module Network.QUIC.Connector where

import Control.Concurrent.STM
import Data.IORef
import Network.QUIC.Types

-- | Read access to the state of a QUIC connection.
--
-- The helpers 'isClient', 'isServer' and 'isConnectionEstablished' in this
-- module are written against this class rather than a concrete type.
class Connector a where
    -- | The role of this endpoint (pure; does not need 'IO').
    getRole :: a -> Role

    -- | Current 'ConnectionState' of the connection.
    getConnectionState :: a -> IO ConnectionState

    -- | Current encryption level of the connection.
    getEncryptionLevel :: a -> IO EncryptionLevel

    -- | Packet number as tracked by the instance.
    -- NOTE(review): whether this is the last-used or next-to-use number is
    -- up to the instance — confirm.
    getPacketNumber :: a -> IO PacketNumber

    -- | Maximum packet size (presumably in bytes — confirm with instances).
    getMaxPacketSize :: a -> IO Int

    -- | Liveness flag of the connection.
    getAlive :: a -> IO Bool

----------------------------------------------------------------

-- | Mutable state of a single QUIC connection.
--
-- Create a fresh value with 'newConnState'.
data ConnState = ConnState {
    role            :: Role
    -- ^ Role of this endpoint.
  , connectionState :: TVar ConnectionState
    -- ^ Progress of the connection; 'newConnState' initialises it to
    -- 'Handshaking'.
  , packetNumber    :: IORef PacketNumber
    -- ^ Packet number counter; 'newConnState' initialises it to @0@.
    -- NOTE(review): the original note "squeezing three to one" suggests a
    -- single counter is used in place of QUIC's three packet-number
    -- spaces — confirm.
  , encryptionLevel :: TVar EncryptionLevel
    -- ^ Current encryption level; 'newConnState' initialises it to
    -- 'InitialLevel'. Held in a 'TVar' "to synchronize" (original note);
    -- presumably so other threads can wait on level changes via STM —
    -- confirm.
  , maxPacketSize   :: IORef Int
    -- ^ Maximum packet size; 'newConnState' initialises it to
    -- 'defaultQUICPacketSize'.
  , connectionAlive :: IORef Bool
    -- ^ Liveness flag; 'newConnState' initialises it to 'True'.
    -- Deliberately kept separate from 'ConnectionState' and stored in an
    -- 'IORef' rather than a 'TVar': using STM in the release action of
    -- 'bracket' appeared to cause a deadlock.
  }

-- | Allocate the state for a new connection playing the given 'Role'.
--
-- Initial values: connection state 'Handshaking', packet number @0@,
-- encryption level 'InitialLevel', maximum packet size
-- 'defaultQUICPacketSize', and the connection marked alive ('True').
newConnState :: Role -> IO ConnState
newConnState rl =
    ConnState rl <$> newTVarIO Handshaking
                 <*> newIORef 0
                 <*> newTVarIO InitialLevel
                 <*> newIORef defaultQUICPacketSize
                 <*> newIORef True

----------------------------------------------------------------

-- | Which end of the connection this endpoint is.
data Role = Client | Server deriving (Eq, Show)

-- | 'True' when the connection's 'getRole' is 'Client'.
isClient :: Connector a => a -> Bool
isClient conn = getRole conn == Client

-- | 'True' when the connection's 'getRole' is 'Server'.
isServer :: Connector a => a -> Bool
isServer conn = getRole conn == Server

----------------------------------------------------------------

-- | Progress of a connection.
--
-- The derived 'Ord' instance follows declaration order:
-- @Handshaking < ReadyFor0RTT < ReadyFor1RTT < Established@.
data ConnectionState = Handshaking
                     | ReadyFor0RTT
                     | ReadyFor1RTT
                     | Established
                     deriving (Eq, Ord, Show)

-- | Whether the connection's current 'ConnectionState' is 'Established'.
-- Any other state ('Handshaking', 'ReadyFor0RTT', 'ReadyFor1RTT') yields
-- 'False'.
isConnectionEstablished :: Connector a => a -> IO Bool
isConnectionEstablished conn = (== Established) <$> getConnectionState conn