-- |
-- Module      : Network.TLS.Record.Disengage
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez
-- Stability   : experimental
-- Portability : unknown
--
-- Disengage a record from the Record layer.
-- The record is decrypted, checked for integrity and then decompressed.
--
module Network.TLS.Record.Disengage
    ( disengageRecord
    ) where

import Control.Monad.State
import Control.Monad.Error

import Network.TLS.Struct
import Network.TLS.Cap
import Network.TLS.State
import Network.TLS.Record.Types
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Util
import Data.ByteString (ByteString)
import qualified Data.ByteString as B

-- | Turn a received ciphertext record into a plaintext record by running
-- 'decryptRecord' and then 'uncompressRecord' on its result.
disengageRecord :: Record Ciphertext -> TLSSt (Record Plaintext)
disengageRecord = decryptRecord >=> uncompressRecord

-- | Inflate the fragment of a compressed record using the compression
-- context obtained through 'withCompression'.
uncompressRecord :: Record Compressed -> TLSSt (Record Plaintext)
uncompressRecord record = onRecordFragment record $ fragmentUncompress $ \bytes ->
    withCompression $ compressionInflate bytes

-- | Decrypt the fragment of a record.
--
-- When 'stRxEncrypted' is not set in the current state the fragment bytes
-- are returned untouched; otherwise they are handed to 'decryptData'
-- together with the current state.
decryptRecord :: Record Ciphertext -> TLSSt (Record Compressed)
decryptRecord record = onRecordFragment record $ fragmentUncipher $ \e -> do
    st <- get
    -- The state fetched above is passed straight to 'decryptData'; nothing
    -- modifies it in between, so there is no need to read it a second time.
    if stRxEncrypted st
        then decryptData record e st
        else return e

-- | Verify the MAC and the padding of already-decrypted record data and
-- return the content on success.
--
-- The MAC (when present) is recomputed with 'makeDigest' over a header
-- rebuilt from the record's type, version and the content length. The
-- padding (when present) must consist solely of bytes equal to its length
-- minus one, except for versions below 'TLS10' where its content is not
-- checked.
--
-- Throws @Error_Protocol ("bad record mac", True, BadRecordMac)@ when either
-- check fails.
getCipherData :: Record a -> CipherData -> TLSSt ByteString
getCipherData (Record pt ver _) cdata = do
    -- check if the MAC is valid.
    macValid <- case cipherDataMAC cdata of
        Nothing     -> return True
        Just digest -> do
            let new_hdr = Header pt ver (fromIntegral $ B.length $ cipherDataContent cdata)
            expected_digest <- makeDigest False new_hdr $ cipherDataContent cdata
            return (expected_digest `bytesEq` digest)

    -- check if the padding is filled with the correct pattern if it exists
    paddingValid <- case cipherDataPadding cdata of
        Nothing  -> return True
        Just pad -> do
            cver <- gets stVersion
            let b = B.length pad - 1
            return (cver < TLS10 || B.replicate (B.length pad) (fromIntegral b) `bytesEq` pad)

    -- Both results are combined with '&&!' and reported through one single
    -- error, so the caller cannot tell which of the two checks failed.
    unless (macValid &&! paddingValid) $ do
        throwError $ Error_Protocol ("bad record mac", True, BadRecordMac)

    return $ cipherDataContent cdata

-- | Decrypt the encrypted content of a record with the active receive
-- cipher found in the supplied state, then validate it with 'getCipherData'.
--
-- Arguments: the record (its header fields are needed for the MAC), the
-- encrypted fragment bytes, and the TLS state snapshot to work from. The
-- state is written back with the receive crypt state's IV updated.
--
-- Errors:
--
-- * @Error_Packet "encrypted content too small for encryption parameters"@
--   when the ciphertext length is inconsistent with the cipher parameters;
--
-- * @Error_Packet "record bad format"@ when a split of the input fails
--   before any authenticated data is involved;
--
-- * @Error_Protocol ("bad record mac", True, BadRecordMac)@ when the
--   decrypted block-cipher plaintext carries an impossible padding length,
--   or when 'getCipherData' rejects the MAC or padding.
decryptData :: Record Ciphertext -> Bytes -> TLSState -> TLSSt Bytes
decryptData record econtent st = decryptOf (bulkF bulk)
  where
    cipher      = fromJust "cipher" $ stActiveRxCipher st
    bulk        = cipherBulk cipher
    cst         = fromJust "rx crypt state" $ stActiveRxCryptState st
    macSize     = hashSize $ cipherHash cipher
    writekey    = cstKey cst
    blockSize   = bulkBlockSize bulk
    econtentLen = B.length econtent

    explicitIV = hasExplicitBlockIV $ stVersion st

    sanityCheckError = throwError (Error_Packet "encrypted content too small for encryption parameters")

    -- Same error value 'getCipherData' raises for a MAC/padding mismatch.
    badRecordMac = Error_Protocol ("bad record mac", True, BadRecordMac)

    decryptOf :: BulkFunctions -> TLSSt Bytes
    decryptOf (BulkBlockF _ decryptF) = do
        -- The ciphertext must be a whole number of blocks and large enough
        -- to hold the optional explicit IV plus a MAC and one padding byte.
        let minContent = (if explicitIV then bulkIVSize bulk else 0) + max (macSize + 1) blockSize
        when ((econtentLen `mod` blockSize) /= 0 || econtentLen < minContent) $ sanityCheckError

        {- update IV -}
        (iv, econtent') <- if explicitIV
            then get2 econtent (bulkIVSize bulk, econtentLen - bulkIVSize bulk)
            else return (cstIV cst, econtent)
        let newiv = fromJust "new iv" $ takelast (bulkBlockSize bulk) econtent'
        put $ st { stActiveRxCryptState = Just $ cst { cstIV = newiv } }

        let content' = decryptF writekey iv econtent'
        -- The last plaintext byte gives the padding length, excluding itself.
        let paddinglength = fromIntegral (B.last content') + 1
        let contentlen = B.length content' - paddinglength - macSize
        -- The padding length comes out of the decrypted data, so a failure
        -- to split here must look exactly like a MAC failure to the peer.
        (content, mac, padding) <- splitDecrypted content' (contentlen, macSize, paddinglength)
        getCipherData record $ CipherData
            { cipherDataContent = content
            , cipherDataMAC     = Just mac
            , cipherDataPadding = Just padding
            }

    decryptOf (BulkStreamF initF _ decryptF) = do
        when (econtentLen < macSize) $ sanityCheckError
        let iv = cstIV cst
        -- An empty stored IV means the stream has not been initialised yet.
        let (content', newiv) = decryptF (if iv /= B.empty then iv else initF writekey) econtent
        {- update Ctx -}
        let contentlen = B.length content' - macSize
        (content, mac) <- get2 content' (contentlen, macSize)
        put $ st { stActiveRxCryptState = Just $ cst { cstIV = newiv } }
        getCipherData record $ CipherData
            { cipherDataContent = content
            , cipherDataMAC     = Just mac
            , cipherDataPadding = Nothing
            }

    -- Split decrypted block-cipher plaintext into (content, mac, padding).
    -- A failed split is reported as a bad record MAC rather than as a
    -- distinct "bad format" error, which would otherwise reveal whether the
    -- padding length byte was plausible.
    -- NOTE(review): no MAC is computed on this path, so a timing difference
    -- with the normal MAC-failure path remains.
    splitDecrypted s ls = maybe (throwError badRecordMac) return $ partition3 s ls

    -- Split a buffer into three parts of the given sizes using 'partition3'.
    get3 s ls = maybe (throwError $ Error_Packet "record bad format") return $ partition3 s ls

    -- Two-part variant of 'get3' (third part of size 0 is discarded).
    get2 s (d1,d2) = get3 s (d1,d2,0) >>= \(r1,r2,_) -> return (r1,r2)