{-# LANGUAGE BangPatterns #-}

module System.IO.Streams.Heartbeat
    ( heartbeatOutputStream
    , heartbeatInputStream
    , HeartbeatException (..)
    ) where

import           Control.Concurrent       (threadDelay)
import           Control.Concurrent.Async (async, cancel, link)
import           Control.Exception        (Exception, throw, throwIO)
import           Control.Monad            (forever)
import           Data.IORef               (atomicModifyIORef', newIORef, writeIORef)
import           Data.Time.Clock          (DiffTime, UTCTime, diffTimeToPicoseconds, diffUTCTime, getCurrentTime)
import           System.IO.Streams        (InputStream, OutputStream)
import qualified System.IO.Streams        as Streams


-- | Wrap an 'OutputStream' so that a heartbeat message is sent whenever
-- nothing has been written to it for the given interval.
--
-- A background thread (started with 'async' and 'link'ed to the calling
-- thread) wakes up periodically and writes the heartbeat message once the
-- interval has elapsed since the last message (or the last heartbeat).
-- Every @Just x@ written to the returned stream is forwarded downstream and
-- restarts the interval.
--
-- Because both the caller and the background thread write to the wrapped
-- stream, writes are serialised with 'Streams.lockingOutputStream'.
--
-- Writing 'Nothing' to the returned 'OutputStream' is required for proper
-- cleanup: it stops the heartbeat thread and then forwards end-of-stream.
heartbeatOutputStream :: DiffTime -- ^ Heartbeat interval
                      -> a        -- ^ Heartbeat message
                      -> OutputStream a -> IO (OutputStream a)
heartbeatOutputStream interval msg downstream = do
    -- io-streams streams are not thread-safe, and the downstream stream is
    -- written from two threads (the caller and the heartbeat thread).
    os <- Streams.lockingOutputStream downstream
    t <- newIORef =<< getCurrentTime
    writeAsync <- async $ delayInterval >> forever (writeHeartbeat os t)
    link writeAsync
    Streams.makeOutputStream (resetHeartbeat os t writeAsync)
  where
    delayInterval = delayDiffTime interval

    -- Send a heartbeat if the interval has elapsed since the last message,
    -- otherwise sleep until the next heartbeat would be due.
    writeHeartbeat os t = do
        !now <- getCurrentTime
        (!timeTilHeartbeat, !triggerHeartbeat) <- atomicModifyIORef' t (heartbeatTime interval now)

        if triggerHeartbeat
            then Streams.write (Just msg) os >> delayInterval
            else delayDiffTime timeTilHeartbeat

    -- Forward a message downstream and restart the heartbeat interval.
    resetHeartbeat os t _ x@(Just _) = Streams.write x os >> getCurrentTime >>= writeIORef t
    -- Stop the heartbeat thread *before* forwarding end-of-stream, so no
    -- heartbeat can be written after the downstream stream has been closed
    -- ('cancel' waits for the thread to exit).
    resetHeartbeat os _ writeAsync Nothing = cancel writeAsync >> Streams.write Nothing os


-- | Thrown by the watcher thread that 'heartbeatInputStream' starts when no
-- message has been received within the grace period.
--
-- The payload is the grace period (heartbeat interval times grace
-- multiplier) that was exceeded, not the time the last message was seen.
data HeartbeatException
    = MissedHeartbeat DiffTime -- ^ The grace period that elapsed without a message
    deriving (Eq, Show)

instance Exception HeartbeatException


-- | Watch an 'InputStream' for liveness.
--
-- Grace period = grace time multiplier x heartbeat interval.
-- Usually something like graceMultiplier = 2 is a good idea.
--
-- Every element pulled through the returned stream restarts the timer.
-- A watcher thread (started with 'async' and 'link'ed to the calling thread)
-- works as follows:
--
-- * It first sleeps for the grace period, then checks once per heartbeat
--   interval.
-- * If no element has been read within the grace period, it throws
--   @'MissedHeartbeat' gracePeriod@.
-- * 'link' re-throws that exception in the thread that called
--   'heartbeatInputStream'.
--
-- Because checks happen once per interval, a missed heartbeat may be
-- reported up to roughly one interval after the grace period has expired.
--
-- NOTE(review): with async >= 2.2, 'link' delivers the exception wrapped in
-- @ExceptionInLinkedThread@, so a handler for 'HeartbeatException' alone will
-- not match it. Confirm against the async version in use.
--
-- The watcher is cancelled when the stream reaches end of input. If the
-- stream is abandoned before then, the watcher keeps running.
heartbeatInputStream :: DiffTime -- ^ Heartbeat interval
                     -> DiffTime -- ^ Grace time multiplier
                     -> InputStream a -> IO (InputStream a)
heartbeatInputStream interval graceMultiplier is = do
    t <- newIORef =<< getCurrentTime
    checkAsync <- async $ delayDiffTime gracePeriod >> forever (checkHeartbeat t)
    link checkAsync
    -- On end of input, cancel the heartbeat watching thread.
    Streams.mapM_ (resetHeartbeat t) is >>= Streams.atEndOfInput (cancel checkAsync)
  where
    gracePeriod = graceMultiplier * interval

    -- Throw if the grace period has passed since the last element, otherwise
    -- sleep one interval before checking again.
    checkHeartbeat t = do
        !now <- getCurrentTime
        !triggerDisconnect <- snd <$> atomicModifyIORef' t (heartbeatTime gracePeriod now)

        if triggerDisconnect
            then throwIO (MissedHeartbeat gracePeriod)
            else delayDiffTime interval

    -- Record the arrival time of each element read from the stream.
    resetHeartbeat t _ = getCurrentTime >>= writeIORef t


-- | Decide whether the allowed time since the last message has run out.
--
-- Shaped to be passed to 'atomicModifyIORef''. The result is
-- @(stored, (remaining, overdue))@, where:
--
-- * @stored@ is the value written back to the ref. When overdue it is the
--   current time, which restarts the timer. Otherwise it is the unchanged
--   last-message time.
-- * @remaining@ is @limit - elapsed@, which is zero or negative once overdue.
-- * @overdue@ is @elapsed >= limit@.
heartbeatTime :: DiffTime -- ^ Maximum time since last message, ie. heartbeat interval or grace period
              -> UTCTime  -- ^ Current timestamp
              -> UTCTime  -- ^ Last message timestamp
              -> (UTCTime, (DiffTime, Bool)) -- ^ (New last message timestamp, (time til heartbeat, send new message?))
heartbeatTime limit now lastMsg
    | overdue   = (now, (remaining, True))
    | otherwise = (lastMsg, (remaining, False))
  where
    elapsed :: DiffTime
    elapsed   = realToFrac (diffUTCTime now lastMsg)
    remaining = limit - elapsed
    overdue   = elapsed >= limit


-- | Block the current thread for the given duration.
--
-- The duration is converted to whole microseconds for 'threadDelay',
-- rounding down (towards negative infinity).
delayDiffTime :: DiffTime -> IO ()
delayDiffTime duration = threadDelay micros
  where
    micros = fromIntegral (diffTimeToPicoseconds duration `div` picosPerMicro)
    picosPerMicro = 1000000