-- |
-- This module handles timeouts using a manager thread that waits in the STM
-- and a single sleeper thread sitting in threadDelay. One can request that an
-- IO action be performed after some number of seconds and later cancel that
-- request if need be.
--
-- The number of threads used is constant, no matter how many timeouts are
-- outstanding.
module Control.Timeout
  ( addTimeout
  , addTimeoutAtomic
  , cancelTimeout
  , TimeoutTag
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad
import System.IO.Unsafe
import qualified Data.Map as Map
import Data.Time.Clock.POSIX

-- The three globals below are created with unsafePerformIO. Each one carries
-- a NOINLINE pragma so that GHC cannot inline (and thereby duplicate) the
-- newTVarIO call, which would hand different use sites different TVars.

-- | This is set, atomically, to True when the manager thread is started.
-- That thread isn't started until the first call to addTimeoutAtomic (which
-- addTimeout uses).
managerThreadStarted :: TVar Bool
managerThreadStarted = unsafePerformIO $ newTVarIO False
{-# NOINLINE managerThreadStarted #-}

-- | When a timeout thread times out, it compares the first element of this
-- tuple against the tag it was created with. If they don't match then it's no
-- longer the current timeout thread and it exits. Otherwise, it sets the
-- second element to True and exits.
signal :: TVar (Int, Bool)
signal = unsafePerformIO $ newTVarIO (0, False)
{-# NOINLINE signal #-}

-- | This is a map of all the timeouts. It maps the absolute time at which a
-- timeout expires to a list of tagged actions to perform at that time. For a
-- given value in the map, the Ints of every element in the list must be
-- unique.
timeouts :: TVar (Map.Map POSIXTime [(Int, IO ())])
timeouts = unsafePerformIO $ newTVarIO Map.empty
{-# NOINLINE timeouts #-}

-- Here's how everything works. The addTimeout and cancelTimeout functions
-- alter the above globals using the STM. The first call to addTimeout (via
-- addTimeoutAtomic) will start a manager thread which watches @signal@ and
-- @timeouts@ for changes: either the timeoutThread has completed or someone
-- has added/removed the least element of @timeouts@.
--
-- In the first case, @timeouts@ is updated by removing all the expired
-- timeouts and their actions are performed.
-- In both cases, the time to the next timeout is calculated and a
-- timeoutThread is forked to sleep until then before signaling via
-- @signal@; any previous timeoutThread is killed.
--
-- Timeouts are identified by their absolute time value and the unique tag
-- number for their action at that time. When creating a timeout, that pair
-- is wrapped in the (opaque) TimeoutTag type and returned. When canceling a
-- timeout, the list of actions for the given absolute time is filtered to
-- remove the indicated action. Because of this, the tags for a given absolute
-- time must be unique. This is achieved by giving the first element a tag of
-- 0 and giving subsequent elements a tag one greater than the current max.
--
-- Each timeoutThread is also given a tag (a separate counter, unrelated to
-- the per-time tags in the last paragraph) so that @signal@ is never set by
-- an old timeoutThread which hasn't died yet.

-- | Split a map of timeouts into the values whose time has been reached
-- (key <= the current time), in ascending order of expiry time, and the map
-- of timeouts that are still pending.
expiredTimers
  :: POSIXTime                   -- ^ the current time
  -> Map.Map POSIXTime a         -- ^ the timeouts map
  -> ([a], Map.Map POSIXTime a)  -- ^ the expired values and the remaining map
expiredTimers curtime m = (Map.elems earlier ++ maybe [] (: []) atNow, later)
  where
    -- splitLookup puts keys < curtime in 'earlier', the value stored exactly
    -- at curtime (if any) in 'atNow', and keys > curtime in 'later'.
    (earlier, atNow, later) = Map.splitLookup curtime m

-- | Remove every expired timer from the given map in a single transaction and
-- return an action that performs them. Expired timers run in order of expiry
-- time; timers sharing an expiry time run in the order they were added.
runExpiredTimers
  :: (Monad m)
  => POSIXTime                               -- ^ the current time
  -> TVar (Map.Map POSIXTime [(Int, m ())])  -- ^ the timeouts map
  -> STM (m ())
runExpiredTimers currentTime tm = do
  m <- readTVar tm
  let (expired, m') = expiredTimers currentTime m
  -- Only write when something actually expired, avoiding a needless write.
  unless (null expired) $ writeTVar tm m'
  -- New entries are consed onto the front of a bucket when they are added,
  -- so each bucket is reversed to recover insertion order.
  return $ mapM_ snd (concatMap reverse expired)

-- | A version of unfoldr which returns the final value as well.
-- Note that the resulting list comes off in reverse order.
unfoldrWithValue :: (b -> Maybe (a, b)) -> b -> ([a], b)
unfoldrWithValue step = go []
  where
    go acc seed = case step seed of
      Nothing         -> (acc, seed)
      Just (x, seed') -> go (x : acc) seed'

-- | This thread sleeps until the given absolute time and then tries to set
-- the snd element of the global @signal@ to True, iff the first element of
-- that global is still equal to its tag number.
timeoutThread
  :: Int        -- ^ the id of this timeout, see @signal@
  -> POSIXTime  -- ^ the time to wake up
  -> IO ()
timeoutThread tag targetTime = sleepUntilTarget >> raiseSignal
  where
    -- Longest single threadDelay, in microseconds (1000 s). It fits in a
    -- 32-bit Int, so long timeouts are slept in chunks instead of overflowing
    -- threadDelay's Int argument.
    maxChunk :: Integer
    maxChunk = 1000000000

    sleepUntilTarget = do
      now <- getPOSIXTime
      -- Round up (in exact NominalDiffTime arithmetic) so we never wake before
      -- targetTime: the manager only runs timers whose time is <= now.
      let remaining = ceiling ((targetTime - now) * 1000000) :: Integer
      when (remaining > 0) $ do
        threadDelay (fromInteger (min remaining maxChunk))
        sleepUntilTarget

    raiseSignal = atomically $ do
      (current, _) <- readTVar signal
      when (current == tag) $ writeTVar signal (tag, True)

-- | This is an opaque type of timeouts. A value of this type is returned
-- when creating a timeout and can be used to cancel the same timeout.
newtype TimeoutTag = TimeoutTag (POSIXTime, Int)

-- | Add an action to be performed at some point in the future. The action will
-- occur inside a thread which is dedicated to performing them so it should
-- run quickly and certainly should not block on IO etc.
--
-- NOTE(review): the manager thread runs actions without catching exceptions,
-- so an action that throws is likely to stop all later timeouts from firing.
addTimeout
  :: Float  -- ^ the number of seconds in the future to perform the action
  -> IO ()  -- ^ the action to perform
  -> IO TimeoutTag
addTimeout delta action = do
  register <- addTimeoutAtomic delta
  atomically (register action)

-- | Similar in function to addTimeout above, this call splits the IO and STM
-- parts of the process so that a timeout can be added atomically. Consider
-- the following code:
--
--   * We add a timeout with an action which reads from a global TVar
--
--   * We add the TimeoutTag (in case we wish to handle the timeout) and
--     some bookkeeping data to the global TVar and trigger some external
--     action (e.g. a network request)
--
-- In this case, the timeout could occur before the bookkeeping is added.
-- Now the timeout code won't find the correct state. If we switch the two
-- actions then we don't have the TimeoutTag to add to the bookkeeping
-- structure and we would need another TVar, or some such, to fill in later.
--
-- Note that the expiry time is fixed when this IO action runs, not when the
-- returned STM action commits. The first call also starts the manager thread.
addTimeoutAtomic
  :: Float                         -- ^ the number of seconds in the future to perform the action
  -> IO (IO () -> STM TimeoutTag)  -- ^ an action to add the timeout and return the tag
addTimeoutAtomic delta = do
  currentTime <- getPOSIXTime
  let future = currentTime + realToFrac delta

      stmAction :: IO () -> STM TimeoutTag
      stmAction action = do
        m <- readTVar timeouts
        let bucket = Map.findWithDefault [] future m
            -- One more than the largest tag in use at this instant, or 0 for a
            -- missing or empty bucket. The (-1) seed keeps 'maximum' total.
            tag = 1 + maximum ((-1) : map fst bucket)
        writeTVar timeouts $ Map.insert future ((tag, action) : bucket) m
        return $ TimeoutTag (future, tag)

  -- If the manager thread isn't running, start it now.
  startp <- atomically $ do
    started <- readTVar managerThreadStarted
    unless started $ writeTVar managerThreadStarted True
    return (not started)
  when startp $ void $ forkIO (timeoutManagerThread timeouts signal 0 Nothing)
  return stmAction

-- | Remove a timeout. This function never fails, but will return False if the
-- given timeout couldn't be found. This may be because cancelTimeout has
-- already been called with this tag, or because the timeout has already
-- fired. Note that, since timeouts are IO actions, they don't run atomically.
-- Thus it's possible that this call returns False and that the timeout is
-- currently in the process of running.
--
-- Note that one should never call cancelTimeout twice with the same tag since
-- it's possible that the tag will be reused and thus the second call could
-- cancel a different timeout.
cancelTimeout
  :: TimeoutTag  -- ^ the tag returned by addTimeout
  -> STM Bool    -- ^ returns False if the timeout didn't exist
cancelTimeout (TimeoutTag (future, n)) = do
  m <- readTVar timeouts
  case Map.lookup future m of
    Nothing -> return False
    Just xs -> do
      let remaining = filter ((/= n) . fst) xs
          found = length remaining /= length xs
      -- Drop the key entirely once its bucket is empty. An empty bucket would
      -- otherwise linger as the map minimum (waking the manager for nothing)
      -- and be extended by a later addTimeoutAtomic at the same instant.
      when found $
        writeTVar timeouts $
          if null remaining
            then Map.delete future m
            else Map.insert future remaining m
      return found

-- | Body of the single manager thread. Each iteration waits until either the
-- current timeoutThread has fired or the earliest key of the timeouts map has
-- changed, runs the expired actions (only in the former case), and then
-- replaces the timeoutThread with one that sleeps until the new earliest key.
-- A minimum of 0 means "no timeout scheduled".
--
-- NOTE(review): actions run directly in this thread without catching
-- exceptions, so an action that throws ends this loop, and because
-- managerThreadStarted stays True no further timeouts will fire. Guarding the
-- actions needs Control.Exception, which this module does not import.
timeoutManagerThread
  :: TVar (Map.Map POSIXTime [(Int, IO ())])  -- ^ the timeouts map
  -> TVar (Int, Bool)                         -- ^ the signal raised by timeoutThread
  -> POSIXTime                                -- ^ the current minimum time (0 if none)
  -> Maybe ThreadId                           -- ^ the id of the current timeoutThread
  -> IO ()
timeoutManagerThread tm sig currentMin mthid = do
  -- True: the timeoutThread fired. False: the earliest key of the map changed.
  (fired, currentTag) <- atomically $ do
    (tag, signalled) <- readTVar sig
    if signalled
      then return (True, tag)
      else do
        m <- readTVar tm
        when (earliest m == currentMin) retry
        return (False, tag)
  when fired $ do
    now <- getPOSIXTime
    join (atomically (runExpiredTimers now tm))
  minTimeout <- atomically (fmap earliest (readTVar tm))
  let nextTag = currentTag + 1
  -- Bump the tag before killing the old timeoutThread so that, if it is still
  -- alive, it can no longer raise the signal.
  atomically $ writeTVar sig (nextTag, False)
  maybe (return ()) killThread mthid
  tid <-
    if minTimeout == 0
      then return Nothing
      else fmap Just (forkIO (timeoutThread nextTag minTimeout))
  timeoutManagerThread tm sig minTimeout tid
  where
    -- Earliest key of the map, or the 0 sentinel when the map is empty.
    earliest m = maybe 0 (fst . fst) (Map.minViewWithKey m)