module Test.DejaFu.Common
(
ThreadId(..)
, CRefId(..)
, MVarId(..)
, TVarId(..)
, initialThread
, IdSource(..)
, nextCRId
, nextMVId
, nextTVId
, nextTId
, initialIdSource
, ThreadAction(..)
, isBlock
, tvarsOf
, tvarsWritten
, tvarsRead
, Lookahead(..)
, rewind
, willRelease
, ActionType(..)
, isBarrier
, isCommit
, synchronises
, crefOf
, mvarOf
, simplifyAction
, simplifyLookahead
, TTrace
, TAction(..)
, Trace
, Decision(..)
, showTrace
, threadNames
, preEmpCount
, Failure(..)
, isInternalError
, isAbort
, isDeadlock
, isUncaughtException
, isIllegalSubconcurrency
, showFail
, MemType(..)
, MonadFailException(..)
, runRefCont
, ehead
, etail
, eidx
, efromJust
, efromList
, fatal
) where
import Control.DeepSeq (NFData(..))
import Control.Exception (Exception(..), MaskingState(..),
SomeException, displayException)
import Control.Monad.Ref (MonadRef(..))
import Data.Function (on)
import Data.List (intercalate)
import Data.List.NonEmpty (NonEmpty(..))
import Data.Maybe (fromMaybe, mapMaybe)
import Data.Set (Set)
import qualified Data.Set as S
-- | Identifier for a thread: an optional display name plus a number.
--
-- Equality and ordering look only at the number; 'show' prints the
-- name when there is one and the number otherwise.
data ThreadId = ThreadId (Maybe String) !Int

instance Eq ThreadId where
  (==) (ThreadId _ a) (ThreadId _ b) = a == b

instance Ord ThreadId where
  compare (ThreadId _ a) (ThreadId _ b) = a `compare` b

instance Show ThreadId where
  show (ThreadId label n) = maybe (show n) id label

instance NFData ThreadId where
  rnf (ThreadId label n) = rnf label `seq` rnf n
-- | Identifier for a 'CRef': an optional display name plus a number.
--
-- Equality and ordering look only at the number; 'show' prints the
-- name when there is one and the number otherwise.
data CRefId = CRefId (Maybe String) !Int

instance Eq CRefId where
  (==) (CRefId _ a) (CRefId _ b) = a == b

instance Ord CRefId where
  compare (CRefId _ a) (CRefId _ b) = a `compare` b

instance Show CRefId where
  show (CRefId label n) = maybe (show n) id label

instance NFData CRefId where
  rnf (CRefId label n) = rnf label `seq` rnf n
-- | Identifier for an 'MVar': an optional display name plus a number.
--
-- Equality and ordering look only at the number; 'show' prints the
-- name when there is one and the number otherwise.
data MVarId = MVarId (Maybe String) !Int

instance Eq MVarId where
  (==) (MVarId _ a) (MVarId _ b) = a == b

instance Ord MVarId where
  compare (MVarId _ a) (MVarId _ b) = a `compare` b

instance Show MVarId where
  show (MVarId label n) = maybe (show n) id label

instance NFData MVarId where
  rnf (MVarId label n) = rnf label `seq` rnf n
-- | Identifier for a 'TVar': an optional display name plus a number.
--
-- Equality and ordering look only at the number; 'show' prints the
-- name when there is one and the number otherwise.
data TVarId = TVarId (Maybe String) !Int

instance Eq TVarId where
  (==) (TVarId _ a) (TVarId _ b) = a == b

instance Ord TVarId where
  compare (TVarId _ a) (TVarId _ b) = a `compare` b

instance Show TVarId where
  show (TVarId label n) = maybe (show n) id label

instance NFData TVarId where
  rnf (TVarId label n) = rnf label `seq` rnf n
-- | The identifier of the initial thread: number 0, named @"main"@.
initialThread :: ThreadId
initialThread = ThreadId mainName 0
  where
    mainName = Just "main"
-- | A supply of fresh identifiers.
--
-- It holds one counter per kind of identifier (CRef, MVar, TVar,
-- thread), plus the names already handed out for each kind. 'nextId'
-- uses those names to disambiguate repeated names.
data IdSource = Id
  { _nextCRId :: Int
  , _nextMVId :: Int
  , _nextTVId :: Int
  , _nextTId :: Int
  , _usedCRNames :: [String]
  , _usedMVNames :: [String]
  , _usedTVNames :: [String]
  , _usedTNames :: [String]
  } deriving (Eq, Ord, Show)

-- | Deeply evaluates every counter and every list of used names.
instance NFData IdSource where
  rnf (Id crid mvid tvid tid crNames mvNames tvNames tNames) =
    rnf crid `seq` rnf mvid `seq` rnf tvid `seq` rnf tid
      `seq` rnf crNames `seq` rnf mvNames `seq` rnf tvNames `seq` rnf tNames
-- | Allocate a fresh 'CRefId' with the given name.
--
-- Increments the CRef counter and records the name; the display name
-- comes from 'nextId', which yields no name for an empty string.
nextCRId :: String -> IdSource -> (IdSource, CRefId)
nextCRId name idsource =
  let n = _nextCRId idsource + 1
      (label, used) = nextId name (_usedCRNames idsource)
      idsource' = idsource { _nextCRId = n, _usedCRNames = used }
  in (idsource', CRefId label n)
-- | Allocate a fresh 'MVarId' with the given name.
--
-- Increments the MVar counter and records the name; the display name
-- comes from 'nextId', which yields no name for an empty string.
nextMVId :: String -> IdSource -> (IdSource, MVarId)
nextMVId name idsource =
  let n = _nextMVId idsource + 1
      (label, used) = nextId name (_usedMVNames idsource)
      idsource' = idsource { _nextMVId = n, _usedMVNames = used }
  in (idsource', MVarId label n)
-- | Allocate a fresh 'TVarId' with the given name.
--
-- Increments the TVar counter and records the name; the display name
-- comes from 'nextId', which yields no name for an empty string.
nextTVId :: String -> IdSource -> (IdSource, TVarId)
nextTVId name idsource =
  let n = _nextTVId idsource + 1
      (label, used) = nextId name (_usedTVNames idsource)
      idsource' = idsource { _nextTVId = n, _usedTVNames = used }
  in (idsource', TVarId label n)
-- | Allocate a fresh 'ThreadId' with the given name.
--
-- Increments the thread counter and records the name; the display
-- name comes from 'nextId', which yields no name for an empty string.
-- The first allocated thread gets number 1; 'initialThread' uses 0.
nextTId :: String -> IdSource -> (IdSource, ThreadId)
nextTId name idsource =
  let n = _nextTId idsource + 1
      (label, used) = nextId name (_usedTNames idsource)
      idsource' = idsource { _nextTId = n, _usedTNames = used }
  in (idsource', ThreadId label n)
-- | The starting identifier supply: every counter at zero, and no
-- names used yet.
initialIdSource :: IdSource
initialIdSource = Id
  { _nextCRId = 0
  , _nextMVId = 0
  , _nextTVId = 0
  , _nextTId = 0
  , _usedCRNames = []
  , _usedMVNames = []
  , _usedTVNames = []
  , _usedTNames = []
  }
-- | All the actions a thread can perform, recorded as the third
-- component of each 'Trace' step.
--
-- Several constructors carry extra payload, such as a @[ThreadId]@
-- list or a 'Bool'. 'rewind' discards that payload when producing the
-- corresponding 'Lookahead'.
data ThreadAction =
    -- 'threadNames' reads forked threads' names from this constructor.
    Fork ThreadId
  | MyThreadId
  | GetNumCapabilities Int
  | SetNumCapabilities Int
  | Yield
  | ThreadDelay Int
  -- MVar operations. The Blocked* variants satisfy 'isBlock'. 'rewind'
  -- maps them to the same 'Lookahead' as their non-blocked
  -- counterparts.
  -- NOTE(review): the [ThreadId] payloads look like the threads woken
  -- by the operation; confirm against the interpreter.
  | NewMVar MVarId
  | PutMVar MVarId [ThreadId]
  | BlockedPutMVar MVarId
  | TryPutMVar MVarId Bool [ThreadId]
  | ReadMVar MVarId
  | TryReadMVar MVarId Bool
  | BlockedReadMVar MVarId
  | TakeMVar MVarId [ThreadId]
  | BlockedTakeMVar MVarId
  | TryTakeMVar MVarId Bool [ThreadId]
  -- CRef operations. 'showTrace' renders a CommitCRef step as "C-".
  | NewCRef CRefId
  | ReadCRef CRefId
  | ReadCRefCas CRefId
  | ModCRef CRefId
  | ModCRefCas CRefId
  | WriteCRef CRefId
  | CasCRef CRefId Bool
  | CommitCRef ThreadId CRefId
  -- STM transactions. 'tvarsRead' and 'tvarsWritten' inspect the TTrace.
  | STM TTrace [ThreadId]
  | BlockedSTM TTrace
  -- Exceptions and masking. Killed is the only action that 'rewind'
  -- maps to Nothing.
  | Catching
  | PopCatching
  | Throw
  | ThrowTo ThreadId
  | BlockedThrowTo ThreadId
  | Killed
  | SetMasking Bool MaskingState
  | ResetMasking Bool MaskingState
  | LiftIO
  | Return
  | Stop
  | Subconcurrency
  | StopSubconcurrency
  deriving (Eq, Show)
-- | Deep evaluation of a thread action. Every field is forced with
-- 'rnf', except 'MaskingState', which is only taken to WHNF.
instance NFData ThreadAction where
  rnf act = case act of
    Fork t -> rnf t
    ThreadDelay n -> rnf n
    GetNumCapabilities c -> rnf c
    SetNumCapabilities c -> rnf c
    NewMVar m -> rnf m
    PutMVar m ts -> rnf m `seq` rnf ts
    BlockedPutMVar m -> rnf m
    TryPutMVar m b ts -> rnf m `seq` rnf b `seq` rnf ts
    ReadMVar m -> rnf m
    TryReadMVar m b -> rnf m `seq` rnf b
    BlockedReadMVar m -> rnf m
    TakeMVar m ts -> rnf m `seq` rnf ts
    BlockedTakeMVar m -> rnf m
    TryTakeMVar m b ts -> rnf m `seq` rnf b `seq` rnf ts
    NewCRef c -> rnf c
    ReadCRef c -> rnf c
    ReadCRefCas c -> rnf c
    ModCRef c -> rnf c
    ModCRefCas c -> rnf c
    WriteCRef c -> rnf c
    CasCRef c b -> rnf c `seq` rnf b
    CommitCRef t c -> rnf t `seq` rnf c
    STM tr ts -> rnf tr `seq` rnf ts
    BlockedSTM tr -> rnf tr
    ThrowTo t -> rnf t
    BlockedThrowTo t -> rnf t
    SetMasking b m -> b `seq` m `seq` ()
    ResetMasking b m -> b `seq` m `seq` ()
    -- The remaining constructors have no fields. Matching above has
    -- already forced the value to WHNF, so there is nothing left to do.
    _ -> ()
-- | Whether an action is one of the @Blocked*@ variants.
isBlock :: ThreadAction -> Bool
isBlock act = case act of
  BlockedThrowTo _ -> True
  BlockedTakeMVar _ -> True
  BlockedReadMVar _ -> True
  BlockedPutMVar _ -> True
  BlockedSTM _ -> True
  _ -> False
-- | Every 'TVarId' an action reads or writes: the union of 'tvarsRead'
-- and 'tvarsWritten'.
tvarsOf :: ThreadAction -> Set TVarId
tvarsOf act = S.union (tvarsRead act) (tvarsWritten act)
-- | The 'TVarId's written by an action.
--
-- Only 'STM' and 'BlockedSTM' carry a transaction trace; every other
-- action yields the empty set. The search descends into both the
-- primary and the optional secondary trace of 'TOrElse' and 'TCatch'.
tvarsWritten :: ThreadAction -> Set TVarId
tvarsWritten act = S.fromList $ concatMap writesIn $ case act of
    STM trc _ -> trc
    BlockedSTM trc -> trc
    _ -> []
  where
    writesIn step = case step of
      TWrite tv -> [tv]
      TOrElse ta tb -> nested ta tb
      TCatch ta tb -> nested ta tb
      _ -> []
    nested ta tb = concatMap writesIn (ta ++ fromMaybe [] tb)
-- | The 'TVarId's read by an action.
--
-- Only 'STM' and 'BlockedSTM' carry a transaction trace; every other
-- action yields the empty set. The search descends into both the
-- primary and the optional secondary trace of 'TOrElse' and 'TCatch'.
tvarsRead :: ThreadAction -> Set TVarId
tvarsRead act = S.fromList (concatMap readsIn stmTrace)
  where
    stmTrace = case act of
      STM trc _ -> trc
      BlockedSTM trc -> trc
      _ -> []
    readsIn (TRead tv) = [tv]
    readsIn (TOrElse ta tb) = concatMap readsIn ta ++ maybe [] (concatMap readsIn) tb
    readsIn (TCatch ta tb) = concatMap readsIn ta ++ maybe [] (concatMap readsIn) tb
    readsIn _ = []
-- | What a thread will do next, without the result payload that a
-- 'ThreadAction' carries.
--
-- 'rewind' converts a 'ThreadAction' into its 'Lookahead'. For example,
-- both 'PutMVar' and 'BlockedPutMVar' become 'WillPutMVar', and the
-- 'Bool' and @[ThreadId]@ payloads of the Try* actions are dropped.
-- 'Trace' steps pair lookaheads with 'ThreadId's.
data Lookahead =
    WillFork
  | WillMyThreadId
  | WillGetNumCapabilities
  | WillSetNumCapabilities Int
  | WillYield
  | WillThreadDelay Int
  -- MVar operations.
  | WillNewMVar
  | WillPutMVar MVarId
  | WillTryPutMVar MVarId
  | WillReadMVar MVarId
  | WillTryReadMVar MVarId
  | WillTakeMVar MVarId
  | WillTryTakeMVar MVarId
  -- CRef operations.
  | WillNewCRef
  | WillReadCRef CRefId
  | WillReadCRefCas CRefId
  | WillModCRef CRefId
  | WillModCRefCas CRefId
  | WillWriteCRef CRefId
  | WillCasCRef CRefId
  | WillCommitCRef ThreadId CRefId
  -- STM, exceptions, masking and the rest.
  | WillSTM
  | WillCatching
  | WillPopCatching
  | WillThrow
  | WillThrowTo ThreadId
  | WillSetMasking Bool MaskingState
  | WillResetMasking Bool MaskingState
  | WillLiftIO
  | WillReturn
  | WillStop
  | WillSubconcurrency
  | WillStopSubconcurrency
  deriving (Eq, Show)
-- | Deep evaluation of a lookahead. Every field is forced with 'rnf',
-- except 'MaskingState', which is only taken to WHNF.
instance NFData Lookahead where
  rnf lh = case lh of
    WillThreadDelay n -> rnf n
    WillSetNumCapabilities c -> rnf c
    WillPutMVar m -> rnf m
    WillTryPutMVar m -> rnf m
    WillReadMVar m -> rnf m
    WillTryReadMVar m -> rnf m
    WillTakeMVar m -> rnf m
    WillTryTakeMVar m -> rnf m
    WillReadCRef c -> rnf c
    WillReadCRefCas c -> rnf c
    WillModCRef c -> rnf c
    WillModCRefCas c -> rnf c
    WillWriteCRef c -> rnf c
    WillCasCRef c -> rnf c
    WillCommitCRef t c -> rnf t `seq` rnf c
    WillThrowTo t -> rnf t
    WillSetMasking b m -> b `seq` m `seq` ()
    WillResetMasking b m -> b `seq` m `seq` ()
    -- The remaining constructors have no fields, and the match has
    -- already forced the value to WHNF.
    _ -> ()
-- | Convert a completed action back into the 'Lookahead' that describes
-- it, dropping any result payload.
--
-- Blocked and non-blocked variants map to the same lookahead.
-- 'Killed' has no lookahead and yields 'Nothing'.
rewind :: ThreadAction -> Maybe Lookahead
rewind act = case act of
  Fork _ -> Just WillFork
  MyThreadId -> Just WillMyThreadId
  GetNumCapabilities _ -> Just WillGetNumCapabilities
  SetNumCapabilities i -> Just (WillSetNumCapabilities i)
  Yield -> Just WillYield
  ThreadDelay n -> Just (WillThreadDelay n)
  NewMVar _ -> Just WillNewMVar
  PutMVar c _ -> Just (WillPutMVar c)
  BlockedPutMVar c -> Just (WillPutMVar c)
  TryPutMVar c _ _ -> Just (WillTryPutMVar c)
  ReadMVar c -> Just (WillReadMVar c)
  BlockedReadMVar c -> Just (WillReadMVar c)
  TryReadMVar c _ -> Just (WillTryReadMVar c)
  TakeMVar c _ -> Just (WillTakeMVar c)
  BlockedTakeMVar c -> Just (WillTakeMVar c)
  TryTakeMVar c _ _ -> Just (WillTryTakeMVar c)
  NewCRef _ -> Just WillNewCRef
  ReadCRef c -> Just (WillReadCRef c)
  ReadCRefCas c -> Just (WillReadCRefCas c)
  ModCRef c -> Just (WillModCRef c)
  ModCRefCas c -> Just (WillModCRefCas c)
  WriteCRef c -> Just (WillWriteCRef c)
  CasCRef c _ -> Just (WillCasCRef c)
  CommitCRef t c -> Just (WillCommitCRef t c)
  STM _ _ -> Just WillSTM
  BlockedSTM _ -> Just WillSTM
  Catching -> Just WillCatching
  PopCatching -> Just WillPopCatching
  Throw -> Just WillThrow
  ThrowTo t -> Just (WillThrowTo t)
  BlockedThrowTo t -> Just (WillThrowTo t)
  Killed -> Nothing
  SetMasking b m -> Just (WillSetMasking b m)
  ResetMasking b m -> Just (WillResetMasking b m)
  LiftIO -> Just WillLiftIO
  Return -> Just WillReturn
  Stop -> Just WillStop
  Subconcurrency -> Just WillSubconcurrency
  StopSubconcurrency -> Just WillStopSubconcurrency
-- | 'True' for the lookaheads listed below and 'False' for all others.
--
-- NOTE(review): presumably these are the actions that can release or
-- enable another thread; confirm against the scheduler before relying
-- on that reading.
willRelease :: Lookahead -> Bool
willRelease lh = case lh of
  WillFork -> True
  WillYield -> True
  WillThreadDelay _ -> True
  WillPutMVar _ -> True
  WillTryPutMVar _ -> True
  WillReadMVar _ -> True
  WillTakeMVar _ -> True
  WillTryTakeMVar _ -> True
  WillSTM -> True
  WillThrow -> True
  WillSetMasking _ _ -> True
  WillResetMasking _ _ -> True
  WillStop -> True
  _ -> False
-- | A coarse classification of actions by synchronisation behaviour,
-- produced by 'simplifyAction' and 'simplifyLookahead'.
--
-- 'isBarrier' holds for the Synchronised* constructors. 'isCommit'
-- holds for the PartiallySynchronised* constructors when their
-- 'CRefId' matches. 'crefOf' and 'mvarOf' extract the carried id.
data ActionType =
    -- From WillReadCRef and WillReadCRefCas.
    UnsynchronisedRead CRefId
  -- From WillWriteCRef.
  | UnsynchronisedWrite CRefId
  -- The fallback for every other lookahead, and for actions with no
  -- lookahead.
  | UnsynchronisedOther
  -- From WillCommitCRef.
  | PartiallySynchronisedCommit CRefId
  -- From WillCasCRef.
  | PartiallySynchronisedWrite CRefId
  -- From WillModCRefCas.
  | PartiallySynchronisedModify CRefId
  -- From WillModCRef.
  | SynchronisedModify CRefId
  -- From the MVar read, take and try variants.
  | SynchronisedRead MVarId
  -- From the MVar put and try-put variants.
  | SynchronisedWrite MVarId
  -- From WillSTM and WillThrowTo.
  | SynchronisedOther
  deriving (Eq, Show)
-- | Deep evaluation of an action type: forces the carried id, if any.
instance NFData ActionType where
  rnf at = case at of
    UnsynchronisedRead c -> rnf c
    UnsynchronisedWrite c -> rnf c
    PartiallySynchronisedCommit c -> rnf c
    PartiallySynchronisedWrite c -> rnf c
    PartiallySynchronisedModify c -> rnf c
    SynchronisedModify c -> rnf c
    SynchronisedRead m -> rnf m
    SynchronisedWrite m -> rnf m
    -- Field-less constructors are already in WHNF after the match.
    _ -> ()
-- | Whether an action type is one of the fully synchronised
-- (Synchronised*) kinds.
isBarrier :: ActionType -> Bool
isBarrier at = case at of
  SynchronisedModify _ -> True
  SynchronisedRead _ -> True
  SynchronisedWrite _ -> True
  SynchronisedOther -> True
  _ -> False
-- | Whether an action type is a partially synchronised operation
-- (commit, CAS write or CAS modify) on the given 'CRefId'.
isCommit :: ActionType -> CRefId -> Bool
isCommit at r = maybe False (== r) $ case at of
  PartiallySynchronisedCommit c -> Just c
  PartiallySynchronisedWrite c -> Just c
  PartiallySynchronisedModify c -> Just c
  _ -> Nothing
-- | An action type synchronises with respect to a 'CRefId' when it is
-- a commit of that CRef ('isCommit') or a barrier ('isBarrier').
synchronises :: ActionType -> CRefId -> Bool
synchronises a r
  | isCommit a r = True
  | otherwise = isBarrier a
-- | The 'CRefId' an action type refers to, if any.
crefOf :: ActionType -> Maybe CRefId
crefOf at = case at of
  UnsynchronisedRead r -> Just r
  UnsynchronisedWrite r -> Just r
  SynchronisedModify r -> Just r
  PartiallySynchronisedCommit r -> Just r
  PartiallySynchronisedWrite r -> Just r
  PartiallySynchronisedModify r -> Just r
  _ -> Nothing
-- | The 'MVarId' an action type refers to, if any.
mvarOf :: ActionType -> Maybe MVarId
mvarOf at = case at of
  SynchronisedRead m -> Just m
  SynchronisedWrite m -> Just m
  _ -> Nothing
-- | Classify a completed action by first converting it with 'rewind'.
-- Actions without a lookahead are treated as 'UnsynchronisedOther'.
simplifyAction :: ThreadAction -> ActionType
simplifyAction act = case rewind act of
  Just lookahead -> simplifyLookahead lookahead
  Nothing -> UnsynchronisedOther
-- | Classify an upcoming action by its synchronisation behaviour.
--
-- MVar operations are synchronised reads or writes. CRef operations
-- fall into the unsynchronised, partially synchronised or synchronised
-- kinds depending on the operation. STM and throwTo are
-- 'SynchronisedOther'; everything else is 'UnsynchronisedOther'.
simplifyLookahead :: Lookahead -> ActionType
simplifyLookahead lh = case lh of
  WillPutMVar c -> SynchronisedWrite c
  WillTryPutMVar c -> SynchronisedWrite c
  WillReadMVar c -> SynchronisedRead c
  WillTryReadMVar c -> SynchronisedRead c
  WillTakeMVar c -> SynchronisedRead c
  WillTryTakeMVar c -> SynchronisedRead c
  WillReadCRef r -> UnsynchronisedRead r
  WillReadCRefCas r -> UnsynchronisedRead r
  WillModCRef r -> SynchronisedModify r
  WillModCRefCas r -> PartiallySynchronisedModify r
  WillWriteCRef r -> UnsynchronisedWrite r
  WillCasCRef r -> PartiallySynchronisedWrite r
  WillCommitCRef _ r -> PartiallySynchronisedCommit r
  WillSTM -> SynchronisedOther
  WillThrowTo _ -> SynchronisedOther
  _ -> UnsynchronisedOther
-- | The trace of an STM transaction, as carried by the 'STM' and
-- 'BlockedSTM' thread actions.
type TTrace = [TAction]
-- | A single step of an STM transaction trace.
--
-- 'tvarsRead' collects the ids in 'TRead' and 'tvarsWritten' collects
-- the ids in 'TWrite'. Both descend into the nested traces of 'TOrElse'
-- and 'TCatch', including the optional second trace when it is present.
data TAction =
    TNew TVarId
  | TRead TVarId
  | TWrite TVarId
  | TRetry
  | TOrElse TTrace (Maybe TTrace)
  | TThrow
  | TCatch TTrace (Maybe TTrace)
  | TStop
  deriving (Eq, Show)
-- | Deep evaluation of an STM trace step.
--
-- Every constructor that carries a 'TVarId' or a nested trace forces
-- that payload with 'rnf'. 'TNew' must be listed explicitly: the
-- catch-all only reaches WHNF, which would leave its 'TVarId' (and the
-- optional name inside it) unevaluated.
instance NFData TAction where
  rnf (TNew t) = rnf t
  rnf (TRead t) = rnf t
  rnf (TWrite t) = rnf t
  rnf (TOrElse tr mtr) = rnf (tr, mtr)
  rnf (TCatch tr mtr) = rnf (tr, mtr)
  -- Only field-less constructors reach here.
  rnf ta = ta `seq` ()
-- | An execution trace with one entry per scheduling step. Each entry
-- holds the 'Decision' taken, a list of ('ThreadId', 'Lookahead')
-- pairs, and the 'ThreadAction' performed.
--
-- NOTE(review): the middle list presumably records the other runnable
-- threads and what they would have done. 'showTrace' ignores it;
-- confirm against the scheduler.
type Trace
  = [(Decision, [(ThreadId, Lookahead)], ThreadAction)]
-- | A scheduling decision: the first component of each 'Trace' step.
--
-- 'showTrace' renders 'Start' as "S<n>-", 'SwitchTo' as "P<n>-" and
-- 'Continue' as "-". 'preEmpCount' counts a 'SwitchTo' as a
-- pre-emption, except directly after a 'Yield' or 'ThreadDelay', or
-- when switching to a commit thread.
data Decision =
    Start ThreadId
  | Continue
  | SwitchTo ThreadId
  deriving (Eq, Show)
-- | Deep evaluation of a decision: forces the target thread id, if any.
instance NFData Decision where
  rnf d = case d of
    Start t -> rnf t
    SwitchTo t -> rnf t
    Continue -> ()
-- | Render a trace compactly.
--
-- The first line concatenates one token per step:
--
-- * "C-" for a 'CommitCRef' step,
-- * "S<n>-" for 'Start',
-- * "P<n>-" for 'SwitchTo',
-- * "-" for 'Continue'.
--
-- Each following line maps a forked thread's number to its name (see
-- 'threadNames'). An empty trace is rendered as "<trace discarded>".
showTrace :: Trace -> String
showTrace [] = "<trace discarded>"
showTrace trc = intercalate "\n" (summary : key)
  where
    summary = concatMap token trc
    token (_, _, CommitCRef _ _) = "C-"
    token (decision, _, _) = case decision of
      Start (ThreadId _ i) -> "S" ++ show i ++ "-"
      SwitchTo (ThreadId _ i) -> "P" ++ show i ++ "-"
      Continue -> "-"
    key = [" " ++ show i ++ ": " ++ name | (i, name) <- threadNames trc]
-- | The (number, name) pairs of every named thread forked in a trace,
-- in trace order. Unnamed forks are skipped.
threadNames :: Trace -> [(Int, String)]
threadNames trc = [(i, name) | (_, _, Fork (ThreadId (Just name) i)) <- trc]
-- | Count the pre-emptive context switches in a schedule prefix, plus
-- the decision about to be made.
--
-- The first argument is the (decision, action) sequence so far. The
-- second is the upcoming decision; its 'Lookahead' is ignored.
--
-- A 'SwitchTo' counts as one pre-emption unless:
--
-- * the preceding action was a 'Yield' or a 'ThreadDelay', or
-- * the target is a commit thread, i.e. one ordered below
--   'initialThread'.
--
-- An empty prefix yields 0.
preEmpCount :: [(Decision, ThreadAction)]
            -> (Decision, Lookahead)
            -> Int
preEmpCount (x:xs) (d, _) = go initialThread x xs where
  -- go tid prior steps: prior is the previous (decision, action) step.
  -- tid is the thread from the latest counted Start/SwitchTo; it is
  -- threaded through but never inspected.
  -- A switch right after yielding or delaying is not a pre-emption.
  go _ (_, Yield) (r@(SwitchTo t, _):rest) = go t r rest
  go _ (_, ThreadDelay _) (r@(SwitchTo t, _):rest) = go t r rest
  go tid prior (r@(SwitchTo t, _):rest)
    -- Switches to commit threads are not counted. Skip past the commit
    -- thread's Continue steps, keeping the earlier thread and prior step.
    | isCommitThread t = go tid prior (skip rest)
    | otherwise = 1 + go t r rest
  go _ _ (r@(Start t, _):rest) = go t r rest
  go tid _ (r@(Continue, _):rest) = go tid r rest
  -- End of the prefix: apply the same rule to the upcoming decision.
  go _ prior [] = case (prior, d) of
    ((_, Yield), SwitchTo _) -> 0
    ((_, ThreadDelay _), SwitchTo _) -> 0
    (_, SwitchTo _) -> 1
    _ -> 0

  -- Commit threads are recognised by ordering below the initial
  -- thread, which has number 0.
  isCommitThread = (< initialThread)

  -- Drop steps until the next decision that is not Continue.
  skip = dropWhile (not . isContextSwitch . fst)
  isContextSwitch Continue = False
  isContextSwitch _ = True
preEmpCount [] _ = 0
-- | The ways a concurrent computation can fail.
--
-- 'showFail' renders each constructor as a bracketed tag, and
-- 'isDeadlock' treats both 'Deadlock' and 'STMDeadlock' as deadlocks.
-- Equality and ordering go through '_other': they compare the
-- constructor position and, for 'UncaughtException', the shown
-- exception.
data Failure
  = InternalError
  | Abort
  | Deadlock
  | STMDeadlock
  | UncaughtException SomeException
  | IllegalSubconcurrency
  deriving Show
-- | Failures are equal when '_other' agrees on them. In particular, two
-- 'UncaughtException's are equal when their exceptions 'show' the
-- same.
instance Eq Failure where
  x == y = _other x == _other y

-- | Orders failures by constructor position, then by the shown
-- exception (both via '_other').
instance Ord Failure where
  compare x y = compare (_other x) (_other y)

-- | Deep evaluation goes through the '_other' representation.
instance NFData Failure where
  rnf failure = rnf (_other failure)
-- | A comparable representation of a 'Failure'.
--
-- The first component is the constructor's position in the
-- declaration. The second is the shown exception for
-- 'UncaughtException', and 'Nothing' for every other constructor.
_other :: Failure -> (Int, Maybe String)
_other failure = case failure of
  InternalError -> (0, Nothing)
  Abort -> (1, Nothing)
  Deadlock -> (2, Nothing)
  STMDeadlock -> (3, Nothing)
  UncaughtException e -> (4, Just (show e))
  IllegalSubconcurrency -> (5, Nothing)
-- | Render a failure as a short bracketed tag. An uncaught exception is
-- shown via 'displayException'.
showFail :: Failure -> String
showFail failure = case failure of
  Abort -> "[abort]"
  Deadlock -> "[deadlock]"
  STMDeadlock -> "[stm-deadlock]"
  InternalError -> "[internal-error]"
  UncaughtException exc -> concat ["[", displayException exc, "]"]
  IllegalSubconcurrency -> "[illegal-subconcurrency]"
-- | Whether a failure is 'InternalError'.
isInternalError :: Failure -> Bool
isInternalError failure = case failure of
  InternalError -> True
  _ -> False
-- | Whether a failure is 'Abort'.
isAbort :: Failure -> Bool
isAbort failure = case failure of
  Abort -> True
  _ -> False
-- | Whether a failure is a deadlock of either kind: 'Deadlock' or
-- 'STMDeadlock'.
isDeadlock :: Failure -> Bool
isDeadlock failure = case failure of
  Deadlock -> True
  STMDeadlock -> True
  _ -> False
-- | Whether a failure is an 'UncaughtException', whatever the
-- exception.
isUncaughtException :: Failure -> Bool
isUncaughtException failure = case failure of
  UncaughtException _ -> True
  _ -> False
-- | Whether a failure is 'IllegalSubconcurrency'.
isIllegalSubconcurrency :: Failure -> Bool
isIllegalSubconcurrency failure = case failure of
  IllegalSubconcurrency -> True
  _ -> False
-- | The available memory models.
--
-- NOTE(review): judging by the names, these select how non-synchronised
-- 'CRef' operations behave. The code that interprets them is not in
-- this module; confirm there.
data MemType =
    SequentialConsistency
  | TotalStoreOrder
  | PartialStoreOrder
  deriving (Eq, Show, Read, Ord, Enum, Bounded)

-- | Every constructor is field-less, so evaluating to WHNF is enough.
instance NFData MemType where
  rnf m = case m of
    SequentialConsistency -> ()
    TotalStoreOrder -> ()
    PartialStoreOrder -> ()
-- | An exception that wraps a 'String' message.
--
-- NOTE(review): the name suggests it is raised for @fail@ in the
-- testing monads, but the throw site is not in this module; confirm.
newtype MonadFailException = MonadFailException String
  deriving Show
instance Exception MonadFailException
-- | Create a reference initialised to 'Nothing' and build a value in
-- continuation-passing style around it.
--
-- @k@ receives a callback. Given an @a@, the callback maps it through
-- @f@, builds the action that writes the result into the reference,
-- and passes that action to @act@. Returns the built value paired with
-- the reference.
runRefCont :: MonadRef r n => (n () -> x) -> (a -> Maybe b) -> ((a -> x) -> x) -> n (x, r (Maybe b))
runRefCont act f k =
  newRef Nothing >>= \ref ->
    let callback a = act (writeRef ref (f a))
    in pure (k callback, ref)
-- | 'head' that reports an empty list via 'fatal', tagged with @src@.
ehead :: String -> [a] -> a
ehead src xs = case xs of
  x:_ -> x
  [] -> fatal src "head: empty list"
-- | 'tail' that reports an empty list via 'fatal', tagged with @src@.
etail :: String -> [a] -> [a]
etail src xs = case xs of
  _:rest -> rest
  [] -> fatal src "tail: empty list"
-- | List indexing, like @(!!)@, that reports out-of-range indices via
-- 'fatal', tagged with @src@.
--
-- Both failure cases go through 'fatal': a negative index and an index
-- at or past the end of the list. Only the first @i + 1@ cells are
-- examined, so the list is walked once and infinite lists are
-- supported.
eidx :: String -> [a] -> Int -> a
eidx src xs i
  | i < 0 = fatal src "(!!): negative index"
  | otherwise = case drop i xs of
      (x:_) -> x
      [] -> fatal src "(!!): index too large"
-- | 'fromJust' that reports 'Nothing' via 'fatal', tagged with @src@.
efromJust :: String -> Maybe a -> a
efromJust src = maybe (fatal src "fromJust: Nothing") id
-- | Convert a list to a 'NonEmpty', reporting an empty list via 'fatal'
-- tagged with @src@.
efromList :: String -> [a] -> NonEmpty a
efromList src xs = case xs of
  [] -> fatal src "fromList: empty list"
  y:ys -> y :| ys
-- | Abort with 'error'. The message is prefixed with
-- "(dejafu: <src>) ".
fatal :: String -> String -> a
fatal src msg = error (concat ["(dejafu: ", src, ") ", msg])
-- | Choose a display name for a new identifier and update the list of
-- used names.
--
-- An empty name gives no display name and leaves the list unchanged.
-- Otherwise the name is recorded. If it has already been used @k > 0@
-- times, the display name becomes "<name>-<k>".
nextId :: String -> [String] -> (Maybe String, [String])
nextId name used = (newName, newUsed)
  where
    anonymous = null name
    clashes = length (filter (== name) used)
    newName
      | anonymous = Nothing
      | clashes == 0 = Just name
      | otherwise = Just (name ++ "-" ++ show clashes)
    newUsed = if anonymous then used else name : used