module Test.DejaFu.Deterministic.Internal.Threading where
import Control.Exception (Exception, MaskingState(..), SomeException, fromException)
import Data.List (intersect, nub)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, isJust, isNothing)
import Test.DejaFu.Deterministic.Internal.Common
import qualified Data.Map.Strict as M
-- | The collection of threads, keyed by 'ThreadId'.
type Threads n r s = Map ThreadId (Thread n r s)
-- | The state of a single thread.
--
-- Field order matters: 'Thread' may be constructed positionally.
data Thread n r s = Thread
  { _continuation :: Action n r s
  -- ^ What the thread runs next; replaced by 'goto' and 'except'.
  , _blocking :: Maybe BlockedOn
  -- ^ What the thread is blocked on, or 'Nothing' if it is not blocked.
  , _handlers :: [Handler n r s]
  -- ^ Exception handler stack, innermost first: 'catching' pushes onto
  -- the front and 'uncatching' pops the front.
  , _masking :: MaskingState
  -- ^ Current asynchronous-exception masking state.
  , _known :: [Either MVarId TVarId]
  -- ^ MVars ('Left') and TVars ('Right') this thread is recorded as
  -- knowing about (maintained by 'knows' / 'forgets'); consulted by
  -- 'isLocked'.
  , _fullknown :: Bool
  -- ^ Whether '_known' is complete. 'isLocked' only relies on '_known'
  -- when every thread has this set.
  }
-- | Create a fresh thread that will run the given action.
--
-- The new thread is not blocked, has no exception handlers, is
-- 'Unmasked', and has an empty, not-yet-complete set of known variables.
mkthread :: Action n r s -> Thread n r s
mkthread c = Thread
  { _continuation = c
  , _blocking = Nothing
  , _handlers = []
  , _masking = Unmasked
  , _known = []
  , _fullknown = False
  }
-- | What a thread can be blocked on.
data BlockedOn
  = OnMVarFull MVarId
  -- ^ Blocked on the given MVar. NOTE(review): presumably "waiting for it
  -- to become full" -- confirm against the code that blocks on it.
  | OnMVarEmpty MVarId
  -- ^ Blocked on the given MVar. NOTE(review): presumably "waiting for it
  -- to become empty" -- confirm against the code that blocks on it.
  | OnTVar [TVarId]
  -- ^ Blocked on any of these TVars; 'wake' with an 'OnTVar' whose list
  -- intersects this one unblocks the thread.
  | OnMask ThreadId
  -- ^ Blocked in relation to the given thread's masking state.
  -- NOTE(review): exact meaning is defined by the callers, not visible here.
  deriving Eq
-- | Check whether a thread is blocked on the same /kind/ of thing as the
-- given 'BlockedOn', ignoring which MVar, TVars or thread it refers to.
-- A thread that is not blocked never matches.
(~=) :: Thread n r s -> BlockedOn -> Bool
thread ~= theblock = case _blocking thread of
    Just current -> kind current == kind theblock
    Nothing -> False
  where
    -- One distinct number per constructor; the payload is ignored.
    kind :: BlockedOn -> Int
    kind (OnMVarFull _) = 0
    kind (OnMVarEmpty _) = 1
    kind (OnTVar _) = 2
    kind (OnMask _) = 3
-- | Check whether thread @tid@ is locked, i.e. cannot make progress.
--
-- * If every thread has '_fullknown' set, the answer is precise: @tid@ is
--   locked iff it exists and no /other/, currently unblocked thread knows
--   about any of the MVars/TVars @tid@ is blocked on.
-- * Otherwise the '_known' lists cannot be trusted, so fall back to the
--   global check: locked iff no thread at all is unblocked.
--
-- NOTE(review): in the precise case a thread that is not blocked at all
-- (or is blocked 'OnMask') is reported as locked; callers presumably only
-- ask about threads that have just blocked.
isLocked :: ThreadId -> Threads n r a -> Bool
isLocked tid ts
  | not everyThreadFullyKnown = all (isJust . _blocking) (M.elems ts)
  | otherwise = maybe False (unreferenced . _blocking) (M.lookup tid ts)
  where
    everyThreadFullyKnown = all _fullknown (M.elems ts)

    -- True when no other unblocked thread references what we wait on.
    unreferenced blocker = case blocker of
      Just (OnMVarFull mvid) -> noRunnableHolder [Left mvid]
      Just (OnMVarEmpty mvid) -> noRunnableHolder [Left mvid]
      Just (OnTVar tvids) -> noRunnableHolder (map Right tvids)
      _ -> True

    -- No thread other than @tid@ is unblocked and knows any of @wanted@.
    noRunnableHolder wanted = not (any (holds wanted) (M.toList ts))

    holds wanted (otherTid, other) =
      otherTid /= tid
        && isNothing (_blocking other)
        && not (null (wanted `intersect` _known other))
-- | An exception handler for some exception type @e@.
--
-- The type @e@ is hidden (existentially quantified); 'propagate' uses
-- 'fromException' to test whether a thrown exception matches it.
data Handler n r s = forall e. Exception e => Handler (e -> Action n r s)
-- | Deliver exception @e@ to thread @tid@ if it has a matching handler.
--
-- The thread's handler stack is searched innermost-first. For the first
-- handler whose exception type matches, the thread's continuation becomes
-- that handler applied to the exception. Its handler stack becomes the
-- handlers /outside/ the matched one, and it is unblocked (via 'except').
--
-- Returns 'Nothing' if the thread does not exist or no handler matches.
propagate :: SomeException -> ThreadId -> Threads n r s -> Maybe (Threads n r s)
propagate e tid threads = do
  thread <- M.lookup tid threads
  (act, outer) <- firstMatch (_handlers thread)
  Just (except act outer tid threads)
  where
    firstMatch [] = Nothing
    firstMatch (Handler h : rest) = case fromException e of
      Just ex -> Just (h ex, rest)
      Nothing -> firstMatch rest
-- | Whether the thread can currently be interrupted.
--
-- This is always the case when 'Unmasked', only while blocked when
-- 'MaskedInterruptible', and never when 'MaskedUninterruptible'.
-- These mirror GHC's rules for delivering asynchronous exceptions.
interruptible :: Thread n r s -> Bool
interruptible thread = case _masking thread of
  Unmasked -> True
  MaskedInterruptible -> isJust (_blocking thread)
  MaskedUninterruptible -> False
-- | Push an exception handler onto the front of a thread's handler stack,
-- so 'propagate' tries it before any previously-installed handler.
--
-- Calls 'error' with an invariant-failure message if the thread does not
-- exist, like 'block', 'knows', 'forgets' and 'fullknown'. The original
-- partial lambda @\\(Just thread)@ failed with an opaque pattern-match error.
catching :: Exception e => (e -> Action n r s) -> ThreadId -> Threads n r s -> Threads n r s
catching h = M.alter go where
  go (Just thread) = Just $ thread { _handlers = Handler h : _handlers thread }
  go _ = error "Invariant failure in 'catching': thread does NOT exist!"
-- | Pop the innermost exception handler from a thread's handler stack.
--
-- Calls 'error' with an invariant-failure message if the thread does not
-- exist or has no handler to pop. Previously these cases failed with an
-- opaque lambda pattern-match error or a bare @Prelude.tail@ error.
uncatching :: ThreadId -> Threads n r s -> Threads n r s
uncatching = M.alter go where
  go (Just thread) = Just $ thread { _handlers = pop (_handlers thread) }
  go _ = error "Invariant failure in 'uncatching': thread does NOT exist!"
  -- Total replacement for 'tail' with a descriptive failure.
  pop (_:outer) = outer
  pop [] = error "Invariant failure in 'uncatching': thread has no handlers to pop!"
-- | Transfer a thread into an exception handler.
--
-- This sets the thread's continuation to @act@ and replaces its handler
-- stack with @hs@. It also clears any blocking, so the thread is runnable.
--
-- Calls 'error' with an invariant-failure message if the thread does not
-- exist, matching the other thread-updating functions in this module.
except :: Action n r s -> [Handler n r s] -> ThreadId -> Threads n r s -> Threads n r s
except act hs = M.alter go where
  go (Just thread) = Just $ thread { _continuation = act, _handlers = hs, _blocking = Nothing }
  go _ = error "Invariant failure in 'except': thread does NOT exist!"
-- | Set a thread's masking state to @ms@.
--
-- Calls 'error' with an invariant-failure message if the thread does not
-- exist, matching the other thread-updating functions in this module.
mask :: MaskingState -> ThreadId -> Threads n r s -> Threads n r s
mask ms = M.alter go where
  go (Just thread) = Just $ thread { _masking = ms }
  go _ = error "Invariant failure in 'mask': thread does NOT exist!"
-- | Replace a thread's continuation with @a@.
--
-- Calls 'error' with an invariant-failure message if the thread does not
-- exist, matching the other thread-updating functions in this module.
goto :: Action n r s -> ThreadId -> Threads n r s -> Threads n r s
goto a = M.alter go where
  go (Just thread) = Just $ thread { _continuation = a }
  go _ = error "Invariant failure in 'goto': thread does NOT exist!"
-- | Start a new thread @tid@ running @a@, forked from thread @parent@.
--
-- The child inherits the parent's masking state. If @parent@ is not in
-- the map, the child starts 'Unmasked'.
launch :: ThreadId -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s
launch parent tid a threads = launch' inherited tid a threads
  where
    parentThread = M.lookup parent threads
    inherited = fromMaybe Unmasked (fmap _masking parentThread)
-- | Insert a new thread @tid@ running @a@ with initial masking state @ms@.
--
-- The thread body receives an "unmask" function. Wrapping a computation
-- in it runs that computation with masking reset to 'Unmasked', then
-- resets masking back to @ms@ and returns the computation's result.
--
-- Any existing thread with the same id is replaced ('M.insert').
launch' :: MaskingState -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s
launch' ms tid a = M.insert tid thread where
  -- Fresh thread: not blocked, no handlers, nothing known yet.
  thread = Thread { _continuation = a umask, _blocking = Nothing, _handlers = [], _masking = ms, _known = [], _fullknown = False }
  -- Reset to 'Unmasked', run @mb@, reset back to @ms@, return @mb@'s result.
  -- NOTE(review): this binding must be generalised to the rank-n type
  -- expected by @a@; keep it free of constraints that would block that.
  umask mb = resetMask True Unmasked >> mb >>= \b -> resetMask False ms >> return b
  -- Emit an 'AResetMask' primitive for masking state @m@, then continue.
  -- NOTE(review): the meaning of the two Bool arguments of 'AResetMask'
  -- is defined elsewhere (Common), not visible here.
  resetMask typ m = cont $ \k -> AResetMask typ True m $ k ()
-- | Remove a thread entirely.
--
-- Killing a thread that is not present leaves the map unchanged.
kill :: ThreadId -> Threads n r s -> Threads n r s
kill tid threads = M.delete tid threads
-- | Mark a thread as blocked on @blockedOn@.
--
-- Calls 'error' if the thread does not exist.
block :: BlockedOn -> ThreadId -> Threads n r s -> Threads n r s
block blockedOn = M.alter $ \mthread -> case mthread of
  Just thread -> Just (thread { _blocking = Just blockedOn })
  Nothing -> error "Invariant failure in 'block': thread does NOT exist!"
-- | Unblock every thread that is blocked on @blockedOn@.
--
-- A thread blocked on 'OnTVar' is woken by an 'OnTVar' whose TVar list
-- shares at least one element with its own. Any other blocker must be
-- equal to @blockedOn@.
--
-- Returns the updated threads and the ids of the woken threads, in
-- ascending key order.
wake :: BlockedOn -> Threads n r s -> (Threads n r s, [ThreadId])
wake blockedOn threads = (fmap release threads, wokenIds)
  where
    wokenIds = [tid | (tid, thread) <- M.toList threads, waitsOnThis thread]

    release thread
      | waitsOnThis thread = thread { _blocking = Nothing }
      | otherwise = thread

    waitsOnThis thread = case _blocking thread of
      Just (OnTVar tvids)
        | OnTVar wanted <- blockedOn -> not (null (tvids `intersect` wanted))
      current -> current == Just blockedOn
-- | Record that a thread knows about the given MVars/TVars.
--
-- The ids are prepended to its '_known' list and duplicates are removed.
-- Calls 'error' if the thread does not exist.
knows :: [Either MVarId TVarId] -> ThreadId -> Threads n r s -> Threads n r s
knows theids = M.alter $ \mthread -> case mthread of
  Just thread -> Just (thread { _known = nub (theids ++ _known thread) })
  Nothing -> error "Invariant failure in 'knows': thread does NOT exist!"
-- | Record that a thread no longer knows about the given MVars/TVars.
--
-- Every occurrence of them is removed from its '_known' list.
-- Calls 'error' if the thread does not exist.
forgets :: [Either MVarId TVarId] -> ThreadId -> Threads n r s -> Threads n r s
forgets theids = M.alter $ \mthread -> case mthread of
  Just thread -> Just (thread { _known = [ref | ref <- _known thread, ref `notElem` theids] })
  Nothing -> error "Invariant failure in 'forgets': thread does NOT exist!"
-- | Mark a thread's '_known' list as complete.
--
-- Calls 'error' if the thread does not exist.
fullknown :: ThreadId -> Threads n r s -> Threads n r s
fullknown = M.alter $ \mthread -> case mthread of
  Just thread -> Just (thread { _fullknown = True })
  Nothing -> error "Invariant failure in 'fullknown': thread does NOT exist!"