{-# LANGUAGE BangPatterns #-}

module System.IO.Streams.Heartbeat
    ( heartbeatOutputStream
    , heartbeatInputStream
    , HeartbeatException (..)
    ) where

import           Control.Concurrent       (threadDelay)
import           Control.Concurrent.Async (async, cancel, link)
import           Control.Exception        (Exception, throw, throwIO)
import           Control.Monad            (forever)
import           Data.IORef               (atomicModifyIORef', newIORef, writeIORef)
import           Data.Time.Clock          (DiffTime, UTCTime, diffTimeToPicoseconds, diffUTCTime, getCurrentTime)
import           System.IO.Streams        (InputStream, OutputStream)
import qualified System.IO.Streams        as Streams


-- | Send a heartbeat message @msg@ on the wrapped stream whenever nothing
-- has been written to it for @interval@.
--
-- A background thread, started with 'async' and 'link'ed to the calling
-- thread, writes the heartbeats. Every @'Just' x@ written to the returned
-- stream is forwarded and restarts the heartbeat timer.
--
-- Writing 'Nothing' to the returned stream is required for proper cleanup.
-- It first cancels the heartbeat thread, so that no heartbeat can follow the
-- end-of-stream marker, and then forwards 'Nothing' downstream.
--
-- Both the caller and the heartbeat thread write to the wrapped stream, and
-- io-streams streams are not thread-safe. The stream is therefore wrapped
-- with 'Streams.lockingOutputStream' to serialise those writes.
heartbeatOutputStream :: DiffTime -- ^ Heartbeat interval
                      -> a        -- ^ Heartbeat message
                      -> OutputStream a -> IO (OutputStream a)
heartbeatOutputStream interval msg rawOs = do
    os         <- Streams.lockingOutputStream rawOs
    lastWrite  <- newIORef =<< getCurrentTime
    writeAsync <- async $ delayInterval >> forever (writeHeartbeat os lastWrite)
    link writeAsync
    Streams.makeOutputStream (forward os lastWrite writeAsync)
  where
    delayInterval = delayDiffTime interval

    -- One iteration of the heartbeat loop: send @msg@ if the interval has
    -- elapsed since the last write, otherwise sleep until it would have.
    writeHeartbeat os lastWrite = do
        !now <- getCurrentTime
        (!timeTilHeartbeat, !triggerHeartbeat) <-
            atomicModifyIORef' lastWrite (heartbeatTime interval now)
        if triggerHeartbeat
            then Streams.write (Just msg) os >> delayInterval
            else delayDiffTime timeTilHeartbeat

    -- Sink for the returned stream.
    forward os lastWrite _ x@(Just _) =
        Streams.write x os >> getCurrentTime >>= writeIORef lastWrite
    forward os _ writeAsync Nothing = do
        -- Stop the heartbeat thread before end-of-stream so that it cannot
        -- write a heartbeat after 'Nothing' has been sent downstream.
        cancel writeAsync
        Streams.write Nothing os


-- | Exception thrown by the monitoring thread of 'heartbeatInputStream' when
-- no element has been read within the grace period.
--
-- The 'DiffTime' field is the grace period that was exceeded
-- (grace time multiplier × heartbeat interval). It is a duration, not the
-- time at which the last message was received.
data HeartbeatException = MissedHeartbeat DiffTime
    deriving (Show, Eq)

instance Exception HeartbeatException


-- | Watch an 'InputStream' for liveness.
--
-- Grace period = grace time multiplier × heartbeat interval.
-- Usually something like @graceMultiplier = 2@ is a good idea.
--
-- A background thread, started with 'async' and 'link'ed to the thread that
-- calls this function, throws a 'MissedHeartbeat' exception if no element is
-- read from the returned stream within the grace period. The timer is
-- refreshed each time an element is read through the returned stream.
-- Reaching end of input cancels the watcher. If the consumer stops reading
-- before end of input, the watcher keeps running.
--
-- After each successful check the watcher sleeps until the grace deadline
-- (never longer than the grace period). A missed heartbeat is therefore
-- reported when the grace period expires, rather than up to one extra
-- interval later.
--
-- NOTE(review): the exception reaches the calling thread through 'link'.
-- Newer @async@ versions wrap it in @ExceptionInLinkedThread@; confirm
-- against the version in use before catching 'HeartbeatException' directly.
heartbeatInputStream :: DiffTime -- ^ Heartbeat interval
                     -> DiffTime -- ^ Grace time multiplier
                     -> InputStream a -> IO (InputStream a)
heartbeatInputStream interval graceMultiplier is = do
    lastRecv   <- newIORef =<< getCurrentTime
    checkAsync <- async $ delayDiffTime gracePeriod >> forever (checkHeartbeat lastRecv)
    link checkAsync
    -- Every element read refreshes the timestamp. At end of input the
    -- watching thread is cancelled.
    Streams.mapM_ (const (stamp lastRecv)) is
        >>= Streams.atEndOfInput (cancel checkAsync)
  where
    gracePeriod :: DiffTime
    gracePeriod = graceMultiplier * interval

    -- One iteration of the watcher: fail if the grace period has elapsed,
    -- otherwise sleep until the current deadline. The sleep is clamped to
    -- the grace period in case the wall clock stepped backwards.
    checkHeartbeat lastRecv = do
        !now <- getCurrentTime
        (!timeLeft, !missed) <-
            atomicModifyIORef' lastRecv (heartbeatTime gracePeriod now)
        if missed
            then throwIO (MissedHeartbeat gracePeriod)
            else delayDiffTime (min gracePeriod timeLeft)

    stamp lastRecv = getCurrentTime >>= writeIORef lastRecv


-- | Pure step function meant to be used with 'atomicModifyIORef'' over the
-- timestamp of the last message.
--
-- Given the maximum allowed gap, the current time and the stored timestamp,
-- it reports how much of the gap remains and whether the gap has been used
-- up (elapsed time >= limit). When it has, the stored timestamp is advanced
-- to the current time so that the next gap is measured from here. Otherwise
-- the stored timestamp is left unchanged. The remaining time is negative
-- once the limit has been overshot.
heartbeatTime :: DiffTime -- ^ Maximum time since last message, ie. heartbeat interval or grace period
              -> UTCTime  -- ^ Current timestamp
              -> UTCTime  -- ^ Last message timestamp
              -> (UTCTime, (DiffTime, Bool)) -- ^ (New last message timestamp, (time til heartbeat, send new message?))
heartbeatTime limit now lastTime
    | elapsed >= limit = (now,      (remaining, True))
    | otherwise        = (lastTime, (remaining, False))
  where
    elapsed :: DiffTime
    elapsed = realToFrac (diffUTCTime now lastTime)

    remaining :: DiffTime
    remaining = limit - elapsed


-- | Block the current thread for the given duration, at microsecond
-- resolution. The sub-microsecond remainder is floored, and non-positive
-- durations are passed to 'threadDelay' as 0.
--
-- 'threadDelay' takes an 'Int' number of microseconds. Converting with
-- 'fromIntegral' would silently wrap for long durations (after about
-- 35 minutes on 32-bit platforms), turning a long sleep into an immediate
-- return. Long durations are therefore split into bounded chunks.
delayDiffTime :: DiffTime -> IO ()
delayDiffTime = go . (`div` 1000000) . diffTimeToPicoseconds
  where
    go :: Integer -> IO ()
    go micros
        | micros > chunk = threadDelay (fromInteger chunk) >> go (micros - chunk)
        | otherwise      = threadDelay (fromInteger (max 0 micros))

    -- 1000 seconds in microseconds; fits comfortably in a 32-bit Int.
    chunk :: Integer
    chunk = 1000000000