{-# LANGUAGE OverloadedStrings #-}

module Network.QUIC.Closer (closure) where

import Foreign.Marshal.Alloc
import qualified Network.Socket as NS
import UnliftIO.Concurrent
import qualified UnliftIO.Exception as E
import Foreign.Ptr

import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Imports
import Network.QUIC.Packet
import Network.QUIC.Recovery
import Network.QUIC.Sender
import Network.QUIC.Types

-- | Finish a connection whose main action produced @result@, then return
-- the value or rethrow.
--
-- * @Right x@: 'closure'' is run with @ConnectionClose NoError 0 \"\"@ and
--   @x@ is returned.
-- * 'TransportErrorIsSent': 'closure'' is run with a 'ConnectionClose'
--   carrying the same error and reason; the same exception is rethrown.
-- * 'ApplicationProtocolErrorIsSent': 'closure'' is run with a
--   'ConnectionCloseApp' carrying the same code and reason; the same
--   exception is rethrown.
-- * 'Abort': handled like the previous case, but the exception thrown
--   afterwards is an 'ApplicationProtocolErrorIsSent' built from the
--   abort's code and reason.
-- * 'VerNego': no close frame is sent; 'NextVersion' is thrown with the
--   same version.
-- * Anything else: rethrown unchanged, without sending a close frame.
closure :: Connection -> LDCC -> Either E.SomeException a -> IO a
closure conn ldcc result = case result of
    Right x -> sendClose (ConnectionClose NoError 0 "") >> return x
    Left se -> fromFailure se
  where
    sendClose = closure' conn ldcc

    -- Guards are tried top to bottom; order matters because
    -- 'TransportErrorIsSent' and 'ApplicationProtocolErrorIsSent' share
    -- one exception type, as do 'Abort' and 'VerNego'.
    fromFailure se
      | Just e@(TransportErrorIsSent err desc) <- E.fromException se =
          sendClose (ConnectionClose err 0 desc) >> E.throwIO e
      | Just e@(ApplicationProtocolErrorIsSent err desc) <- E.fromException se =
          sendClose (ConnectionCloseApp err desc) >> E.throwIO e
      | Just (Abort err desc) <- E.fromException se =
          sendClose (ConnectionCloseApp err desc)
            >> E.throwIO (ApplicationProtocolErrorIsSent err desc)
      | Just (VerNego ver) <- E.fromException se =
          E.throwIO (NextVersion ver)
      | otherwise = E.throwIO se

-- | Tear down the connection's I\/O and hand the final close exchange to a
-- background thread.
--
-- The readers are killed, the timeouter is replaced and all sockets are
-- detached from the connection. @frame@ is encoded once (via 'encodeCC')
-- into a freshly allocated buffer, and 'closer' is forked to send it on the
-- first socket. When that thread finishes, the buffer is freed, every
-- detached socket is closed and the action returned by
-- 'replaceKillTimeouter' is run.
--
-- If no sockets remain there is nothing to send on: the returned kill
-- action is run and the function returns, instead of failing with a
-- pattern-match error. If preparing the packet throws, the same cleanup
-- runs before the exception propagates, so neither the buffer nor the
-- sockets leak.
closure' :: Connection -> LDCC -> Frame -> IO ()
closure' conn ldcc frame = do
    killReaders conn
    killTimeouter <- replaceKillTimeouter conn
    socks <- clearSockets conn
    case socks of
        []    -> killTimeouter
        s : _ -> sendInBackground killTimeouter socks s
  where
    bufsiz = maximumUdpPayloadSize

    -- Allocate the buffer, encode the frame and fork 'closer'. Owns the
    -- cleanup of the buffer, the detached sockets and the old timeouter.
    sendInBackground killTimeouter socks s = do
        -- [0, bufsiz) receives the encoded packet(s);
        -- [2*bufsiz, 3*bufsiz) is the receive area.
        -- The middle third is not referenced here.
        sendBuf <- mallocBytes (bufsiz * 3)
        let cleanup = do
                free sendBuf
                mapM_ NS.close socks
                killTimeouter
            recvBuf = sendBuf `plusPtr` (bufsiz * 2)
            recv = NS.recvBuf s recvBuf bufsiz
            hook = onCloseCompleted $ connHooks conn
            prepare = do
                siz <- encodeCC conn frame sendBuf bufsiz
                send <- mkSend s sendBuf siz
                pto <- getPTO ldcc
                return (send, pto)
        -- The sockets are no longer reachable from conn, so if preparation
        -- fails nobody else would release them.
        (send, pto) <- prepare `E.onException` cleanup
        void $ forkFinally (closer pto send recv hook) $ \_ -> cleanup

    -- A client with a known server address uses sendBufTo; otherwise the
    -- socket's default destination is used.
    mkSend s sendBuf siz
        | isClient conn = do
            msa <- getServerAddr conn
            return $ case msa of
                Nothing -> NS.sendBuf s sendBuf siz
                Just sa -> NS.sendBufTo s sendBuf siz sa
        | otherwise = return $ NS.sendBuf s sendBuf siz

-- | Encode @frame@ into @sendBuf0@ (at most @bufsiz0@ bytes) as the
-- packet(s) for the connection's current encryption level, and return the
-- total number of bytes written.
--
-- * 0-RTT level is treated as Initial level.
-- * At Handshake level, an Initial packet is written first and a Handshake
--   packet directly after it in the same buffer.
-- * At any other level, a single packet is written at that level.
--
-- 'nextPacketNumber' is called once per packet attempt. A negative size
-- from 'encodePlainPacket' is reported as 0 bytes and is not qlogged.
--
-- In Initial and Handshake packets an application close
-- ('ConnectionCloseApp', frame type 0x1d) is replaced by
-- @ConnectionClose ApplicationError 0 \"\"@, as required by RFC 9000
-- §10.2.3. Type 0x1d is only permitted in 0-RTT/1-RTT packets (§12.4).
encodeCC :: Connection -> Frame -> Buffer -> BufferSize -> IO Int
encodeCC conn frame sendBuf0 bufsiz0 = do
    lvl0 <- getEncryptionLevel conn
    let lvl | lvl0 == RTT0Level = InitialLevel
            | otherwise         = lvl0
    if lvl == HandshakeLevel then do
        siz0 <- encCC sendBuf0 bufsiz0 InitialLevel
        let sendBuf1 = sendBuf0 `plusPtr` siz0
            bufsiz1  = bufsiz0 - siz0
        siz1 <- encCC sendBuf1 bufsiz1 HandshakeLevel
        return (siz0 + siz1)
      else
        encCC sendBuf0 bufsiz0 lvl
  where
    -- The frame actually placed in a packet of the given level. Application
    -- state must not leak through Initial/Handshake packets, so the code is
    -- replaced with APPLICATION_ERROR and the reason phrase is cleared.
    frameFor lvl
        | lvl == InitialLevel || lvl == HandshakeLevel
        , ConnectionCloseApp{} <- frame = ConnectionClose ApplicationError 0 ""
        | otherwise = frame

    -- Encode one packet at @lvl@; returns the bytes written, or 0 on failure.
    encCC sendBuf bufsiz lvl = do
        header <- mkHeader conn lvl
        mypn <- nextPacketNumber conn
        let plain = Plain (Flags 0) mypn [frameFor lvl] 0
            ppkt  = PlainPacket header plain
        siz <- fst <$> encodePlainPacket conn sendBuf bufsiz ppkt Nothing
        if siz >= 0 then do
            now <- getTimeMicrosecond
            qlogSent conn ppkt now
            return siz
          else
            return 0

-- | Background loop that resends the encoded close packet(s).
--
-- At most three rounds are made. Each round:
--
-- 1. calls @send@;
-- 2. reads and discards datagrams via @recv@ for up to @pto@, measured from
--    the start of the round (see @drain@);
-- 3. waits up to half a @pto@ for one more datagram.
--
-- The outcome of step 3 decides what happens next:
--
-- * nothing arrives: @hook@ is run and the loop stops;
-- * @recv@ returns 0: the loop stops without running @hook@;
-- * anything else: the next round starts.
--
-- If the third round also ends with a received datagram, the loop stops
-- without running @hook@.
closer :: Microseconds -> IO Int -> IO Int -> IO () -> IO ()
closer (Microseconds pto) send recv hook = loop (3 :: Int)
  where
    -- Draining stops once fewer than this many microseconds of the
    -- window remain.
    minRemaining :: Int
    minRemaining = 5000

    loop 0 = return ()
    loop n = do
        _ <- send
        base <- getTimeMicrosecond
        drain base pto
        mx <- timeout (Microseconds (pto .>>. 1)) recv
        case mx of
          Nothing -> hook
          Just 0  -> return ()
          Just _  -> loop (n - 1)

    -- Read and discard datagrams until a read times out, recv returns 0,
    -- or fewer than 'minRemaining' microseconds are left of the @pto@
    -- window that started at @base@. The remaining time is always
    -- recomputed from the fixed start (@pto - elapsed@). Previously the
    -- cumulative elapsed time was subtracted again from an already reduced
    -- duration, which ended the window early.
    drain base remaining = do
        mx <- timeout (Microseconds remaining) recv
        case mx of
          Just len | len /= 0 -> do
              Microseconds elapsed <- getElapsedTimeMicrosecond base
              let remaining' = pto - elapsed
              when (remaining' >= minRemaining) $ drain base remaining'
          _ -> return ()