{-# LANGUAGE DeriveDataTypeable #-}

module Network.QUIC.Connection.Timeout (
    timeouter
  , timeout
  , fire
  , cfire
  , delay
  ) where

import Control.Concurrent
import Control.Concurrent.STM
import Data.Typeable
import GHC.Event
import System.IO.Unsafe (unsafePerformIO)
import qualified UnliftIO.Exception as E

import Network.QUIC.Connection.Types
import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Types

-- | Exception thrown to a thread whose 'timeout' has expired.
--
-- It is registered as an /asynchronous/ exception: 'toException' wraps it
-- in 'SomeAsyncException' and 'fromException' unwraps it from there
-- (via 'E.asyncExceptionToException' / 'E.asyncExceptionFromException').
-- Helpers that distinguish synchronous from asynchronous exceptions
-- therefore classify it as asynchronous.
data TimeoutException = TimeoutException deriving (Show, Typeable)

instance E.Exception TimeoutException where
  fromException = E.asyncExceptionFromException
  toException = E.asyncExceptionToException

-- | Process-wide queue of pending timeout actions.
--
-- 'timeout' enqueues the action that interrupts a timed-out thread, and
-- 'timeouter' dequeues and runs it.
--
-- 'unsafePerformIO' is used to create a single top-level 'TQueue'. The
-- NOINLINE pragma is what stops GHC from inlining the definition, which
-- could otherwise allocate more than one queue.
globalTimeoutQ :: TQueue (IO ())
globalTimeoutQ = unsafePerformIO newTQueueIO
{-# NOINLINE globalTimeoutQ #-}

-- | Worker loop that runs the actions queued on 'globalTimeoutQ', one at a
-- time, forever. It blocks (in STM) while the queue is empty.
--
-- 'timeout' only enqueues its interrupting action, so no 'timeout' can
-- fire unless this loop is running. NOTE(review): presumably it is forked
-- once at start-up elsewhere — confirm.
--
-- No exception handler is installed here, so an exception escaping a
-- queued action terminates the loop.
timeouter :: IO ()
timeouter = forever $ join $ atomically $ readTQueue globalTimeoutQ

-- | Run an action with a time limit.
--
-- Returns @'Just' r@ if the action finishes within the given number of
-- microseconds, or 'Nothing' if the calling thread receives
-- 'TimeoutException' first.
--
-- Mechanism: a timer is registered with the GHC system timer manager. When
-- it expires, the timer callback does not throw directly. Instead it
-- enqueues @killMe@ on 'globalTimeoutQ', and 'timeouter' later runs it,
-- which throws 'TimeoutException' to the calling thread. 'E.bracket'
-- unregisters the timer once the action finishes, whether normally or by
-- exception.
--
-- NOTE(review): the indirection through the queue is presumably there
-- because 'throwTo' blocks until the target thread accepts the exception,
-- and that should not happen inside a timer-manager callback — confirm.
--
-- NOTE(review): two caveats of this design, both visible in this module:
--
--   * If the timer has already fired but @killMe@ has not yet run when the
--     action finishes, @cleanup@ cannot retract the queued @killMe@. A stray
--     'TimeoutException' may then reach the thread outside this handler.
--   * 'TimeoutException' carries no identity, so with nested 'timeout'
--     calls the innermost handler catches whichever timer fires.
timeout :: Microseconds -> IO a -> IO (Maybe a)
timeout (Microseconds ms) action = do
    tid <- myThreadId
    timmgr <- getSystemTimerManager
    let killMe :: IO ()
        killMe = E.throwTo tid TimeoutException
        -- Runs in the timer manager; only hands the real work to 'timeouter'.
        onTimeout :: IO ()
        onTimeout = atomically $ writeTQueue globalTimeoutQ killMe
        setup :: IO TimeoutKey
        setup = registerTimeout timmgr ms onTimeout
        cleanup :: TimeoutKey -> IO ()
        cleanup key = unregisterTimeout timmgr key
    -- The handler matches only 'TimeoutException'; it is asynchronous, hence
    -- the SyncOrAsync variant.
    E.handleSyncOrAsync (\TimeoutException -> return Nothing) $
        E.bracket setup cleanup $ \_ -> Just <$> action

-- | Schedule a one-shot callback to run after the given number of
-- microseconds.
--
-- When the timer expires, the callback runs only if the connection is still
-- alive ('getAlive'). Any exception from the liveness check or from the
-- callback is handed to 'ignore' (presumably discarding it), so nothing
-- propagates out of the timer-manager callback. The timer cannot be
-- cancelled; 'cfire' returns a cancel action instead.
fire :: Connection -> Microseconds -> TimeoutCallback -> IO ()
fire conn (Microseconds microseconds) action = do
    timmgr <- getSystemTimerManager
    void $ registerTimeout timmgr microseconds action'
  where
    action' :: IO ()
    action' = guarded `E.catchSyncOrAsync` ignore
    -- The liveness check sits inside the catch too, so it cannot leak an
    -- exception into the timer manager either.
    guarded :: IO ()
    guarded = do
        alive <- getAlive conn
        when alive action

-- | Like 'fire', but return an action that cancels the timer.
--
-- When the timer expires, the callback runs only if the connection is still
-- alive ('getAlive'). Any exception from the liveness check or from the
-- callback is handed to 'ignore' (presumably discarding it). The returned
-- action unregisters the timer from the system timer manager.
cfire :: Connection -> Microseconds -> TimeoutCallback -> IO (IO ())
cfire conn (Microseconds microseconds) action = do
    timmgr <- getSystemTimerManager
    key <- registerTimeout timmgr microseconds action'
    let cancel :: IO ()
        cancel = unregisterTimeout timmgr key
    return cancel
  where
    action' :: IO ()
    action' = guarded `E.catchSyncOrAsync` ignore
    -- The liveness check sits inside the catch too, so it cannot leak an
    -- exception into the timer manager either.
    guarded :: IO ()
    guarded = do
        alive <- getAlive conn
        when alive action

-- | Suspend the current thread for the given number of microseconds.
-- This is a thin wrapper that unwraps 'Microseconds' for 'threadDelay'.
delay :: Microseconds -> IO ()
delay (Microseconds microseconds) = threadDelay microseconds