-- |
-- Module      : Network.TLS.Record.Disengage
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez
-- Stability   : experimental
-- Portability : unknown
--
-- Disengage a record from the Record layer.
-- The record is decrypted, checked for integrity and then decompressed.
--
module Network.TLS.Record.Disengage
    ( disengageRecord
    ) where

import Control.Monad.State
import Control.Monad.Error

import Network.TLS.Struct
import Network.TLS.Cap
import Network.TLS.Record.State
import Network.TLS.Record.Types
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Util
import Data.ByteString (ByteString)
import qualified Data.ByteString as B

-- | Turn a ciphertext record into a plaintext record: the record is first
-- decrypted (which also verifies it) and the result is then uncompressed.
disengageRecord :: Record Ciphertext -> RecordM (Record Plaintext)
disengageRecord ciphertext = decryptRecord ciphertext >>= uncompressRecord

-- | Inflate the fragment of a compressed record using the compression
-- state held in the record monad.
uncompressRecord :: Record Compressed -> RecordM (Record Plaintext)
uncompressRecord record =
    onRecordFragment record (fragmentUncompress (withCompression . compressionInflate))

-- | Decrypt the fragment of a record. When no cipher is set in the current
-- record state the fragment bytes are passed through untouched.
decryptRecord :: Record Ciphertext -> RecordM (Record Compressed)
decryptRecord record = onRecordFragment record (fragmentUncipher decipher)
  where
    decipher :: Bytes -> RecordM Bytes
    decipher payload = do
        st <- get
        case stCipher st of
            Nothing -> return payload
            Just _  -> do
                ver <- getRecordVersion
                decryptData ver record payload st

-- | Verify the MAC and the padding carried by a 'CipherData' and return its
-- content. Both verifications are run before their results are combined
-- with '&&!', and a failure of either one raises the same
-- @"bad record mac"@ protocol error.
getCipherData :: Record a -> CipherData -> RecordM ByteString
getCipherData (Record contentType recordVer _) cdata = do
    macOk <- maybe (return True) verifyMac (cipherDataMAC cdata)
    padOk <- maybe (return True) verifyPadding (cipherDataPadding cdata)
    unless (macOk &&! padOk) $
        throwError $ Error_Protocol ("bad record mac", True, BadRecordMac)
    return payload
  where
    payload = cipherDataContent cdata

    -- Recompute the digest over a header describing the content and
    -- compare it (with 'bytesEq') against the digest found in the record.
    verifyMac received = do
        let hdr = Header contentType recordVer (fromIntegral (B.length payload))
        computed <- makeDigest hdr payload
        return (computed `bytesEq` received)

    -- From TLS10 on, every padding byte must hold the padding length minus
    -- one; for earlier versions the padding content is not inspected.
    verifyPadding pad = do
        negotiated <- getRecordVersion
        let padLen   = B.length pad
            expected = B.replicate padLen (fromIntegral (padLen - 1))
        return (negotiated < TLS10 || expected `bytesEq` pad)

-- | Decrypt the encrypted content of a record with the cipher found in the
-- given record state, store the new IV in the crypt state, and hand the
-- resulting content, MAC and (for block ciphers) padding to 'getCipherData'
-- for verification.
--
-- The record state is expected to have a cipher set ('fromJust' is applied
-- to it). Content too short for the cipher parameters raises an
-- 'Error_Packet'.
decryptData :: Version -> Record Ciphertext -> Bytes -> RecordState -> RecordM Bytes
decryptData ver record econtent tst = decryptOf (bulkF bulk)
  where
    cipher      = fromJust "cipher" $ stCipher tst
    bulk        = cipherBulk cipher
    cst         = stCryptState tst
    macSize     = hashSize $ cipherHash cipher
    writekey    = cstKey cst
    blockSize   = bulkBlockSize bulk
    econtentLen = B.length econtent

    explicitIV = hasExplicitBlockIV ver

    sanityCheckError =
        throwError (Error_Packet "encrypted content too small for encryption parameters")

    decryptOf :: BulkFunctions -> RecordM Bytes
    decryptOf (BulkBlockF _ decryptF) = do
        let ivLen      = if explicitIV then bulkIVSize bulk else 0
            minContent = ivLen + max (macSize + 1) blockSize
        -- the content must be a whole number of blocks and large enough to
        -- hold the optional explicit IV plus the MAC and a padding byte
        when (econtentLen `mod` blockSize /= 0 || econtentLen < minContent)
            sanityCheckError

        -- pick the IV: either the leading bytes of the content, or the one
        -- kept in the crypt state
        (iv, body) <-
            if explicitIV
                then get2 econtent (ivLen, econtentLen - ivLen)
                else return (cstIV cst, econtent)

        -- the last encrypted block becomes the IV stored in the state
        let nextIV = fromJust "new iv" $ takelast (bulkBlockSize bulk) body
        modify $ \txs -> txs { stCryptState = cst { cstIV = nextIV } }

        -- the last decrypted byte gives the padding length minus one
        let plain   = decryptF writekey iv body
            padLen  = fromIntegral (B.last plain) + 1
            dataLen = B.length plain - padLen - macSize
        (content, mac, padding) <- get3 plain (dataLen, macSize, padLen)
        getCipherData record CipherData
            { cipherDataContent = content
            , cipherDataMAC     = Just mac
            , cipherDataPadding = Just padding
            }

    decryptOf (BulkStreamF initF _ decryptF) = do
        when (econtentLen < macSize) sanityCheckError

        -- an empty IV in the crypt state means the stream context still has
        -- to be created from the key
        let streamCtx        = if B.null (cstIV cst) then initF writekey else cstIV cst
            (plain, nextCtx) = decryptF streamCtx econtent
            dataLen          = B.length plain - macSize
        (content, mac) <- get2 plain (dataLen, macSize)
        modify $ \txs -> txs { stCryptState = cst { cstIV = nextCtx } }
        getCipherData record CipherData
            { cipherDataContent = content
            , cipherDataMAC     = Just mac
            , cipherDataPadding = Nothing
            }

    -- Split a bytestring in three parts of the given sizes, raising an
    -- 'Error_Packet' when 'partition3' rejects the sizes.
    get3 bytes sizes =
        case partition3 bytes sizes of
            Nothing    -> throwError $ Error_Packet "record bad format"
            Just parts -> return parts

    -- Two-part variant of 'get3': the third part is requested with size 0
    -- and dropped.
    get2 bytes (n1, n2) = do
        (p1, p2, _) <- get3 bytes (n1, n2, 0)
        return (p1, p2)