module Test.DejaFu.Conc.Internal where
import Control.Exception (MaskingState(..), toException)
import Control.Monad.Ref (MonadRef, newRef, writeRef)
import Data.Functor (void)
import Data.List (sort)
import Data.List.NonEmpty (NonEmpty(..), fromList)
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, isJust, isNothing, listToMaybe)
import Test.DPOR (Scheduler)
import Test.DejaFu.Common
import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Memory
import Test.DejaFu.Conc.Internal.Threading
import Test.DejaFu.STM (Result(..))
-- | Drive a set of threads under a scheduler until execution stops.
--
-- Each step offers the scheduler the runnable threads: those with no
-- blocking reason, including any threads 'addCommitThreads' adds for the
-- current write buffer. It then runs a single action of the chosen thread
-- with 'stepThread'.
--
-- Execution stops in any of these cases:
--
-- * the main thread is no longer in the thread map;
-- * every thread is blocked and the main thread is blocked on an @MVar@,
--   on a @throwTo@ (@OnMask@), or in STM;
-- * the scheduler makes no choice;
-- * a 'Failure' is reported.
--
-- Failures are written to @ref@ as @Just (Left failure)@; this function never
-- writes a successful result itself.
--
-- Returns the final scheduler state and the trace. The trace is built by
-- prepending, so it is returned most-recent-step first (the scheduler is
-- given the reversed, oldest-first view).
runThreads
  :: MonadRef r n
  => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace))
  -- ^ STM runner, passed through to 'stepThread'.
  -> Scheduler ThreadId ThreadAction Lookahead g
  -- ^ Chooses which runnable thread to step next.
  -> MemType
  -- ^ Memory model, passed through to 'stepThread'.
  -> g
  -- ^ Initial scheduler state.
  -> Threads n r s
  -- ^ Initial threads.
  -> IdSource
  -- ^ Initial source of fresh identifiers.
  -> r (Maybe (Either Failure a))
  -- ^ Result slot; failures are written here.
  -> n (g, Trace ThreadId ThreadAction Lookahead)
runThreads runstm sched memtype origg origthreads idsrc ref =
    -- Start with no trace, no previous thread, an empty write buffer and
    -- 2 capabilities.
    go idsrc [] Nothing origg origthreads emptyBuffer 2
  where
    -- One scheduling step. The guards are checked in order and the later
    -- ones rely on the earlier ones having failed: 'chosen' is only forced
    -- once 'isAborted' is False, and 'fromJust thread' only once
    -- 'isNonexistant' is False.
    go idSource sofar prior g threads wb caps
      | isTerminated  = stop g
      | isDeadlocked  = die g Deadlock
      | isSTMLocked   = die g STMDeadlock
      | isAborted     = die g' Abort
      | isNonexistant = die g' InternalError
      | isBlocked     = die g' InternalError
      | otherwise = do
          stepped <- stepThread runstm memtype action idSource chosen threads wb caps
          case stepped of
            Right (threads', idSource', act, wb', caps') ->
              loop threads' idSource' act wb' caps'
            Left UncaughtException
              | mainThreadDied -> die g' UncaughtException
              | otherwise ->
                  loop (kill chosen threads) idSource Killed wbAfterKill caps
            Left failure -> die g' failure
      where
        -- Ask the scheduler for the next thread, given the trace so far
        -- (oldest first), the previous thread with its last action, and the
        -- runnable threads with the head of their lookahead.
        (choice, g') = sched
          (map (\(d, _, a) -> (d, a)) $ reverse sofar)
          ((\p (_, _, a) -> (p, a)) <$> prior <*> listToMaybe sofar)
          (fromList $ map (\(t, l :| _) -> (t, l)) runnable')
          g
        chosen = fromJust choice
        runnable' = [(t, nextActions t) | t <- sort $ M.keys runnable]
        runnable = M.filter (isNothing . _blocking) threadsc
        thread = M.lookup chosen threadsc
        threadsc = addCommitThreads wb threads
        action = _continuation (fromJust thread)

        isAborted = isNothing choice
        isBlocked = isJust . _blocking $ fromJust thread
        isNonexistant = isNothing thread
        isTerminated = M.notMember initialThread threads

        -- Deadlock detection looks only at the real threads, not at the
        -- commit threads.
        allBlocked = all (isJust . _blocking) threads
        mainBlockedOn b = maybe False (~= b) (M.lookup initialThread threads)
        isDeadlocked = allBlocked &&
          any mainBlockedOn [OnMVarFull undefined, OnMVarEmpty undefined, OnMask undefined]
        isSTMLocked = allBlocked && mainBlockedOn (OnTVar [])

        -- An 'UncaughtException' means the main thread died in one of two
        -- cases: main was itself the stepped thread, or the stepped thread
        -- threw to main with 'AThrowTo' and main had no handler.
        -- 'stepThrowTo' reports the second case from the thrower's step, so
        -- killing 'chosen' there would kill the thrower and let main run on.
        mainThreadDied = chosen == initialThread || throwsToMain
        throwsToMain = case action of
          AThrowTo t _ _ -> t == initialThread
          _ -> False

        -- 'AAtom' runs under the write barrier of 'synchronised' in
        -- 'stepThread', which has already flushed the buffer by the time the
        -- failure is reported. Continue with an empty buffer (as
        -- 'synchronised' does on success) so those writes are not committed
        -- a second time.
        wbAfterKill = case action of
          AAtom _ _ -> emptyBuffer
          _ -> wb

        -- Clear an 'OnMask' block held by any thread waiting to throw to
        -- thread @tid@.
        unblockWaitingOn tid = fmap unblock where
          unblock thrd = case _blocking thrd of
            Just (OnMask t) | t == tid -> thrd { _blocking = Nothing }
            _ -> thrd

        -- How this step appears in the trace, relative to the previous thread.
        decision
          | Just chosen == prior = Continue
          | prior `notElem` map (Just . fst) runnable' = Start chosen
          | otherwise = SwitchTo chosen

        nextActions t = lookahead . _continuation . fromJust $ M.lookup t threadsc

        stop outg = pure (outg, sofar)
        die outg reason = writeRef ref (Just $ Left reason) >> stop outg

        -- Record the step and continue. A pending throwTo aimed at the
        -- stepped thread is woken once that thread is interruptible or has
        -- gone.
        loop threads' idSource' act wb' caps' =
          let sofar' = (decision, runnable', act) : sofar
              threads'' = if (interruptible <$> M.lookup chosen threads') /= Just False
                then unblockWaitingOn chosen threads'
                else threads'
          in go idSource' sofar' (Just chosen) g' (delCommitThreads threads'') wb' caps'
-- | Interpret a single primitive 'Action' on behalf of thread @tid@.
--
-- On success the result contains, in order:
--
-- * the updated thread map;
-- * the updated id source;
-- * the 'ThreadAction' to record in the trace;
-- * the write buffer;
-- * the number of capabilities.
--
-- Actions wrapped in @synchronised@ flush the write buffer before they run
-- and return an empty buffer.
--
-- An exception with no handler gives @Left UncaughtException@. The result
-- does not say which thread died: see @stepThrowTo@, which reports the death
-- of the /target/ thread from the thrower's step.
stepThread :: forall n r s. MonadRef r n
  => (forall x. s x -> IdSource -> n (Result x, IdSource, TTrace))
  -- ^ STM runner, used by 'AAtom'.
  -> MemType
  -- ^ Memory model; decides whether 'AWriteRef' writes immediately or buffers.
  -> Action n r s
  -- ^ The action to interpret.
  -> IdSource
  -- ^ Source of fresh identifiers.
  -> ThreadId
  -- ^ The thread performing the action.
  -> Threads n r s
  -- ^ All threads.
  -> WriteBuffer r
  -- ^ Pending buffered writes.
  -> Int
  -- ^ Current number of capabilities.
  -> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
stepThread runstm memtype action idSource tid threads wb caps = case action of
    AFork n a b -> stepFork n a b
    AMyTId c -> stepMyTId c
    AGetNumCapabilities c -> stepGetNumCapabilities c
    ASetNumCapabilities i c -> stepSetNumCapabilities i c
    AYield c -> stepYield c
    ANewVar n c -> stepNewVar n c
    APutVar var a c -> stepPutVar var a c
    ATryPutVar var a c -> stepTryPutVar var a c
    AReadVar var c -> stepReadVar var c
    ATakeVar var c -> stepTakeVar var c
    ATryTakeVar var c -> stepTryTakeVar var c
    ANewRef n a c -> stepNewRef n a c
    AReadRef ref c -> stepReadRef ref c
    AReadRefCas ref c -> stepReadRefCas ref c
    AModRef ref f c -> stepModRef ref f c
    AModRefCas ref f c -> stepModRefCas ref f c
    AWriteRef ref a c -> stepWriteRef ref a c
    ACasRef ref tick a c -> stepCasRef ref tick a c
    ACommit t c -> stepCommit t c
    AAtom stm c -> stepAtom stm c
    ALift na -> stepLift na
    AThrow e -> stepThrow e
    AThrowTo t e c -> stepThrowTo t e c
    ACatching h ma c -> stepCatching h ma c
    APopCatching a -> stepPopCatching a
    AMasking m ma c -> stepMasking m ma c
    AResetMask b1 b2 m c -> stepResetMask b1 b2 m c
    AReturn c -> stepReturn c
    AMessage m c -> stepMessage m c
    AStop na -> stepStop na
  where
    -- Launch a new thread running @a@, with an id drawn from the id source
    -- (named @n@). The parent continues with @b@ applied to the new id.
    stepFork :: String
      -> ((forall b. M n r s b -> M n r s b) -> Action n r s)
      -> (ThreadId -> Action n r s)
      -> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
    stepFork n a b = return $ Right (goto (b newtid) tid threads', idSource', Fork newtid, wb, caps) where
      threads' = launch tid newtid a threads
      (idSource', newtid) = nextTId n idSource

    stepMyTId c = simple (goto (c tid) tid threads) MyThreadId

    stepGetNumCapabilities c = simple (goto (c caps) tid threads) $ GetNumCapabilities caps

    -- The new capability count is returned in the fifth component.
    stepSetNumCapabilities i c = return $ Right (goto c tid threads, idSource, SetNumCapabilities i, wb, i)

    stepYield c = simple (goto c tid threads) Yield

    -- MVar operations are all 'synchronised'. The helpers report whether
    -- the operation succeeded and which threads were woken; an
    -- unsuccessful blocking operation is recorded as a @Blocked*@ action.
    stepPutVar cvar@(MVar cvid _) a c = synchronised $ do
      (success, threads', woken) <- putIntoMVar cvar a c tid threads
      simple threads' $ if success then PutVar cvid woken else BlockedPutVar cvid

    stepTryPutVar cvar@(MVar cvid _) a c = synchronised $ do
      (success, threads', woken) <- tryPutIntoMVar cvar a c tid threads
      simple threads' $ TryPutVar cvid success woken

    stepReadVar cvar@(MVar cvid _) c = synchronised $ do
      (success, threads', _) <- readFromMVar cvar c tid threads
      simple threads' $ if success then ReadVar cvid else BlockedReadVar cvid

    stepTakeVar cvar@(MVar cvid _) c = synchronised $ do
      (success, threads', woken) <- takeFromMVar cvar c tid threads
      simple threads' $ if success then TakeVar cvid woken else BlockedTakeVar cvid

    stepTryTakeVar cvar@(MVar cvid _) c = synchronised $ do
      (success, threads', woken) <- tryTakeFromMVar cvar c tid threads
      simple threads' $ TryTakeVar cvid success woken

    -- Plain reads are not synchronised; they pass @tid@ to 'readCRef'
    -- (presumably so a thread can see its own buffered writes; TODO confirm).
    stepReadRef cref@(CRef crid _) c = do
      val <- readCRef cref tid
      simple (goto (c val) tid threads) $ ReadRef crid

    stepReadRefCas cref@(CRef crid _) c = do
      tick <- readForTicket cref tid
      simple (goto (c tick) tid threads) $ ReadRefCas crid

    -- Atomic modify: read, apply @f@, write the new value immediately.
    stepModRef cref@(CRef crid _) f c = synchronised $ do
      (new, val) <- f <$> readCRef cref tid
      writeImmediate cref new
      simple (goto (c val) tid threads) $ ModRef crid

    -- Atomic modify via compare-and-swap against a freshly read ticket. The
    -- CAS result is discarded.
    stepModRefCas cref@(CRef crid _) f c = synchronised $ do
      tick@(Ticket _ _ old) <- readForTicket cref tid
      let (new, val) = f old
      void $ casCRef cref tid tick new
      simple (goto (c val) tid threads) $ ModRefCas crid

    -- Under 'SequentialConsistency' the write is applied immediately.
    -- Otherwise it is buffered: keyed by @(tid, Nothing)@ under
    -- 'TotalStoreOrder', and by @(tid, Just crid)@ under 'PartialStoreOrder'.
    stepWriteRef cref@(CRef crid _) a c = case memtype of
      SequentialConsistency -> do
        writeImmediate cref a
        simple (goto c tid threads) $ WriteRef crid
      TotalStoreOrder -> do
        wb' <- bufferWrite wb (tid, Nothing) cref a
        return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps)
      PartialStoreOrder -> do
        wb' <- bufferWrite wb (tid, Just crid) cref a
        return $ Right (goto c tid threads, idSource, WriteRef crid, wb', caps)

    stepCasRef cref@(CRef crid _) tick a c = synchronised $ do
      (suc, tick') <- casCRef cref tid tick a
      simple (goto (c (suc, tick')) tid threads) $ CasRef crid suc

    -- Commit a buffered write of thread @t@, using the same buffer keys as
    -- 'stepWriteRef'. The thread map is unchanged. Under
    -- 'SequentialConsistency' writes are never buffered, so a commit there
    -- is an internal error.
    stepCommit t c = do
      wb' <- case memtype of
        SequentialConsistency ->
          error "Attempting to commit under SequentialConsistency"
        TotalStoreOrder -> commitWrite wb (t, Nothing)
        PartialStoreOrder -> commitWrite wb (t, Just c)
      return $ Right (threads, idSource, CommitRef t c, wb', caps)

    -- Run an STM transaction:
    -- * on success, wake threads blocked on the written TVars;
    -- * on 'Retry', block this thread on the touched TVars;
    -- * on 'Exception', throw it in this thread, keeping the id source
    --   returned by the transaction.
    stepAtom stm c = synchronised $ do
      (res, idSource', trace) <- runstm stm idSource
      case res of
        Success _ written val ->
          let (threads', woken) = wake (OnTVar written) threads
          in return $ Right (goto (c val) tid threads', idSource', STM trace woken, wb, caps)
        Retry touched ->
          let threads' = block (OnTVar touched) tid threads
          in return $ Right (threads', idSource', BlockedSTM trace, wb, caps)
        Exception e -> do
          res' <- stepThrow e
          return $ case res' of
            Right (threads', _, _, _, _) -> Right (threads', idSource', Throw, wb, caps)
            Left err -> Left err

    -- Install handler @h@ for @ma@. Both the body and the handler finish
    -- with 'APopCatching' before continuing with @c@.
    stepCatching h ma c = simple threads' Catching where
      a = runCont ma (APopCatching . c)
      e exc = runCont (h exc) (APopCatching . c)
      threads' = goto a tid (catching e tid threads)

    stepPopCatching a = simple threads' PopCatching where
      threads' = goto a tid (uncatching tid threads)

    -- Raise an exception in this thread; 'propagate' returns Nothing when
    -- no handler is installed.
    stepThrow e =
      case propagate (toException e) tid threads of
        Just threads' -> simple threads' Throw
        Nothing -> return $ Left UncaughtException

    -- Throw to thread @t@. The outcome depends on the target:
    -- * missing target: the thrower just continues;
    -- * interruptible target: the exception is propagated in it; an
    --   unhandled exception kills a non-main target;
    -- * uninterruptible target: the thrower is blocked with @OnMask t@.
    --   The block is applied to the original @threads@, so the thrower's
    --   continuation is not advanced past this throw.
    -- The @Left UncaughtException@ for an unhandled exception in the main
    -- thread is reported while @tid@ (the thrower) is the stepping thread,
    -- so the caller must not attribute it to @tid@.
    stepThrowTo t e c = synchronised $
      let threads' = goto c tid threads
          blocked = block (OnMask t) tid threads
      in case M.lookup t threads of
        Just thread
          | interruptible thread -> case propagate (toException e) t threads' of
            Just threads'' -> simple threads'' $ ThrowTo t
            Nothing
              | t == initialThread -> return $ Left UncaughtException
              | otherwise -> simple (kill t threads') $ ThrowTo t
          | otherwise -> simple blocked $ BlockedThrowTo t
        Nothing -> simple threads' $ ThrowTo t

    -- Run @ma@ with masking state @m@ and restore the previous state @m'@
    -- afterwards. The @umask@ given to @ma@ temporarily switches back to
    -- @m'@ around its argument, then returns to @m@.
    stepMasking :: MaskingState
      -> ((forall b. M n r s b -> M n r s b) -> M n r s a)
      -> (a -> Action n r s)
      -> n (Either Failure (Threads n r s, IdSource, ThreadAction, WriteBuffer r, Int))
    stepMasking m ma c = simple threads' $ SetMasking False m where
      a = runCont (ma umask) (AResetMask False False m' . c)
      m' = _masking . fromJust $ M.lookup tid threads
      umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> return b
      resetMask typ ms = cont $ \k -> AResetMask typ True ms $ k ()
      threads' = goto a tid (mask m tid threads)

    -- Set the masking state to @m@. @b1@ selects whether a 'SetMasking' or
    -- a 'ResetMasking' action is recorded.
    stepResetMask b1 b2 m c = simple threads' act where
      act = (if b1 then SetMasking else ResetMasking) b2 m
      threads' = goto c tid (mask m tid threads)

    -- Create an empty MVar with a fresh id.
    stepNewVar n c = do
      let (idSource', newmvid) = nextMVId n idSource
      ref <- newRef Nothing
      let mvar = MVar newmvid ref
      return $ Right (goto (c mvar) tid threads, idSource', NewVar newmvid, wb, caps)

    -- Create a CRef holding @a@ with a fresh id. The stored value is
    -- @(M.empty, 0, a)@; presumably per-thread state plus a counter, but
    -- TODO confirm.
    stepNewRef n a c = do
      let (idSource', newcrid) = nextCRId n idSource
      ref <- newRef (M.empty, 0, a)
      let cref = CRef newcrid ref
      return $ Right (goto (c cref) tid threads, idSource', NewRef newcrid, wb, caps)

    -- Run a lifted action of the underlying monad and continue with its
    -- result.
    stepLift na = do
      a <- na
      simple (goto a tid threads) LiftIO

    stepReturn c = simple (goto c tid threads) Return

    stepMessage m c = simple (goto c tid threads) (Message m)

    -- Run the final action, then remove this thread.
    stepStop na = na >> simple (kill tid threads) Stop

    -- Successful step that leaves the id source, buffer and capabilities
    -- unchanged.
    simple threads' act = return $ Right (threads', idSource, act, wb, caps)

    -- Apply the write barrier before running @ma@. On success the buffer is
    -- replaced by 'emptyBuffer'. On failure the barrier has still been
    -- applied, but no buffer is returned; the caller decides which buffer
    -- to continue with.
    synchronised ma = do
      writeBarrier wb
      res <- ma
      return $ case res of
        Right (threads', idSource', act', _, caps') -> Right (threads', idSource', act', emptyBuffer, caps')
        _ -> res