-- |
-- Module      : Network.TLS.Record.Disengage
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Disengage a record from the Record layer.
-- The record is decrypted, checked for integrity and then decompressed.
--
-- Starting with TLS v1.3, only the "null" compression method is negotiated in
-- the handshake, so the decompression step will be a no-op.  Decryption and
-- integrity verification are performed using an AEAD cipher only.
--
{-# LANGUAGE FlexibleContexts #-}

module Network.TLS.Record.Disengage
        ( disengageRecord
        ) where

import Control.Monad.State.Strict

import Crypto.Cipher.Types (AuthTag(..))
import Network.TLS.Struct
import Network.TLS.ErrT
import Network.TLS.Cap
import Network.TLS.Record.State
import Network.TLS.Record.Types
import Network.TLS.Cipher
import Network.TLS.Crypto
import Network.TLS.Compression
import Network.TLS.Util
import Network.TLS.Wire
import Network.TLS.Packet
import Network.TLS.Imports
import qualified Data.ByteString as B
import qualified Data.ByteArray as B (convert, xor)

-- | Disengage a record received from the record layer: the record is first
-- passed through 'decryptRecord' and the result is then passed through
-- 'uncompressRecord'.
disengageRecord :: Record Ciphertext -> RecordM (Record Plaintext)
disengageRecord = decryptRecord >=> uncompressRecord

-- | Turn a compressed record into a plaintext record by running
-- 'compressionInflate' on the fragment bytes with the compression state
-- made available by 'withCompression'.
uncompressRecord :: Record Compressed -> RecordM (Record Plaintext)
uncompressRecord record = onRecordFragment record $ fragmentUncompress $ \bytes ->
    withCompression $ compressionInflate bytes

-- | Decrypt a record according to the current record state.
--
-- * When no cipher is set in the state, the fragment bytes are passed
--   through unchanged.
-- * When the record options indicate TLS 1.3, the record is handled by the
--   local @decryptData13@, which dispatches on the outer content type.
-- * Otherwise the fragment is decrypted with 'decryptData'.
decryptRecord :: Record Ciphertext -> RecordM (Record Compressed)
decryptRecord record@(Record ct ver fragment) = do
    st <- get
    case stCipher st of
        Nothing -> noDecryption
        _       -> do
            recOpts <- getRecordOptions
            let mver = recordVersion recOpts
            if recordTLS13 recOpts
                then decryptData13 mver (fragmentGetBytes fragment) st
                else onRecordFragment record $ fragmentUncipher $ \e ->
                        decryptData mver record e st
  where
    -- Re-tag the fragment as compressed without touching its bytes.
    noDecryption :: RecordM (Record Compressed)
    noDecryption = onRecordFragment record $ fragmentUncipher return

    -- TLS 1.3 path: only application-data records are decrypted; the real
    -- content type is then recovered from the decrypted inner plaintext.
    -- Change-cipher-spec and alert records are passed through undecrypted,
    -- and any other outer content type is rejected.
    decryptData13 :: Version -> ByteString -> RecordState -> RecordM (Record Compressed)
    decryptData13 mver e st = case ct of
      ProtocolType_AppData -> do
          inner <- decryptData mver record e st
          case unInnerPlaintext inner of
            Left message   -> throwError $ Error_Protocol (message, True, UnexpectedMessage)
            Right (ct', d) -> return $ Record ct' ver (fragmentCompressed d)
      ProtocolType_ChangeCipherSpec -> noDecryption
      ProtocolType_Alert            -> noDecryption
      _                             -> throwError $ Error_Protocol ("illegal plain text", True, UnexpectedMessage)

-- | Split a decrypted TLS 1.3 inner plaintext into its content type and its
-- content.
--
-- Trailing zero bytes are stripped first; the last remaining byte is then
-- read as the content type and everything before it is the content.
--
-- Returns 'Left' with a message when:
--
-- * nothing remains once the trailing zero bytes are removed,
-- * the content-type byte is not recognised by 'valToType', or
-- * the content is empty for a handshake or alert record.
unInnerPlaintext :: ByteString -> Either String (ProtocolType, ByteString)
unInnerPlaintext inner =
    case B.unsnoc dc of
        Nothing         -> Left $ unknownContentType13 (0 :: Word8)
        Just (bytes, c) ->
            case valToType c of
                Nothing -> Left $ unknownContentType13 c
                Just ct
                    | B.null bytes && ct `elem` nonEmptyContentTypes ->
                        Left ("empty " ++ show ct ++ " record disallowed")
                    | otherwise -> Right (ct, bytes)
  where
    -- dc is the input without its trailing run of zero bytes.
    (dc, _pad) = B.spanEnd (== 0) inner

    -- Content types for which an empty content is refused.
    nonEmptyContentTypes :: [ProtocolType]
    nonEmptyContentTypes = [ ProtocolType_Handshake, ProtocolType_Alert ]

    unknownContentType13 :: Show a => a -> String
    unknownContentType13 c = "unknown TLS 1.3 content type: " ++ show c

-- | Verify the MAC and the padding carried by a 'CipherData' and return its
-- content.
--
-- The record supplies the protocol type and version used to rebuild the
-- header over which the expected digest is computed.  A missing MAC or a
-- missing padding is treated as valid.  Both checks are evaluated before
-- their results are combined, and any failure is reported with the same
-- 'BadRecordMac' protocol error.
getCipherData :: Record a -> CipherData -> RecordM ByteString
getCipherData (Record pt ver _) cdata = do
    -- check if the MAC is valid.
    macValid <- case cipherDataMAC cdata of
        Nothing     -> return True
        Just digest -> do
            let new_hdr = Header pt ver (fromIntegral $ B.length $ cipherDataContent cdata)
            expected_digest <- makeDigest new_hdr $ cipherDataContent cdata
            return (expected_digest `bytesEq` digest)

    -- check if the padding is filled with the correct pattern if it exists
    -- (before TLS10 this checks instead that the padding length is minimal)
    paddingValid <- case cipherDataPadding cdata of
        Nothing           -> return True
        Just (pad, blksz) -> do
            cver <- getRecordVersion
            -- b is the value every padding byte is expected to hold:
            -- the length of the padding minus one.
            let b = B.length pad - 1
            return $ if cver < TLS10
                then b < blksz
                else B.replicate (B.length pad) (fromIntegral b) `bytesEq` pad

    unless (macValid &&! paddingValid) $
        throwError $ Error_Protocol ("bad record mac", True, BadRecordMac)

    return $ cipherDataContent cdata

-- | Decrypt the encrypted content of a record using the cipher and crypt
-- state found in the supplied 'RecordState'.
--
-- The work is dispatched on the kind of bulk state:
--
-- * block: the content is decrypted, the IV in the record state is updated,
--   and content, MAC and padding are handed to 'getCipherData';
-- * stream: the content is decrypted, the stream state is updated, and
--   content and MAC are handed to 'getCipherData';
-- * AEAD: the nonce and additional data are built, the content is decrypted
--   and the authentication tag is compared here;
-- * uninitialized: reported as an 'InternalError' protocol error.
--
-- Encrypted content shorter than the cipher parameters require is reported
-- with 'Error_Packet'.
decryptData :: Version -> Record Ciphertext -> ByteString -> RecordState -> RecordM ByteString
decryptData ver record econtent tst = decryptOf (cstKey cst)
  where cipher :: Cipher
        cipher      = fromJust "cipher" $ stCipher tst
        bulk :: Bulk
        bulk        = cipherBulk cipher
        cst :: CryptState
        cst         = stCryptState tst
        macSize :: Int
        macSize     = hashDigestSize $ cipherHash cipher
        blockSize :: Int
        blockSize   = bulkBlockSize bulk
        econtentLen :: Int
        econtentLen = B.length econtent

        explicitIV :: Bool
        explicitIV = hasExplicitBlockIV ver

        sanityCheckError :: RecordM a
        sanityCheckError = throwError (Error_Packet "encrypted content too small for encryption parameters")

        decryptOf :: BulkState -> RecordM ByteString
        decryptOf (BulkStateBlock decryptF) = do
            let minContent = (if explicitIV then bulkIVSize bulk else 0) + max (macSize + 1) blockSize

            -- check if we have enough bytes to cover the minimum for this cipher
            when ((econtentLen `mod` blockSize) /= 0 || econtentLen < minContent) sanityCheckError

            {- update IV -}
            -- with an explicit IV, the IV is the leading part of the
            -- encrypted content; otherwise it comes from the crypt state
            (iv, econtent') <- if explicitIV
                                  then get2o econtent (bulkIVSize bulk, econtentLen - bulkIVSize bulk)
                                  else return (cstIV cst, econtent)
            let (content', iv') = decryptF iv econtent'
            modify $ \txs -> txs { stCryptState = cst { cstIV = iv' } }

            -- the last decrypted byte gives the padding length, to which
            -- one is added to account for that length byte itself.
            -- NOTE(review): B.last relies on content' being non-empty, which
            -- holds if decryptF preserves the length validated above.
            let paddinglength = fromIntegral (B.last content') + 1
            let contentlen = B.length content' - paddinglength - macSize
            (content, mac, padding) <- get3i content' (contentlen, macSize, paddinglength)
            getCipherData record CipherData
                    { cipherDataContent = content
                    , cipherDataMAC     = Just mac
                    , cipherDataPadding = Just (padding, blockSize)
                    }

        decryptOf (BulkStateStream (BulkStream decryptF)) = do
            -- check if we have enough bytes to cover the minimum for this cipher
            when (econtentLen < macSize) sanityCheckError

            let (content', bulkStream') = decryptF econtent
            {- update Ctx -}
            let contentlen        = B.length content' - macSize
            (content, mac) <- get2i content' (contentlen, macSize)
            modify $ \txs -> txs { stCryptState = cst { cstKey = BulkStateStream bulkStream' } }
            getCipherData record CipherData
                    { cipherDataContent = content
                    , cipherDataMAC     = Just mac
                    , cipherDataPadding = Nothing
                    }

        decryptOf (BulkStateAEAD decryptF) = do
            let authTagLen  = bulkAuthTagLen bulk
                nonceExpLen = bulkExplicitIV bulk
                cipherLen   = econtentLen - authTagLen - nonceExpLen

            -- check if we have enough bytes to cover the minimum for this cipher
            when (econtentLen < (authTagLen + nonceExpLen)) sanityCheckError

            -- outer layout: explicit nonce, ciphertext, authentication tag
            (enonce, econtent', authTag) <- get3o econtent (nonceExpLen, cipherLen, authTagLen)
            let encodedSeq = encodeWord64 $ msSequence $ stMacState tst
                iv = cstIV (stCryptState tst)
                ivlen = B.length iv
                Header typ v _ = recordToHeader record
                -- from TLS 1.3 on the header length covers the whole
                -- encrypted content; before that only the ciphertext part
                hdrLen = if ver >= TLS13 then econtentLen else cipherLen
                hdr = Header typ v $ fromIntegral hdrLen
                -- additional data: the header alone from TLS 1.3 on,
                -- the sequence number followed by the header before that
                ad | ver >= TLS13 = encodeHeader hdr
                   | otherwise    = B.concat [ encodedSeq, encodeHeader hdr ]
                -- sequence number left-padded with zero bytes to the IV length
                sqnc = B.replicate (ivlen - 8) 0 `B.append` encodedSeq
                -- without explicit nonce bytes the nonce is the IV xor-ed
                -- with the padded sequence number; otherwise the explicit
                -- nonce is appended to the IV
                nonce | nonceExpLen == 0 = B.xor iv sqnc
                      | otherwise = iv `B.append` enonce
                (content, authTag2) = decryptF nonce econtent' ad

            when (AuthTag (B.convert authTag) /= authTag2) $
                throwError $ Error_Protocol ("bad record mac", True, BadRecordMac)

            modify incrRecordState
            return content

        decryptOf BulkStateUninitialized =
            throwError $ Error_Protocol ("decrypt state uninitialized", True, InternalError)

        -- handling of outer format can report errors with Error_Packet
        get3o :: ByteString -> (Int, Int, Int) -> RecordM (ByteString, ByteString, ByteString)
        get3o s ls = maybe (throwError $ Error_Packet "record bad format") return $ partition3 s ls

        get2o :: ByteString -> (Int, Int) -> RecordM (ByteString, ByteString)
        get2o s (d1,d2) = get3o s (d1,d2,0) >>= \(r1,r2,_) -> return (r1,r2)

        -- all format errors related to decrypted content are reported
        -- externally as integrity failures, i.e. BadRecordMac
        get3i :: ByteString -> (Int, Int, Int) -> RecordM (ByteString, ByteString, ByteString)
        get3i s ls = maybe (throwError $ Error_Protocol ("record bad format", True, BadRecordMac)) return $ partition3 s ls

        get2i :: ByteString -> (Int, Int) -> RecordM (ByteString, ByteString)
        get2i s (d1,d2) = get3i s (d1,d2,0) >>= \(r1,r2,_) -> return (r1,r2)