module Control.Distributed.Process.Internal.Closure.BuiltIn 
  ( -- * Remote table 
    remoteTable
    -- * Static dictionaries and associated operations
  , staticDecode
  , sdictUnit
  , sdictProcessId
  , sdictSendPort
    -- * Some static values
  , sndStatic
    -- * The CP type and associated combinators 
  , CP
  , idCP
  , splitCP
  , returnCP
  , bindCP 
  , seqCP 
    -- * CP versions of Cloud Haskell primitives  
  , cpLink
  , cpUnlink
  , cpSend
  , cpExpect
  , cpNewChan
  ) where

import Data.ByteString.Lazy (ByteString)  
import Data.Binary (decode, encode)
import Data.Rank1Typeable (Typeable, ANY, ANY1, ANY2, ANY3, ANY4)
import Data.Rank1Dynamic (toDynamic)
import Control.Distributed.Static 
  ( RemoteTable
  , registerStatic
  , Static
  , staticLabel
  , staticApply
  , Closure
  , closure
  , closureApplyStatic
  , closureApply
  , staticCompose
  , staticClosure
  )
import Control.Distributed.Process.Serializable 
  ( SerializableDict(..)
  , Serializable
  )
import Control.Distributed.Process.Internal.Types 
  ( Process
  , ProcessId
  , SendPort
  , ReceivePort
  )
import Control.Distributed.Process.Internal.Primitives 
  ( link
  , unlink
  , send
  , expect
  , newChan
  )

--------------------------------------------------------------------------------
-- Remote table                                                               --
--------------------------------------------------------------------------------

-- | Register the built-in static values of this module in a 'RemoteTable'.
--
-- Each entry pairs a label that is used with 'staticLabel' elsewhere in this
-- module with the value it denotes. Polymorphic values are registered at the
-- placeholder types 'ANY', 'ANY1', ... from "Data.Rank1Typeable".
remoteTable :: RemoteTable -> RemoteTable
remoteTable rtable = foldr (uncurry registerStatic) rtable builtIns
  where
    -- (label, value) pairs. 'foldr' nests the registrations with the first
    -- entry outermost, exactly like a chain of composed 'registerStatic' calls.
    builtIns =
      [ ("$decodeDict",      toDynamic (decodeDict       :: SerializableDict ANY -> ByteString -> ANY))
      , ("$sdictUnit",       toDynamic (SerializableDict :: SerializableDict ()))
      , ("$sdictProcessId",  toDynamic (SerializableDict :: SerializableDict ProcessId))
      , ("$sdictSendPort_",  toDynamic (sdictSendPort_   :: SerializableDict ANY -> SerializableDict (SendPort ANY)))
      , ("$returnProcess",   toDynamic (return           :: ANY -> Process ANY))
      , ("$seqProcess",      toDynamic ((>>)             :: Process ANY1 -> Process ANY2 -> Process ANY2))
      , ("$bindProcess",     toDynamic ((>>=)            :: Process ANY1 -> (ANY1 -> Process ANY2) -> Process ANY2))
      , ("$decodeProcessId", toDynamic (decode           :: ByteString -> ProcessId))
      , ("$link",            toDynamic link)
      , ("$unlink",          toDynamic unlink)
      , ("$sendDict",        toDynamic (sendDict         :: SerializableDict ANY -> ProcessId -> ANY -> Process ()))
      , ("$expectDict",      toDynamic (expectDict       :: SerializableDict ANY -> Process ANY))
      , ("$newChanDict",     toDynamic (newChanDict      :: SerializableDict ANY -> Process (SendPort ANY, ReceivePort ANY)))
      , ("$cpSplit",         toDynamic (cpSplit          :: (ANY1 -> Process ANY3) -> (ANY2 -> Process ANY4) -> (ANY1, ANY2) -> Process (ANY3, ANY4)))
      , ("$snd",             toDynamic (snd              :: (ANY1, ANY2) -> ANY2))
      ]

    -- The helpers below match on 'SerializableDict' so that the wrapped
    -- primitive can be used at type @a@. NOTE(review): this assumes the usual
    -- GADT definition of 'SerializableDict' that carries the instance.
    decodeDict :: forall a. SerializableDict a -> ByteString -> a
    decodeDict dict = case dict of SerializableDict -> decode

    sdictSendPort_ :: forall a. SerializableDict a -> SerializableDict (SendPort a)
    sdictSendPort_ dict = case dict of SerializableDict -> SerializableDict

    sendDict :: forall a. SerializableDict a -> ProcessId -> a -> Process ()
    sendDict dict = case dict of SerializableDict -> send

    expectDict :: forall a. SerializableDict a -> Process a
    expectDict dict = case dict of SerializableDict -> expect

    newChanDict :: forall a. SerializableDict a -> Process (SendPort a, ReceivePort a)
    newChanDict dict = case dict of SerializableDict -> newChan

    -- Run @f@ on the first component, then @g@ on the second, and pair the
    -- results.
    cpSplit :: forall a b c d. (a -> Process c) -> (b -> Process d) -> (a, b) -> Process (c, d)
    cpSplit f g (a, b) =
      f a >>= \c ->
      g b >>= \d ->
      return (c, d)

--------------------------------------------------------------------------------
-- Static dictionaries and associated operations                              --
--------------------------------------------------------------------------------

-- | Turn a static serialization dictionary into a static decoder.
--
-- The result is the @$decodeDict@ static applied to the dictionary, i.e. a
-- 'Static' reference to 'decode' at type @a@.
--
-- See module documentation of "Control.Distributed.Process.Closure" for an
-- example.
staticDecode :: Typeable a => Static (SerializableDict a) -> Static (ByteString -> a)
staticDecode = staticApply decodeWithDict
  where
    decodeWithDict :: Typeable t => Static (SerializableDict t -> ByteString -> t)
    decodeWithDict = staticLabel "$decodeDict"

-- | Static serialization dictionary for the unit type @()@.
sdictUnit :: Static (SerializableDict ())
sdictUnit = staticLabel unitLabel
  where
    unitLabel = "$sdictUnit"

-- | Static serialization dictionary for 'ProcessId'.
sdictProcessId :: Static (SerializableDict ProcessId)
sdictProcessId = staticLabel pidLabel
  where
    pidLabel = "$sdictProcessId"

-- | Lift a static serialization dictionary for @a@ to one for
-- @'SendPort' a@.
sdictSendPort :: Typeable a
              => Static (SerializableDict a) -> Static (SerializableDict (SendPort a))
sdictSendPort dict = sendPortDictStatic `staticApply` dict
  where
    sendPortDictStatic :: Typeable t
                       => Static (SerializableDict t -> SerializableDict (SendPort t))
    sendPortDictStatic = staticLabel "$sdictSendPort_"

--------------------------------------------------------------------------------
-- Static values                                                              --
--------------------------------------------------------------------------------

-- | Static reference to 'snd' (the @$snd@ entry).
sndStatic :: Static ((a, b) -> b)
sndStatic = staticLabel sndLabel
  where
    sndLabel = "$snd"

--------------------------------------------------------------------------------
-- The CP type and associated combinators                                     --
--------------------------------------------------------------------------------

-- | @CP a b@ is a process with input of type @a@ and output of type @b@,
-- represented as a 'Closure' of a function @a -> 'Process' b@.
type CP a b = Closure (a -> Process b)

-- | Static reference to 'return' in the 'Process' monad (the
-- @$returnProcess@ entry).
returnProcessStatic :: Typeable a => Static (a -> Process a)
returnProcessStatic = staticLabel returnLabel
  where
    returnLabel = "$returnProcess"

-- | 'CP' version of 'Control.Category.id': a closure that returns its input
-- unchanged.
idCP :: Typeable a => CP a a
idCP = staticClosure passThrough
  where
    passThrough :: Typeable b => Static (b -> Process b)
    passThrough = returnProcessStatic

-- | 'CP' version of ('Control.Arrow.***'): run the first process on the
-- first component of the input, then the second process on the second
-- component, and pair the results.
splitCP :: (Typeable a, Typeable b, Typeable c, Typeable d)
        => CP a c -> CP b d -> CP (a, b) (c, d)
splitCP p q = closureApply (closureApplyStatic splitStatic p) q
  where
    splitStatic :: Static ((a -> Process c) -> (b -> Process d) -> (a, b) -> Process (c, d))
    splitStatic = staticLabel "$cpSplit"

-- | 'CP' version of 'Control.Monad.return': a closure whose process yields
-- @x@.
--
-- @x@ is stored in the closure as its 'encode'd bytes, paired with a static
-- decoder built from @dict@ and composed with 'return'.
returnCP :: forall a. Serializable a
         => Static (SerializableDict a) -> a -> Closure (Process a)
returnCP dict x = closure returnDecoded $ encode x
  where
    returnDecoded :: Static (ByteString -> Process a)
    returnDecoded = staticCompose returnProcessStatic (staticDecode dict)

-- | 'CP' version of ('Control.Monad.>>'): a closure that runs @p@, discards
-- its result, and then runs @q@.
seqCP :: (Typeable a, Typeable b)
      => Closure (Process a) -> Closure (Process b) -> Closure (Process b)
seqCP p q = closureApply (closureApplyStatic thenStatic p) q
  where
    thenStatic :: (Typeable a, Typeable b)
               => Static (Process a -> Process b -> Process b)
    thenStatic = staticLabel "$seqProcess"

-- | (Not quite the) 'CP' version of ('Control.Monad.>>='): run @x@ and feed
-- its result to @f@.
--
-- \"Not quite\" because the first argument is a closure of a plain process
-- rather than a 'CP'.
bindCP :: forall a b. (Typeable a, Typeable b)
       => Closure (Process a) -> CP a b -> Closure (Process b)
bindCP x f = closureApply (closureApplyStatic bindStatic x) f
  where
    bindStatic :: (Typeable a, Typeable b)
               => Static (Process a -> (a -> Process b) -> Process b)
    bindStatic = staticLabel "$bindProcess"

--------------------------------------------------------------------------------
-- CP versions of Cloud Haskell primitives                                    --
--------------------------------------------------------------------------------

-- | Static reference to 'decode' at type 'ProcessId' (the
-- @$decodeProcessId@ entry).
decodeProcessIdStatic :: Static (ByteString -> ProcessId)
decodeProcessIdStatic = staticLabel pidDecoderLabel
  where
    pidDecoderLabel = "$decodeProcessId"

-- | 'CP' version of 'link': a closure that links to the given process.
--
-- The 'ProcessId' is stored in the closure as encoded bytes and decoded with
-- 'decodeProcessIdStatic' before being passed to 'link'.
cpLink :: ProcessId -> Closure (Process ())
cpLink pid = closure linkDecoded (encode pid)
  where
    linkDecoded :: Static (ByteString -> Process ())
    linkDecoded = linkStatic `staticCompose` decodeProcessIdStatic

    linkStatic :: Static (ProcessId -> Process ())
    linkStatic = staticLabel "$link"

-- | 'CP' version of 'unlink': a closure that unlinks from the given process.
--
-- The 'ProcessId' is stored in the closure as encoded bytes and decoded with
-- 'decodeProcessIdStatic' before being passed to 'unlink'.
cpUnlink :: ProcessId -> Closure (Process ())
cpUnlink pid = closure unlinkDecoded (encode pid)
  where
    unlinkDecoded :: Static (ByteString -> Process ())
    unlinkDecoded = unlinkStatic `staticCompose` decodeProcessIdStatic

    unlinkStatic :: Static (ProcessId -> Process ())
    unlinkStatic = staticLabel "$unlink"

-- | 'CP' version of 'send': a closure that sends its input to the given
-- process.
--
-- The target 'ProcessId' is stored in the closure as encoded bytes; @dict@
-- selects the message type.
cpSend :: forall a. Typeable a
       => Static (SerializableDict a) -> ProcessId -> CP a ()
cpSend dict pid = closure sendToPid (encode pid)
  where
    -- Decode the stored ProcessId, then hand it to 'send' specialised by @dict@.
    sendToPid :: Static (ByteString -> a -> Process ())
    sendToPid = staticCompose (staticApply sendDictStatic dict) decodeProcessIdStatic

    sendDictStatic :: Typeable a
                   => Static (SerializableDict a -> ProcessId -> a -> Process ())
    sendDictStatic = staticLabel "$sendDict"

-- | 'CP' version of 'expect': a closure that runs 'expect' at the type
-- selected by @dict@.
cpExpect :: Typeable a => Static (SerializableDict a) -> Closure (Process a)
cpExpect dict = staticClosure $ staticApply expectStatic dict
  where
    expectStatic :: Typeable a => Static (SerializableDict a -> Process a)
    expectStatic = staticLabel "$expectDict"

-- | 'CP' version of 'newChan': a closure that runs 'newChan' at the type
-- selected by @dict@, yielding the send and receive ports.
cpNewChan :: Typeable a
          => Static (SerializableDict a)
          -> Closure (Process (SendPort a, ReceivePort a))
cpNewChan dict = staticClosure $ staticApply newChanStatic dict
  where
    newChanStatic :: Typeable a
                  => Static (SerializableDict a -> Process (SendPort a, ReceivePort a))
    newChanStatic = staticLabel "$newChanDict"