module Control.TimeWarp.Timed.TimedT
( TimedT
, PureThreadId
, runTimedT
, defaultLoggerName
) where
import Control.Applicative ((<|>))
import Control.Exception.Base (AsyncException (ThreadKilled),
Exception (fromException),
SomeException (..))
import Control.Lens (at, makeLenses, use, view, (%=), (%~),
(&), (.=), (.~), (<&>), (<<+=),
(<<.=), (^.))
import Control.Monad (unless, void)
import Control.Monad.Catch (Handler (..), MonadCatch, MonadMask,
MonadThrow, catch, catchAll, catches,
finally, mask, throwM, try,
uninterruptibleMask)
import Control.Monad.Cont (ContT (..), runContT)
import Control.Monad.Loops (whileM_)
import Control.Monad.Reader (ReaderT (..), ask, local, runReaderT)
import Control.Monad.State (MonadState (get, put, state), StateT,
evalStateT)
import Control.Monad.Trans (MonadIO, MonadTrans, lift, liftIO)
import Data.Function (on)
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.List (foldl')
import Data.Ord (comparing)
import Formatting (sformat, shown, (%))
import qualified Data.Map as M
import qualified Data.PQueue.Min as PQ
import Safe (fromJustNote)
import System.Wlog (CanLog, HasLoggerName (..),
LoggerName, WithLogger, logDebug,
logWarning)
import Control.TimeWarp.Timed.MonadTimed (Microsecond, MonadTimed (..),
MonadTimedError (MTTimeoutError),
ThreadId, after, for, mcs, schedule,
timeout, virtualTime)
-- | Identifier of an emulated thread. Ids are handed out sequentially from a
-- counter in the scheduler state (see 'getNextThreadId'), so the derived
-- 'Ord' follows creation order.
newtype PureThreadId = PureThreadId Integer
    deriving (Eq, Ord)

-- | Renders as @PureThreadId <n>@. Unlike a bare 'show' override, this
-- respects precedence: when the id is an argument of another constructor
-- (e.g. inside 'Just') it is parenthesised, the way derived 'Show'
-- instances behave.
instance Show PureThreadId where
    showsPrec d (PureThreadId tid) =
        showParen (d > 10) $ showString "PureThreadId " . showsPrec 11 tid
-- | Per-thread context carried by every computation and every scheduled
-- event of an emulated thread.
data ThreadCtx c = ThreadCtx
    { _threadId   :: PureThreadId
      -- ^ identifier reported by 'myThreadId'
    , _handlers   :: [(Handler c (), Handler c ())]
      -- ^ stack of (user handler, continuation handler) pairs pushed by
      -- 'catch', innermost first
    , _loggerName :: LoggerName
      -- ^ current logger name of the thread
    }

makeLenses ''ThreadCtx
-- | A unit of work queued in the scheduler: an action of some emulated
-- thread, to be run at a given virtual time within that thread's context.
data Event m c = Event
    { _timestamp :: Microsecond
      -- ^ virtual time at which the action becomes due
    , _action    :: m ()
      -- ^ the work to perform
    , _threadCtx :: ThreadCtx c
      -- ^ context of the thread the action belongs to
    }

makeLenses ''Event
-- | Events compare by timestamp alone (the action and context are ignored);
-- this is the ordering used by the scheduler's 'PQ.MinQueue'.
instance Eq (Event m c) where
    a == b = on (==) _timestamp a b

instance Ord (Event m c) where
    compare a b = comparing _timestamp a b
-- | Global scheduler state shared by all emulated threads.
data Scenario m c = Scenario
    { _events          :: PQ.MinQueue (Event m c)
      -- ^ pending events, earliest timestamp first
    , _curTime         :: Microsecond
      -- ^ current virtual time
    , _asyncExceptions :: M.Map PureThreadId SomeException
      -- ^ exceptions posted by 'throwTo', keyed by target thread
    , _threadsCounter  :: Integer
      -- ^ source of fresh 'PureThreadId' values
    }

makeLenses ''Scenario
-- | Initial scheduler state: empty event queue, virtual clock at zero, no
-- posted asynchronous exceptions, and a thread counter starting at zero.
emptyScenario :: Scenario m c
emptyScenario = Scenario {..}
  where
    _threadsCounter  = 0
    _asyncExceptions = M.empty
    _curTime         = 0
    _events          = PQ.empty
-- | Scheduler-level monad: a 'StateT' over the global 'Scenario' (event
-- queue, virtual clock, posted asynchronous exceptions, thread counter)
-- running on top of the base monad @m@. All instances are derived through
-- the underlying 'StateT'.
newtype Core m a = Core
    { getCore :: StateT (Scenario (TimedT m) (Core m)) m a
    } deriving (Functor, Applicative, Monad, MonadIO, MonadThrow,
                MonadCatch, MonadMask,
                MonadState (Scenario (TimedT m) (Core m)))
-- | Lifts a base-monad action through the 'StateT' layer of 'Core'.
instance MonadTrans Core where
    lift action = Core (lift action)
-- | Emulated-time monad transformer.
--
-- A computation reads its thread's 'ThreadCtx' from the 'ReaderT' layer and
-- runs in continuation-passing style ('ContT') over 'Core'. The 'ContT'
-- layer is what lets 'wait' suspend a thread: the current continuation is
-- stored as an 'Event' in the scheduler queue instead of being run.
newtype TimedT m a = TimedT
    { unwrapTimedT :: ReaderT (ThreadCtx (Core m))
                        ( ContT ()
                            ( Core m )
                        ) a
    } deriving (Functor, Applicative, Monad, MonadIO)
-- | Throwing is delegated to the underlying transformer stack.
instance MonadThrow m => MonadThrow (TimedT m) where
    throwM err = TimedT (throwM err)
-- | Lifts a base-monad action through 'Core', 'ContT' and 'ReaderT'.
instance MonadTrans TimedT where
    lift action = TimedT (lift (lift (lift action)))
-- | Exposes the base monad's state (not the scheduler's 'Scenario') by
-- lifting every operation into @m@.
instance MonadState s m => MonadState s (TimedT m) where
    get     = lift get
    put s   = lift (put s)
    state f = lift (state f)
-- | The logger name lives in the thread context; modifying it affects only
-- the given computation (it is applied with 'local').
instance HasLoggerName (TimedT m) where
    getLoggerName = TimedT (view loggerName)
    modifyLoggerName f (TimedT inner) =
        TimedT (local (loggerName %~ f) inner)
-- | Logging relies entirely on the class's default methods.
instance CanLog m => CanLog (TimedT m)
-- | Wrapper for an exception raised by the continuation that follows a
-- 'catch' block. Wrapping keeps that 'catch' block's own handler from
-- intercepting it; 'contHandler' unwraps it and rethrows the original.
newtype ContException = ContException SomeException
    deriving (Show)

instance Exception ContException
-- | 'catch' for emulated threads.
--
-- A thread may suspend inside the protected action (via 'wait'). The rest
-- of the action is then resumed later by the scheduler, outside the
-- Haskell-level 'catches' installed here. To cover that case, the handler
-- is also pushed onto the '_handlers' stack of the thread context. That
-- context is captured by every event the thread schedules, and the
-- scheduler re-installs the stack when resuming.
--
-- The continuation @c@ (everything after the 'catch') must not be handled
-- by @handler@. Exceptions escaping it are therefore wrapped in
-- 'ContException'. 'contHandler' is tried first, and it unwraps and
-- rethrows them to the enclosing handlers.
instance (MonadCatch m) => MonadCatch (TimedT m) where
    catch m handler =
        TimedT $
        ReaderT $
        \r ->
        ContT $
        \c ->
        -- continuation guarded so that its exceptions bypass @handler@
        let safeCont x = c x `catchAll` (throwM . ContException)
            -- context of the protected action: this catch's handlers on top
            r' = r & handlers %~ (:) (Handler handler', contHandler)
            act = unwrapCore' r' $ m >>= wrapCore . safeCont
            -- the handler and what follows it run in the outer context @r@,
            -- so this catch no longer applies to them
            handler' e = unwrapCore' r $ handler e >>= wrapCore . c
        in act `catches` [contHandler, Handler handler']
-- | Handler that strips a 'ContException' and rethrows the exception it
-- carries, so that the exception reaches the enclosing handlers.
contHandler :: MonadThrow m => Handler m ()
contHandler = Handler unwrapAndRethrow
  where
    unwrapAndRethrow (ContException inner) = throwM inner
-- | Masking is a no-op in emulation: the body receives 'id' as its
-- "restore" function and nothing is blocked. The interruptible and
-- uninterruptible variants behave the same.
instance (MonadIO m, MonadCatch m) => MonadMask (TimedT m) where
    mask body = body id
    uninterruptibleMask body = body id
-- | Embeds a scheduler-level action into 'TimedT' (through 'ContT' and
-- 'ReaderT').
wrapCore :: Monad m => Core m a -> TimedT m a
wrapCore core = TimedT (lift (lift core))
-- | Runs a 'TimedT' computation down to 'Core' in the given thread context,
-- feeding its result to the supplied continuation.
unwrapCore :: ThreadCtx (Core m)
           -> (a -> Core m ())
           -> TimedT m a
           -> Core m ()
unwrapCore ctx finish timed =
    runContT (runReaderT (unwrapTimedT timed) ctx) finish
-- | 'unwrapCore' for unit-returning computations, with a trivial
-- continuation.
unwrapCore' :: Monad m => ThreadCtx (Core m) -> TimedT m () -> Core m ()
unwrapCore' ctx timed = unwrapCore ctx pure timed
-- | Runs a 'TimedT' computation from 'emptyScenario' and discards both its
-- result and the final scheduler state. No thread context is available:
-- any attempt to read it fails with an error.
getTimedT :: Monad m => TimedT m a -> m ()
getTimedT timed =
    evalStateT (getCore (unwrapCore vacuumCtx (const (return ())) timed))
               emptyScenario
  where
    vacuumCtx = error "Access to thread context from nowhere"
-- | Runs an emulation to completion in the base monad.
--
-- The given computation becomes the main thread, with a fresh
-- 'PureThreadId' and logger name 'defaultLoggerName'. It runs until it
-- first suspends. After that, events are popped from the queue in
-- timestamp order until the queue is empty. For each event:
--
-- * the virtual clock is set to the event's timestamp;
-- * an exception posted for that thread via 'throwTo' (if any) is removed
--   from the map and raised first;
-- * the event's action runs inside the thread's stored handler stack,
--   innermost handler first.
--
-- An exception not handled by any stored handler escapes the loop and
-- aborts the run. The main thread's outermost handler simply rethrows.
launchTimedT :: (MonadIO m, MonadCatch m) => TimedT m () -> m ()
launchTimedT timed = getTimedT $ do
    mainThreadCtx >>= flip runInSandbox timed
    whileM_ notDone $ do
        -- pop the earliest event (queue is non-empty: checked by notDone)
        nextEv <- wrapCore . Core $ do
            (ev, evs') <- fromJustNote "Suddenly no more events" . PQ.minView
                <$> use events
            events .= evs'
            return ev
        TimedT $ curTime .= nextEv ^. timestamp
        let ctx = nextEv ^. threadCtx
            tid = ctx ^. threadId
        -- take (and clear) any exception posted to this thread
        maybeAsyncExc <- TimedT $ asyncExceptions . at tid <<.= Nothing
        let act = do
                mapM_ throwInnard maybeAsyncExc
                runInSandbox ctx (nextEv ^. action)
        wrapCore $ (unwrapCore' ctx act) `catchesSeq` (ctx ^. handlers)
  where
    notDone :: Monad m => TimedT m Bool
    notDone = not . PQ.null <$> TimedT (use events)
    -- run a thread computation to its next suspension point in context r
    runInSandbox r = wrapCore . unwrapCore' r
    mainThreadCtx = getNextThreadId <&>
        \tid ->
        ThreadCtx
        { _threadId = tid
        , _handlers = [( Handler throwInnard
                       , contHandler
                       )]
        , _loggerName = defaultLoggerName
        }
    -- re-install the stored handler stack; the list is innermost-first,
    -- so the left fold wraps inner handlers before outer ones
    catchesSeq = foldl' $ \act (h, hc) -> act `catches` [hc, h]
    -- rethrow the wrapped exception with its original dynamic type
    throwInnard (SomeException e) = throwM e
-- | Allocates a fresh thread id. The counter's value before the increment
-- is used, so the first id handed out is 0.
getNextThreadId :: Monad m => TimedT m PureThreadId
getNextThreadId = TimedT $ do
    issued <- threadsCounter <<+= 1
    return (PureThreadId issued)
-- | Runs a 'TimedT' computation in emulated time and returns its result.
--
-- Any exception raised by the computation is captured inside the emulation
-- and rethrown in the base monad once the emulation has finished. The
-- emulation finishes only when the event queue is empty, i.e. when every
-- other emulated thread has also run out of work.
runTimedT
    :: (MonadIO m, MonadCatch m)
    => TimedT m a -> m a
runTimedT timed = do
    resultRef <- liftIO (newIORef Nothing)
    launchTimedT $ try timed >>= liftIO . writeIORef resultRef . Just
    outcome <- liftIO (readIORef resultRef)
    case fromJustNote "runTimedT: no result" outcome of
        Left (err :: SomeException) -> throwM err
        Right value                 -> return value
-- | Whether the exception is the 'ThreadKilled' asynchronous exception.
isThreadKilled :: SomeException -> Bool
isThreadKilled e = case fromException e of
    Just ThreadKilled -> True
    _                 -> False
-- | Handler wrapped around every forked thread's action. It logs the
-- exception that terminated the thread, at debug level for 'ThreadKilled'
-- and at warning level for anything else. The exception is not rethrown.
threadKilledNotifier
    :: WithLogger m
    => SomeException -> m ()
threadKilledNotifier e =
    (if isThreadKilled e then logDebug else logWarning) message
  where
    message = sformat ("Thread killed by exception: " % shown) e
-- | Emulated threads are identified by 'PureThreadId'.
type instance ThreadId (TimedT m) = PureThreadId
-- | Emulated-time implementation of 'MonadTimed'.
--
-- Time never advances on its own. The scheduler sets the clock to each
-- event's timestamp as it dispatches events in order.
instance (CanLog m, MonadIO m, MonadThrow m, MonadCatch m) =>
         MonadTimed (TimedT m) where
    -- | Current virtual time, as stored in the scheduler state.
    virtualTime = TimedT $ use curTime

    currentTime = virtualTime

    -- | Enqueues @act@ as a new thread due at the current virtual time.
    -- The new thread has a fresh id, inherits the caller's logger name and
    -- starts with an empty handler stack. Any exception escaping @act@ is
    -- logged by 'threadKilledNotifier'. The caller then waits 1 mcs, so the
    -- new thread (due now) is dispatched before the caller resumes.
    fork act = do
        _timestamp <- virtualTime
        tid <- getNextThreadId
        logName <- getLoggerName
        let _threadCtx =
                ThreadCtx
                { _threadId = tid
                , _handlers = []
                , _loggerName = logName
                }
            _action = act `catch` threadKilledNotifier
        TimedT $ events %= PQ.insert Event {..}
        wait $ for 1 mcs
        return tid

    -- | Suspends the current thread. Its continuation is queued as an event
    -- in the current thread context, due at the requested time but never
    -- earlier than now.
    wait relativeToNow = do
        cur <- virtualTime
        ctx <- TimedT ask
        let event following =
                Event
                { _threadCtx = ctx
                , _timestamp = max cur (relativeToNow cur)
                , _action = wrapCore following
                }
        TimedT . lift . ContT $
            \c -> events %= PQ.insert (event $ c ())

    myThreadId = TimedT $ view threadId

    -- | Posts an asynchronous exception to thread @tid@. All of the target's
    -- queued events are moved to the current time so it wakes up promptly.
    -- The exception is raised when the thread is next dispatched. If an
    -- exception is already pending for that thread, the earlier one is kept.
    throwTo tid e = do
        wakeUpThread
        TimedT $ asyncExceptions . at tid %= (<|> Just (SomeException e))
      where
        wakeUpThread = TimedT $ do
            time <- use curTime
            let modifyRequired event =
                    if event ^. threadCtx . threadId == tid
                    then event & timestamp .~ time
                    else event
            events %= PQ.fromList . map modifyRequired . PQ.toList

    -- | Runs @action'@. A check is scheduled after @t@ which throws
    -- 'MTTimeoutError' to the calling thread if @action'@ has not finished
    -- by then. Completion (normal or exceptional) is recorded in an 'IORef'.
    timeout t action' = do
        pid <- myThreadId
        done <- liftIO $ newIORef False
        schedule (after t) $ do
            finished <- liftIO (readIORef done)
            unless finished $
                throwTo pid $ MTTimeoutError "Timeout exceeded"
        action' `finally` liftIO (writeIORef done True)

    -- Not supported by the emulation. Fail with an explicit message instead
    -- of the anonymous "Prelude.undefined".
    forkSlave = error "forkSlave: not implemented for TimedT emulation"
-- | Logger name given to the main emulated thread. Forked threads inherit
-- the logger name of the thread that forked them.
defaultLoggerName :: LoggerName
defaultLoggerName = "emulation"