{-# LANGUAGE ExistentialQuantification  #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE PatternGuards              #-}
module Control.Distributed.Process.ManagedProcess.Server
  ( 
    condition
  , state
  , input
  , reply
  , replyWith
  , noReply
  , continue
  , timeoutAfter
  , hibernate
  , stop
  , stopWith
  , replyTo
  , replyChan
  , reject
  , rejectWith
  , become
    
  , noReply_
  , haltNoReply_
  , continue_
  , timeoutAfter_
  , hibernate_
  , stop_
    
  , handleCall
  , handleCallIf
  , handleCallFrom
  , handleCallFromIf
  , handleRpcChan
  , handleRpcChanIf
  , handleCast
  , handleCastIf
  , handleInfo
  , handleRaw
  , handleDispatch
  , handleDispatchIf
  , handleExit
  , handleExitIf
    
  , action
  , handleCall_
  , handleCallIf_
  , handleCallFrom_
  , handleCallFromIf_
  , handleRpcChan_
  , handleRpcChanIf_
  , handleCast_
  , handleCastIf_
    
  , handleControlChan
  , handleControlChan_
    
  , handleExternal
  , handleExternal_
  , handleCallExternal
  ) where
import Control.Concurrent.STM (STM, atomically)
import Control.Distributed.Process hiding (call, Message)
import qualified Control.Distributed.Process as P (Message)
import Control.Distributed.Process.Serializable
import Control.Distributed.Process.ManagedProcess.Internal.Types hiding (liftIO, lift)
import Control.Distributed.Process.Extras
  ( ExitReason(..)
  , Routable(..)
  )
import Control.Distributed.Process.Extras.Time
import Prelude hiding (init)
-- | Build a 'Condition' from a predicate that looks at both the current
-- server state and the incoming message payload.
condition :: forall a b. (Serializable a, Serializable b)
          => (a -> b -> Bool)
          -> Condition a b
condition predicate = Condition predicate
-- | Build a 'Condition' that inspects only the server state; the message
-- payload is not consulted.
state :: forall s m. (Serializable m) => (s -> Bool) -> Condition s m
state predicate = State predicate
-- | Build a 'Condition' that inspects only the message payload; the server
-- state is not consulted.
input :: forall s m. (Serializable m) => (m -> Bool) -> Condition s m
input predicate = Input predicate
-- | Reject a request with a textual reason. The rejection is paired with a
-- 'continue' action for the given state.
reject :: forall r s . s -> String -> Reply r s
reject st rs = do
  next <- continue st
  return (ProcessReject rs next)
-- | Like 'reject', but the reason may be any 'Show'-able value. The value is
-- rendered with 'show'.
rejectWith :: forall r m s . (Show r) => s -> r -> Reply m s
rejectWith st rs =
  let reason = show rs
  in reject st reason
-- | Reply with @r@ and continue with state @s@. This is shorthand for
-- 'replyWith' applied to a 'continue' action.
reply :: (Serializable r) => r -> s -> Reply r s
reply r s = do
  next <- continue s
  replyWith r next
-- | Pair a reply value with an arbitrary follow-up 'ProcessAction'.
replyWith :: (Serializable r)
          => r
          -> ProcessAction s
          -> Reply r s
replyWith r next = return (ProcessReply r next)
-- | Send no reply. The given follow-up 'ProcessAction' is wrapped in
-- 'NoReply'.
noReply :: (Serializable r) => ProcessAction s -> Reply r s
noReply next = return (NoReply next)
-- | Combine 'noReply' with 'continue'. No reply is produced, and processing
-- continues with the given state.
noReply_ :: forall s r . (Serializable r) => s -> Reply r s
noReply_ s = noReply =<< continue s
-- | Combine 'noReply' with 'stop'. No reply is produced, and the follow-up
-- action is a stop carrying the given 'ExitReason'.
haltNoReply_ :: Serializable r => ExitReason -> Reply r s
haltNoReply_ r = noReply =<< stop r
-- | Produce a 'ProcessContinue' action carrying the given state.
continue :: s -> Action s
continue s = return (ProcessContinue s)
-- | Identical to 'continue'. It is provided alongside the other
-- @_@-suffixed helpers in this module.
continue_ :: (s -> Action s)
continue_ = \s -> return (ProcessContinue s)
-- | Produce a 'ProcessTimeout' action carrying the delay @d@ and state @s@.
timeoutAfter :: Delay -> s -> Action s
timeoutAfter d s = return (ProcessTimeout d s)
-- | 'StatelessHandler'-shaped form of 'timeoutAfter'. It takes the delay
-- first and the state second.
timeoutAfter_ :: StatelessHandler s Delay
timeoutAfter_ d s = return (ProcessTimeout d s)
-- | Produce a 'ProcessHibernate' action carrying the interval @d@ and
-- state @s@.
hibernate :: TimeInterval -> s -> Process (ProcessAction s)
hibernate d s = return (ProcessHibernate d s)
-- | 'StatelessHandler'-shaped form of 'hibernate'. It takes the interval
-- first and the state second.
hibernate_ :: StatelessHandler s TimeInterval
hibernate_ d s = return (ProcessHibernate d s)
-- | Produce a 'ProcessBecome' action. It carries a replacement
-- 'ProcessDefinition' together with the state to use with it.
become :: forall s . ProcessDefinition s -> s -> Action s
become newDef st = return (ProcessBecome newDef st)
-- | Produce a 'ProcessStop' action for the given 'ExitReason'.
stop :: ExitReason -> Action s
stop reason = return (ProcessStop reason)
-- | Like 'stop', but also carries a final state (as 'ProcessStopping').
stopWith :: s -> ExitReason -> Action s
stopWith st reason = return (ProcessStopping st reason)
-- | 'StatelessHandler'-shaped form of 'stop'. The state argument is
-- ignored.
stop_ :: StatelessHandler s ExitReason
stop_ reason = const (stop reason)
-- | Send a reply to the caller identified by a 'CallRef'. The message is
-- wrapped in a 'CallResponse' tagged with the identifier stored in the
-- 'CallRef'.
replyTo :: (Serializable m) => CallRef m -> m -> Process ()
replyTo cRef msg =
  case cRef of
    CallRef (_, tag) -> sendTo cRef (CallResponse msg tag)
-- | Send a reply over a typed channel. This is simply 'sendChan'.
replyChan :: (Serializable m) => SendPort m -> m -> Process ()
replyChan port msg = sendChan port msg
-- | Handle call messages with a state-less function from request to reply.
-- The guarding condition accepts every request (see 'handleCallIf_').
handleCall_ :: (Serializable a, Serializable b)
           => (a -> Process b)
           -> Dispatcher s
handleCall_ handler = handleCallIf_ (input (\_ -> True)) handler
-- | Handle call messages that satisfy a 'Condition' using a function that
-- never sees the server state. The function's result is sent to the caller
-- as a 'CallResponse', and processing continues with the state unchanged.
handleCallIf_ :: forall s a b . (Serializable a, Serializable b)
    => Condition s a
    -> (a -> Process b)
    -> Dispatcher s
handleCallIf_ cond handler =
    DispatchIf
      { dispatch   = \st (CallMessage payload ref) ->
          sendReply ref st =<< handler payload
      , dispatchIf = checkCall cond
      }
  where
    -- Answer the caller recorded in the 'CallRef', then continue with the
    -- untouched state.
    sendReply :: (Serializable b)
              => CallRef b
              -> s
              -> b
              -> Process (ProcessAction s)
    sendReply ref st result =
      let (recipient, callId) = unCaller ref
      in sendTo recipient (CallResponse result callId) >> continue st
-- | Handle call messages with a 'CallHandler' that receives the server state
-- and the request. The guarding condition accepts every request (see
-- 'handleCallIf').
handleCall :: (Serializable a, Serializable b)
           => CallHandler s a b
           -> Dispatcher s
handleCall handler = handleCallIf (state (\_ -> True)) handler
-- | Like 'handleCall', but only call messages satisfying the 'Condition' are
-- dispatched. The handler's 'Reply' determines what goes back to the
-- caller: a response, a rejection, or nothing ('NoReply').
handleCallIf :: forall s a b . (Serializable a, Serializable b)
    => Condition s a
    -> CallHandler s a b
    -> Dispatcher s
handleCallIf cond handler =
  DispatchIf
    { dispatchIf = checkCall cond
    , dispatch   = \st (CallMessage payload ref) ->
        mkReply ref =<< handler st payload
    }
-- | Handle call messages with a 'StatelessCallHandler'. The handler receives
-- the caller's 'CallRef' and the request, but not the server state. Every
-- request is accepted (see 'handleCallFromIf_').
handleCallFrom_ :: forall s a b . (Serializable a, Serializable b)
                => StatelessCallHandler s a b
                -> Dispatcher s
handleCallFrom_ handler = handleCallFromIf_ (input (\_ -> True)) handler
-- | Like 'handleCallFrom_', but only call messages satisfying the
-- 'Condition' are dispatched. The server state is ignored by the dispatch
-- function.
handleCallFromIf_ :: forall s a b . (Serializable a, Serializable b)
                  => Condition s a
                  -> StatelessCallHandler s a b
                  -> Dispatcher s
handleCallFromIf_ cond handler =
  DispatchIf
    { dispatchIf = checkCall cond
    , dispatch   = \_ (CallMessage payload ref) ->
        mkReply ref =<< handler ref payload
    }
-- | Handle call messages with a 'DeferredCallHandler'. The handler receives
-- the caller's 'CallRef', the server state and the request. Every request is
-- accepted (see 'handleCallFromIf').
handleCallFrom :: forall s a b . (Serializable a, Serializable b)
           => DeferredCallHandler s a b
           -> Dispatcher s
handleCallFrom handler = handleCallFromIf (state (\_ -> True)) handler
-- | Like 'handleCallFrom', but only call messages satisfying the 'Condition'
-- are dispatched.
handleCallFromIf :: forall s a b . (Serializable a, Serializable b)
    => Condition s a
    -> DeferredCallHandler s a b
    -> Dispatcher s
handleCallFromIf cond handler =
  DispatchIf
    { dispatchIf = checkCall cond
    , dispatch   = \st (CallMessage payload ref) ->
        mkReply ref =<< handler ref st payload
    }
-- | Handle channel (RPC) messages with a 'ChannelHandler'. Every request is
-- accepted (see 'handleRpcChanIf').
handleRpcChan :: forall s a b . (Serializable a, Serializable b)
              => ChannelHandler s a b
              -> Dispatcher s
handleRpcChan handler = handleRpcChanIf (input (\_ -> True)) handler
-- | Handle channel (RPC) messages satisfying a 'Condition'. The handler is
-- given the reply 'SendPort', the server state and the request.
handleRpcChanIf :: forall s a b . (Serializable a, Serializable b)
                => Condition s a
                -> ChannelHandler s a b
                -> Dispatcher s
handleRpcChanIf cond handler =
  DispatchIf
    { dispatchIf = checkRpc cond
    , dispatch   = \st (ChanMessage payload replyPort) ->
        handler replyPort st payload
    }
-- | Handle channel (RPC) messages with a 'StatelessChannelHandler'. Every
-- request is accepted (see 'handleRpcChanIf_').
handleRpcChan_ :: forall s a b . (Serializable a, Serializable b)
                  => StatelessChannelHandler s a b
                  -> Dispatcher s
handleRpcChan_ handler = handleRpcChanIf_ (input (\_ -> True)) handler
-- | Handle channel (RPC) messages satisfying a 'Condition' with a
-- 'StatelessChannelHandler'. The handler is given the reply 'SendPort', the
-- request and then the server state.
handleRpcChanIf_ :: forall s a b . (Serializable a, Serializable b)
                 => Condition s a
                 -> StatelessChannelHandler s a b
                 -> Dispatcher s
handleRpcChanIf_ cond handler =
  DispatchIf
    { dispatchIf = checkRpc cond
    , dispatch   = \st ((ChanMessage payload replyPort) :: Message a b) ->
        handler replyPort payload st
    }
-- | Handle cast messages with a 'CastHandler'. Every message is accepted
-- (see 'handleCastIf').
handleCast :: (Serializable a)
           => CastHandler s a
           -> Dispatcher s
handleCast handler = handleCastIf (input (\_ -> True)) handler
-- | Handle cast messages satisfying a 'Condition'. The handler receives the
-- server state and the payload.
handleCastIf :: forall s a . (Serializable a)
    => Condition s a
    -> CastHandler s a
    -> Dispatcher s
handleCastIf cond handler =
  DispatchIf
    { dispatchIf = checkCast cond
    , dispatch   = \st ((CastMessage payload) :: Message a ()) ->
        handler st payload
    }
-- | Dispatch values produced by an 'STM' action to an 'ActionHandler'. The
-- action is stored as 'stmAction'. Both derived matchers read from it and
-- wrap each value with 'unsafeWrapMessage'.
handleExternal :: forall s a . (Serializable a)
               => STM a
               -> ActionHandler s a
               -> ExternDispatcher s
handleExternal source handler =
    DispatchSTM
      { stmAction   = source
      , dispatchStm = handler
      , matchStm    = matchWrapped
      , matchAnyStm = matchAnyWrapped
      }
  where
    -- The pattern signatures name the scoped type variable @a@ from the
    -- signature above.
    matchWrapped      = matchSTM source (\(m :: a) -> return (unsafeWrapMessage m))
    matchAnyWrapped f = matchSTM source (\(m :: a) -> return (f (unsafeWrapMessage m)))
-- | Like 'handleExternal', but takes a 'StatelessHandler' (value first,
-- state second). The arguments are swapped before delegating.
handleExternal_ :: forall s a . (Serializable a)
                => STM a
                -> StatelessHandler s a
                -> ExternDispatcher s
handleExternal_ source handler =
  handleExternal source (\st value -> handler value st)
-- | Serve requests that are read via an 'STM' action instead of arriving as
-- call messages.
--
-- Each value read by @reader@ is passed, with the server state, to the
-- 'CallHandler'. A 'ProcessReply' result is handed to @writer@ inside
-- 'atomically'. 'ProcessReject' and 'NoReply' results write nothing; only
-- their follow-up 'ProcessAction' is kept.
handleCallExternal :: forall s r w . (Serializable r)
                   => STM r
                   -> (w -> STM ())
                   -> CallHandler s r w
                   -> ExternDispatcher s
handleCallExternal reader writer handler =
    DispatchSTM
      { stmAction   = reader
      , dispatchStm = \st req -> handler st req >>= deliver
      , matchStm    = matchWrapped
      , matchAnyStm = matchAnyWrapped
      }
  where
    -- Matchers over the STM source; each value read is wrapped with
    -- 'unsafeWrapMessage'.
    matchWrapped      = matchSTM reader (\(m :: r) -> return (unsafeWrapMessage m))
    matchAnyWrapped f = matchSTM reader (\(m :: r) -> return (f (unsafeWrapMessage m)))

    -- Forward an actual reply to the writer, then yield the follow-up action.
    deliver outcome =
      case outcome of
        NoReply next           -> return next
        ProcessReject _ next   -> return next
        ProcessReply resp next -> liftIO (atomically (writer resp)) >> return next
-- | Dispatch cast messages taken from a 'ControlChannel' to an
-- 'ActionHandler'.
--
-- The port used is the second component of 'unControl'. The handler
-- receives the server state and the payload.
handleControlChan :: forall s a . (Serializable a)
    => ControlChannel a
    -> ActionHandler s a
    -> ExternDispatcher s
handleControlChan chan handler =
    DispatchCC
      { channel      = port
      , dispatchChan = \st ((CastMessage payload) :: Message a ()) ->
          handler st payload
      }
  where
    (_, port) = unControl chan
-- | Like 'handleControlChan', but takes a 'StatelessHandler'. The handler
-- receives the payload first and the state second.
handleControlChan_ :: forall s a. (Serializable a)
           => ControlChannel a
           -> StatelessHandler s a
           -> ExternDispatcher s
handleControlChan_ chan handler =
    DispatchCC
      { channel      = port
      , dispatchChan = \st ((CastMessage payload) :: Message a ()) ->
          handler payload st
      }
  where
    (_, port) = unControl chan
-- | Handle cast messages with a 'StatelessHandler'. Every message is
-- accepted (see 'handleCastIf_').
handleCast_ :: (Serializable a)
            => StatelessHandler s a
            -> Dispatcher s
handleCast_ handler = handleCastIf_ (input (\_ -> True)) handler
-- | Handle cast messages satisfying a 'Condition' with a 'StatelessHandler'.
-- The handler receives the payload first and the state second.
handleCastIf_ :: forall s a . (Serializable a)
    => Condition s a
    -> StatelessHandler s a
    -> Dispatcher s
handleCastIf_ cond handler =
  DispatchIf
    { dispatchIf = checkCast cond
    , dispatch   = \st ((CastMessage payload) :: Message a ()) ->
        handler payload st
    }
-- | Turn a 'StatelessHandler' (payload first, state second) into a
-- dispatcher via 'handleDispatch'. It therefore accepts a payload of type
-- @a@ whether it arrived as a call, cast or channel message.
action :: forall s a . (Serializable a)
    => StatelessHandler s a
    -> Dispatcher s
action handler = handleDispatch run
  where
    run :: ActionHandler s a
    run st msg = handler msg st
-- | Dispatch any message carrying a payload of type @a@ to an
-- 'ActionHandler'. Every payload is accepted (see 'handleDispatchIf').
handleDispatch :: forall s a . (Serializable a)
               => ActionHandler s a
               -> Dispatcher s
handleDispatch handler = handleDispatchIf (input (\_ -> True)) handler
-- | Dispatch messages satisfying a 'Condition' to an 'ActionHandler',
-- regardless of message shape (call, cast or channel).
--
-- No reply is sent on the caller's behalf; the handler only sees the state
-- and the payload.
handleDispatchIf :: forall s a . (Serializable a)
                 => Condition s a
                 -> ActionHandler s a
                 -> Dispatcher s
handleDispatchIf cond handler =
    DispatchIf { dispatch = runHandler, dispatchIf = check cond }
  where
    -- Every message shape carries a payload of type @a@. Pass it, together
    -- with the current state, to the handler.
    runHandler :: s -> Message a () -> Process (ProcessAction s)
    runHandler st msg =
      case msg of
        CastMessage payload   -> handler st payload
        CallMessage payload _ -> handler st payload
        ChanMessage payload _ -> handler st payload
-- | Handle raw mailbox messages of type @a@ with an 'ActionHandler'.
--
-- The message is decoded via 'handleMessage'. That call yields 'Nothing'
-- when the message does not hold a value of type @a@.
-- NOTE(review): presumably this lets other dispatchers try the message;
-- confirm against the server loop.
handleInfo :: forall s a. (Serializable a)
           => ActionHandler s a
           -> DeferredDispatcher s
handleInfo handler =
  DeferredDispatcher
    { dispatchInfo = \st msg -> handleMessage msg (handler st) }
-- | Handle every raw 'P.Message' without decoding it. The handler's action
-- is always wrapped in 'Just'.
handleRaw :: forall s. ActionHandler s P.Message
          -> DeferredDispatcher s
handleRaw handler =
  DeferredDispatcher
    { dispatchInfo = \st msg -> handler st msg >>= return . Just }
-- | Handle exit signals whose reason decodes to type @a@.
--
-- The handler is given the 'ProcessId' associated with the signal, the
-- server state and the decoded reason. Decoding goes through
-- 'handleMessage'.
handleExit :: forall s a. (Serializable a)
           => (ProcessId -> ActionHandler s a)
           -> ExitSignalDispatcher s
handleExit handler =
  ExitSignalDispatcher
    { dispatchExit = \st pid msg -> handleMessage msg (handler pid st) }
-- | Like 'handleExit', but the decoded reason must also satisfy a predicate
-- over the server state and the reason. The check is done via
-- 'handleMessageIf'.
handleExitIf :: forall s a . (Serializable a)
             => (s -> a -> Bool)
             -> (ProcessId -> ActionHandler s a)
             -> ExitSignalDispatcher s
handleExitIf predicate handler =
  ExitSignalDispatcher
    { dispatchExit = \st pid msg ->
        handleMessageIf msg (predicate st) (handler pid st)
    }
-- | Act on a handler's 'ProcessReply' for the caller identified by @cRef@,
-- and return the follow-up 'ProcessAction':
--
--   * 'NoReply': nothing is sent.
--   * 'ProcessReply': a 'CallResponse' tagged with the call id is sent.
--   * 'ProcessReject': a 'CallRejected' tagged with the call id is sent.
--
-- Any other shape terminates the process via 'die'.
mkReply :: (Serializable b)
        => CallRef b
        -> ProcessReply b s
        -> Process (ProcessAction s)
mkReply cRef act
  | NoReply next <- act
  = return next
  | ProcessReply response next <- act
  = sendTo cRef (CallResponse response callId) >> return next
  | ProcessReject reason next <- act
  = sendTo cRef (CallRejected reason callId) >> return next
  | otherwise
  = die $ ExitOther "mkReply.InvalidState"
  where
    -- Lazy binding: the call id is only demanded when a message is sent.
    CallRef (_, callId) = cRef
-- | Evaluate a 'Condition' against the server state and a message.
--
-- 'Condition' sees both the state and the payload. 'State' sees only the
-- state. 'Input' sees only the payload.
check :: forall s m a . (Serializable m)
            => Condition s m
            -> s
            -> Message m a
            -> Bool
check cond st msg =
  case cond of
    Condition p -> p st (decode msg)
    State     p -> p st
    Input     p -> p (decode msg)
-- | Like 'check', but only channel (RPC) messages can pass. Any other
-- message shape yields 'False' without consulting the 'Condition'.
checkRpc :: forall s m a . (Serializable m)
            => Condition s m
            -> s
            -> Message m a
            -> Bool
checkRpc cond st msg =
  case msg of
    ChanMessage _ _ -> check cond st msg
    _               -> False
-- | Like 'check', but only call messages can pass. Any other message shape
-- yields 'False' without consulting the 'Condition'.
checkCall :: forall s m a . (Serializable m)
             => Condition s m
             -> s
             -> Message m a
             -> Bool
checkCall cond st msg =
  case msg of
    CallMessage _ _ -> check cond st msg
    _               -> False
-- | Like 'check', but only cast messages can pass. Any other message shape
-- yields 'False' without consulting the 'Condition'.
checkCast :: forall s m . (Serializable m)
             => Condition s m
             -> s
             -> Message m ()
             -> Bool
checkCast cond st msg =
  case msg of
    CastMessage _ -> check cond st msg
    _             -> False
-- | Extract the payload from any 'Message' shape. The reply reference or
-- reply port, if present, is discarded.
decode :: Message a b -> a
decode msg =
  case msg of
    CastMessage payload   -> payload
    CallMessage payload _ -> payload
    ChanMessage payload _ -> payload