{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Snap.Internal.Http.Server.TimeoutManager
  ( TimeoutManager
  , TimeoutThread
  , initialize
  , stop
  , register
  , tickle
  , set
  , modify
  , cancel
  ) where

------------------------------------------------------------------------------
import           Control.Exception                (evaluate, finally)
import qualified Control.Exception                as E
import           Control.Monad                    (Monad (return, (>>=)), mapM_, void, when)
import qualified Data.ByteString.Char8            as S
import           Data.IORef                       (IORef, newIORef, readIORef, writeIORef)
import           Prelude                          (Bool, Double, IO, Int, Show (..), const, fromIntegral, max, null, otherwise, round, ($), ($!), (+), (++), (-), (.), (<=), (==))
------------------------------------------------------------------------------
import           Control.Concurrent               (MVar, newEmptyMVar, putMVar, readMVar, takeMVar, tryPutMVar)
------------------------------------------------------------------------------
import           Snap.Internal.Http.Server.Clock  (ClockTime)
import qualified Snap.Internal.Http.Server.Clock  as Clock
import           Snap.Internal.Http.Server.Common (atomicModifyIORef', eatException)
import qualified Snap.Internal.Http.Server.Thread as T


------------------------------------------------------------------------------
-- | The state of a timeout is its deadline, expressed on the 'ClockTime'
-- clock. The distinguished value 'canceled' (zero) marks a timeout that no
-- longer applies.
type State = ClockTime

-- | Sentinel deadline meaning the timeout has been canceled.
canceled :: State
canceled = 0

-- | True exactly when the state holds the 'canceled' sentinel.
isCanceled :: State -> Bool
isCanceled st = st == canceled


------------------------------------------------------------------------------
-- | Handle for a thread registered with a 'TimeoutManager'.
data TimeoutThread = TimeoutThread
    { _thread   :: !T.SnapThread    -- ^ the managed thread
    , _state    :: !(IORef State)   -- ^ current deadline, or 'canceled'
    , _hGetTime :: !(IO ClockTime)  -- ^ clock used to read the current time
    }

-- | A handle is displayed as the thread it wraps.
instance Show TimeoutThread where
    show tt = show (_thread tt)


------------------------------------------------------------------------------
-- | Given a 'State' value and the current time, apply the given modification
-- function to the amount of time remaining.
--
-- A canceled state is returned unchanged. Otherwise the time remaining until
-- the deadline (clamped at zero when the deadline is already past) is passed
-- to the modification function, and its result is added to @now@ to form the
-- new deadline.
smap :: ClockTime -> (ClockTime -> ClockTime) -> State -> State
smap now f deadline
  | isCanceled deadline = deadline
  | otherwise           = t'
  where
    remaining    = max 0 (deadline - now)
    newremaining = f remaining
    t'           = now + newremaining


------------------------------------------------------------------------------
-- | Shared state of a timeout manager.
data TimeoutManager = TimeoutManager
    { _defaultTimeout :: !ClockTime
      -- ^ initial timeout given to every newly registered thread
    , _pollInterval   :: !ClockTime
      -- ^ how long the manager thread sleeps between sweeps
    , _getTime        :: !(IO ClockTime)
      -- ^ clock used to read the current time
    , _threads        :: !(IORef [TimeoutThread])
      -- ^ handles currently being watched
    , _morePlease     :: !(MVar ())
      -- ^ wake-up signal the manager blocks on while it has nothing to watch
    , _managerThread  :: !(MVar T.SnapThread)
      -- ^ handle of the manager thread itself, filled in by 'initialize'
    }


------------------------------------------------------------------------------
-- | Create a new TimeoutManager and fork its manager thread.
--
-- Both durations are converted with 'Clock.fromSecs'. Setup runs under
-- 'E.uninterruptibleMask_' so the manager thread handle is always stored in
-- its 'MVar' (which 'stop' reads) before an asynchronous exception can
-- interrupt this function.
initialize :: Double            -- ^ default timeout
           -> Double            -- ^ poll interval
           -> IO ClockTime      -- ^ function to get current time
           -> IO TimeoutManager
initialize defaultTimeout interval getTime = E.uninterruptibleMask_ $ do
    conns <- newIORef []
    mp    <- newEmptyMVar
    mthr  <- newEmptyMVar

    let tm = TimeoutManager (Clock.fromSecs defaultTimeout)
                            (Clock.fromSecs interval)
                            getTime
                            conns
                            mp
                            mthr

    thr <- T.fork "snap-server: timeout manager" $ managerThread tm
    putMVar mthr thr
    return tm


------------------------------------------------------------------------------
-- | Stop a TimeoutManager by passing its manager thread to
-- 'T.cancelAndWait'. When the manager thread exits, it cancels and waits for
-- every thread still registered with it.
stop :: TimeoutManager -> IO ()
stop tm = readMVar (_managerThread tm) >>= T.cancelAndWait


------------------------------------------------------------------------------
-- | Nudge the manager thread in case it is blocked waiting for work. Uses
-- 'tryPutMVar', so a pending, not yet consumed wake-up never blocks the caller.
wakeup :: TimeoutManager -> IO ()
wakeup tm = void $ tryPutMVar (_morePlease tm) $! ()


------------------------------------------------------------------------------
-- | Register a new thread with the TimeoutManager.
--
-- The initial deadline is the current time plus the default timeout of the
-- manager. Forking the thread and adding its handle to the watch list happen
-- together under 'E.uninterruptibleMask_', so an asynchronous exception
-- cannot arrive between the two and leave a forked thread untracked.
register :: TimeoutManager                         -- ^ manager to register
                                                   --   with
         -> S.ByteString                           -- ^ thread label
         -> ((forall a . IO a -> IO a) -> IO ())   -- ^ thread action to run
         -> IO TimeoutThread
register tm label action = do
    now <- getTime
    let !state = now + defaultTimeout
    stateRef <- newIORef state

    th <- E.uninterruptibleMask_ $ do
        t <- T.fork label action
        let h = TimeoutThread t stateRef getTime
        atomicModifyIORef' threads (\x -> (h:x, ())) >>= evaluate
        return $! h
    wakeup tm
    return th

  where
    getTime        = _getTime tm
    threads        = _threads tm
    defaultTimeout = _defaultTimeout tm


------------------------------------------------------------------------------
-- | Tickle the timeout on a connection to be at least N seconds into the
-- future. If the existing timeout is set for M seconds from now, where M > N,
-- then the timeout is unaffected.
tickle :: TimeoutThread -> Int -> IO ()
tickle th = modify th . max
{-# INLINE tickle #-}


------------------------------------------------------------------------------
-- | Set the timeout on a connection to be N seconds into the future.
set :: TimeoutThread -> Int -> IO ()
set th = modify th . const
{-# INLINE set #-}


------------------------------------------------------------------------------
-- | Modify the timeout with the given function.
--
-- The function receives the time remaining before the deadline, never
-- negative, converted with 'Clock.toSecs' and rounded to an 'Int'. It
-- returns the new amount of time to allow, measured from now. Canceled
-- timeouts are left untouched.
--
-- The update is a single atomic modify rather than a separate read and
-- write. The manager thread and 'cancel' both write 'canceled' into this
-- same 'IORef' from other threads, and a plain read/write pair could
-- overwrite that mark with a fresh deadline. Under the atomic modify,
-- 'smap' sees the canceled state and keeps it.
modify :: TimeoutThread -> (Int -> Int) -> IO ()
modify th f = do
    now <- getTime
    atomicModifyIORef' stateRef (\state -> (smap now f' state, ()))
  where
    f' !x    = Clock.fromSecs $! fromIntegral $ f $ round $ Clock.toSecs x
    getTime  = _hGetTime th
    stateRef = _state th
{-# INLINE modify #-}


------------------------------------------------------------------------------
-- | Cancel a timeout: mark the handle as 'canceled', then cancel its thread.
-- The manager thread drops the handle from its list once 'T.isFinished'
-- reports that the thread is done.
cancel :: TimeoutThread -> IO ()
cancel h = E.uninterruptibleMask_ $ do
    writeIORef (_state h) canceled
    T.cancel $ _thread h
{-# INLINE cancel #-}


------------------------------------------------------------------------------
-- | Body of the manager thread.
--
-- The manager sweeps the registered handles and then sleeps for the poll
-- interval, repeating until it is stopped. A handle whose deadline has passed
-- has its thread canceled and is marked 'canceled'. A canceled handle is
-- dropped once its thread has finished. While no handles are registered, the
-- manager blocks until 'wakeup' is called. However the loop exits, every
-- handle still in the list has its thread canceled and waited for.
managerThread :: TimeoutManager -> (forall a. IO a -> IO a) -> IO ()
managerThread tm restore = restore loop `finally` cleanup
  where
    cleanup = E.uninterruptibleMask_ $
              eatException (readIORef threads >>= destroyAll)

    --------------------------------------------------------------------------
    getTime      = _getTime tm
    morePlease   = _morePlease tm
    pollInterval = _pollInterval tm
    threads      = _threads tm

    --------------------------------------------------------------------------
    -- The handle list is taken out of the IORef and put back (together with
    -- anything registered in the meantime) inside an uninterruptible mask.
    -- This way an asynchronous exception cannot strand handles in a local
    -- variable where cleanup would never see them. Only the idle wait on
    -- the wake-up MVar is made interruptible again, via the restore function
    -- supplied by the mask.
    loop = do
        now <- getTime
        E.uninterruptibleMask $ \restore' -> do
            handles <- atomicModifyIORef' threads (\x -> ([], x))
            if null handles
              then restore' $ takeMVar morePlease
              else do
                handles' <- processHandles now handles
                atomicModifyIORef' threads (\x -> (handles' ++ x, ()))
                  >>= evaluate
        Clock.sleepFor pollInterval
        loop

    --------------------------------------------------------------------------
    -- Walk the handles, accumulating (in reverse order) the ones to keep.
    processHandles now handles = go handles []
      where
        go [] !kept = return $! kept

        go (x:xs) !kept = do
            !state <- readIORef $ _state x
            !kept' <- if isCanceled state
                        then do
                          b <- T.isFinished (_thread x)
                          return $! if b then kept else (x:kept)
                        else do
                          when (state <= now) $ do
                            T.cancel (_thread x)
                            writeIORef (_state x) canceled
                          return (x:kept)
            go xs kept'

    --------------------------------------------------------------------------
    -- Cancel every thread first, then wait for all of them.
    destroyAll xs = do
        mapM_ (T.cancel . _thread) xs
        mapM_ (T.wait . _thread) xs