{-# LANGUAGE MultiParamTypeClasses #-}
-- |
-- Module      : Network.TLS.Record.State
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.Record.State
    ( CryptState(..)
    , CryptLevel(..)
    , HasCryptLevel(..)
    , MacState(..)
    , RecordOptions(..)
    , RecordState(..)
    , newRecordState
    , incrRecordState
    , RecordM
    , runRecordM
    , getRecordOptions
    , getRecordVersion
    , setRecordIV
    , withCompression
    , computeDigest
    , makeDigest
    , getBulk
    , getMacSequence
    ) where

import Control.Monad.State.Strict
import Network.TLS.Compression
import Network.TLS.Cipher
import Network.TLS.ErrT
import Network.TLS.Struct
import Network.TLS.Wire

import Network.TLS.Packet
import Network.TLS.MAC
import Network.TLS.Util
import Network.TLS.Imports
import Network.TLS.Types

import qualified Data.ByteString as B

-- | Cryptographic material stored in a 'RecordState' (see 'stCryptState').
data CryptState = CryptState
    { cstKey        :: !BulkState
      -- ^ Bulk cipher state; 'newRecordState' starts it as
      -- 'BulkStateUninitialized'.
    , cstIV         :: !ByteString
      -- ^ Current IV; replaced by 'setRecordIV'.
    , cstMacSecret  :: !ByteString
      -- ^ In TLS 1.2 or earlier, this holds the MAC secret (the key passed
      -- to the MAC function in 'computeDigest').
      -- In TLS 1.3, this holds application traffic secret N.
    } deriving (Show)

-- | MAC bookkeeping for the record layer.
newtype MacState = MacState
    { msSequence :: Word64
      -- ^ Sequence number fed into the MAC input by 'computeDigest';
      -- starts at 0 in 'newRecordState' and is bumped by 'incrRecordState'.
    } deriving (Show)

-- | Settings a 'RecordM' computation reads but never modifies.
data RecordOptions = RecordOptions
    { recordVersion :: Version  -- ^ version to use when sending/receiving
    , recordTLS13   :: Bool     -- ^ TLS13 record processing
    }

-- | TLS encryption level.
data CryptLevel
    = CryptInitial            -- ^ Unprotected traffic
    | CryptMasterSecret       -- ^ Protected with master secret (TLS < 1.3)
    | CryptEarlySecret        -- ^ Protected with early traffic secret (TLS 1.3)
    | CryptHandshakeSecret    -- ^ Protected with handshake traffic secret (TLS 1.3)
    | CryptApplicationSecret  -- ^ Protected with application traffic secret (TLS 1.3)
    deriving (Eq,Show)

-- | Maps a secret type to its corresponding 'CryptLevel'.
-- Only the type of the proxy argument matters; its value is never inspected.
class HasCryptLevel a where
    getCryptLevel :: proxy a -> CryptLevel

instance HasCryptLevel EarlySecret where
    getCryptLevel _ = CryptEarlySecret

instance HasCryptLevel HandshakeSecret where
    getCryptLevel _ = CryptHandshakeSecret

instance HasCryptLevel ApplicationSecret where
    getCryptLevel _ = CryptApplicationSecret

-- | Mutable state threaded through 'RecordM'.
data RecordState = RecordState
    { stCipher      :: Maybe Cipher
      -- ^ Cipher in use, if any. 'getBulk' and 'computeDigest' call
      -- 'error' when this is 'Nothing'.
    , stCompression :: Compression
      -- ^ Compression state, updated through 'withCompression'.
    , stCryptLevel  :: !CryptLevel
      -- ^ Current encryption level; 'CryptInitial' in 'newRecordState'.
    , stCryptState  :: !CryptState
      -- ^ Keys, IV and MAC secret.
    , stMacState    :: !MacState
      -- ^ MAC sequence number.
    } deriving (Show)

-- | Record-layer computation: reads 'RecordOptions', threads a
-- 'RecordState', and may fail with a 'TLSError'. A failed run returns only
-- the error, not a state.
newtype RecordM a = RecordM
    { runRecordM :: RecordOptions
                 -> RecordState
                 -> Either TLSError (a, RecordState)
    }

-- | 'pure' yields a value without touching the options or the state.
--
-- 'pure' is defined here and 'Monad' relies on the default
-- @return = pure@. Defining @pure = return@ together with a hand-written
-- 'return' is non-canonical and is flagged by
-- @-Wnoncanonical-monad-instances@.
instance Applicative RecordM where
    pure a = RecordM $ \_ st -> Right (a, st)
    (<*>)  = ap

-- | Sequencing passes the same 'RecordOptions' to both steps and threads
-- the 'RecordState' from the first into the second. The first 'Left'
-- short-circuits the rest of the computation.
instance Monad RecordM where
    m1 >>= m2 = RecordM $ \opt st ->
        case runRecordM m1 opt st of
            Left err       -> Left err
            Right (a, st2) -> runRecordM (m2 a) opt st2

-- | Maps the result value; state and errors pass through unchanged.
instance Functor RecordM where
    fmap f m = RecordM $ \opt st ->
        case runRecordM m opt st of
            Left err       -> Left err
            Right (a, st2) -> Right (f a, st2)

-- | Read the 'RecordOptions' the computation is running with.
-- The state is left unchanged.
getRecordOptions :: RecordM RecordOptions
getRecordOptions = RecordM $ \opt st -> Right (opt, st)

-- | The 'recordVersion' field of the current 'RecordOptions'.
getRecordVersion :: RecordM Version
getRecordVersion = recordVersion <$> getRecordOptions

-- | Direct access to the threaded 'RecordState'. None of these methods fail.
instance MonadState RecordState RecordM where
    put x   = RecordM $ \_ _  -> Right ((), x)
    get     = RecordM $ \_ st -> Right (st, st)
    state f = RecordM $ \_ st -> Right (f st)

-- | Errors are carried as 'Left'.
--
-- 'catchError' runs the handler starting from the state that was current
-- /before/ the failed action. A 'Left' carries no state, so any state
-- changes made by the failing action are discarded.
instance MonadError TLSError RecordM where
    throwError e   = RecordM $ \_ _ -> Left e
    catchError m f = RecordM $ \opt st ->
        case runRecordM m opt st of
            Left err -> runRecordM (f err) opt st
            r        -> r

-- | Initial record state. It has:
--
-- * no cipher and 'nullCompression';
-- * level 'CryptInitial';
-- * an uninitialised bulk state with empty IV and MAC secret;
-- * MAC sequence number 0.
newRecordState :: RecordState
newRecordState = RecordState
    { stCipher      = Nothing
    , stCompression = nullCompression
    , stCryptLevel  = CryptInitial
    , stCryptState  = CryptState BulkStateUninitialized B.empty B.empty
    , stMacState    = MacState 0
    }

-- | Increment the MAC sequence number by one.
--
-- 'Word64' addition wraps around on overflow; no overflow check is made
-- here.
incrRecordState :: RecordState -> RecordState
incrRecordState ts = ts { stMacState = MacState (ms + 1) }
  where
    MacState ms = stMacState ts

-- | Replace the IV ('cstIV') in the state's 'CryptState'.
-- The bulk state and MAC secret are left untouched.
setRecordIV :: ByteString -> RecordState -> RecordState
setRecordIV iv st = st { stCryptState = (stCryptState st) { cstIV = iv } }

-- | Run @f@ on the current 'Compression'.
-- The first component of its result is stored back into the
-- 'RecordState', and the second component is returned.
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f = state $ \st ->
    let (nc, a) = f (stCompression st)
    in  (a, st { stCompression = nc })

-- | Compute the MAC of a record.
-- Returns the MAC together with the state whose MAC sequence number has
-- been incremented via 'incrRecordState'.
--
-- The MAC input is the concatenation of:
--
-- * the sequence number encoded with 'encodeWord64';
-- * the encoded header;
-- * the content.
--
-- For versions below 'TLS10', 'macSSL' is used and the header is encoded
-- with 'encodeHeaderNoVer'. Otherwise 'hmac' is used with 'encodeHeader'.
-- The MAC is keyed with 'cstMacSecret' and uses the current cipher's hash.
--
-- Calls 'error' (via 'fromJust') when 'stCipher' is 'Nothing'.
computeDigest :: Version
              -> RecordState
              -> Header
              -> ByteString
              -> (ByteString, RecordState)
computeDigest ver tstate hdr content = (digest, incrRecordState tstate)
  where
    digest     = macF (cstMacSecret cst) msg
    cst        = stCryptState tstate
    cipher     = fromJust "cipher" $ stCipher tstate
    hashA      = cipherHash cipher
    encodedSeq = encodeWord64 $ msSequence $ stMacState tstate
    msg        = B.concat [ encodedSeq, encodeHdr hdr, content ]

    -- Only the MAC primitive and the header encoding depend on the version.
    (macF, encodeHdr)
        | ver < TLS10 = (macSSL hashA, encodeHeaderNoVer)
        | otherwise   = (hmac hashA, encodeHeader)

-- | Compute the MAC for a record using the configured record version and
-- the current state.
-- The state with the incremented sequence number is stored back.
--
-- 'computeDigest' already returns @(digest, newState)@, which is exactly
-- the shape 'state' expects.
makeDigest :: Header -> ByteString -> RecordM ByteString
makeDigest hdr content = do
    ver <- getRecordVersion
    state $ \st -> computeDigest ver st hdr content

-- | The 'Bulk' of the current cipher.
--
-- Calls 'error' (via 'fromJust') when no cipher is set in the state.
getBulk :: RecordM Bulk
getBulk = cipherBulk . fromJust "cipher" . stCipher <$> get

-- | The current MAC sequence number.
getMacSequence :: RecordM Word64
getMacSequence = msSequence . stMacState <$> get