-- Decrypt.hs: OpenPGP (RFC4880) recursive packet decryption
-- Copyright © 2013-2019 Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).
{-# LANGUAGE FlexibleContexts #-}

module Data.Conduit.OpenPGP.Decrypt
  ( conduitDecrypt
  ) where

import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.IO.Unlift (MonadUnliftIO)
import Control.Monad.Trans.Resource (MonadResource, MonadThrow)
import qualified Crypto.Hash as CH
import qualified Crypto.Hash.Algorithms as CHA
import Data.Binary (get)
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import qualified Data.ByteString.Base16.Lazy as B16L
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import qualified Data.Conduit.Combinators as CC
import qualified Data.Conduit.List as CL
import Data.Conduit.OpenPGP.Compression (conduitDecompress)
import Data.Conduit.Serialization.Binary (conduitGet)
import Data.Maybe (fromJust, isNothing)

import Codec.Encryption.OpenPGP.CFB (decryptOpenPGPCfb, decryptPreservingNonce)
import Codec.Encryption.OpenPGP.S2K (skesk2Key)
import Codec.Encryption.OpenPGP.Types

-- | State threaded through the packet stream while decrypting.
data RecursorState = RecursorState
  { _depth :: Int
    -- ^ Number of decryption layers enclosing the current stream.
  , _lastPKESK :: Maybe PKESK
    -- ^ Not set or read anywhere in this module.
  , _lastSKESK :: Maybe SKESK
    -- ^ Most recent SKESK packet seen (nested streams inherit it); its key
    -- is used for the next encrypted data packet.
  , _lastNonce :: Maybe B.ByteString
    -- ^ Prefix returned by 'decryptPreservingNonce' for the enclosing SEIPD;
    -- hashed first when computing the MDC.
  , _lastClearText :: Maybe B.ByteString
    -- ^ Complete decrypted body of the enclosing SEIPD, used to check its MDC.
  } deriving (Eq, Show)

-- | Initial state: depth 0 and nothing seen yet.
def :: RecursorState
def = RecursorState 0 Nothing Nothing Nothing Nothing

-- | Asks the user for input. The argument is the prompt; in this module the
-- result is used as the passphrase for an SKESK.
type InputCallback m = String -> m BL.ByteString

-- | Nesting depth beyond which decryption is abandoned, so that a message that
-- (directly or through compression) contains itself cannot recurse forever.
maxRecursionDepth :: Int
maxRecursionDepth = 42

-- | Tag and length octets of a new-format MDC packet. RFC 4880 also feeds
-- these two octets into the MDC hash.
mdcPacketHeader :: B.ByteString
mdcPacketHeader = B.pack [0xD3, 0x14]

-- | Size of a complete MDC packet: two header octets plus a 20-octet SHA-1
-- digest.
mdcPacketLength :: Int
mdcPacketLength = 22

-- | Decrypt symmetrically-encrypted data in a packet stream.
--
-- SKESK packets are consumed (not passed downstream) and remembered. Each
-- following Symmetrically Encrypted Data packet or Symmetrically Encrypted
-- Integrity Protected Data packet is handled as follows:
--
-- * it is decrypted with a key derived from that SKESK and a passphrase
--   obtained from the callback;
-- * the plaintext is parsed, decompressed and decrypted recursively;
-- * the resulting packets are emitted in place of the encrypted one.
--
-- All other packets pass through unchanged.
--
-- The following failures are raised as 'IOError's built with 'userError':
--
-- * an encrypted packet with no preceding SKESK;
-- * a cipher error;
-- * a missing or mismatched MDC;
-- * excessive nesting.
conduitDecrypt ::
     (MonadUnliftIO m, MonadResource m, MonadThrow m)
  => InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt = conduitDecrypt' def

-- | @conduitDecrypt@ starting from an explicit state; used for recursion into
-- decrypted plaintext.
conduitDecrypt' ::
     (MonadUnliftIO m, MonadResource m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> ConduitT Pkt Pkt m ()
conduitDecrypt' rs cb = CC.concatMapAccumM push rs
  where
    push ::
         (MonadUnliftIO m, MonadResource m, MonadThrow m)
      => Pkt
      -> RecursorState
      -> m (RecursorState, [Pkt])
    push i s
      | _depth s > maxRecursionDepth =
        throwUserError "I think we've been quine-attacked"
      | otherwise =
        case i of
          SKESKPkt {} -> return (s {_lastSKESK = Just (fromPkt i)}, [])
          SymEncDataPkt bs -> do
            skesk <- requireSKESK s
            d <- decryptSEDP s cb skesk bs
            return (s, d)
          SymEncIntegrityProtectedDataPkt _ bs -> do
            skesk <- requireSKESK s
            d <- decryptSEIPDP s cb skesk bs
            return (s, d)
          m@(ModificationDetectionCodePkt mdc) -> do
            when (isNothing (_lastClearText s)) $
              throwUserError "MDC with no referent"
            let mcalculated =
                  calculateMDC <$> _lastNonce s <*> _lastClearText s
            -- NOTE(review): this message includes the full decrypted
            -- plaintext; consider whether that is acceptable where errors
            -- are logged.
            when (mcalculated /= Just mdc) $
              throwUserError $
              "MDC indicates tampering: " ++
              show (B16L.encode mdc) ++
              " versus " ++
              maybe "" (show . B16L.encode) mcalculated ++
              " ... " ++
              show (_lastNonce s) ++ " / " ++ show (_lastClearText s)
            return (s, [m])
          p -> return (s, [p])

-- | Abort with an 'IOError' carrying the given message. This is what 'fail'
-- does in 'IO', but it needs only 'MonadIO' rather than 'MonadFail', so it
-- typechecks under this module's constraints on any GHC version.
throwUserError :: MonadIO m => String -> m a
throwUserError = liftIO . ioError . userError

-- | The SKESK remembered in the state, or a descriptive error when an
-- encrypted data packet arrives without one.
requireSKESK :: MonadIO m => RecursorState -> m SKESK
requireSKESK =
  maybe
    (throwUserError "encrypted data packet with no preceding SKESK")
    return .
  _lastSKESK

-- | Parse a decrypted plaintext into packets, then decompress and recursively
-- decrypt them with the given state. Returns every resulting packet.
recurseInto ::
     (MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> B.ByteString
  -> m [Pkt]
recurseInto rs cb plaintext =
  runConduitRes $
  CB.sourceLbs (BL.fromStrict plaintext) .|
  conduitGet get .|
  conduitDecompress .|
  conduitDecrypt' rs cb .|
  CL.consume

-- | Decrypt a Symmetrically Encrypted Data packet body with the key derived
-- from @skesk@ and a passphrase read from the callback, then recurse into the
-- plaintext one level deeper.
decryptSEDP ::
     (MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
-- FIXME: this shouldn't pass the whole SKESK
decryptSEDP rs cb skesk bs = do
  passphrase <- liftIO $ cb "Input the passphrase I want"
  let key = skesk2Key skesk passphrase
  decrypted <-
    either throwUserError return $
    decryptOpenPGPCfb (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) key
  recurseInto rs {_depth = _depth rs + 1} cb decrypted

-- | Decrypt a Symmetrically Encrypted Integrity Protected Data packet body.
--
-- The trailing MDC is verified on the raw plaintext before anything is parsed.
-- The plaintext is then recursed into one level deeper, with the nonce and
-- plaintext recorded so the parsed MDC packet can be checked as well.
decryptSEIPDP ::
     (MonadUnliftIO m, MonadIO m, MonadThrow m)
  => RecursorState
  -> InputCallback IO
  -> SKESK
  -> BL.ByteString
  -> m [Pkt]
-- FIXME: this shouldn't pass the whole SKESK
decryptSEIPDP rs cb skesk bs = do
  passphrase <- liftIO $ cb "Input the passphrase I want"
  let key = skesk2Key skesk passphrase
  (nonce, decrypted) <-
    either throwUserError return $
    decryptPreservingNonce (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) key
  verifyTrailingMDC nonce decrypted
  let inner =
        rs
          { _depth = _depth rs + 1
          , _lastNonce = Just nonce
          , _lastClearText = Just decrypted
          }
  recurseInto inner cb decrypted

-- | Check that a decrypted SEIPD body ends with an MDC packet whose digest
-- matches 'calculateMDC'. The packet must be the octets 0xD3 0x14 followed by
-- the 20-octet digest.
--
-- RFC 4880 requires the MDC packet to be last. Checking the raw bytes means a
-- ciphertext truncated to drop the MDC is rejected. Otherwise it would still
-- parse as valid packets and be accepted with no integrity check.
verifyTrailingMDC :: MonadIO m => B.ByteString -> B.ByteString -> m ()
verifyTrailingMDC nonce cleartext =
  when
    (B.length cleartext < 23 ||
     header /= mdcPacketHeader ||
     BL.fromStrict digest /= calculateMDC nonce cleartext) $
  throwUserError "SEIPD plaintext does not end with a valid MDC packet"
  where
    (header, digest) =
      B.splitAt 2 (B.drop (B.length cleartext - mdcPacketLength) cleartext)

-- | Expected MDC digest for a decrypted SEIPD body. It is the SHA-1 of:
--
-- * the nonce;
-- * the plaintext with its trailing 22-octet MDC packet removed;
-- * the octets 0xD3 0x14.
--
-- Inputs shorter than 23 octets yield an empty string, which never equals a
-- 20-octet digest.
calculateMDC :: B.ByteString -> B.ByteString -> BL.ByteString
calculateMDC nonce garbage
  | B.length garbage < 23 = mempty -- FIXME: this is horrible
  | otherwise =
    BL.fromStrict .
    BA.convert . (CH.hash :: B.ByteString -> CH.Digest CHA.SHA1) $
    nonce <>
    B.take (B.length garbage - mdcPacketLength) garbage <>
    mdcPacketHeader