module Network.QUIC.Connection.Timeout (
    timeout,
    fire,
    cfire,
    delay,
) where

import Control.Concurrent
import Control.Exception
import Data.Unique (Unique, newUnique)
import GHC.Conc.Sync
import Network.QUIC.Event

import Network.QUIC.Connection.Types
import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Types

-- | The exception that 'timeout' throws to the computation it is guarding
-- once the deadline passes. Each call to 'timeout' makes a fresh value from
-- 'newUnique', so its handler only catches its own 'Timeout' (compared with
-- '==') and lets any other 'Timeout' propagate.
newtype Timeout = Timeout Unique deriving (Eq)

instance Show Timeout where
    show _ = "<<timeout>>"

-- | 'Timeout' is raised as an /asynchronous/ exception: 'toException' wraps it
-- in 'SomeAsyncException', and 'fromException' unwraps it from there.
instance Exception Timeout where
    toException = asyncExceptionToException
    fromException = asyncExceptionFromException

-- | Run an 'IO' computation with a deadline. Returns 'Nothing' if no result
-- is available within the given number of microseconds. Returns @'Just' a@
-- if the computation finishes in time.
--
-- * A negative interval means \"wait indefinitely\": the action is simply
--   run and its result wrapped in 'Just'.
-- * A zero interval returns 'Nothing' immediately without running the action.
--
-- The 'String' is used only to label the helper thread that delivers the
-- timeout (the label is @\"timeout:\" ++ label@), which helps when debugging.
--
-- 'timeout' cancels the computation by throwing it the 'Timeout'
-- exception. Consequently blanket exception handlers (e.g. catching
-- 'SomeException') within the computation will break the timeout behavior.
--
-- NOTE(review): this always goes through the system 'TimerManager', so it
-- presumably requires the threaded RTS — confirm.
timeout :: Microseconds -> String -> IO a -> IO (Maybe a)
timeout (Microseconds n) label f
    | n < 0 = Just <$> f
    | n == 0 = pure Nothing
    | otherwise = do
        -- In the threaded RTS, we use the Timer Manager to delay the
        -- (fairly expensive) 'forkIO' call until the timeout has expired.
        --
        -- An additional thread is required for the actual delivery of
        -- the Timeout exception because killThread (or another throwTo)
        -- is the only way to reliably interrupt a throwTo in flight.
        pid <- myThreadId
        ex <- Timeout <$> newUnique
        tm <- getSystemTimerManager
        -- 'lock' synchronizes the timeout handler and the main thread:
        --  * the main thread can disable the handler by writing to 'lock';
        --  * the handler communicates the spawned thread's id through 'lock'.
        -- These two cases are mutually exclusive.
        lock <- newEmptyMVar
        let handleTimeout = do
                v <- isEmptyMVar lock
                when v $ void $ forkIOWithUnmask $ \unmask -> unmask $ do
                    tid <- myThreadId
                    labelThread tid $ "timeout:" ++ label
                    -- Only throw if the main thread has not already
                    -- disarmed us by filling 'lock' in cleanupTimeout.
                    v2 <- tryPutMVar lock tid
                    when v2 $ throwTo pid ex
            cleanupTimeout key = uninterruptibleMask_ $ do
                -- Filling 'lock' disarms the handler. The stored value is
                -- only a marker: nothing ever reads it on this path.
                v <- tryPutMVar lock (error "timeout: disarm marker, never read")
                if v
                    then unregisterTimeout tm key
                    -- The handler won the race: 'lock' holds the id of the
                    -- thread delivering 'ex'; kill it so it cannot throw
                    -- after we return.
                    else takeMVar lock >>= killThread
        handleJust
            (\e -> if e == ex then Just () else Nothing)
            (\_ -> pure Nothing)
            ( bracket
                (registerTimeout tm n handleTimeout)
                cleanupTimeout
                (\_ -> Just <$> f)
            )

-- | Schedule @action@ to run once after the given delay, using the system
-- timer manager. The registration key is discarded, so this timer cannot be
-- cancelled. Use 'cfire' when cancellation is needed.
--
-- The callback is wrapped by 'guardedCallback': it runs only if the
-- connection is still alive when the timer expires, and any exception it
-- throws is handed to 'ignore'.
fire :: Connection -> Microseconds -> TimeoutCallback -> IO ()
fire conn (Microseconds microseconds) action = do
    timmgr <- getSystemTimerManager
    void $ registerTimeout timmgr microseconds (guardedCallback conn action)

-- | Like 'fire', but returns an action that unregisters the timer via
-- 'unregisterTimeout'.
--
-- NOTE(review): calling the returned action after the timer has already
-- fired is presumably harmless; this depends on 'unregisterTimeout'
-- semantics, so confirm.
cfire :: Connection -> Microseconds -> TimeoutCallback -> IO (IO ())
cfire conn (Microseconds microseconds) action = do
    timmgr <- getSystemTimerManager
    key <- registerTimeout timmgr microseconds (guardedCallback conn action)
    pure $ unregisterTimeout timmgr key

-- | Timer callback wrapper shared by 'fire' and 'cfire'. It skips @action@
-- when 'getAlive' reports the connection is no longer alive. Any exception
-- from @action@ is routed to 'ignore', presumably so it cannot escape into
-- the timer manager (confirm).
--
-- Only @when alive action@ is covered by 'catch'; the 'getAlive' call
-- itself runs outside it.
guardedCallback :: Connection -> TimeoutCallback -> IO ()
guardedCallback conn action = do
    alive <- getAlive conn
    when alive action `catch` ignore

-- | Suspend the current thread for the given number of microseconds,
-- using 'threadDelay'.
delay :: Microseconds -> IO ()
delay (Microseconds microseconds) = threadDelay microseconds