module Test.DejaFu.Deterministic.Internal.Threading where
import Control.Exception (Exception, MaskingState(..), SomeException(..), fromException)
import Control.Monad.Cont (cont)
import Data.List (intersect, nub)
import Data.Map (Map)
import Data.Maybe (fromMaybe, isJust, isNothing)
import Test.DejaFu.STM (CTVarId)
import Test.DejaFu.Deterministic.Internal.Common
import qualified Data.Map as M
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>))
#endif
-- | The threads in the system, keyed by their 'ThreadId'.
type Threads n r s = Map ThreadId (Thread n r s)
-- | The state kept about a single thread.
data Thread n r s = Thread
  { _continuation :: Action n r s
  -- ^ The action this thread is set to run; replaced by 'goto'.
  , _blocking :: Maybe BlockedOn
  -- ^ What the thread is blocked on, or 'Nothing' if it is not
  -- blocked. Set by 'block', cleared by 'wake'.
  , _handlers :: [Handler n r s]
  -- ^ Exception handlers, tried front-to-back by 'propagate'.
  , _masking :: MaskingState
  -- ^ The thread's current exception masking state.
  , _known :: [Either CVarId CTVarId]
  -- ^ Variables this thread is recorded as referencing; maintained by
  -- 'knows' and 'forgets', consulted by 'isLocked'.
  , _fullknown :: Bool
  -- ^ Set by 'fullknown'. 'isLocked' only trusts '_known' when every
  -- thread has this set (presumably meaning '_known' is complete for
  -- this thread — NOTE(review): confirm with the code that sets it).
  }
-- | What a blocked thread is waiting on. Compared by full equality in
-- 'wake', and by constructor only in '~='.
data BlockedOn
  = OnCVarFull CVarId
  -- ^ Blocked on the given 'CVarId' (by name: waiting while it is full).
  | OnCVarEmpty CVarId
  -- ^ Blocked on the given 'CVarId' (by name: waiting while it is empty).
  | OnCTVar [CTVarId]
  -- ^ Blocked on any of the given 'CTVarId's; 'wake' with an
  -- overlapping 'OnCTVar' unblocks it.
  | OnMask ThreadId
  -- ^ Blocked on the masking state of the given thread.
  deriving Eq
-- | Check whether a thread is blocked on the same /kind/ of thing as
-- the given 'BlockedOn', ignoring the identifiers carried inside. A
-- thread that is not blocked never matches, and in that case the
-- 'BlockedOn' argument is not inspected at all.
(~=) :: Thread n r s -> BlockedOn -> Bool
thread ~= theblock = maybe False (`sameKind` theblock) (_blocking thread)
  where
    -- The thread's own block is matched first, as in a tuple case.
    sameKind (OnCVarFull _) (OnCVarFull _) = True
    sameKind (OnCVarEmpty _) (OnCVarEmpty _) = True
    sameKind (OnCTVar _) (OnCTVar _) = True
    sameKind (OnMask _) (OnMask _) = True
    sameKind _ _ = False
-- | Decide whether thread @tid@ should be treated as locked.
--
-- * When every thread in the map has '_fullknown' set, the per-thread
--   '_known' lists are consulted. @tid@ is locked unless some /other/
--   thread, itself not blocked, knows about one of the variables @tid@
--   is blocked on ('OnCVarFull', 'OnCVarEmpty' or 'OnCTVar').
--   * If @tid@ is not in the map the result is 'False'.
--   * Any other '_blocking' value for @tid@ (an 'OnMask' block, or not
--     being blocked at all) yields 'True'. NOTE(review): callers
--     presumably only ask about blocked threads; confirm before relying
--     on the unblocked case.
--
-- * Otherwise the '_known' lists are not trusted, and @tid@ is locked
--   exactly when every thread in the map is blocked.
isLocked :: ThreadId -> Threads n r s -> Bool
isLocked tid ts
  | allKnown = case M.lookup tid ts of
      Just thread -> noRefs (_blocking thread)
      Nothing -> False
  | otherwise = all (isJust . _blocking) threads
  where
    threads = M.elems ts

    -- Are all threads' '_known' lists marked as complete?
    allKnown = all _fullknown threads

    -- True if no other runnable thread references what the block is on.
    noRefs (Just (OnCVarFull cvarid)) = not $ referenced [Left cvarid]
    noRefs (Just (OnCVarEmpty cvarid)) = not $ referenced [Left cvarid]
    noRefs (Just (OnCTVar ctvids)) = not $ referenced (map Right ctvids)
    noRefs _ = True

    -- Does any thread other than @tid@, which is not itself blocked,
    -- know about at least one of the given variables? 'any' stops at
    -- the first match instead of materialising every such thread.
    referenced lookingfor = any (check lookingfor) (M.toList ts)

    check lookingfor (thetid, thethread)
      | thetid == tid = False
      | otherwise = not (null (lookingfor `intersect` _known thethread))
          && isNothing (_blocking thethread)
-- | An exception handler for some exception type @e@. The type is
-- hidden so that handlers for different exception types can share one
-- list; 'propagate' uses 'fromException' to pick one that applies.
data Handler n r s = forall e. Exception e => Handler (e -> Action n r s)
-- | Find the first handler (front-to-back) whose exception type matches
-- the given exception. On success, returns that handler applied to the
-- exception together with the handlers /after/ it in the list; returns
-- 'Nothing' if no handler matches.
propagate :: SomeException -> [Handler n r s] -> Maybe (Action n r s, [Handler n r s])
propagate _ [] = Nothing
propagate e (Handler h : hs) = case h <$> fromException e of
  Just act -> Just (act, hs)
  Nothing -> propagate e hs
-- | Whether a thread can currently receive an asynchronous exception:
-- always when 'Unmasked', only while blocked when
-- 'MaskedInterruptible', and never when 'MaskedUninterruptible'.
interruptible :: Thread n r s -> Bool
interruptible thread = case _masking thread of
  Unmasked -> True
  MaskedInterruptible -> isJust (_blocking thread)
  MaskedUninterruptible -> False
-- | Replace the continuation of an existing thread, leaving the rest of
-- its state untouched.
--
-- Calls 'error' with an invariant-failure message if the thread is not
-- in the map, in the same way as 'block', 'knows', 'forgets' and
-- 'fullknown', rather than failing on an incomplete pattern.
goto :: Action n r s -> ThreadId -> Threads n r s -> Threads n r s
goto a = M.alter go
  where
    go (Just thread) = Just thread { _continuation = a }
    go Nothing = error "Invariant failure in 'goto': thread does NOT exist!"
-- | Start a new thread @tid@ running @a@. The new thread inherits the
-- masking state of @parent@, or starts 'Unmasked' if @parent@ is not
-- in the map. See 'launch'' for the function that is passed to @a@.
launch :: ThreadId -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s
launch parent tid a threads = launch' inherited tid a threads
  where
    parentThread = M.lookup parent threads
    inherited = fromMaybe Unmasked (fmap _masking parentThread)
-- | Insert a fresh thread @tid@ with the given masking state, replacing
-- any existing entry for @tid@. The thread starts unblocked, with no
-- handlers, an empty '_known' list and '_fullknown' unset.
--
-- @a@ receives a function which runs an inner computation after an
-- 'AResetMask' to 'Unmasked', then issues an 'AResetMask' back to
-- @mask@ once the inner computation has produced its result.
launch' :: MaskingState -> ThreadId -> ((forall b. M n r s b -> M n r s b) -> Action n r s) -> Threads n r s -> Threads n r s
launch' mask tid a = M.insert tid fresh
  where
    fresh = Thread
      { _continuation = a unmask
      , _blocking = Nothing
      , _handlers = []
      , _masking = mask
      , _known = []
      , _fullknown = False
      }

    unmask inner = do
      setMask True Unmasked
      result <- inner
      setMask False mask
      return result

    setMask flag newState = cont (\k -> AResetMask flag True newState (k ()))
-- | Remove a thread from the map. Unlike 'block' and friends, removing
-- a thread that is not present is a no-op rather than an error.
kill :: ThreadId -> Threads n r s -> Threads n r s
kill tid = M.delete tid
-- | Mark an existing thread as blocked on @blockedOn@, overwriting any
-- previous block. Calls 'error' if the thread is not in the map.
block :: BlockedOn -> ThreadId -> Threads n r s -> Threads n r s
block blockedOn = M.alter setBlock
  where
    setBlock mthread = case mthread of
      Just thread -> Just thread { _blocking = Just blockedOn }
      Nothing -> error "Invariant failure in 'block': thread does NOT exist!"
-- | Wake every thread blocked on @blockedOn@.
--
-- Returns the updated map, in which woken threads have '_blocking' reset
-- to 'Nothing', and the IDs of the woken threads in ascending order.
--
-- Which threads are woken:
--
-- * A thread blocked on 'OnCTVar' is woken by an 'OnCTVar' sharing at
--   least one 'CTVarId' with it.
-- * In every other case the thread's block must equal @blockedOn@.
wake :: BlockedOn -> Threads n r s -> (Threads n r s, [ThreadId])
wake blockedOn threads = (M.union (M.map clear woken) untouched, M.keys woken)
  where
    -- Disjoint key sets, so the union rebuilds exactly the original keys.
    (woken, untouched) = M.partition isBlocked threads

    clear thread = thread { _blocking = Nothing }

    isBlocked thread = case (_blocking thread, blockedOn) of
      (Just (OnCTVar ctvids), OnCTVar wanted) -> not (null (ctvids `intersect` wanted))
      (current, _) -> current == Just blockedOn
-- | Record that a thread references the given variables. New IDs are
-- placed in front of the existing '_known' list and duplicates are
-- removed with 'nub' (first occurrence kept). Calls 'error' if the
-- thread is not in the map.
knows :: [Either CVarId CTVarId] -> ThreadId -> Threads n r s -> Threads n r s
knows theids = M.alter $ \mthread -> case mthread of
  Nothing -> error "Invariant failure in 'knows': thread does NOT exist!"
  Just thread -> Just thread { _known = nub (theids ++ _known thread) }
-- | Remove the given variables from a thread's '_known' list, keeping
-- the order of the remaining entries. Calls 'error' if the thread is
-- not in the map.
forgets :: [Either CVarId CTVarId] -> ThreadId -> Threads n r s -> Threads n r s
forgets theids = M.alter $ \mthread -> case mthread of
  Nothing -> error "Invariant failure in 'forgets': thread does NOT exist!"
  Just thread -> Just thread { _known = [v | v <- _known thread, v `notElem` theids] }
-- | Set the '_fullknown' flag on a thread; 'isLocked' only trusts the
-- '_known' lists once every thread has it set. Calls 'error' if the
-- thread is not in the map.
fullknown :: ThreadId -> Threads n r s -> Threads n r s
fullknown = M.alter markFull
  where
    markFull Nothing = error "Invariant failure in 'fullknown': thread does NOT exist!"
    markFull (Just thread) = Just thread { _fullknown = True }