{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE GADTs #-}
module Control.Eff.Concurrent.Dispatcher
( runMainProcess
, defaultMain
, spawn
, DispatcherError(..)
, DispatcherIO
, HasDispatcherIO
, ProcIO
)
where
import GHC.Stack
import Data.Maybe
import Data.Kind
import qualified Control.Exception as Exc
import Control.Concurrent as Concurrent
import Control.Concurrent.STM as STM
import Control.Eff
import Control.Eff.Concurrent.MessagePassing
import Control.Eff.ExceptionExtra
import Control.Eff.Lift
import Control.Eff.Log
import Control.Eff.Reader.Strict as Reader
import Control.Lens
import Control.Monad ( when
, void
)
import qualified Control.Monad.State as Mtl
import Data.Dynamic
import Data.Typeable ( typeRep )
import Data.Map ( Map )
import qualified Data.Map as Map
-- | Per-process book-keeping record kept in the dispatcher's process table.
data ProcessInfo =
  ProcessInfo
    { _processId      :: ProcessId
      -- ^ The id under which the process is registered.
    , _messageQ       :: STM.TQueue Dynamic
      -- ^ Incoming messages, stored dynamically typed.
    , _exitOnShutdown :: Bool
      -- ^ When 'True', errors encountered while receiving are re-thrown;
      -- when 'False' they are handed to the process as a control message
      -- (shown as @trapExit@ in the 'Show' instance).
    }

makeLenses ''ProcessInfo
-- | Renders the process id and the /trap exit/ flag, which is the negation
-- of the stored @exitOnShutdown@ field. The message queue is not shown.
instance Show ProcessInfo where
  show info = concat
    [ "ProcessInfo: "
    , show (view processId info)
    , " trapExit: "
    , show (not (view exitOnShutdown info))
    ]
-- | The shared state of the dispatcher. It lives in a 'STM.TVar' (see
-- 'DispatcherVar') and is only modified inside STM transactions.
data Dispatcher =
  Dispatcher
    { _nextPid               :: ProcessId
      -- ^ The id handed to the next process that registers a queue.
    , _processTable          :: Map ProcessId ProcessInfo
      -- ^ All currently registered processes.
    , _threadIdTable         :: Map ProcessId ThreadId
      -- ^ The thread each registered process runs in.
    , _schedulerShuttingDown :: Bool
      -- ^ Set during tear down; no new queues are created once it is 'True'.
    , _logChannel            :: LogChannel String
      -- ^ The channel log messages are forwarded to.
    }

makeLenses ''Dispatcher
-- | Wrapper around the transactional variable holding the 'Dispatcher'
-- state; this is the value provided through the 'Reader' effect.
newtype DispatcherVar =
  DispatcherVar
    { fromDispatcherVar :: STM.TVar Dispatcher
    }
  deriving Typeable
-- | Errors raised by the dispatcher and by the processes it runs.
data DispatcherError
  = UnhandledMessageReceived Dynamic ProcessId
    -- ^ A process received a message whose dynamic type did not match the
    -- type it was waiting for.
  | ProcessNotFound ProcessId
    -- ^ No entry for the given process id exists in the process table.
  | ProcessException String ProcessId
    -- ^ A process raised an error with the given message.
  | DispatcherShuttingDown
    -- ^ A new process could not be registered because the dispatcher is
    -- shutting down.
  | LowLevelIOException Exc.SomeException
    -- ^ Wrapper handed to @liftRethrow@ around lifted 'IO' actions;
    -- presumably carries exceptions thrown by those actions.
  deriving (Typeable, Show)

instance Exc.Exception DispatcherError
-- | The constraints an effect list @r@ must satisfy so that dispatcher
-- operations can run in it: lifted 'IO', 'DispatcherError' exceptions,
-- 'String' logging and read access to the 'DispatcherVar'.
type HasDispatcherIO r =
  ( HasCallStack
  , SetMember Lift (Lift IO) r
  , Member (Exc DispatcherError) r
  , Member (Logs String) r
  , Member (Reader DispatcherVar) r
  )
-- | The concrete effect list the dispatcher itself runs in. The order
-- matters: it is the order in which the effects are handled.
type DispatcherIO
  = '[ Exc DispatcherError
     , Reader DispatcherVar
     , Logs String
     , Lift IO
     ]
-- | The effects available to a process scheduled by this dispatcher.
type ProcIO = ConsProcIO DispatcherIO

-- | Prepend the 'MessagePassing' and 'Process' effects to an effect list.
type ConsProcIO r = MessagePassing ': Process ': r
-- | Lets processes log 'String' messages directly; the implementation is
-- delegated to @logMessageFreeEff@.
instance MonadLog String (Eff ProcIO) where
  logMessageFree = logMessageFreeEff
-- | Create a fresh dispatcher and run the given process as its first
-- (main) process in the calling thread.
--
-- The dispatcher state is allocated with 'Exc.bracket'; when the main
-- process finishes, normally or by an exception, the dispatcher is torn
-- down: it is marked as shutting down, the threads of all other registered
-- processes are killed and the call blocks until both dispatcher tables are
-- empty.
--
-- Log messages are forwarded to the given 'LogChannel'. A 'DispatcherError'
-- raised by the main process is logged, the process' clean-up action is run
-- and the error is re-thrown.
runMainProcess :: Eff ProcIO a -> LogChannel String -> IO a
runMainProcess e logC = withDispatcher
  (dispatchMessages
    (\cleanup -> do
      mt <- lift myThreadId
      mp <- self
      logMessage (show mp ++ " main process started in thread " ++ show mt)
      res <- try e
      case res of
        Left ex -> do
          logMessage
            (  show mp
            ++ " main process exception: "
            ++ (show :: DispatcherError -> String) ex
            )
          lift (runCleanUpAction cleanup)
          throwError ex
        Right rres -> do
          logMessage (show mp ++ " main process exited")
          lift (runCleanUpAction cleanup)
          return rres
    )
  )
 where
  -- Allocate the dispatcher state, run the action and always tear down.
  withDispatcher :: Eff DispatcherIO a -> IO a
  withDispatcher mainProcessAction = do
    myTId <- myThreadId
    Exc.bracket
      (newTVarIO (Dispatcher myPid Map.empty Map.empty False logC))
      (tearDownDispatcher myTId)
      (\sch -> runLift
        (forwardLogsToChannel
          logC
          (runReader (runErrorRethrowIO mainProcessAction) (DispatcherVar sch))
        )
      )
   where
    -- The process id handed to the first (main) process.
    myPid = 1

    tearDownDispatcher myTId v = do
      logChannelPutIO logC (show myTId ++ " begin dispatcher tear down")
      sch <- atomically
        (do
          old <- readTVar v
          -- The calling thread is never killed below, and if the main
          -- process did not get to run its clean-up action its entries are
          -- still registered. Nobody else would remove them, so the wait at
          -- the end of the tear down could never finish; drop them here.
          let ownEntries = Map.filter (== myTId) (old ^. threadIdTable)
              new =
                old
                  & schedulerShuttingDown .~ True
                  & threadIdTable %~ (`Map.difference` ownEntries)
                  & processTable %~ (`Map.difference` ownEntries)
          writeTVar v new
          -- Return the snapshot taken before the update; it is the list of
          -- threads to kill.
          return old
        )
      logChannelPutIO
        logC
        (show myTId ++ " killing threads: " ++ show
          (sch ^.. threadIdTable . traversed)
        )
      imapM_ (killProcThread myTId) (sch ^. threadIdTable)
      -- Block until every remaining process has unregistered itself.
      atomically
        (do
          dispatcher <- readTVar v
          let allThreadsDead =
                dispatcher ^. threadIdTable . to Map.null
                  && dispatcher ^. processTable . to Map.null
          STM.check allThreadsDead
        )
      logChannelPutIO logC "all threads dead"

    -- Kill the thread of a process unless it is the calling thread itself.
    killProcThread myTId pid tid = when
      (myTId /= tid)
      (  logChannelPutIO logC ("killing thread " ++ show pid)
      >> killThread tid
      )
-- | Run the given process as main process via 'runMainProcess', with a log
-- channel whose messages are written to standard output using 'print'.
-- A fixed message is logged when the channel is opened and another one
-- when it is closed.
defaultMain :: Eff ProcIO a -> IO a
defaultMain c = runLoggingT
  (logChannelBracket startMsg exitMsg (runMainProcess c))
  (print :: String -> IO ())
 where
  startMsg = Just "~~~~~~ main process started"
  exitMsg  = Just "====== main process exited"
-- | Run a process against an already existing dispatcher, in the calling
-- thread. Log messages go to the dispatcher's log channel. A
-- 'DispatcherError' raised by the process is returned as 'Left' instead of
-- being thrown.
runChildProcess
  :: DispatcherVar
  -> (CleanUpAction -> Eff ProcIO a)
  -> IO (Either DispatcherError a)
runChildProcess dispatcherVar procAction = do
  logC <- view logChannel <$> readTVarIO (fromDispatcherVar dispatcherVar)
  runLift
    $ forwardLogsToChannel logC
    $ runReader (runError (dispatchMessages procAction)) dispatcherVar
-- | Read the log channel stored in the current dispatcher state.
getLogChannel :: HasDispatcherIO r => Eff r (LogChannel String)
getLogChannel = do
  dispatcherTVar <- getDispatcherTVar
  dispatcher     <- lift (readTVarIO dispatcherTVar)
  return (dispatcher ^. logChannel)
-- | Run a stateful STM action on the 'ProcessInfo' of the given process and
-- store the resulting 'ProcessInfo' back into the process table, all within
-- one dispatcher transaction.
--
-- Throws 'ProcessNotFound' when the process table has no entry for the id.
overProcessInfo
  :: HasDispatcherIO r
  => ProcessId
  -> Mtl.StateT ProcessInfo STM.STM a
  -> Eff r a
overProcessInfo pid stAction =
  overDispatcher
      (do
        mInfo <- use (processTable . at pid)
        case mInfo of
          Just info -> do
            (result, info') <- Mtl.lift (Mtl.runStateT stAction info)
            processTable . ix pid .= info'
            return (Right result)
          Nothing -> return (Left (ProcessNotFound pid))
      )
    >>= liftEither
-- | Start a new process in a freshly forked thread and return its id.
--
-- The call blocks until the child has either published its process id or
-- its thread has finished without doing so; in the latter case
-- 'DispatcherShuttingDown' is thrown.
--
-- A 'DispatcherError' raised by the process body is caught and logged.
-- When the thread ends, the clean-up action of the process (if one was
-- registered) is run and the way the thread ended is logged.
spawn :: HasDispatcherIO r => Eff ProcIO () -> Eff r ProcessId
spawn mfa = do
  dispatcherVar <- ask
  pidVar        <- lift newEmptyTMVarIO
  cleanupVar    <- lift newEmptyTMVarIO
  lc            <- getLogChannel
  void
    (lift
      (Concurrent.forkFinally
        (runChildProcess
          dispatcherVar
          (\cleanUpAction -> do
            -- Publish the clean-up action first so that the finalizer below
            -- can find it, then publish the pid to unblock the parent.
            lift (atomically (STM.putTMVar cleanupVar cleanUpAction))
            pid <- self
            lift (atomically (STM.putTMVar pidVar (Just pid)))
            catchError
              mfa
              (\(ex :: DispatcherError) ->
                logMessage ("process exception: " ++ show ex)
              )
          )
        )
        (\threadResult -> do
          mCleanUp <- atomically (STM.tryTakeTMVar cleanupVar)
          -- Unblock the parent in case the child never published a pid.
          void (atomically (tryPutTMVar pidVar Nothing))
          case mCleanUp of
            Nothing      -> return ()
            Just cleanUp -> do
              runCleanUpAction cleanUp
              tid <- myThreadId
              logChannelPutIO
                lc
                (  "thread "
                ++ show tid
                ++ either
                     (\se -> " killed by exception: " ++ show se)
                     (const " exited")
                     threadResult
                )
        )
      )
    )
  mPid <- lift (atomically (STM.takeTMVar pidVar))
  maybe (throwError DispatcherShuttingDown) return mPid
-- | An 'IO' action that unregisters a process from the dispatcher.
newtype CleanUpAction =
  CleanUpAction
    { runCleanUpAction :: IO ()
    }
-- | Register the calling thread as a process (see 'withMessageQueue') and
-- interpret the 'MessagePassing' and 'Process' effects of the given action
-- in terms of the dispatcher state.
--
-- The action receives the 'CleanUpAction' that unregisters the process. If
-- the action fails with a 'DispatcherError', the error is logged, the
-- clean-up action is run and the error is re-thrown. On success the
-- clean-up action is /not/ run here; that is left to the caller.
dispatchMessages
  :: forall r a
   . (HasDispatcherIO r, HasCallStack)
  => (CleanUpAction -> Eff (ConsProcIO r) a)
  -> Eff r a
dispatchMessages processAction = withMessageQueue
  (\cleanUpAction pinfo ->
    try
        -- The inner handler interprets 'MessagePassing', the outer one
        -- interprets 'Process'.
        (handle_relay
          return
          (goProc (pinfo ^. processId))
          (handle_relay return (go (pinfo ^. processId))
            (processAction cleanUpAction)
          ))
      >>=
      either
        (\(e :: DispatcherError) ->
          do logMsg (show (pinfo^.processId)
                      ++ " cleanup on exception: "
                      ++ show e)
             lift (runCleanUpAction cleanUpAction)
             throwError e)
        return
  )
 where
  -- Handler for the 'MessagePassing' effect.
  go
    :: forall v
     . HasCallStack
    => ProcessId
    -> MessagePassing v
    -> (v -> Eff (Process ': r) a)
    -> Eff (Process ': r) a
  -- Sending: look up the receiver and enqueue the message in a single
  -- transaction. The continuation receives 'True' if the receiver was found
  -- and 'False' otherwise.
  go _pid (SendMessage toPid reqIn) k = do
    psVar <- getDispatcherTVar
    liftRethrow
        LowLevelIOException
        (atomically
          (do
            p <- readTVar psVar
            let mto = p ^. processTable . at toPid
            case mto of
              Just toProc ->
                let dReq = toDyn reqIn
                in  do
                      writeTQueue (toProc ^. messageQ) dReq
                      return True
              Nothing -> return False
          )
        )
      >>= k
  -- Receiving: take the next message from the process' own queue and try to
  -- cast it to the expected type.
  go pid (ReceiveMessage onMsg) k = do
    catchError
      (do
        -- 'readTQueue' runs inside the dispatcher transaction and retries
        -- while the queue is empty.
        mDynMsg <- overProcessInfo pid
          (do
            mq <- use messageQ
            Mtl.lift (readTQueue mq))
        case fromDynamic mDynMsg of
          Just req -> let result = onMsg req in k (Message result)
          -- The 'Nothing' value is only used as a proxy to obtain the type
          -- representation of the expected message type.
          nix@Nothing ->
            let
              msg =
                "unexpected message: " ++ show mDynMsg ++ " expected: " ++ show
                  (typeRep nix)
            in do
                isExitOnShutdown <- overProcessInfo pid (use exitOnShutdown)
                if isExitOnShutdown
                  then throwError (UnhandledMessageReceived mDynMsg pid)
                  else k (ProcessControlMessage msg)
      )
      -- NOTE(review): the continuation @k@ is invoked inside the guarded
      -- action, so this handler presumably also sees errors raised by the
      -- rest of the process, not only errors of the receive itself - confirm
      -- that this is intended.
      (\(se :: DispatcherError) -> do
        isExitOnShutdown <- overProcessInfo pid (use exitOnShutdown)
        if isExitOnShutdown
          then throwError se
          else k (ProcessControlMessage (show se))
      )
  -- Handler for the 'Process' effect.
  goProc
    :: forall v x
     . HasCallStack
    => ProcessId
    -> Process v
    -> (v -> Eff r x)
    -> Eff r x
  goProc pid SelfPid k = k pid
  -- The stored flag is the negation of the trap-exit flag.
  goProc pid (TrapExit s) k =
    overProcessInfo pid (exitOnShutdown .= (not s)) >>= k
  goProc pid GetTrapExit k =
    overProcessInfo pid (use exitOnShutdown) >>= k . not
  -- Raising an error aborts the process; the continuation is dropped.
  goProc pid (RaiseError msg) _k = do
    logMsg (show pid ++ " error raised: " ++ msg)
    throwError (ProcessException msg pid)
-- | Register the calling thread as a new process, i.e. allocate a process
-- id and a message queue for it, and run the given action with the new
-- 'ProcessInfo' and a 'CleanUpAction' that removes the registration again.
--
-- Throws 'DispatcherShuttingDown' if the dispatcher is shutting down and
-- therefore does not accept new processes.
--
-- The clean-up action is not run by this function; running it more than
-- once is harmless.
withMessageQueue
  :: HasDispatcherIO r => (CleanUpAction -> ProcessInfo -> Eff r a) -> Eff r a
withMessageQueue m = do
  mpinfo <- createQueue
  lc     <- getLogChannel
  case mpinfo of
    Just pinfo -> do
      cleanUpAction <-
        getDispatcherTVar
          >>= return
          .   CleanUpAction
          .   destroyQueue lc (pinfo ^. processId)
      m cleanUpAction pinfo
    Nothing -> throwError DispatcherShuttingDown
 where
  -- Allocate the next process id and a fresh queue, and register both the
  -- process and the calling thread, unless the dispatcher is shutting down.
  createQueue = do
    myTId <- lift myThreadId
    overDispatcher
      (do
        abortNow <- use schedulerShuttingDown
        if abortNow
          then return Nothing
          else do
            pid     <- nextPid <<+= 1
            channel <- Mtl.lift newTQueue
            let pinfo = ProcessInfo pid channel True
            threadIdTable . at pid .= Just myTId
            processTable . at pid .= Just pinfo
            return (Just pinfo)
      )

  -- Remove the process from both dispatcher tables and log the outcome.
  --
  -- The entries are removed even while the dispatcher is shutting down:
  -- the tear down waits for both tables to become empty, so a process that
  -- kept its entries at that point would block the tear down forever.
  destroyQueue lc pid psVar = do
    didWork <- Exc.try
      (overDispatcherIO
        psVar
        (do
          os <- processTable . at pid <<.= Nothing
          ot <- threadIdTable . at pid <<.= Nothing
          return (os, isJust os || isJust ot)
        )
      )
    let
      -- Describe the process for the failure message: its 'ProcessInfo' if
      -- it is still registered, otherwise just its id.
      getCause =
        Exc.try @Exc.SomeException
            (overDispatcherIO psVar (use (processTable . at pid)))
          >>= either (return . (show pid ++) . show)
                     (return . maybe (show pid) show)
    case didWork of
      Right (pinfo, True) ->
        logChannelPutIO lc ("destroying queue: " ++ show pinfo)
      Right (pinfo, False) ->
        logChannelPutIO lc ("queue already destroyed: " ++ show pinfo)
      -- Best effort: a failing clean-up is logged, not propagated.
      Left (e :: Exc.SomeException) ->
        getCause
          >>= logChannelPutIO lc
          .   (("failed to destroy queue: " ++ show e ++ " ") ++)
-- | Run a stateful STM action on the dispatcher state found in the
-- environment, as one atomic transaction. The lifted 'IO' action is passed
-- through @liftRethrow@ together with 'LowLevelIOException'.
overDispatcher
  :: HasDispatcherIO r => Mtl.StateT Dispatcher STM.STM a -> Eff r a
overDispatcher stAction =
  getDispatcherTVar >>= \dispatcherTVar ->
    liftRethrow LowLevelIOException (overDispatcherIO dispatcherTVar stAction)
-- | Atomically read the dispatcher state, run the given stateful STM action
-- on it, write the resulting state back and return the action's result.
overDispatcherIO
  :: STM.TVar Dispatcher -> Mtl.StateT Dispatcher STM.STM a -> IO a
overDispatcherIO psVar stAction = STM.atomically $ do
  (result, modified) <- Mtl.runStateT stAction =<< STM.readTVar psVar
  STM.writeTVar psVar modified
  return result
-- | Fetch the transactional variable holding the dispatcher state from the
-- 'Reader' environment.
getDispatcherTVar :: HasDispatcherIO r => Eff r (TVar Dispatcher)
getDispatcherTVar = fmap fromDispatcherVar ask
-- | Take a snapshot of the current dispatcher state.
getDispatcher :: HasDispatcherIO r => Eff r Dispatcher
getDispatcher = getDispatcherTVar >>= lift . readTVarIO