-- Decrypt.hs: OpenPGP (RFC4880) recursive packet decryption
-- Copyright © 2013-2016 Clint Adams
-- This software is released under the terms of the Expat license.
-- (See the LICENSE file).

{-# LANGUAGE FlexibleContexts #-}

-- | Conduit-based, recursive decryption of symmetrically-encrypted OpenPGP
-- packets: encrypted data packets are decrypted with the key derived from the
-- most recent SKESK, and the resulting plaintext is re-parsed, decompressed
-- and fed back through the same decryptor one nesting level deeper.
module Data.Conduit.OpenPGP.Decrypt (
   conduitDecrypt
) where

import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Resource (MonadBaseControl, MonadResource, MonadThrow, runResourceT)
import qualified Control.Monad.Trans.State.Lazy as S
import qualified Crypto.Hash as CH
import qualified Crypto.Hash.Algorithms as CHA
import qualified Data.ByteArray as BA
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import Data.Conduit.Serialization.Binary (conduitGet)
import Data.Conduit.OpenPGP.Compression (conduitDecompress)
import qualified Data.Conduit.List as CL
import Data.Default.Class (Default, def)
import Data.Maybe (fromJust, isNothing)
import Data.Binary (get)

import Codec.Encryption.OpenPGP.S2K (skesk2Key)
import Codec.Encryption.OpenPGP.CFB (decrypt, decryptOpenPGPCfb)
import Codec.Encryption.OpenPGP.Types

-- | State threaded through one level of 'conduitDecrypt''.
data RecursorState = RecursorState {
    _depth :: Int
    -- ^ nesting level: 0 for the outermost stream, one more per decrypted layer
  , _lastPKESK :: Maybe PKESK
    -- ^ not written anywhere in this module (PKESK packets pass through as-is)
  , _lastSKESK :: Maybe SKESK
    -- ^ most recent SKESK seen at this level; keys the following encrypted data
  , _lastLDP :: Maybe LiteralData
    -- ^ last literal data packet produced by a nested decryption at this level
} deriving (Eq, Show)

instance Default RecursorState where
    def = RecursorState 0 Nothing Nothing Nothing

-- | Passphrase source: called with a prompt string, returns the passphrase
-- that is fed to 'skesk2Key'.
type InputCallback m = String -> m BL.ByteString

-- | Deepest nesting level that is still processed. Streams nested deeper than
-- this are rejected, which bounds the recursion on self-reproducing
-- ("quine") encrypted packets.
maxRecursionDepth :: Int
maxRecursionDepth = 42

-- | Pass packets through unchanged, except that symmetrically-encrypted data
-- packets (with and without integrity protection) are replaced by the packets
-- their decrypted, decompressed plaintext contains, recursively.
--
-- The key comes from the most recent SKESK packet in the stream (which is
-- consumed, not emitted) and the passphrase obtained from the callback.
-- Fails on encrypted data with no preceding SKESK, on an MDC with no literal
-- data to check against, on an MDC mismatch, and on excessive nesting.
conduitDecrypt :: (MonadBaseControl IO m, MonadResource m) => InputCallback IO -> Conduit Pkt m Pkt
conduitDecrypt = conduitDecrypt' 0

-- | 'conduitDecrypt' starting at an explicit nesting level.
conduitDecrypt' :: (MonadBaseControl IO m, MonadResource m) => Int -> InputCallback IO -> Conduit Pkt m Pkt
conduitDecrypt' depth cb = CL.concatMapAccumM push (def { _depth = depth })
  where
    push :: (MonadBaseControl IO m, MonadResource m) => Pkt -> RecursorState -> m (RecursorState, [Pkt])
    push i s
        | _depth s > maxRecursionDepth = fail "I think we've been quine-attacked"
        | otherwise = case i of
            -- Remember the SKESK for the data packet(s) that follow; don't emit it.
            SKESKPkt{} -> return (s { _lastSKESK = Just (fromPkt i) }, [])
            SymEncDataPkt bs -> recurseWith decryptSEDP bs
            SymEncIntegrityProtectedDataPkt _ bs -> recurseWith decryptSEIPDP bs
            -- NOTE(review): RFC 4880 §5.14 defines the MDC as SHA-1 over the
            -- whole plaintext (prefix included), whereas this compares it
            -- against the SHA-1 of the literal payload only. Also, _lastLDP
            -- is only set from packets produced by a nested decryption
            -- (processLDPs), never from a LiteralDataPkt flowing through this
            -- level directly. Confirm both match what the CFB 'decrypt' yields.
            mdcPkt@(ModificationDetectionCodePkt mdc) -> case _lastLDP s of
                Nothing -> fail "MDC with no referent"
                Just ldp -> do
                    when (sha1Digest (_literalDataPayload ldp) /= mdc) $
                        fail "MDC indicates tampering"
                    return (s, [mdcPkt])
            p -> return (s, [p])
      where
        -- Decrypt with the remembered SKESK and splice in the resulting
        -- packets. The decryptors process the plaintext one level deeper.
        recurseWith decryptor bs = case _lastSKESK s of
            Nothing -> fail "encrypted data with no preceding SKESK"
            Just skesk -> do
                d <- decryptor (_depth s) cb skesk bs
                return (processLDPs s d, d)

    -- Record the last LiteralDataPkt (if any) among freshly decrypted packets.
    processLDPs :: RecursorState -> [Pkt] -> RecursorState
    processLDPs s ds = S.execState (mapM_ ldpCheck ds) s

    ldpCheck :: Pkt -> S.State RecursorState ()
    ldpCheck l@LiteralDataPkt{} = S.modify (\o -> o { _lastLDP = Just (fromPkt l) })
    ldpCheck _ = return ()

    -- SHA-1 of a lazy ByteString, as a lazy ByteString (the MDC's representation).
    sha1Digest :: BL.ByteString -> BL.ByteString
    sha1Digest = BL.fromStrict . BA.convert . (CH.hashlazy :: BL.ByteString -> CH.Digest CHA.SHA1)

-- | Decrypt a Symmetrically Encrypted Data packet body (OpenPGP CFB mode)
-- and recursively process its plaintext.
--
-- @depth@ is the nesting level of the encrypted packet itself.
decryptSEDP :: (MonadBaseControl IO m, MonadIO m, MonadThrow m) => Int -> InputCallback IO -> SKESK -> BL.ByteString -> m [Pkt]
decryptSEDP depth cb skesk bs = -- FIXME: this shouldn't pass the whole SKESK
    decryptAndRecurse depth cb $ \passphrase ->
        fmap BL.fromStrict
             (decryptOpenPGPCfb (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) (skesk2Key skesk passphrase))

-- | Decrypt a Symmetrically Encrypted Integrity Protected Data packet body
-- and recursively process its plaintext.
--
-- @depth@ is the nesting level of the encrypted packet itself.
decryptSEIPDP :: (MonadBaseControl IO m, MonadIO m, MonadThrow m) => Int -> InputCallback IO -> SKESK -> BL.ByteString -> m [Pkt]
decryptSEIPDP depth cb skesk bs = -- FIXME: this shouldn't pass the whole SKESK
    decryptAndRecurse depth cb $ \passphrase ->
        fmap BL.fromStrict
             (decrypt (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) (skesk2Key skesk passphrase))

-- | Shared driver for the two decryptors.
--
-- It does three things in order:
--
-- 1. Asks the callback for a passphrase.
-- 2. Runs the supplied passphrase-to-plaintext function, failing in @m@
--    with its message on 'Left'.
-- 3. Parses and decompresses the plaintext into packets, and decrypts those
--    at nesting level @depth + 1@.
decryptAndRecurse :: (MonadBaseControl IO m, MonadIO m, MonadThrow m)
                  => Int
                  -> InputCallback IO
                  -> (BL.ByteString -> Either String BL.ByteString)
                  -> m [Pkt]
decryptAndRecurse depth cb decryptWithPassphrase = do
    passphrase <- liftIO $ cb "Input the passphrase I want"
    plaintext <- either fail return (decryptWithPassphrase passphrase)
    -- Incrementing here is what lets the maxRecursionDepth guard in
    -- conduitDecrypt' terminate nested/self-reproducing encryption.
    runResourceT $ CB.sourceLbs plaintext $= conduitGet get $= conduitDecompress $= conduitDecrypt' (depth + 1) cb $$ CL.consume