{-# LANGUAGE MultiParamTypeClasses #-}
-- |
-- Module      : Network.TLS.Record.State
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.Record.State
    ( CryptState(..)
    , CryptLevel(..)
    , HasCryptLevel(..)
    , MacState(..)
    , RecordOptions(..)
    , RecordState(..)
    , newRecordState
    , incrRecordState
    , RecordM
    , runRecordM
    , getRecordOptions
    , getRecordVersion
    , setRecordIV
    , withCompression
    , computeDigest
    , makeDigest
    , getBulk
    , getMacSequence
    ) where

import Control.Monad.State.Strict
import Network.TLS.Compression
import Network.TLS.Cipher
import Network.TLS.ErrT
import Network.TLS.Struct
import Network.TLS.Wire

import Network.TLS.Packet
import Network.TLS.MAC
import Network.TLS.Util
import Network.TLS.Imports
import Network.TLS.Types

import qualified Data.ByteString as B

-- | Cryptographic material used by the record layer.
data CryptState = CryptState
    { cstKey        :: !BulkState
      -- ^ bulk encryption state ('BulkStateUninitialized' in 'newRecordState')
    , cstIV         :: !ByteString
      -- ^ IV; replaced by 'setRecordIV'
    , cstMacSecret  :: !ByteString
      -- ^ In TLS 1.2 or earlier, this holds mac secret.
      -- In TLS 1.3, this holds application traffic secret N.
      -- 'computeDigest' uses it as the MAC key.
    } deriving (Show)

-- | Sequence-number state that is fed into the record MAC
-- (see 'computeDigest').
newtype MacState = MacState
    { msSequence :: Word64
      -- ^ current sequence number; advanced by 'incrRecordState'
    } deriving (Show)

-- | Options that 'RecordM' computations can read but never modify
-- (see 'getRecordOptions').
data RecordOptions = RecordOptions
    { recordVersion :: Version  -- ^ version to use when sending/receiving
    , recordTLS13   :: Bool     -- ^ TLS13 record processing
    }

-- | TLS encryption level.
data CryptLevel
    = CryptInitial            -- ^ Unprotected traffic
    | CryptMasterSecret       -- ^ Protected with master secret (TLS < 1.3)
    | CryptEarlySecret        -- ^ Protected with early traffic secret (TLS 1.3)
    | CryptHandshakeSecret    -- ^ Protected with handshake traffic secret (TLS 1.3)
    | CryptApplicationSecret  -- ^ Protected with application traffic secret (TLS 1.3)
    deriving (Eq, Show)

-- | Types that are associated with a fixed 'CryptLevel'.  The argument is
-- only a proxy carrying the type; its value is never inspected.
class HasCryptLevel a where
    getCryptLevel :: proxy a -> CryptLevel

instance HasCryptLevel EarlySecret where
    getCryptLevel _ = CryptEarlySecret

instance HasCryptLevel HandshakeSecret where
    getCryptLevel _ = CryptHandshakeSecret

instance HasCryptLevel ApplicationSecret where
    getCryptLevel _ = CryptApplicationSecret

-- | State threaded through 'RecordM' computations.
data RecordState = RecordState
    { stCipher      :: Maybe Cipher
      -- ^ cipher in use; 'Nothing' in 'newRecordState'
    , stCompression :: Compression
      -- ^ compression state, updated by 'withCompression'
    , stCryptLevel  :: !CryptLevel
      -- ^ current encryption level
    , stCryptState  :: !CryptState
      -- ^ keys, IV and MAC secret
    , stMacState    :: !MacState
      -- ^ MAC sequence number
    } deriving (Show)

-- | Record-layer computation: reads 'RecordOptions', threads a
-- 'RecordState', and may fail with a 'TLSError'.  A failing computation
-- yields no state, so any state changes it made are lost.
newtype RecordM a = RecordM
    { runRecordM :: RecordOptions
                 -> RecordState
                 -> Either TLSError (a, RecordState)
    }

-- | Map over the result; options, state and errors pass through unchanged.
instance Functor RecordM where
    fmap f m = RecordM $ \opt st ->
        case runRecordM m opt st of
            Left err       -> Left err
            Right (a, st2) -> Right (f a, st2)

-- | 'pure' is defined directly (not as @pure = return@) so that the
-- 'Monad' instance can rely on the default @return = pure@.  This keeps
-- the instances canonical and avoids mutual definitions.
instance Applicative RecordM where
    pure a = RecordM $ \_ st -> Right (a, st)
    (<*>)  = ap

-- | Sequencing: the second action sees the state produced by the first.
-- The first 'Left' short-circuits the rest of the computation.
instance Monad RecordM where
    m1 >>= m2 = RecordM $ \opt st ->
        case runRecordM m1 opt st of
            Left err       -> Left err
            Right (a, st2) -> runRecordM (m2 a) opt st2

-- | Return the 'RecordOptions' the computation runs with; the state is
-- left untouched.
getRecordOptions :: RecordM RecordOptions
getRecordOptions = RecordM $ \opt st -> Right (opt, st)

-- | The protocol 'Version' stored in the current 'RecordOptions'.
getRecordVersion :: RecordM Version
getRecordVersion = recordVersion <$> getRecordOptions

-- | Direct access to the threaded 'RecordState'.  None of these
-- operations can fail.
instance MonadState RecordState RecordM where
    put x   = RecordM $ \_ _  -> Right ((), x)
    get     = RecordM $ \_ st -> Right (st, st)
    state f = RecordM $ \_ st -> Right (f st)

-- | Errors are 'TLSError's.  Because a failed action returns no state,
-- 'catchError' runs the handler from the state that was current
-- /before/ the failed action started.
instance MonadError TLSError RecordM where
    throwError e   = RecordM $ \_ _ -> Left e
    catchError m f = RecordM $ \opt st ->
        case runRecordM m opt st of
            Left err -> runRecordM (f err) opt st
            r        -> r

-- | Initial record state: no cipher, null compression, unprotected
-- ('CryptInitial') traffic, empty key material and sequence number 0.
newRecordState :: RecordState
newRecordState = RecordState
    { stCipher      = Nothing
    , stCompression = nullCompression
    , stCryptLevel  = CryptInitial
    , stCryptState  = CryptState BulkStateUninitialized B.empty B.empty
    , stMacState    = MacState 0
    }

-- | Advance the MAC sequence number by one.  The counter is a 'Word64',
-- so it wraps modulo 2^64; nothing here guards against that.
incrRecordState :: RecordState -> RecordState
incrRecordState ts =
    ts { stMacState = MacState (msSequence (stMacState ts) + 1) }

-- | Replace the IV ('cstIV') of the current 'CryptState'.  The key and
-- MAC secret are left untouched.
setRecordIV :: ByteString -> RecordState -> RecordState
setRecordIV iv st = st { stCryptState = (stCryptState st) { cstIV = iv } }

-- | Run a step over the current 'Compression', store the updated
-- compression back into the state, and return the step's result.
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f = state $ \st ->
    let (nc, a) = f (stCompression st)
    in  (a, st { stCompression = nc })

-- | Compute the MAC of a record.  Returns the MAC together with the input
-- state whose sequence number has been advanced ('incrRecordState').
--
-- The MAC key is 'cstMacSecret'.  The MAC input is the 64-bit encoded
-- sequence number, then the encoded header, then the content:
--
-- * @ver < TLS10@: 'macSSL', with the header encoded without version
--   ('encodeHeaderNoVer');
-- * otherwise: 'hmac', with the full header ('encodeHeader').
--
-- Precondition: 'stCipher' must be 'Just'.  Otherwise the digest is
-- bottom (via 'fromJust' with label @"cipher"@).
computeDigest :: Version -> RecordState -> Header -> ByteString
              -> (ByteString, RecordState)
computeDigest ver tstate hdr content = (digest, incrRecordState tstate)
  where
    digest     = macF (cstMacSecret (stCryptState tstate)) msg
    cipher     = fromJust "cipher" $ stCipher tstate
    hashA      = cipherHash cipher
    encodedSeq = encodeWord64 $ msSequence $ stMacState tstate

    (macF, msg)
        | ver < TLS10 = ( macSSL hashA
                        , B.concat [encodedSeq, encodeHeaderNoVer hdr, content] )
        | otherwise   = ( hmac hashA
                        , B.concat [encodedSeq, encodeHeader hdr, content] )

-- | Compute the MAC of a record with the configured version and the
-- current state, and store the state with the advanced sequence number.
-- 'computeDigest' already returns @(result, newState)@, so it plugs
-- straight into 'state'.
makeDigest :: Header -> ByteString -> RecordM ByteString
makeDigest hdr content = do
    ver <- getRecordVersion
    state $ \st -> computeDigest ver st hdr content

-- | The bulk algorithm of the cipher stored in the state.
--
-- If no cipher has been installed ('stCipher' is 'Nothing'), this fails
-- inside 'RecordM' with a 'TLSError'.  The caller can then handle it like
-- any other record-layer error, rather than the process crashing on a
-- partial 'fromJust'.
getBulk :: RecordM Bulk
getBulk = do
    mcipher <- gets stCipher
    case mcipher of
        Nothing     -> throwError $ Error_Misc "getBulk: no cipher in record state"
        Just cipher -> return (cipherBulk cipher)

-- | The current MAC sequence number; the state is not modified.
getMacSequence :: RecordM Word64
getMacSequence = gets (msSequence . stMacState)