module Test.DejaFu.Conc.Internal where
import Control.Exception (MaskingState(..),
toException)
import Control.Monad.Conc.Class (MonadConc,
rtsSupportsBoundThreads)
import Control.Monad.Ref (MonadRef, newRef, readRef,
writeRef)
import Data.Functor (void)
import Data.List (sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe (isJust)
import Data.Monoid ((<>))
import Data.Sequence (Seq, (<|))
import qualified Data.Sequence as Seq
import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Memory
import Test.DejaFu.Conc.Internal.STM
import Test.DejaFu.Conc.Internal.Threading
import Test.DejaFu.Internal
import Test.DejaFu.Schedule
import Test.DejaFu.Types
-- | An execution trace, built up left to right as a 'Seq'.
--
-- Each entry holds three things:
--
-- * the scheduling 'Decision' that was taken;
-- * the other runnable threads that could have been chosen instead, each
--   paired with a 'Lookahead' at its next action;
-- * the 'ThreadAction' the chosen thread actually performed.
type SeqTrace = Seq (Decision, [(ThreadId, Lookahead)], ThreadAction)
-- | Run a concurrent computation to completion under the given 'Scheduler'
-- and memory model.
--
-- The computation starts as the single 'initialThread'. That thread is
-- bound (via 'makeBound') only when 'rtsSupportsBoundThreads' holds.
-- Scheduling starts from the supplied scheduler state, identifier source
-- and capability count.
--
-- Returns four things:
--
-- * the program's result, or the 'Failure' that stopped it;
-- * the final 'Context';
-- * the full execution trace;
-- * the last (thread, action) pair, if any.
--
-- 'runThreads' returns as soon as the initial thread is gone or a failure
-- occurs, so other threads may still be present. All of them are killed
-- before the result is read.
runConcurrency :: (MonadConc n, MonadRef r n)
  => Scheduler g
  -> MemType
  -> g
  -> IdSource
  -> Int
  -> M n r a
  -> n (Either Failure a, Context n r g, SeqTrace, Maybe (ThreadId, ThreadAction))
runConcurrency sched memtype g idsrc caps ma = do
  -- The main thread's continuation, plus the ref that receives the outcome.
  -- Normal completion is wired to 'Just . Right' here; 'runThreads' writes
  -- a 'Left' failure into the same ref when execution stops early.
  (mainAction, resultRef) <- runRefCont AStop (Just . Right) (runM ma)
  let unbound = launch' Unmasked initialThread (const mainAction) M.empty
  -- Only request a bound main thread when the RTS can provide one.
  boundOrNot <- if rtsSupportsBoundThreads
    then makeBound initialThread unbound
    else pure unbound
  let startCtx = Context
        { cSchedState = g
        , cIdSource   = idsrc
        , cThreads    = boundOrNot
        , cWriteBuf   = emptyBuffer
        , cCaps       = caps
        }
  (endCtx, trace, lastStep) <- runThreads sched memtype resultRef startCtx
  -- Tear down every thread that outlived the scheduling loop.
  let leftover = cThreads endCtx
  mapM_ (flip kill leftover) (M.keys leftover)
  result <- readRef resultRef
  pure (efromJust "runConcurrency" result, endCtx, trace, lastStep)
-- | Everything the scheduling loop carries from one step to the next.
data Context n r g = Context
  { cSchedState :: g
    -- ^ The scheduler's own state. It is handed to 'scheduleThread' and
    -- replaced by the state that call returns.
  , cIdSource :: IdSource
    -- ^ Supply of fresh identifiers for new threads, 'MVar's, 'CRef's and
    -- STM transactions.
  , cThreads :: Threads n r
    -- ^ The threads of the computation, looked up by 'ThreadId'.
  , cWriteBuf :: WriteBuffer r
    -- ^ 'CRef' writes buffered under 'TotalStoreOrder' or
    -- 'PartialStoreOrder'. Reset to empty after a synchronising action.
  , cCaps :: Int
    -- ^ The number of capabilities the program observes and may change.
  }
-- | Run a collection of threads until one of these happens:
--
-- * the initial thread has finished;
-- * the computation deadlocks;
-- * the scheduler stops execution.
--
-- Each iteration asks the 'Scheduler' to pick a runnable thread, runs a
-- single step of it with 'stepThread', and appends the result to the
-- trace.
--
-- Returns the final 'Context', the accumulated trace, and the most recent
-- (thread, action) pair.
--
-- Execution can stop with a 'Failure': 'Deadlock', 'STMDeadlock',
-- 'InternalError', 'Abort', or a failure reported by 'stepThread'. In
-- that case the failure is written to @ref@ as @Just (Left failure)@
-- before returning.
runThreads :: (MonadConc n, MonadRef r n)
  => Scheduler g
  -> MemType
  -> r (Maybe (Either Failure a))
  -> Context n r g
  -> n (Context n r g, SeqTrace, Maybe (ThreadId, ThreadAction))
runThreads sched memtype ref = go Seq.empty Nothing where
  -- @sofar@ is the trace built up so far. @prior@ is the previously
  -- scheduled thread together with the action it performed.
  go sofar prior ctx
    | isTerminated = pure (ctx, sofar, prior)
    | isDeadlocked = die sofar prior Deadlock ctx
    | isSTMLocked = die sofar prior STMDeadlock ctx
    | otherwise =
      -- Keep the scheduler's updated state even if execution stops here.
      let ctx' = ctx { cSchedState = g' }
      in case choice of
        Just chosen -> case M.lookup chosen threadsc of
          Just thread
            -- The scheduler picked a thread that cannot run.
            | isBlocked thread -> die sofar prior InternalError ctx'
            | otherwise -> step chosen thread ctx'
          -- The scheduler picked a thread that does not exist.
          Nothing -> die sofar prior InternalError ctx'
        -- The scheduler declined to pick any thread.
        Nothing -> die sofar prior Abort ctx'
    where
      -- Ask the scheduler for the next thread, offering only runnable ones.
      -- NOTE(review): 'efromList' presumably fails loudly on an empty list;
      -- the deadlock checks above are what keep that list non-empty. Confirm.
      (choice, g') = scheduleThread sched prior (efromList "runThreads" runnable') (cSchedState ctx)
      -- Runnable threads in ascending 'ThreadId' order, each paired with a
      -- lookahead at its next action.
      runnable' = [(t, lookahead (_continuation a)) | (t, a) <- sortOn fst $ M.assocs runnable]
      runnable = M.filter (not . isBlocked) threadsc
      -- The real threads plus extra entries derived from the write buffer.
      -- These are presumably pseudo-threads that commit buffered writes
      -- (compare 'ACommit' in 'stepThread').
      threadsc = addCommitThreads (cWriteBuf ctx) threads
      threads = cThreads ctx
      isBlocked = isJust . _blocking
      -- Execution is over once the initial thread has been removed.
      isTerminated = initialThread `notElem` M.keys threads
      -- Deadlock: no real thread can run (this checks @threads@, not
      -- @threadsc@), and the initial thread is blocked on an MVar or on a
      -- throwTo to a masked thread. The 'undefined' arguments are only
      -- placeholders. NOTE(review): presumably '~=' compares just the
      -- constructor of the blocking reason. Confirm.
      isDeadlocked = M.null (M.filter (not . isBlocked) threads) &&
        (((~= OnMVarFull undefined) <$> M.lookup initialThread threads) == Just True ||
         ((~= OnMVarEmpty undefined) <$> M.lookup initialThread threads) == Just True ||
         ((~= OnMask undefined) <$> M.lookup initialThread threads) == Just True)
      -- As above, but the initial thread is blocked on TVars (an STM retry).
      isSTMLocked = M.null (M.filter (not . isBlocked) threads) &&
        ((~= OnTVar []) <$> M.lookup initialThread threads) == Just True
      -- Wake every thread blocked trying to throw an exception to @tid@.
      -- 'AThrowTo' in 'stepThread' blocks those threads with @OnMask tid@.
      unblockWaitingOn tid = fmap unblock where
        unblock thrd = case _blocking thrd of
          Just (OnMask t) | t == tid -> thrd { _blocking = Nothing }
          _ -> thrd
      -- Record the failure in the result ref, then stop.
      die sofar' finalDecision reason finalCtx = do
        writeRef ref (Just $ Left reason)
        pure (finalCtx, sofar', finalDecision)
      -- Run one step of the chosen thread, then continue the loop.
      step chosen thread ctx' = do
          -- This context equals @ctx'@: the old context with the new
          -- scheduler state.
          (res, actOrTrc) <- stepThread sched memtype chosen (_continuation thread) $ ctx { cSchedState = g' }
          let trc = getTrc actOrTrc
          let sofar' = sofar <> trc
          let prior' = getPrior actOrTrc
          case res of
            Right ctx'' ->
              -- If the thread just stepped is now 'interruptible' (or is no
              -- longer in the map), threads waiting to throw to it may
              -- proceed. 'delCommitThreads' is then applied to the result.
              let threads' = if (interruptible <$> M.lookup chosen (cThreads ctx'')) /= Just False
                               then unblockWaitingOn chosen (cThreads ctx'')
                               else cThreads ctx''
                  ctx''' = ctx'' { cThreads = delCommitThreads threads' }
              in go sofar' prior' ctx'''
            Left failure ->
              -- A failed step yields no new context, so fall back to the
              -- threads as they were before the step.
              let ctx'' = ctx' { cThreads = delCommitThreads threads }
              in die sofar' prior' failure ctx''
        where
          -- 'Continue' if the same thread runs again. 'Start' if the
          -- previous thread is no longer runnable or there was none.
          -- Otherwise 'SwitchTo'.
          decision
            | Just chosen == (fst <$> prior) = Continue
            | (fst <$> prior) `notElem` map (Just . fst) runnable' = Start chosen
            | otherwise = SwitchTo chosen
          -- A subconcurrency step contributes a 'Subconcurrency' entry
          -- followed by the whole nested trace.
          getTrc (Single a) = Seq.singleton (decision, alternatives, a)
          getTrc (SubC as _) = (decision, alternatives, Subconcurrency) <| as
          -- The runnable threads that were not chosen.
          alternatives = filter (\(t, _) -> t /= chosen) runnable'
          -- After a subconcurrency step, the nested run's last
          -- (thread, action) pair becomes the new @prior@.
          getPrior (Single a) = Just (chosen, a)
          getPrior (SubC _ finalD) = finalD
-- | The outcome of one 'stepThread' call, as it will be recorded in the
-- trace.
data Act
  = Single ThreadAction
    -- ^ The thread performed exactly one action.
  | SubC SeqTrace (Maybe (ThreadId, ThreadAction))
    -- ^ The thread ran a nested subconcurrency computation. The fields are
    -- the nested trace and the nested run's final (thread, action) pair,
    -- if any.
  deriving (Eq, Show)
-- | Run a single primitive action of thread @tid@.
--
-- Returns the updated 'Context', or the 'Failure' that ends execution.
-- This is paired with an 'Act' describing what was done. That is usually
-- one 'ThreadAction'; for a nested subconcurrency run it is a 'SubC'
-- carrying that run's trace.
stepThread :: forall n r g. (MonadConc n, MonadRef r n)
  => Scheduler g
  -> MemType
  -> ThreadId
  -> Action n r
  -> Context n r g
  -> n (Either Failure (Context n r g), Act)
stepThread sched memtype tid action ctx = case action of
    -- Allocate a new thread id, launch the child, and resume the parent
    -- with the child's id.
    AFork n a b -> pure $
      let threads' = launch tid newtid a (cThreads ctx)
          (idSource', newtid) = nextTId n (cIdSource ctx)
      in (Right ctx { cThreads = goto (b newtid) tid threads', cIdSource = idSource' }, Single (Fork newtid))
    -- As 'AFork', but the child is also made a bound thread.
    AForkOS n a b -> do
      let (idSource', newtid) = nextTId n (cIdSource ctx)
      let threads' = launch tid newtid a (cThreads ctx)
      threads'' <- makeBound newtid threads'
      pure (Right ctx { cThreads = goto (b newtid) tid threads'', cIdSource = idSource' }, Single (Fork newtid))
    -- A thread counts as bound when its '_bound' field is set.
    AIsBound c ->
      let isBound = isJust (_bound =<< M.lookup tid (cThreads ctx))
      in simple (goto (c isBound) tid (cThreads ctx)) (IsCurrentThreadBound isBound)
    AMyTId c -> simple (goto (c tid) tid (cThreads ctx)) MyThreadId
    AGetNumCapabilities c -> simple (goto (c (cCaps ctx)) tid (cThreads ctx)) $ GetNumCapabilities (cCaps ctx)
    ASetNumCapabilities i c -> pure
      (Right ctx { cThreads = goto c tid (cThreads ctx), cCaps = i }, Single (SetNumCapabilities i))
    AYield c -> simple (goto c tid (cThreads ctx)) Yield
    ADelay n c -> simple (goto c tid (cThreads ctx)) (ThreadDelay n)
    -- A new MVar is a fresh id plus a ref that starts out empty ('Nothing').
    ANewMVar n c -> do
      let (idSource', newmvid) = nextMVId n (cIdSource ctx)
      ref <- newRef Nothing
      let mvar = MVar newmvid ref
      pure (Right ctx { cThreads = goto (c mvar) tid (cThreads ctx), cIdSource = idSource' }, Single (NewMVar newmvid))
    -- MVar operations are synchronising: the write buffer is flushed first.
    -- Blocking variants record a Blocked* action when they cannot proceed.
    APutMVar cvar@(MVar cvid _) a c -> synchronised $ do
      (success, threads', woken) <- putIntoMVar cvar a c tid (cThreads ctx)
      simple threads' $ if success then PutMVar cvid woken else BlockedPutMVar cvid
    ATryPutMVar cvar@(MVar cvid _) a c -> synchronised $ do
      (success, threads', woken) <- tryPutIntoMVar cvar a c tid (cThreads ctx)
      simple threads' $ TryPutMVar cvid success woken
    AReadMVar cvar@(MVar cvid _) c -> synchronised $ do
      (success, threads', _) <- readFromMVar cvar c tid (cThreads ctx)
      simple threads' $ if success then ReadMVar cvid else BlockedReadMVar cvid
    ATryReadMVar cvar@(MVar cvid _) c -> synchronised $ do
      (success, threads', _) <- tryReadFromMVar cvar c tid (cThreads ctx)
      simple threads' $ TryReadMVar cvid success
    ATakeMVar cvar@(MVar cvid _) c -> synchronised $ do
      (success, threads', woken) <- takeFromMVar cvar c tid (cThreads ctx)
      simple threads' $ if success then TakeMVar cvid woken else BlockedTakeMVar cvid
    ATryTakeMVar cvar@(MVar cvid _) c -> synchronised $ do
      (success, threads', woken) <- tryTakeFromMVar cvar c tid (cThreads ctx)
      simple threads' $ TryTakeMVar cvid success woken
    -- A new CRef's ref starts as (M.empty, 0, initial value).
    -- NOTE(review): the first two components presumably hold per-thread
    -- state and a counter for tickets. Confirm in the Memory module.
    ANewCRef n a c -> do
      let (idSource', newcrid) = nextCRId n (cIdSource ctx)
      ref <- newRef (M.empty, 0, a)
      let cref = CRef newcrid ref
      pure (Right ctx { cThreads = goto (c cref) tid (cThreads ctx), cIdSource = idSource' }, Single (NewCRef newcrid))
    -- Plain reads are not synchronising.
    AReadCRef cref@(CRef crid _) c -> do
      val <- readCRef cref tid
      simple (goto (c val) tid (cThreads ctx)) $ ReadCRef crid
    AReadCRefCas cref@(CRef crid _) c -> do
      tick <- readForTicket cref tid
      simple (goto (c tick) tid (cThreads ctx)) $ ReadCRefCas crid
    -- Atomic modify: read, then write the new value immediately.
    AModCRef cref@(CRef crid _) f c -> synchronised $ do
      (new, val) <- f <$> readCRef cref tid
      writeImmediate cref new
      simple (goto (c val) tid (cThreads ctx)) $ ModCRef crid
    -- Modify via a ticket and compare-and-swap. The CAS result is discarded.
    AModCRefCas cref@(CRef crid _) f c -> synchronised $ do
      tick@(Ticket _ _ old) <- readForTicket cref tid
      let (new, val) = f old
      void $ casCRef cref tid tick new
      simple (goto (c val) tid (cThreads ctx)) $ ModCRefCas crid
    -- Write handling depends on the memory model:
    -- * SequentialConsistency writes immediately.
    -- * TotalStoreOrder buffers per thread, keyed by (tid, Nothing).
    -- * PartialStoreOrder buffers per thread and CRef, keyed by (tid, Just crid).
    AWriteCRef cref@(CRef crid _) a c -> case memtype of
      SequentialConsistency -> do
        writeImmediate cref a
        simple (goto c tid (cThreads ctx)) $ WriteCRef crid
      TotalStoreOrder -> do
        wb' <- bufferWrite (cWriteBuf ctx) (tid, Nothing) cref a
        pure (Right ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Single (WriteCRef crid))
      PartialStoreOrder -> do
        wb' <- bufferWrite (cWriteBuf ctx) (tid, Just crid) cref a
        pure (Right ctx { cThreads = goto c tid (cThreads ctx), cWriteBuf = wb' }, Single (WriteCRef crid))
    ACasCRef cref@(CRef crid _) tick a c -> synchronised $ do
      (suc, tick') <- casCRef cref tid tick a
      simple (goto (c (suc, tick')) tid (cThreads ctx)) $ CasCRef crid suc
    -- Commit one buffered write, using the same key shape as AWriteCRef.
    -- Under SequentialConsistency nothing is buffered, so reaching this
    -- case is a fatal error.
    ACommit t c -> do
      wb' <- case memtype of
        SequentialConsistency ->
          fatal "stepThread.ACommit" "Attempting to commit under SequentialConsistency"
        TotalStoreOrder -> commitWrite (cWriteBuf ctx) (t, Nothing)
        PartialStoreOrder -> commitWrite (cWriteBuf ctx) (t, Just c)
      pure (Right ctx { cWriteBuf = wb' }, Single (CommitCRef t c))
    -- Run an STM transaction:
    -- * Success wakes threads blocked on the TVars that were written.
    -- * Retry blocks this thread on the TVars it touched.
    -- * Exception throws the exception in this thread.
    -- The advanced id source is kept in every non-failing case.
    AAtom stm c -> synchronised $ do
      (res, idSource', trace) <- runTransaction stm (cIdSource ctx)
      case res of
        Success _ written val ->
          let (threads', woken) = wake (OnTVar written) (cThreads ctx)
          in pure (Right ctx { cThreads = goto (c val) tid threads', cIdSource = idSource' }, Single (STM trace woken))
        Retry touched ->
          let threads' = block (OnTVar touched) tid (cThreads ctx)
          in pure (Right ctx { cThreads = threads', cIdSource = idSource'}, Single (BlockedSTM trace))
        Exception e -> do
          let act = STM trace []
          res' <- stepThrow tid (cThreads ctx) act e
          pure $ case res' of
            (Right ctx', _) -> (Right ctx' { cIdSource = idSource' }, Single act)
            (Left err, _) -> (Left err, Single act)
    ALift na -> do
      a <- runLiftedAct tid (cThreads ctx) na
      simple (goto a tid (cThreads ctx)) LiftIO
    AThrow e -> stepThrow tid (cThreads ctx) Throw e
    -- Throw to another thread:
    -- * If the target is interruptible, the exception is delivered there
    --   and the current thread resumes with @c@.
    -- * If the target is not interruptible, the current thread blocks
    --   with @OnMask t@.
    -- * If the target does not exist, the current thread just continues.
    AThrowTo t e c -> synchronised $
      let threads' = goto c tid (cThreads ctx)
          blocked = block (OnMask t) tid (cThreads ctx)
      in case M.lookup t (cThreads ctx) of
           Just thread
             | interruptible thread -> stepThrow t threads' (ThrowTo t) e
             | otherwise -> simple blocked $ BlockedThrowTo t
           Nothing -> simple threads' $ ThrowTo t
    -- Install handler @h@ around @ma@. Once @ma@ finishes, APopCatching
    -- removes the handler again.
    ACatching h ma c ->
      let a = runCont ma (APopCatching . c)
          e exc = runCont (h exc) c
          threads' = goto a tid (catching e tid (cThreads ctx))
      in simple threads' Catching
    APopCatching a ->
      let threads' = goto a tid (uncatching tid (cThreads ctx))
      in simple threads' PopCatching
    -- Run the body under masking state @m@. @m'@ is the state before the
    -- block. The body is given @umask@, which temporarily restores @m'@
    -- around its argument and then returns to @m@. When the body finishes,
    -- @m'@ is restored.
    AMasking m ma c ->
      let a = runCont (ma umask) (AResetMask False False m' . c)
          m' = _masking . efromJust "stepThread.AMasking" $ M.lookup tid (cThreads ctx)
          umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> pure b
          resetMask typ ms = cont $ \k -> AResetMask typ True ms $ k ()
          threads' = goto a tid (mask m tid (cThreads ctx))
      in simple threads' $ SetMasking False m
    -- @b1@ chooses between SetMasking and ResetMasking in the trace.
    -- @b2@ is True for resets issued by @umask@ in AMasking, and False for
    -- the restore at the end of a masked block.
    AResetMask b1 b2 m c ->
      let act = (if b1 then SetMasking else ResetMasking) b2 m
          threads' = goto c tid (mask m tid (cThreads ctx))
      in simple threads' act
    AReturn c -> simple (goto c tid (cThreads ctx)) Return
    -- Run the final action, then remove this thread.
    AStop na -> do
      na
      threads' <- kill tid (cThreads ctx)
      simple threads' Stop
    -- A nested run is only allowed while this is the sole thread. It uses
    -- the same scheduler and memory model, and threads the scheduler state
    -- and id source back out. The result is passed to @c@, wrapped in
    -- AStopSub.
    ASub ma c
      | M.size (cThreads ctx) > 1 -> pure (Left IllegalSubconcurrency, Single Subconcurrency)
      | otherwise -> do
          (res, ctx', trace, finalDecision) <-
            runConcurrency sched memtype (cSchedState ctx) (cIdSource ctx) (cCaps ctx) ma
          pure (Right ctx { cThreads = goto (AStopSub (c res)) tid (cThreads ctx)
                          , cIdSource = cIdSource ctx'
                          , cSchedState = cSchedState ctx' }, SubC trace finalDecision)
    AStopSub c -> simple (goto c tid (cThreads ctx)) StopSubconcurrency
  where
    -- Deliver an exception to thread @t@. If no handler takes it, the
    -- outcome depends on the thread:
    -- * the initial thread fails the whole run with UncaughtException;
    -- * any other thread is killed.
    stepThrow t ts act e =
      let some = toException e
      in case propagate some t ts of
           Just ts' -> simple ts' act
           Nothing
             | t == initialThread -> pure (Left (UncaughtException some), Single act)
             | otherwise -> do
                 ts' <- kill t ts
                 simple ts' act
    -- Successful step: replace the thread map and record one action.
    simple threads' act = pure (Right ctx { cThreads = threads' }, Single act)
    -- Flush buffered writes with 'writeBarrier' before running @ma@. On
    -- success the write buffer is reset to empty.
    synchronised ma = do
      writeBarrier (cWriteBuf ctx)
      res <- ma
      pure $ case res of
        (Right ctx', act) -> (Right ctx' { cWriteBuf = emptyBuffer }, act)
        _ -> res