module Test.DejaFu.Conc.Internal.Threading where
import Control.Exception (Exception, MaskingState(..),
SomeException, fromException)
import Data.List (intersect)
import Data.Map.Strict (Map)
import Data.Maybe (fromMaybe, isJust)
import Test.DejaFu.Common
import Test.DejaFu.Conc.Internal.Common
import qualified Data.Map.Strict as M
-- | The state of every thread, keyed by its 'ThreadId'.
type Threads n r = Map ThreadId (Thread n r)
-- | Per-thread state stored in a 'Threads' map.
data Thread n r = Thread
  { _continuation :: Action n r
  -- ^ The action the thread continues with (replaced by 'goto' and 'except').
  , _blocking :: Maybe BlockedOn
  -- ^ What the thread is blocked on, or 'Nothing' if it is not blocked.
  , _handlers :: [Handler n r]
  -- ^ Installed exception handlers, most recently installed first
  -- ('catching' conses onto the front, 'uncatching' drops the front).
  , _masking :: MaskingState
  -- ^ The thread's current exception-masking state.
  }
-- | Build a fresh thread record that will run the given action. The new
-- thread is not blocked, has no exception handlers, and is 'Unmasked'.
mkthread :: Action n r -> Thread n r
mkthread c = Thread
  { _continuation = c
  , _blocking     = Nothing
  , _handlers     = []
  , _masking      = Unmasked
  }
-- | The primitive a thread is blocked on.
data BlockedOn
  = OnMVarFull MVarId
  -- ^ Blocked on an 'MVar' (the \"full\" case; presumably waiting for it
  -- to become full -- confirm against callers).
  | OnMVarEmpty MVarId
  -- ^ Blocked on an 'MVar' (the \"empty\" case; presumably waiting for it
  -- to become empty -- confirm against callers).
  | OnTVar [TVarId]
  -- ^ Blocked on a set of 'TVar's; 'wake' releases it when any of them
  -- is named.
  | OnMask ThreadId
  -- ^ Blocked in connection with the masking state of the given thread
  -- (presumably a throw to a masked target -- confirm).
  deriving Eq
-- | Check whether a thread is blocked on the same /kind/ of primitive as
-- the given 'BlockedOn'. Only constructors are compared; the identifiers
-- they carry are ignored. A thread that is not blocked never matches.
(~=) :: Thread n r -> BlockedOn -> Bool
thread ~= theblock = maybe False (sameKind theblock) (_blocking thread)
  where
    sameKind (OnMVarFull _)  (OnMVarFull _)  = True
    sameKind (OnMVarEmpty _) (OnMVarEmpty _) = True
    sameKind (OnTVar _)      (OnTVar _)      = True
    sameKind (OnMask _)      (OnMask _)      = True
    sameKind _               _               = False
-- | An installed exception handler.
--
-- The exception type it accepts is hidden existentially; 'propagate'
-- uses 'fromException' to test whether a thrown 'SomeException' matches.
-- When chosen, the handler receives the exception and the thread's
-- masking state at that moment ('except' passes '_masking'), and yields
-- the action the thread continues with.
data Handler n r = forall e. Exception e => Handler (e -> MaskingState -> Action n r)
-- | Try to deliver an exception to one of the thread's own handlers.
--
-- Handlers are tried most-recently-installed first. The first whose
-- exception type matches (via 'fromException') is chosen. It is removed
-- together with every handler installed after it, and the thread is
-- redirected into it with 'except'.
--
-- Returns 'Nothing' if the thread is absent or no handler matches.
propagate :: SomeException -> ThreadId -> Threads n r -> Maybe (Threads n r)
propagate e tid threads =
  redirect <$> (findHandler . _handlers =<< M.lookup tid threads)
  where
    redirect (act, rest) = except act rest tid threads

    findHandler [] = Nothing
    findHandler (Handler h : rest) = case fromException e of
      Just exc -> Just (h exc, rest)
      Nothing  -> findHandler rest
-- | Whether the thread can currently be interrupted. That is always true
-- when 'Unmasked', true only while blocked when 'MaskedInterruptible',
-- and never true when 'MaskedUninterruptible'.
interruptible :: Thread n r -> Bool
interruptible thread = case _masking thread of
  Unmasked              -> True
  MaskedInterruptible   -> isJust (_blocking thread)
  MaskedUninterruptible -> False
-- | Push an exception handler onto the front of a thread's handler stack.
--
-- The thread's masking state at installation time is remembered. If the
-- handler is later invoked under a different masking state, its action
-- is wrapped in @AResetMask False False@ with the remembered state
-- (presumably restoring it before the handler body runs). Otherwise the
-- action is used as-is. Threads not in the map are left unchanged.
catching :: Exception e => (e -> Action n r) -> ThreadId -> Threads n r -> Threads n r
catching h = M.adjust install
  where
    install thread = thread { _handlers = wrapped : _handlers thread }
      where
        saved   = _masking thread
        wrapped = Handler $ \e ms ->
          if ms == saved then h e else AResetMask False False saved (h e)
-- | Pop the most recently installed exception handler from a thread's
-- stack, undoing the matching 'catching'. Threads not in the map are left
-- unchanged.
--
-- Each call presumably pairs with an earlier 'catching', so the stack
-- should never be empty here. If that invariant is ever broken, the
-- error names this function instead of the anonymous
-- @Prelude.tail: empty list@. Like the field it replaces, the error is
-- only raised when '_handlers' is forced.
uncatching :: ThreadId -> Threads n r -> Threads n r
uncatching = M.adjust $ \thread -> thread { _handlers = pop (_handlers thread) }
  where
    pop (_:hs) = hs
    pop []     = error "Test.DejaFu.Conc.Internal.Threading.uncatching: empty handler stack"
-- | Redirect a thread into an exception handler.
--
-- The thread's continuation becomes the handler action applied to its
-- current masking state. Its handler stack is replaced with @hs@, and
-- any blocking is cleared. Threads not in the map are left unchanged.
except :: (MaskingState -> Action n r) -> [Handler n r] -> ThreadId -> Threads n r -> Threads n r
except actf hs = M.adjust redirect
  where
    redirect thread =
      let ms = _masking thread
      in thread { _blocking = Nothing, _handlers = hs, _continuation = actf ms }
-- | Set a thread's masking state. Threads not in the map are left unchanged.
mask :: MaskingState -> ThreadId -> Threads n r -> Threads n r
mask ms tid = M.adjust setMask tid
  where
    setMask thread = thread { _masking = ms }
-- | Replace a thread's continuation with the given action. Threads not in
-- the map are left unchanged.
goto :: Action n r -> ThreadId -> Threads n r -> Threads n r
goto a tid threads = M.adjust (\thread -> thread { _continuation = a }) tid threads
-- | Add a new thread @tid@ running @a@, via 'launch''.
--
-- The new thread inherits the masking state of @parent@, or 'Unmasked'
-- if @parent@ is not in the map.
launch :: ThreadId -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch parent tid a threads =
  let inherited = fromMaybe Unmasked (fmap _masking (M.lookup parent threads))
  in launch' inherited tid a threads
-- | Insert a new thread under @tid@ with masking state @ms@. Any existing
-- entry for @tid@ is replaced ('M.insert').
--
-- The thread's action is @a@ applied to @umask@. @umask@ takes any
-- computation @mb@ and does three things: it issues @resetMask True Unmasked@,
-- runs @mb@, then issues @resetMask False ms@ and returns @mb@'s result.
-- In effect it presumably unmasks for the duration of @mb@ and then
-- restores the masking state captured here at launch time.
launch' :: MaskingState -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch' ms tid a = M.insert tid thread where
  -- A fresh thread: not blocked, no handlers, starting in state @ms@.
  thread = Thread { _continuation = a umask, _blocking = Nothing, _handlers = [], _masking = ms }
  -- NOTE(review): @a@ requires this to be polymorphic in the result type
  -- (rank-2 argument); it relies on this local binding being generalised.
  umask mb = resetMask True Unmasked >> mb >>= \b -> resetMask False ms >> pure b
  -- Emit an 'AResetMask' step carrying the given flag and masking state,
  -- then continue with @()@.
  resetMask typ m = cont $ \k -> AResetMask typ True m $ k ()
-- | Remove a thread from the map entirely, discarding its continuation,
-- handlers, and blocking state. Absent threads are ignored.
kill :: ThreadId -> Threads n r -> Threads n r
kill tid threads = M.delete tid threads
-- | Mark a thread as blocked on the given primitive. Threads not in the
-- map are left unchanged.
block :: BlockedOn -> ThreadId -> Threads n r -> Threads n r
block blockedOn tid = M.adjust markBlocked tid
  where
    markBlocked thread = thread { _blocking = Just blockedOn }
-- | Unblock every thread blocked on the given primitive.
--
-- Returns the updated map and the IDs of the threads that were woken, in
-- ascending key order. A thread blocked on 'OnTVar' is woken by an
-- 'OnTVar' whose list shares at least one 'TVarId' with its own. Every
-- other blocking reason must match exactly.
wake :: BlockedOn -> Threads n r -> (Threads n r, [ThreadId])
wake blockedOn threads = (M.map release threads, woken)
  where
    woken = [tid | (tid, thread) <- M.toList threads, wokenBy thread]

    release thread
      | wokenBy thread = thread { _blocking = Nothing }
      | otherwise      = thread

    wokenBy thread = case _blocking thread of
      Just (OnTVar mine) | OnTVar theirs <- blockedOn ->
        not (null (mine `intersect` theirs))
      current -> current == Just blockedOn