{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Handshake.State13 (
    CryptLevel (
        CryptEarlySecret,
        CryptHandshakeSecret,
        CryptApplicationSecret
    ),
    TrafficSecret,
    getTxRecordState,
    getRxRecordState,
    setTxRecordState,
    setRxRecordState,
    getTxLevel,
    getRxLevel,
    clearTxRecordState,
    clearRxRecordState,
    setHelloParameters13,
    transcriptHash,
    wrapAsMessageHash13,
    PendingRecvAction (..),
    setPendingRecvActions,
    popPendingRecvAction,
) where

import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef
import Data.List (foldl')

import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types

-- | Snapshot of the transmit-side record state: the hash of the installed
-- cipher, the cipher, the current crypt level and the stored traffic secret.
-- Forcing the cipher or hash fails when no cipher is installed.
getTxRecordState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getTxRecordState ctx = getXState ctx ctxTxRecordState

-- | Snapshot of the receive-side record state: the hash of the installed
-- cipher, the cipher, the current crypt level and the stored traffic secret.
-- Forcing the cipher or hash fails when no cipher is installed.
getRxRecordState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getRxRecordState ctx = getXState ctx ctxRxRecordState

-- | Read the record state held in the MVar selected by @func@ (either
-- 'ctxTxRecordState' or 'ctxRxRecordState') and return
-- @(hash, cipher, level, secret)@.
--
-- * @hash@ is 'cipherHash' of the installed cipher.
-- * @level@ is 'stCryptLevel'.
-- * @secret@ is the 'cstMacSecret' field of the crypt state, where
--   'setXState'' stores the traffic secret.
--
-- The cipher and hash are evaluated lazily. Forcing either one raises an
-- error when 'stCipher' is 'Nothing'. That is the case after 'clearXState'
-- and, with QUIC, in general. Use 'getXLevel' when only the level is needed.
getXState
    :: Context
    -> (Context -> MVar RecordState)
    -> IO (Hash, Cipher, CryptLevel, ByteString)
getXState ctx func = do
    rst <- readMVar (func ctx)
    -- Kept lazy on purpose, so that callers which never touch the cipher do
    -- not fail when none is installed (same as the previous fromJust).
    let usedCipher = case stCipher rst of
            Just c -> c
            Nothing -> error "getXState: no cipher installed in record state"
        usedHash = cipherHash usedCipher
        level = stCryptLevel rst
        secret = cstMacSecret $ stCryptState rst
    return (usedHash, usedCipher, level, secret)

-- | Crypt level of the transmit-side record state.
--
-- With QUIC, 'stCipher' is 'Nothing', so the cipher returned by
-- 'getTxRecordState' cannot be used there. This accessor reads only
-- 'stCryptLevel' and is safe in that case.
getTxLevel :: Context -> IO CryptLevel
getTxLevel ctx = getXLevel ctx ctxTxRecordState

-- | Crypt level of the receive-side record state. This reads only
-- 'stCryptLevel', so it also works when no cipher is installed.
getRxLevel :: Context -> IO CryptLevel
getRxLevel ctx = getXLevel ctx ctxRxRecordState

-- | Read only the 'stCryptLevel' of the record state held in the MVar
-- selected by @func@. Unlike 'getXState', this never inspects 'stCipher'.
getXLevel
    :: Context
    -> (Context -> MVar RecordState)
    -> IO CryptLevel
getXLevel ctx func = stCryptLevel <$> readMVar (func ctx)

-- | Traffic secrets that can report the 'CryptLevel' they belong to.
class TrafficSecret secretTy where
    -- | Split a traffic secret into its crypt level and its raw bytes.
    fromTrafficSecret :: secretTy -> (CryptLevel, ByteString)

-- The level comes from the type parameter @a@ (via 'HasCryptLevel'). The
-- constructor itself carries only the secret bytes.
instance HasCryptLevel a => TrafficSecret (AnyTrafficSecret a) where
    fromTrafficSecret prx@(AnyTrafficSecret s) = (getCryptLevel prx, s)

-- The level comes from the type parameter @a@ (via 'HasCryptLevel'). The
-- constructor itself carries only the secret bytes.
instance HasCryptLevel a => TrafficSecret (ClientTrafficSecret a) where
    fromTrafficSecret prx@(ClientTrafficSecret s) = (getCryptLevel prx, s)

-- The level comes from the type parameter @a@ (via 'HasCryptLevel'). The
-- constructor itself carries only the secret bytes.
instance HasCryptLevel a => TrafficSecret (ServerTrafficSecret a) where
    fromTrafficSecret prx@(ServerTrafficSecret s) = (getCryptLevel prx, s)

-- | Replace the transmit-side record state with one derived from the given
-- traffic secret. The bulk cipher is initialised for encryption.
setTxRecordState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setTxRecordState = setXState ctxTxRecordState BulkEncrypt

-- | Replace the receive-side record state with one derived from the given
-- traffic secret. The bulk cipher is initialised for decryption.
setRxRecordState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setRxRecordState = setXState ctxRxRecordState BulkDecrypt

-- | Unpack a typed traffic secret into its crypt level and raw bytes, then
-- install the resulting record state with 'setXState''.
setXState
    :: TrafficSecret ty
    => (Context -> MVar RecordState)
    -> BulkDirection
    -> Context
    -> Hash
    -> Cipher
    -> ty
    -> IO ()
setXState func encOrDec ctx h cipher ts =
    let (lvl, secret) = fromTrafficSecret ts
     in setXState' func encOrDec ctx h cipher lvl secret

-- | Build a fresh 'RecordState' from a traffic secret and store it in the
-- MVar selected by @func@, discarding the previous contents.
--
-- * Key and IV are expanded from @secret@ with 'hkdfExpandLabel', using the
--   labels @\"key\"@ and @\"iv\"@ and an empty context. This is the traffic
--   key calculation of RFC 8446, section 7.3.
-- * The sequence number restarts at 0, compression is 'nullCompression',
--   and 'stCipher' is set to @Just cipher@.
setXState'
    :: (Context -> MVar RecordState)
    -> BulkDirection
    -> Context
    -> Hash
    -> Cipher
    -> CryptLevel
    -> ByteString
    -> IO ()
setXState' func encOrDec ctx h cipher lvl secret =
    modifyMVar_ (func ctx) (const (return rt))
  where
    bulk = cipherBulk cipher
    keySize = bulkKeySize bulk
    -- The IV length is never allowed to drop below 8 bytes.
    ivSize = max 8 (bulkIVSize bulk + bulkExplicitIV bulk)
    key = hkdfExpandLabel h secret "key" "" keySize
    iv = hkdfExpandLabel h secret "iv" "" ivSize
    cst =
        CryptState
            { cstKey = bulkInit bulk encOrDec key
            , cstIV = iv
            , -- The MAC-secret slot keeps the traffic secret itself;
              -- getXState reads it back from here.
              cstMacSecret = secret
            }
    rt =
        RecordState
            { stCryptState = cst
            , stMacState = MacState{msSequence = 0}
            , stCryptLevel = lvl
            , stCipher = Just cipher
            , stCompression = nullCompression
            }

-- | Drop the cipher from the transmit-side record state. The other fields
-- are kept.
clearTxRecordState :: Context -> IO ()
clearTxRecordState = clearXState ctxTxRecordState

-- | Drop the cipher from the receive-side record state. The other fields
-- are kept.
clearRxRecordState :: Context -> IO ()
clearRxRecordState = clearXState ctxRxRecordState

-- | Set 'stCipher' to 'Nothing' in the record state held in the MVar
-- selected by @func@. All other fields of the record state are unchanged.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx =
    modifyMVar_ (func ctx) (\rt -> return rt{stCipher = Nothing})

-- | Record the TLS 1.3 cipher in the handshake state.
--
-- * First call (no pending cipher yet):
--   * store the cipher and 'nullCompression';
--   * turn the buffered handshake messages into a running hash context
--     using the cipher's hash.
-- * Later calls (e.g. after a HelloRetryRequest): succeed only if the
--   cipher is unchanged. Otherwise return an 'IllegalParameter' protocol
--   error.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher = do
    hst <- get
    case hstPendingCipher hst of
        Nothing -> do
            put
                hst
                    { hstPendingCipher = Just cipher
                    , hstPendingCompression = nullCompression
                    , hstHandshakeDigest = updateDigest $ hstHandshakeDigest hst
                    }
            return $ Right ()
        Just oldcipher
            | cipher == oldcipher -> return $ Right ()
            | otherwise ->
                return $
                    Left $
                        Error_Protocol
                            "TLS 1.3 cipher changed after hello retry"
                            IllegalParameter
  where
    hashAlg = cipherHash cipher
    -- The buffered messages are reversed before hashing (presumably they are
    -- stored newest-first; verify against the code that appends them).
    -- foldl' keeps the hash context evaluated at each step instead of
    -- building a chain of thunks.
    updateDigest (HandshakeMessages bytes) =
        HandshakeDigestContext $
            foldl' hashUpdate (hashInit hashAlg) $
                reverse bytes
    updateDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"

-- | When a HelloRetryRequest is sent or received, the existing transcript
-- must be wrapped in a \"message_hash\" construct (RFC 8446, section 4.4.1).
-- This applies to the key-schedule computations as well as to the ones for
-- PSK binders.
--
-- The running digest is folded through 'foldHandshakeDigest', using the
-- hash of the pending cipher.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 = do
    cipher <- getPendingCipher
    foldHandshakeDigest (cipherHash cipher) foldFunc
  where
    -- The synthetic handshake message has three parts:
    --   * type 254 (message_hash);
    --   * a 24-bit length whose two high bytes are zero;
    --   * the current transcript hash.
    -- The single length byte assumes the digest is under 256 bytes, which
    -- holds for SHA-2 family digests.
    foldFunc dig =
        B.concat
            [ "\254\0\0"
            , B.singleton (fromIntegral $ B.length dig)
            , dig
            ]

-- | Finalise and return the current transcript hash.
--
-- Raises an error in two cases:
--
-- * the context has no handshake state;
-- * the digest is still a list of raw messages, i.e. no hash context has
--   been initialised yet (see 'setHelloParameters13').
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx = do
    mhst <- getHState ctx
    case hstHandshakeDigest <$> mhst of
        Just (HandshakeDigestContext hashCtx) -> return $ hashFinal hashCtx
        Just (HandshakeMessages _) -> error "un-initialized handshake digest"
        Nothing -> error "transcriptHash: no handshake state"

-- | Overwrite the list of pending receive actions stored in the context.
setPendingRecvActions :: Context -> [PendingRecvAction] -> IO ()
setPendingRecvActions ctx = writeIORef (ctxPendingRecvActions ctx)

-- | Remove and return the first pending receive action, or 'Nothing' when
-- the list is empty (in which case the IORef is left untouched).
popPendingRecvAction :: Context -> IO (Maybe PendingRecvAction)
popPendingRecvAction ctx = do
    let ref = ctxPendingRecvActions ctx
    pending <- readIORef ref
    case pending of
        action : rest -> writeIORef ref rest >> return (Just action)
        [] -> return Nothing