{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
module Test.DejaFu.Conc.Internal.Threading where
import Control.Exception (Exception, MaskingState(..),
SomeException, fromException)
import Data.List (intersect)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (isJust)
import GHC.Stack (HasCallStack)
import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Internal
import Test.DejaFu.Types
-- | All threads of a computation, keyed by their 'ThreadId'.
type Threads n = Map ThreadId (Thread n)
-- | The state of a single thread.
data Thread n = Thread
  { _continuation :: Action n
  -- ^ The action this thread will run next (replaced by 'goto' and
  -- 'except').
  , _blocking :: Maybe BlockedOn
  -- ^ What the thread is blocked on, or 'Nothing' if it is runnable.
  , _handlers :: [Handler n]
  -- ^ Installed exception handlers, innermost (most recently pushed by
  -- 'catching') first.
  , _masking :: MaskingState
  -- ^ The thread's current exception masking state.
  , _bound :: Maybe (BoundThread n (Action n))
  -- ^ The bound thread lifted actions are run in, if 'makeBound' has
  -- attached one.
  }
-- | Create a fresh thread which will run the given action.
--
-- The new thread is not blocked, has no exception handlers, is
-- 'Unmasked', and has no bound thread.
mkthread :: Action n -> Thread n
mkthread act = Thread
  { _continuation = act
  , _blocking = Nothing
  , _handlers = []
  , _masking = Unmasked
  , _bound = Nothing
  }
-- | The things a thread can be blocked on.
data BlockedOn
  = OnMVarFull MVarId
  -- ^ Blocked on an @MVar@ (presumably waiting for it to become full —
  -- confirm against callers).
  | OnMVarEmpty MVarId
  -- ^ Blocked on an @MVar@ (presumably waiting for it to become empty —
  -- confirm against callers).
  | OnTVar [TVarId]
  -- ^ Blocked on a set of @TVar@s; 'wake' releases the thread when any of
  -- them appears in the @OnTVar@ being woken.
  | OnMask ThreadId
  -- ^ Blocked on another thread's masking state (NOTE(review): presumably
  -- while throwing to it — confirm against callers).
  deriving Eq
-- | Check whether a thread is blocked on the same /kind/ of thing as the
-- given 'BlockedOn', ignoring the specific identifiers involved.
--
-- A thread which is not blocked never matches.
(~=) :: Thread n -> BlockedOn -> Bool
thread ~= theblock = maybe False (`sameKind` theblock) (_blocking thread)
  where
    sameKind (OnMVarFull _) (OnMVarFull _) = True
    sameKind (OnMVarEmpty _) (OnMVarEmpty _) = True
    sameKind (OnTVar _) (OnTVar _) = True
    sameKind (OnMask _) (OnMask _) = True
    sameKind _ _ = False
-- | An exception handler for some exception type @e@.
--
-- The handler receives the caught exception and the thread's masking state
-- at the moment the exception is handled (see 'except'), and produces the
-- action the thread continues with.
data Handler n = forall e. Exception e => Handler (e -> MaskingState -> Action n)
-- | Raise an exception in a thread.
--
-- The thread's handler stack is searched from the innermost handler
-- outwards for the first handler whose exception type matches. That
-- handler and every handler inside it are discarded, and control is
-- transferred to the handler via 'except'. Returns 'Nothing' if no
-- handler accepts the exception.
propagate :: HasCallStack => SomeException -> ThreadId -> Threads n -> Maybe (Threads n)
propagate e tid threads = go (_handlers (elookup tid threads))
  where
    go [] = Nothing
    go (Handler h : rest) = case fromException e of
      Just ex -> Just (except (h ex) rest tid threads)
      Nothing -> go rest
-- | Check whether a thread is currently interruptible.
--
-- An unmasked thread always is; a thread masked interruptibly is only
-- while it is blocked; an uninterruptibly masked thread never is.
-- (Presumably consulted before delivering an asynchronous exception.)
interruptible :: Thread n -> Bool
interruptible thread = case _masking thread of
  Unmasked -> True
  MaskedInterruptible -> isJust (_blocking thread)
  MaskedUninterruptible -> False
-- | Push an exception handler onto a thread's handler stack.
--
-- The handler remembers the masking state in force when it was installed.
-- If the masking state differs when the handler fires, the handler's
-- action is wrapped in an 'AResetMask' that sets the masking state back
-- to the remembered one.
catching :: (Exception e, HasCallStack) => (e -> Action n) -> ThreadId -> Threads n -> Threads n
catching h = eadjust install
  where
    install thread =
      let installedMs = _masking thread
          handler = Handler $ \e ms ->
            if ms == installedMs
              then h e
              else AResetMask False False installedMs (h e)
      in thread { _handlers = handler : _handlers thread }
-- | Pop the innermost exception handler off a thread's handler stack.
uncatching :: HasCallStack => ThreadId -> Threads n -> Threads n
uncatching = eadjust popHandler
  where
    popHandler thread = thread { _handlers = etail (_handlers thread) }
-- | Transfer control of a thread to an exception handler.
--
-- The thread's continuation becomes the handler action applied to the
-- thread's current masking state, its handler stack is replaced with the
-- given one, and it is marked as no longer blocked.
except :: HasCallStack => (MaskingState -> Action n) -> [Handler n] -> ThreadId -> Threads n -> Threads n
except actf hs = eadjust enter
  where
    enter thread = thread
      { _blocking = Nothing
      , _handlers = hs
      , _continuation = actf (_masking thread)
      }
-- | Set a thread's masking state.
mask :: HasCallStack => MaskingState -> ThreadId -> Threads n -> Threads n
mask newState = eadjust setMask
  where
    setMask thread = thread { _masking = newState }
-- | Replace a thread's continuation with the given action.
goto :: HasCallStack => Action n -> ThreadId -> Threads n -> Threads n
goto next = eadjust setContinuation
  where
    setContinuation thread = thread { _continuation = next }
-- | Start a new thread which inherits the masking state of its parent.
--
-- The first 'ThreadId' is the parent, the second is the id for the new
-- thread; see 'launch'' for how the thread is constructed.
launch :: HasCallStack => ThreadId -> ThreadId -> ((forall b. ModelConc n b -> ModelConc n b) -> Action n) -> Threads n -> Threads n
launch parent tid a threads =
  let inherited = _masking (elookup parent threads)
  in launch' inherited tid a threads
-- | Start a new thread with the given masking state.
--
-- The thread is inserted under @tid@; it is not blocked, has no exception
-- handlers, and has no bound thread. Its action is handed an "unmask"
-- wrapper which resets the masking state to 'Unmasked', runs the wrapped
-- computation, resets the masking state back to @ms@, and returns the
-- computation's result.
launch' :: HasCallStack => MaskingState -> ThreadId -> ((forall b. ModelConc n b -> ModelConc n b) -> Action n) -> Threads n -> Threads n
launch' ms tid a = einsert tid (Thread
  { _continuation = a unmask
  , _blocking = Nothing
  , _handlers = []
  , _masking = ms
  , _bound = Nothing
  })
  where
    unmask inner = do
      resetMask True Unmasked
      result <- inner
      resetMask False ms
      pure result

    resetMask typ m = ModelConc $ \k -> AResetMask typ True m $ k ()
-- | Mark a thread as blocked on the given thing.
block :: HasCallStack => BlockedOn -> ThreadId -> Threads n -> Threads n
block blockedOn = eadjust markBlocked
  where
    markBlocked thread = thread { _blocking = Just blockedOn }
-- | Unblock every thread blocked on the given thing.
--
-- A thread blocked on 'OnTVar' is woken by an 'OnTVar' sharing at least
-- one @TVar@ with it; any other thread is woken only when it is blocked on
-- exactly the given 'BlockedOn'. Returns the updated threads together with
-- the ids of the woken threads, in ascending key order.
wake :: BlockedOn -> Threads n -> (Threads n, [ThreadId])
wake blockedOn threads = (fmap release threads, woken)
  where
    woken = M.keys (M.filter wakeable threads)

    release thread
      | wakeable thread = thread { _blocking = Nothing }
      | otherwise = thread

    wakeable thread = case _blocking thread of
      Just (OnTVar tvids)
        | OnTVar targets <- blockedOn -> not (null (tvids `intersect` targets))
      current -> current == Just blockedOn
-- | Run the given action to obtain a bound thread, and attach it to the
-- thread @tid@.
makeBound :: (MonadDejaFu n, HasCallStack)
  => n (BoundThread n (Action n)) -> ThreadId -> Threads n -> n (Threads n)
makeBound fbt tid threads = attach <$> fbt
  where
    attach bt = eadjust (\t -> t { _bound = Just bt }) tid threads
-- | Remove a thread, first killing its bound thread if it has one.
kill :: (MonadDejaFu n, HasCallStack) => ThreadId -> Threads n -> n (Threads n)
kill tid threads = do
  case _bound (elookup tid threads) of
    Just bt -> killBoundThread bt
    Nothing -> pure ()
  pure (M.delete tid threads)
-- | Run a lifted action on behalf of a thread.
--
-- If the thread has a bound thread, the action runs inside it; otherwise
-- (including when @tid@ is not in the map) it runs directly.
runLiftedAct :: MonadDejaFu n => ThreadId -> Threads n -> n (Action n) -> n (Action n)
runLiftedAct tid threads ma =
  maybe ma (`runInBoundThread` ma) (M.lookup tid threads >>= _bound)