{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RankNTypes #-}

-- |
-- Module      : Test.DejaFu.Conc.Internal.Threading
-- Copyright   : (c) 2016--2017 Michael Walker
-- License     : MIT
-- Maintainer  : Michael Walker
-- Stability   : experimental
-- Portability : ExistentialQuantification, RankNTypes
--
-- Operations and types for threads. This module is NOT considered to
-- form part of the public interface of this library.
module Test.DejaFu.Conc.Internal.Threading where

import qualified Control.Concurrent.Classy as C
import Control.Exception (Exception, MaskingState(..), SomeException, fromException)
import Control.Monad (forever)
import Data.List (intersect)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (isJust)

import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Internal
import Test.DejaFu.Types

--------------------------------------------------------------------------------
-- * Threads

-- | The collection of threads, stored in a map indexed by 'ThreadId'.
type Threads n r = Map ThreadId (Thread n r)

-- | Everything that is tracked for a single thread.
data Thread n r = Thread
  { _continuation :: Action n r
    -- ^ The action the thread will execute next.
  , _blocking :: Maybe BlockedOn
    -- ^ What the thread is blocked on, or 'Nothing' if it is not blocked.
  , _handlers :: [Handler n r]
    -- ^ Stack of exception handlers; new handlers are added at the head.
  , _masking :: MaskingState
    -- ^ The current exception masking state.
  , _bound :: Maybe (BoundThread n r)
    -- ^ State of the associated bound thread, if there is one.
  }

-- | The state of a bound thread.
data BoundThread n r = BoundThread
  { _runboundIO :: C.MVar n (n (Action n r))
    -- ^ Write an action here to have it run in the bound thread.
  , _getboundIO :: C.MVar n (Action n r)
    -- ^ Read from here to obtain the result of an action written to
    -- '_runboundIO'.
  , _boundTId :: C.ThreadId n
    -- ^ Thread ID of the bound thread.
  }

-- | Build a thread that has only a continuation: it is not blocked,
-- has no exception handlers, is 'Unmasked', and is not bound.
mkthread :: Action n r -> Thread n r
mkthread c = Thread
  { _continuation = c
  , _blocking     = Nothing
  , _handlers     = []
  , _masking      = Unmasked
  , _bound        = Nothing
  }

--------------------------------------------------------------------------------
-- * Blocking

-- | A @BlockedOn@ records what sort of variable (or which thread's
-- mask) a thread is blocked on.
data BlockedOn
  = OnMVarFull MVarId
  | OnMVarEmpty MVarId
  | OnTVar [TVarId]
  | OnMask ThreadId
  deriving Eq

-- | Check whether a thread is blocked in the same /way/ as the given
-- block. Only the constructor is compared, never its payload, and a
-- thread which is not blocked at all never matches.
(~=) :: Thread n r -> BlockedOn -> Bool
thread ~= theblock = maybe False (`sameKind` theblock) (_blocking thread)
  where
    -- The thread's own block is inspected first, then the one asked
    -- about.
    sameKind (OnMVarFull _)  (OnMVarFull _)  = True
    sameKind (OnMVarEmpty _) (OnMVarEmpty _) = True
    sameKind (OnTVar _)      (OnTVar _)      = True
    sameKind (OnMask _)      (OnMask _)      = True
    sameKind _               _               = False

--------------------------------------------------------------------------------
-- * Exceptions

-- | An exception handler: given the exception and a masking state, it
-- produces the action to continue with. The exception type is hidden
-- existentially, so handlers for different exception types can share
-- one list.
data Handler n r
  = forall e. Exception e => Handler (e -> MaskingState -> Action n r)

-- | Propagate an exception upwards through a thread's handler stack,
-- stopping at the first handler whose exception type matches.
--
-- On a match the thread is updated with 'except', keeping only the
-- handlers below the one that matched. If no handler can deal with
-- the exception the result is 'Nothing'.
propagate :: SomeException -> ThreadId -> Threads n r -> Maybe (Threads n r)
propagate e tid threads = go (_handlers (elookup "propagate" tid threads))
  where
    go [] = Nothing
    go (Handler h : rest) = case fromException e of
      -- The handler accepts this exception type: switch the thread to it.
      Just exc -> Just (except (h exc) rest tid threads)
      -- Wrong exception type: try the next handler down the stack.
      Nothing -> go rest

-- | Check if a thread can be interrupted by an exception.
--
-- An 'Unmasked' thread always can; a 'MaskedInterruptible' thread
-- can only while it is blocked; a 'MaskedUninterruptible' thread
-- never can.
interruptible :: Thread n r -> Bool
interruptible thread = case _masking thread of
  Unmasked              -> True
  MaskedInterruptible   -> isJust (_blocking thread)
  MaskedUninterruptible -> False

-- | Register a new exception handler.
catching :: Exception e => (e -> Action n r) -> ThreadId -> Threads n r -> Threads n r
catching h = eadjust "catching" pushHandler
  where
    pushHandler thread = thread { _handlers = wrapped : _handlers thread }
      where
        -- The masking state in force when the handler is registered.
        installedMask = _masking thread

        -- If the handler is entered under a different masking state,
        -- its continuation is prefixed with an 'AResetMask' back to
        -- the state captured above; otherwise it is used unchanged.
        wrapped = Handler $ \exc currentMask ->
          let resume = h exc
          in if currentMask /= installedMask
               then AResetMask False False installedMask resume
               else resume

-- | Remove the most recently registered exception handler of a thread.
uncatching :: ThreadId -> Threads n r -> Threads n r
uncatching = eadjust "uncatching" popHandler
  where
    popHandler thread =
      thread { _handlers = etail "uncatching" (_handlers thread) }

-- | Raise an exception in a thread: the thread is unblocked, its
-- handler stack is replaced by the given one, and its continuation
-- becomes the given function applied to the thread's current masking
-- state.
except :: (MaskingState -> Action n r) -> [Handler n r] -> ThreadId -> Threads n r -> Threads n r
except actf hs = eadjust "except" enterHandler
  where
    enterHandler thread = thread
      { _blocking     = Nothing
      , _handlers     = hs
      , _continuation = actf (_masking thread)
      }

-- | Set the masking state of a thread.
mask :: MaskingState -> ThreadId -> Threads n r -> Threads n r
mask ms = eadjust "mask" setMask
  where
    setMask thread = thread { _masking = ms }

--------------------------------------------------------------------------------
-- * Manipulating threads

-- | Replace the @Action@ a thread will execute next.
goto :: Action n r -> ThreadId -> Threads n r -> Threads n r
goto a = eadjust "goto" setContinuation
  where
    setContinuation thread = thread { _continuation = a }

-- | Start a thread with the given ID, inheriting the masking state
-- of the parent thread. The new ID must not already be in use!
launch :: ThreadId -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch parent tid a threads =
  let parentMask = _masking (elookup "launch" parent threads)
  in launch' parentMask tid a threads

-- | Start a thread with the given ID and masking state. The ID must
-- not already be in use!
--
-- The thread's action is handed an \"unmask\" function, which wraps
-- a computation with an 'AResetMask' to 'Unmasked' before it and an
-- 'AResetMask' back to the initial masking state after it.
launch' :: MaskingState -> ThreadId -> ((forall b. M n r b -> M n r b) -> Action n r) -> Threads n r -> Threads n r
launch' ms tid a = einsert "launch'" tid newThread
  where
    newThread = Thread
      { _continuation = a umask
      , _blocking     = Nothing
      , _handlers     = []
      , _masking      = ms
      , _bound        = Nothing
      }

    -- Run a computation unmasked, then restore the initial state and
    -- return the computation's result.
    umask mb = do
      resetMask True Unmasked
      result <- mb
      resetMask False ms
      pure result

    resetMask typ m = cont (\k -> AResetMask typ True m (k ()))

-- | Mark a thread as blocked on the given thing.
block :: BlockedOn -> ThreadId -> Threads n r -> Threads n r
block blockedOn = eadjust "block" setBlocked
  where
    setBlocked thread = thread { _blocking = Just blockedOn }

-- | Unblock all threads waiting on the appropriate block, returning
-- the updated threads together with the IDs of those which were
-- woken. For 'TVar' blocks, this will wake all threads waiting on at
-- least one of the given 'TVar's.
wake :: BlockedOn -> Threads n r -> (Threads n r, [ThreadId])
wake blockedOn threads = (fmap unblock threads, M.keys woken)
  where
    woken = M.filter isBlocked threads

    unblock thread =
      if isBlocked thread
        then thread { _blocking = Nothing }
        else thread

    isBlocked thread = case (_blocking thread, blockedOn) of
      -- TVar blocks match when the two lists share at least one TVar.
      (Just (OnTVar waitingOn), OnTVar changed) ->
        not (null (waitingOn `intersect` changed))
      -- Every other block must be exactly equal.
      (current, _) -> current == Just blockedOn

--------------------------------------------------------------------------------
-- ** Bound threads

-- | Turn a thread into a bound thread.
--
-- A worker is forked with 'C.forkOSN'; it loops forever, taking an
-- action from one @MVar@, running it, and putting the result into a
-- second @MVar@. Both @MVar@s and the worker's thread ID are recorded
-- in the thread's '_bound' field.
makeBound :: C.MonadConc n => ThreadId -> Threads n r -> n (Threads n r)
makeBound tid threads = do
    requests <- C.newEmptyMVar
    replies <- C.newEmptyMVar
    workerTid <- C.forkOSN workerName (serve requests replies)
    let worker = BoundThread
          { _runboundIO = requests
          , _getboundIO = replies
          , _boundTId   = workerTid
          }
    pure (eadjust "makeBound" (attach worker) tid threads)
  where
    workerName = "bound worker for '" ++ show tid ++ "'"

    attach worker thread = thread { _bound = Just worker }

    serve requests replies =
      forever (C.takeMVar requests >>= \job -> job >>= C.putMVar replies)

-- | Kill a thread and remove it from the thread map.
--
-- If the thread is bound, its worker thread is killed first.
kill :: C.MonadConc n => ThreadId -> Threads n r -> n (Threads n r)
kill tid threads = do
  case _bound (elookup "kill" tid threads) of
    Just worker -> C.killThread (_boundTId worker)
    Nothing     -> pure ()
  pure (M.delete tid threads)

-- | Run an action.
--
-- If the thread is bound, the action is handed to its worker thread
-- and the result is read back from it. Otherwise (including when the
-- thread is not in the map) the action is run directly.
runLiftedAct :: C.MonadConc n => ThreadId -> Threads n r -> n (Action n r) -> n (Action n r)
runLiftedAct tid threads ma = maybe ma viaWorker (M.lookup tid threads >>= _bound)
  where
    viaWorker bt = C.putMVar (_runboundIO bt) ma >> C.takeMVar (_getboundIO bt)