{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Snap.Internal.Http.Server.TimeoutManager
  ( TimeoutManager
  , TimeoutThread
  , initialize
  , stop
  , register
  , tickle
  , set
  , modify
  , cancel
  ) where

------------------------------------------------------------------------------
import           Control.Exception                (evaluate, finally)
import qualified Control.Exception                as E
import           Control.Monad                    (Monad (return, (>>=)), mapM_, void, when)
import qualified Data.ByteString.Char8            as S
import           Data.IORef                       (IORef, newIORef, readIORef, writeIORef)
import           Prelude                          (Bool, Double, IO, Int, Show (..), const, fromIntegral, max, null, otherwise, round, ($), ($!), (+), (++), (-), (.), (<=), (==))
------------------------------------------------------------------------------
import           Control.Concurrent               (MVar, newEmptyMVar, putMVar, readMVar, takeMVar, tryPutMVar)
------------------------------------------------------------------------------
import           Snap.Internal.Http.Server.Clock  (ClockTime)
import qualified Snap.Internal.Http.Server.Clock  as Clock
import           Snap.Internal.Http.Server.Common (atomicModifyIORef', eatException)
import qualified Snap.Internal.Http.Server.Thread as T


------------------------------------------------------------------------------
-- | A timeout's state: the absolute deadline (on the manager's clock) after
-- which its thread is cancelled, or the sentinel 'canceled'.
type State = ClockTime

-- | Sentinel state marking a timeout that is no longer live.
--
-- NOTE(review): a genuine deadline equal to 0 would be indistinguishable from
-- this sentinel; presumably the clock never reports such a value — confirm.
canceled :: State
canceled = 0

-- | Whether a 'State' is the 'canceled' sentinel.
isCanceled :: State -> Bool
isCanceled = (== 0)


------------------------------------------------------------------------------
-- | Handle for a thread whose lifetime is bounded by a timeout.
data TimeoutThread = TimeoutThread {
      _thread     :: !T.SnapThread
        -- ^ the forked worker thread
    , _state      :: !(IORef State)
        -- ^ current deadline, or 'canceled'
    , _hGetTime   :: !(IO ClockTime)
        -- ^ clock used to compute new deadlines for this thread
    }

-- | Shows only the underlying thread; the timeout state is not included.
instance Show TimeoutThread where
    show = show . _thread


------------------------------------------------------------------------------
-- | Given a 'State' value and the current time, apply the given modification
-- function to the amount of time remaining.
--
-- The remaining time is clamped at zero before @f@ sees it, so an
-- already-expired deadline is treated as having no time left. The result is
-- a new deadline @f remaining@ after @now@. A 'canceled' state is returned
-- unchanged, so a modification can never revive a canceled timeout.
smap :: ClockTime -> (ClockTime -> ClockTime) -> State -> State
smap now f deadline
    | isCanceled deadline = deadline
    | otherwise           = now + newremaining
  where
    -- time left before the deadline, never negative
    remaining    = max 0 (deadline - now)
    newremaining = f remaining


------------------------------------------------------------------------------
-- | Shared state of a timeout manager and the threads it supervises.
data TimeoutManager = TimeoutManager {
      _defaultTimeout :: !ClockTime
        -- ^ initial timeout given to newly 'register'ed threads
    , _pollInterval   :: !ClockTime
        -- ^ how long the manager thread sleeps between scans
    , _getTime        :: !(IO ClockTime)
        -- ^ source of the current time
    , _threads        :: !(IORef [TimeoutThread])
        -- ^ threads currently supervised by the manager
    , _morePlease     :: !(MVar ())
        -- ^ signalled by 'wakeup' to rouse an idle manager thread
    , _managerThread  :: !(MVar T.SnapThread)
        -- ^ the manager thread itself; filled in by 'initialize'
    }


------------------------------------------------------------------------------
-- | Create a new TimeoutManager and fork its manager thread.
--
-- Both durations are converted with 'Clock.fromSecs'. Setup runs under
-- 'E.uninterruptibleMask_', so once the manager thread has been forked it is
-- always recorded in '_managerThread'. 'stop' reads that MVar to find the
-- thread.
initialize :: Double            -- ^ default timeout (seconds)
           -> Double            -- ^ poll interval (seconds)
           -> IO ClockTime      -- ^ function to get current time
           -> IO TimeoutManager
initialize defaultTimeout interval getTime = E.uninterruptibleMask_ $ do
    conns <- newIORef []
    mp    <- newEmptyMVar
    mthr  <- newEmptyMVar

    let tm = TimeoutManager (Clock.fromSecs defaultTimeout)
                            (Clock.fromSecs interval)
                            getTime
                            conns
                            mp
                            mthr

    thr <- T.fork "snap-server: timeout manager" $ managerThread tm
    putMVar mthr thr
    return tm


------------------------------------------------------------------------------
-- | Stop a TimeoutManager by cancelling its manager thread with
-- 'T.cancelAndWait'.
--
-- The manager thread's cleanup handler cancels, and waits for, every thread
-- still registered with it.
stop :: TimeoutManager -> IO ()
stop tm = readMVar (_managerThread tm) >>= T.cancelAndWait


------------------------------------------------------------------------------
-- | Wake the manager thread if it is idle, waiting for threads to supervise.
--
-- Never blocks. If a wake-up is already pending, the extra signal is
-- dropped.
wakeup :: TimeoutManager -> IO ()
wakeup tm = void $ tryPutMVar (_morePlease tm) $! ()


------------------------------------------------------------------------------
-- | Register a new thread with the TimeoutManager.
--
-- The thread is forked with the given label and action. Its deadline starts
-- at the manager's default timeout from now. The fork and the insertion into
-- the manager's thread list run under 'E.uninterruptibleMask_', so an
-- asynchronous exception cannot arrive between them. The manager is then
-- woken in case it was idle.
register :: TimeoutManager                        -- ^ manager to register
                                                  --   with
         -> S.ByteString                          -- ^ thread label
         -> ((forall a . IO a -> IO a) -> IO ())  -- ^ thread action to run
         -> IO TimeoutThread
register tm label action = do
    now <- getTime
    let !state = now + defaultTimeout
    stateRef <- newIORef state
    th <- E.uninterruptibleMask_ $ do
        t <- T.fork label action
        let h = TimeoutThread t stateRef getTime
        atomicModifyIORef' threads (\x -> (h:x, ())) >>= evaluate
        return $! h
    wakeup tm
    return th

  where
    getTime        = _getTime tm
    threads        = _threads tm
    defaultTimeout = _defaultTimeout tm


------------------------------------------------------------------------------
-- | Tickle the timeout on a connection to be at least N seconds into the
-- future. If the existing timeout is set for M seconds from now, where M > N,
-- then the timeout is unaffected. A canceled timeout is left canceled.
tickle :: TimeoutThread -> Int -> IO ()
tickle th = modify th . max
{-# INLINE tickle #-}


------------------------------------------------------------------------------
-- | Set the timeout on a connection to be N seconds into the future,
-- replacing the current deadline regardless of how much time was left.
-- A canceled timeout is left canceled.
set :: TimeoutThread -> Int -> IO ()
set th = modify th . const
{-# INLINE set #-}


------------------------------------------------------------------------------
-- | Modify the timeout with the given function.
--
-- The function receives the time remaining before the current deadline as a
-- whole number of seconds (rounded, and clamped at zero). It returns the new
-- number of seconds remaining, and the deadline is reset to that many seconds
-- from now. A canceled timeout is left unchanged.
--
-- NOTE(review): the state is read and written as two separate operations.
-- A concurrent 'cancel', or the manager marking this timeout canceled, that
-- lands between them is overwritten with the new deadline.
modify :: TimeoutThread -> (Int -> Int) -> IO ()
modify th f = do
    now   <- getTime
    state <- readIORef stateRef
    let !state' = smap now f' state
    writeIORef stateRef state'

  where
    -- lift the caller's whole-seconds function to clock units
    f' !x    = Clock.fromSecs $! fromIntegral $ f $ round $ Clock.toSecs x
    getTime  = _hGetTime th
    stateRef = _state th
{-# INLINE modify #-}


------------------------------------------------------------------------------
-- | Cancel a timeout: mark its state 'canceled' and cancel its thread.
--
-- Both steps run under 'E.uninterruptibleMask_'. The manager removes the
-- entry from its list once the thread reports that it has finished.
cancel :: TimeoutThread -> IO ()
cancel h = E.uninterruptibleMask_ $ do
    writeIORef (_state h) canceled
    T.cancel $ _thread h
{-# INLINE cancel #-}


------------------------------------------------------------------------------
-- | Body of the manager thread forked by 'initialize'.
--
-- On each iteration it scans the registered threads:
--
--   * it cancels those whose deadline has passed;
--   * it drops canceled entries whose thread has finished.
--
-- It then sleeps for the poll interval. When no threads are registered, it
-- blocks until 'wakeup' is called. However the loop exits (for example via
-- 'stop'), the cleanup handler cancels every thread still in the list and
-- waits for each of them. Exceptions raised during cleanup are passed to
-- 'eatException'.
managerThread :: TimeoutManager -> (forall a. IO a -> IO a) -> IO ()
managerThread tm restore = restore loop `finally` cleanup
  where
    cleanup = E.uninterruptibleMask_ $
              eatException (readIORef threads >>= destroyAll)

    --------------------------------------------------------------------------
    getTime      = _getTime tm
    morePlease   = _morePlease tm
    pollInterval = _pollInterval tm
    threads      = _threads tm

    --------------------------------------------------------------------------
    -- The list is taken out of the IORef, processed, and put back under
    -- 'E.uninterruptibleMask'. An asynchronous exception therefore cannot
    -- strike while entries are outside the ref, where 'cleanup' would miss
    -- them. Only the idle wait on 'morePlease' is interruptible. 'register'
    -- may prepend new entries in the meantime, so the survivors are prepended
    -- onto the ref's current contents rather than replacing them.
    loop = do
        now <- getTime
        E.uninterruptibleMask $ \restore' -> do
            handles <- atomicModifyIORef' threads (\x -> ([], x))
            if null handles
              then restore' $ takeMVar morePlease
              else do
                handles' <- processHandles now handles
                atomicModifyIORef' threads (\x -> (handles' ++ x, ()))
                    >>= evaluate
        Clock.sleepFor pollInterval
        loop

    --------------------------------------------------------------------------
    -- Cancel expired entries (marking them 'canceled') and keep them. Keep
    -- canceled entries only until their thread has finished. The returned
    -- list is in reverse order of the input.
    processHandles now handles = go handles []
      where
        go [] !kept = return $! kept

        go (x:xs) !kept = do
            !state <- readIORef $ _state x
            !kept' <-
                if isCanceled state
                  then do b <- T.isFinished (_thread x)
                          return $! if b
                                      then kept
                                      else (x:kept)
                  else do when (state <= now) $ do
                            T.cancel (_thread x)
                            writeIORef (_state x) canceled
                          return (x:kept)
            go xs kept'

    --------------------------------------------------------------------------
    -- Cancel all threads first, then wait for each, so the cancellations are
    -- all issued before any waiting starts.
    destroyAll xs = do
        mapM_ (T.cancel . _thread) xs
        mapM_ (T.wait . _thread) xs