{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-} -- for GHC 7.4 or earlier

-- |
--
-- In order to provide slowloris protection, Warp provides timeout handlers. We
-- follow these rules:
--
-- * A timeout is created when a connection is opened.
--
-- * When all request headers are read, the timeout is tickled.
--
-- * Every time at least 2048 bytes of the request body are read, the timeout
--   is tickled.
--
-- * The timeout is paused while executing user code. This applies both to
--   the application itself and to a ResponseSource response. The timeout is
--   resumed as soon as we return from user code.
--
-- * Every time data is successfully sent to the client, the timeout is tickled.

module Network.Wai.Handler.Warp.Timeout (
  -- * Types
    Manager
  , TimeoutAction
  , Handle
  -- * Manager
  , initialize
  , stopManager
  , withManager
  -- * Registration
  , register
  , registerKillThread
  -- * Control
  , tickle
  , cancel
  , pause
  , resume
  -- * Exceptions
  , TimeoutThread (..)
  ) where

#ifndef MIN_VERSION_base
#define MIN_VERSION_base(x,y,z) 1
#endif

#if MIN_VERSION_base(4,6,0)
import Control.Concurrent (mkWeakThreadId, ThreadId)
#else
import GHC.Conc (ThreadId(..))
import GHC.Exts (mkWeak#)
import GHC.IO (IO (IO))
#endif
import Control.Concurrent (myThreadId)
import qualified Control.Exception as E
import GHC.Weak (Weak (..))
import System.Mem.Weak (deRefWeak)
import Data.Typeable (Typeable)
import Control.Reaper
#if USE_ATOMIC_PRIMOPS
import qualified Data.Atomics.Counter.Unboxed as C
import Data.Atomics.Counter.Unboxed (casCounter)
#else
import Network.Wai.Handler.Warp.IORef
#endif

----------------------------------------------------------------

-- | A timeout manager: a 'Reaper' whose workload is the list of registered
--   'Handle's, and to which individual 'Handle's are added.
type Manager = Reaper [Handle] Handle

-- | An action to be performed on timeout. The manager runs it with every
--   exception caught and ignored.
type TimeoutAction = IO ()

-- | A handle used by 'Manager': the action to run on timeout, paired with a
--   mutable counter holding the handle's current state (see the list of
--   states below). The counter is an atomic-primops 'C.AtomicCounter' when
--   built with USE_ATOMIC_PRIMOPS, and an 'IORef' 'Int' otherwise.
data Handle = Handle TimeoutAction
#if USE_ATOMIC_PRIMOPS
    {-# UNPACK #-} !C.AtomicCounter
#else
    {-# UNPACK #-} !(IORef Int)
#endif

-- Four states for the counter inside a 'Handle' (an AtomicCounter, or an
-- IORef Int when USE_ATOMIC_PRIMOPS is not set):
--
-- 1: Active    -- Manager turns it to Inactive.
-- 2: Inactive  -- Manager removes it with timeout action.
-- 3: Paused    -- Manager does not change it.
-- 4: Canceled  -- Manager removes it without timeout action.

----------------------------------------------------------------

-- | Create a timeout manager whose reaper runs every @timeout@ microseconds
--   (the first argument).
--
--   On each pass, every registered 'Handle' is examined with a
--   compare-and-swap from active (1) to inactive (2):
--
--   * if the swap succeeds, the handle is kept (now inactive);
--
--   * if the swap fails and the counter reads inactive (2), the timeout
--     action is run (any exception it throws is ignored) and the handle is
--     dropped;
--
--   * if the counter reads canceled (4), the handle is dropped without
--     running the action;
--
--   * otherwise (e.g. paused, 3) the handle is kept.
initialize :: Int -> IO Manager
initialize timeout = mkReaper settings
  where
    settings = defaultReaperSettings
        { reaperAction = mkListAction prune
        , reaperDelay  = timeout
        }
    prune h@(Handle onTimeout state) = do
        -- Attempt the active -> inactive transition and learn the state.
        (swapped, current) <- casCounter state 1 2
        if current == 4
            then return Nothing
            else if current == 2 && not swapped
                then do
                    onTimeout `E.catch` ignoreAll
                    return Nothing
                else return (Just h)

#if !USE_ATOMIC_PRIMOPS
-- | Fallback compare-and-swap used when atomic-primops is not available.
--
--   @casCounter ref old new@ atomically replaces the value in @ref@ with
--   @new@ if it currently equals @old@. It returns @(True, new)@ on success
--   and @(False, current)@ on failure, where @current@ is the value that was
--   found in @ref@.
--
--   A failed comparison must leave the reference untouched and report the
--   value actually found: 'initialize' relies on this to keep paused (3)
--   handles paused, to drop canceled (4) ones, and to detect inactive (2)
--   ones whose timeout has expired.
casCounter :: IORef Int -> Int -> Int -> IO (Bool, Int)
casCounter ref old new = atomicModifyIORef' ref $ \curr ->
    if curr == old
        then (new, (True, new))
        else (curr, (False, curr))
#endif

----------------------------------------------------------------

-- | Stop the timeout manager. The reaper is stopped, and the timeout action
--   of every handle in the workload it returns is run, with any exception
--   ignored. The whole operation runs with asynchronous exceptions masked.
stopManager :: Manager -> IO ()
stopManager mgr = E.mask_ $ do
    leftovers <- reaperStop mgr
    mapM_ runAction leftovers
  where
    runAction (Handle onTimeout _) = E.catch onTimeout ignoreAll

-- | Exception handler that discards whatever it is given. It makes running
--   a timeout action best-effort.
ignoreAll :: E.SomeException -> IO ()
ignoreAll = const (return ())

----------------------------------------------------------------

-- | Register a timeout action with the manager. The returned 'Handle'
--   starts in the active state (1) and is added to the manager's workload.
register :: Manager -> TimeoutAction -> IO Handle
register mgr onTimeout = do
    h <- fmap (Handle onTimeout) freshState
    reaperAdd mgr h
    return h
  where
#if USE_ATOMIC_PRIMOPS
    freshState :: IO C.AtomicCounter
    freshState = C.newCounter 1
#else
    freshState :: IO (IORef Int)
    freshState = newIORef 1
#endif

-- | Register a timeout action that throws 'TimeoutThread' to the calling
--   thread. The action holds only a weak reference to the caller's
--   'ThreadId'.
registerKillThread :: Manager -> IO Handle
registerKillThread m =
    myThreadId >>= mkWeakThreadId >>= register m . killIfExist

-- Holding a 'ThreadId' through a strong reference keeps it from being
-- garbage collected even after the thread has been killed, so the timeout
-- action keeps only a weak reference to it.
-- 'deRefWeak' yields 'Just' while the 'ThreadId' is still reachable; only
-- in that case is 'TimeoutThread' thrown to the thread.
killIfExist :: Weak ThreadId -> TimeoutAction
killIfExist wtid = do
    mtid <- deRefWeak wtid
    case mtid of
        Nothing  -> return ()
        Just tid -> E.throwTo tid TimeoutThread

-- | The exception thrown by the timeout action that 'registerKillThread'
--   installs.
data TimeoutThread = TimeoutThread deriving Typeable

instance Show TimeoutThread where
    showsPrec _ TimeoutThread = showString "Thread killed by Warp's timeout reaper"

instance E.Exception TimeoutThread

#if !MIN_VERSION_base(4,6,0)
-- | Fallback used when base < 4.6; in that case 'mkWeakThreadId' is not
--   imported from "Control.Concurrent" above. It builds a weak pointer whose
--   key is the unboxed thread (@t#@) and whose value is the boxed 'ThreadId'.
--
--   NOTE(review): the third argument of 'mkWeak#' is the finalizer; confirm
--   that passing 'Nothing' here is treated as "no finalizer" by the old GHC
--   versions this branch targets.
mkWeakThreadId :: ThreadId -> IO (Weak ThreadId)
mkWeakThreadId t@(ThreadId t#) = IO $ \s ->
   case mkWeak# t# t Nothing s of
      (# s1, w #) -> (# s1, Weak w #)
#endif

----------------------------------------------------------------

-- | Overwrite the state counter of a 'Handle' with the given value. This is
--   a plain write, not a compare-and-swap.
writeCounter :: Int -> Handle -> IO ()
#if USE_ATOMIC_PRIMOPS
writeCounter newState (Handle _ counter) = C.writeCounter counter newState
#else
writeCounter newState (Handle _ ref) = writeIORef ref newState
#endif
{-# INLINE writeCounter #-}

-- | Mark the handle as active (1). On each pass the 'Manager' turns active
--   handles inactive, so a handle that is not tickled between two passes
--   has its timeout action run.
tickle :: Handle -> IO ()
tickle h = writeCounter 1 h

-- | Mark the handle as canceled (4). On its next pass the 'Manager' drops
--   the handle without running its timeout action.
cancel :: Handle -> IO ()
cancel h = writeCounter 4 h

-- | Mark the handle as paused (3). The 'Manager' leaves a paused handle's
--   state unchanged until it is resumed.
pause :: Handle -> IO ()
pause h = writeCounter 3 h

-- | Return a paused handle to the active state; identical to 'tickle'.
resume :: Handle -> IO ()
resume h = tickle h

----------------------------------------------------------------

-- | Call the inner function with a timeout manager.
--
--   The manager is created with 'initialize'. Because 'E.bracket' is used,
--   it is always shut down with 'stopManager' once the inner function
--   returns or throws, so its reaper is not left running afterwards.
--   Note that 'stopManager' runs the timeout action of every handle the
--   manager still holds at that point.
withManager :: Int -- ^ timeout in microseconds
            -> (Manager -> IO a)
            -> IO a
withManager timeout f = E.bracket (initialize timeout) stopManager f