-- |
-- This module handles timeouts using STM, a manager thread, and a single
-- helper thread at a time sitting in threadDelay. One can request that an IO
-- action be performed after some number of seconds, and later cancel that
-- request if need be. The number of threads used is constant.
module Control.Timeout (
     addTimeout
   , cancelTimeout
   , TimeoutTag) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad

import System.IO.Unsafe

import qualified Data.Map as Map
import Data.Time.Clock.POSIX

-- | This is set, atomically, to True when the manager thread is started.
--   This thread isn't started unless someone actually creates a timeout.
--
--   This is a process-wide global created with 'unsafePerformIO'. The NOINLINE
--   pragma is required so that GHC never duplicates the 'newTVarIO' call,
--   which would give different use sites different TVars.
managerThreadStarted :: TVar Bool
managerThreadStarted = unsafePerformIO $ newTVarIO False
{-# NOINLINE managerThreadStarted #-}

-- | When a timeout thread times out, it compares the first element of this
--  tuple against the value that it was created with. If they don't match then
--  it's no longer the current timeout thread and it exits. Otherwise, it sets
--  the second element to True and exits.
--
--  This is a process-wide global created with 'unsafePerformIO'; it must be
--  NOINLINE so that every use refers to the same TVar.
signal :: TVar (Int, Bool)
signal = unsafePerformIO $ newTVarIO (0, False)
{-# NOINLINE signal #-}

-- | This is a map of all the timeouts. It maps the absolute time
--   that the timeout expires at to a list of tagged actions
--   to perform at that time. For a given value in the map, the Ints of
--   every element in the list must be unique.
--
--   This is a process-wide global created with 'unsafePerformIO'; it must be
--   NOINLINE so that every use refers to the same TVar.
timeouts :: TVar (Map.Map POSIXTime [(Int, IO ())])
timeouts = unsafePerformIO $ newTVarIO Map.empty
{-# NOINLINE timeouts #-}

-- Here's how everything works. The addTimeout and cancelTimeout functions
-- alter the above globals using the STM. The first call to addTimeout will
-- start a manager thread which watches @signal@ and @timeouts@ for changes:
-- either the timeoutThread has completed or someone has added/removed the
-- least element of timeouts.
-- In the first case, @timeouts@ is updated by removing all the expired
-- timeouts and their actions are performed. In both cases, the time to the
-- next timeout is calculated and a timeoutThread is forked to sleep for that
-- length of time before signaling via @signal@.
-- Timeouts are identified by their absolute time value and the unique tag
-- number for their action at that time. When creating a timeout, that pair
-- is wrapped in the (opaque) TimeoutTag type and returned. When canceling a
-- timeout, the list of actions for the given absolute time is filtered to
-- remove the indicated action. Because of this, the tags for a given absolute
-- time must be unique. This is achieved by giving the first element a tag of 0
-- and giving subsequent elements a tag one greater than the current max.
-- Each timeoutThread is also given a tag (an Int, but a separate counter from
-- the per-action tags described above) so that @signal@ is never set by an
-- old timeoutThread which hasn't died yet.

-- | Split a map of timeouts into the values whose time has come (key less
--   than or equal to the current time) and the map of those still pending.
--   The expired values are returned in chronological order, earliest first,
--   so that overdue timeouts fire in the order in which they were due.
expiredTimers :: POSIXTime  -- ^ the current time
              -> Map.Map POSIXTime a  -- ^ the timeouts map
              -> ([a], Map.Map POSIXTime a)  -- ^ the list of actions and a new map
expiredTimers curtime m = (Map.elems expired, pending)
  where
    -- splitLookup puts a key equal to curtime in the middle slot; that entry
    -- has expired too, so merge it back into the "before" half.
    (before, atNow, pending) = Map.splitLookup curtime m
    expired = maybe before (\v -> Map.insert curtime v before) atNow

-- | Remove every expired timer from the timeouts map held in @tm@ and return
--   an action that performs their actions. The removal happens inside this
--   STM transaction; the returned action must be run by the caller
--   afterwards, outside of STM. Actions run in the order in which
--   'expiredTimers' lists their times, and in list order within one time.
runExpiredTimers :: (Monad m)
                 => POSIXTime  -- ^ the current time
                 -> TVar (Map.Map POSIXTime [(Int, m ())])  -- ^ the timeouts map
                 -> STM (m ())  -- ^ an action that runs every expired timer
runExpiredTimers currentTime tm = do
  m <- readTVar tm
  let (expired, remaining) = expiredTimers currentTime m
  -- Leave the map untouched when nothing has expired.
  unless (null expired) $ writeTVar tm remaining
  return $ mapM_ snd (concat expired)

-- | Like unfoldr, but also hands back the seed that stopped the unfolding.
--   Elements are pushed onto an accumulator as they are produced, so the
--   returned list is in the reverse of generation order.
unfoldrWithValue :: (b -> Maybe (a, b)) -> b -> ([a], b)
unfoldrWithValue step seed = go seed []
  where
    go s acc = maybe (acc, s) (\(v, s') -> go s' (v : acc)) (step s)

-- | A thread that sleeps until the given absolute time and then tries to set
--   the snd element of the global @signal@ to True. It only does so if the
--   first element of that global still equals this thread's tag; otherwise a
--   newer timeoutThread has superseded it and nothing is written.
--
--   A target time of 0 is the "nothing scheduled" sentinel used for the
--   minimum expiry when the timeouts map is empty. There is nothing to wait
--   for, so the thread exits immediately without touching @signal@. Signalling
--   in that case would wake the manager for no reason, over and over.
timeoutThread :: Int  -- ^ the id of this timeout, see @signal@
              -> POSIXTime  -- ^ the absolute time to wake up at
              -> IO ()
timeoutThread tag targetTime
  | targetTime == 0 = return ()
  | otherwise = do
      sleepUntilTarget
      atomically $ do
        (current, _) <- readTVar signal
        when (current == tag) $ writeTVar signal (tag, True)
  where
    -- The remaining time is (target - now). It is rounded up to whole
    -- microseconds so we never wake before the target, and kept as an
    -- Integer so the conversion cannot overflow.
    sleepUntilTarget = do
      now <- getPOSIXTime
      let remaining = ceiling ((targetTime - now) * 1000000) :: Integer
      when (remaining > 0) $ do
        -- threadDelay takes an Int: sleep in bounded chunks for very distant
        -- targets, then re-check the clock.
        threadDelay (fromInteger (min remaining (toInteger (maxBound :: Int))))
        sleepUntilTarget

-- | Opaque handle for one scheduled timeout: addTimeout returns it and
--   cancelTimeout accepts it. Internally it pairs the absolute expiry time
--   with the action's tag within that time's slot of the timeouts map.
newtype TimeoutTag
  = TimeoutTag (POSIXTime, Int)

-- | Add an action to be performed at some point in the future. The action will
--   occur inside a thread which is dedicated to performing them, so it should
--   run quickly and certainly should not block on IO etc.
--
--   The first call also starts the manager thread.
addTimeout :: Float  -- ^ the number of seconds in the future to perform the action
           -> (IO ())  -- ^ the action to perform
           -> IO TimeoutTag  -- ^ a tag that can be passed to cancelTimeout
addTimeout delta action = do
  currentTime <- getPOSIXTime
  let future = currentTime + realToFrac delta
  tag <- atomically $ do
    m <- readTVar timeouts
    -- Treat a missing slot and an empty slot the same way, so that 'maximum'
    -- is never applied to an empty list. Otherwise the new tag is one greater
    -- than the largest tag already used at this time.
    let slot = Map.findWithDefault [] future m
        n = case slot of
              [] -> 0
              _ -> 1 + maximum (map fst slot)
    writeTVar timeouts $ Map.insert future ((n, action) : slot) m
    return $ TimeoutTag (future, n)
  -- If the manager thread isn't running, start it now. Reading and setting
  -- the flag in one transaction ensures only one caller forks it.
  startp <- atomically $ do
    started <- readTVar managerThreadStarted
    unless started $ writeTVar managerThreadStarted True
    return (not started)
  when startp $ do
    _ <- forkIO (timeoutManagerThread timeouts signal 0 Nothing)
    return ()
  return tag

-- | Remove a timeout. This function never fails, but will return False if the
--   given timeout couldn't be found. This may be because cancelTimeout has
--   already been called with this tag, or because the timeout has already
--   fired. Note that, since timeouts are IO actions, they don't run atomically.
--   Thus it's possible that this call returns False and that the timeout is
--   currently in the process of running.
--   Note that one should never call cancelTimeout twice with the same tag since
--   it's possible that the tag will be reused and thus the second call could
--   cancel a different timeout.
--
--   When the cancelled action was the last one scheduled for its time, that
--   time is removed from the map entirely rather than left with an empty list.
cancelTimeout :: TimeoutTag  -- ^ the tag returned by addTimeout
              -> STM Bool  -- ^ returns False if the timeout didn't exist
cancelTimeout (TimeoutTag (future, n)) = do
  m <- readTVar timeouts
  case Map.lookup future m of
       Just xs | any ((== n) . fst) xs -> do
         -- Tags are unique within a slot, so this removes exactly one action.
         let remaining = filter ((/= n) . fst) xs
         writeTVar timeouts $
           if null remaining
              then Map.delete future m
              else Map.insert future remaining m
         return True
       _ -> return False

-- | The body of the manager thread. It never returns.
--
--   Each iteration blocks (using STM 'retry') until one of two things happens:
--   the current timeoutThread has set the snd element of @signal@ to True, or
--   the earliest expiry time in the timeouts map no longer equals
--   @currentMin@ (a timeout was added or removed at the front). If the
--   timeoutThread fired, every expired action is run.
--
--   The manager then bumps the tag stored in @signal@, so that a stale
--   timeoutThread can no longer set it, and kills the previous timeoutThread.
--   If anything is scheduled, it forks a new timeoutThread for the earliest
--   expiry time.
--
--   NOTE(review): expired actions run on this thread. An action that throws an
--   exception will end the manager thread, and since managerThreadStarted
--   stays True it is never restarted, so no later timeouts fire. Consider
--   catching and reporting exceptions per action.
timeoutManagerThread :: TVar (Map.Map POSIXTime [(Int, IO ())])  -- ^ the timeouts map
                     -> TVar (Int, Bool)  -- ^ the signal shared with timeoutThread
                     -> POSIXTime  -- ^ the current minimum time (0 if nothing is scheduled)
                     -> Maybe ThreadId  -- ^ the id of the current timeoutThread
                     -> IO ()
timeoutManagerThread tm signal currentMin mthid = do
  -- fired is True when woken by the timeoutThread, and False when the
  -- earliest expiry time in the map has changed.
  (fired, currentTag) <- atomically $ do
    (tag, isSet) <- readTVar signal
    if isSet
       then return (True, tag)
       else do
         m <- readTVar tm
         if earliest m /= currentMin
            then return (False, tag)
            else retry

  currentTime <- getPOSIXTime
  when fired $ join (atomically (runExpiredTimers currentTime tm))
  minTimeout <- atomically (fmap earliest (readTVar tm))
  let nextTag = currentTag + 1
  atomically $ writeTVar signal (nextTag, False)
  maybe (return ()) killThread mthid
  -- With nothing scheduled there is nothing to sleep for: the retry above
  -- already wakes us when a timeout is added.
  mtid <- if minTimeout == 0
             then return Nothing
             else fmap Just (forkIO (timeoutThread nextTag minTimeout))
  timeoutManagerThread tm signal minTimeout mtid
  where
    -- Earliest expiry time in the map, or 0 when the map is empty.
    earliest m = maybe 0 (fst . fst) (Map.minViewWithKey m)