{-# LANGUAGE FlexibleContexts #-}
module Data.Conduit.OpenPGP.Decrypt (
conduitDecrypt
) where
import Control.Monad (when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Resource (MonadBaseControl, MonadResource, MonadThrow, runResourceT)
import qualified Control.Monad.Trans.State.Lazy as S
import qualified Crypto.Hash as CH
import qualified Crypto.Hash.Algorithms as CHA
import qualified Data.ByteArray as BA
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import qualified Data.Conduit.Binary as CB
import Data.Conduit.Serialization.Binary (conduitGet)
import Data.Conduit.OpenPGP.Compression (conduitDecompress)
import qualified Data.Conduit.List as CL
import Data.Default.Class (Default, def)
import Data.Maybe (fromJust, isNothing)
import Data.Binary (get)
import Codec.Encryption.OpenPGP.S2K (skesk2Key)
import Codec.Encryption.OpenPGP.CFB (decrypt, decryptOpenPGPCfb)
import Codec.Encryption.OpenPGP.Types
-- | State threaded through 'conduitDecrypt'' while it walks a packet stream.
data RecursorState = RecursorState
    { _depth     :: Int                -- ^ nesting depth of the stream currently being processed
    , _lastPKESK :: Maybe PKESK        -- ^ most recent PKESK (NOTE(review): not updated by any code in this module)
    , _lastSKESK :: Maybe SKESK        -- ^ most recent SKESK, used as the key source for decryption
    , _lastLDP   :: Maybe LiteralData  -- ^ most recent literal data among decrypted output, used for MDC checks
    }
    deriving (Eq, Show)
-- | Starting state: depth zero and no remembered packets.
instance Default RecursorState where
    def = RecursorState
        { _depth     = 0
        , _lastPKESK = Nothing
        , _lastSKESK = Nothing
        , _lastLDP   = Nothing
        }
-- | Callback for obtaining input from the user. It receives a prompt string
-- and returns the response, which this module feeds to 'skesk2Key' as the
-- passphrase.
type InputCallback m = String -> m BL.ByteString
-- | Decrypt symmetrically encrypted data packets in an OpenPGP packet stream.
--
-- The callback is asked for a passphrase each time an encrypted data packet is
-- encountered. The packet is replaced in the output by its decrypted,
-- decompressed and recursively decrypted contents. Processing starts at
-- nesting depth zero.
conduitDecrypt :: (MonadBaseControl IO m, MonadResource m) => InputCallback IO -> Conduit Pkt m Pkt
conduitDecrypt cb = conduitDecrypt' initialDepth cb
  where
    initialDepth = 0
-- | Worker for 'conduitDecrypt'. The 'Int' argument is the current nesting
-- depth: the number of encrypted containers that enclose the packets being
-- processed.
--
-- * 'SKESKPkt's are remembered and removed from the output.
-- * 'SymEncDataPkt' and 'SymEncIntegrityProtectedDataPkt' bodies are decrypted
--   with the most recently seen SKESK. The resulting packets are run through
--   this conduit again at depth + 1 and then emitted.
-- * 'ModificationDetectionCodePkt's are compared with the SHA-1 of the payload
--   of the most recent literal data packet produced by a decryption.
-- * Every other packet passes through unchanged.
--
-- Processing fails (via 'fail') in four cases:
--
-- * the depth exceeds 42;
-- * encrypted data arrives before any SKESK;
-- * an MDC has no preceding literal data;
-- * an MDC does not match.
--
-- NOTE(review): RFC 4880 section 5.14 defines the MDC hash over the decrypted
-- prefix, the plaintext packets and the MDC packet header, not over a literal
-- data payload alone. Confirm this comparison is what is intended.
conduitDecrypt' :: (MonadBaseControl IO m, MonadResource m) => Int -> InputCallback IO -> Conduit Pkt m Pkt
conduitDecrypt' depth cb = CL.concatMapAccumM push def { _depth = depth }
  where
    -- Nesting beyond this depth is treated as a self-reproducing ("quine")
    -- message that would otherwise recurse forever.
    maxDepth :: Int
    maxDepth = 42

    push :: (MonadBaseControl IO m, MonadResource m) => Pkt -> RecursorState -> m (RecursorState, [Pkt])
    push i s
      | _depth s > maxDepth = fail "I think we've been quine-attacked"
      | otherwise = case i of
          SKESKPkt{} -> return (s { _lastSKESK = Just (fromPkt i) }, [])
          SymEncDataPkt bs -> descend decryptSEDP bs
          SymEncIntegrityProtectedDataPkt _ bs -> descend decryptSEIPDP bs
          ModificationDetectionCodePkt mdc -> case _lastLDP s of
            Nothing -> fail "MDC with no referent"
            Just ldp
              | sha1 (_literalDataPayload ldp) /= mdc -> fail "MDC indicates tampering"
              | otherwise -> return (s, [i])
          p -> return (s, [p])
      where
        -- Decrypt with the last SKESK and process the plaintext one level
        -- deeper. Without the increment the maxDepth guard could never fire.
        descend decryptBody bs = case _lastSKESK s of
          Nothing -> fail "Encrypted data with no preceding SKESK"
          Just skesk -> do
            d <- decryptBody (_depth s + 1) cb skesk bs
            return (processLDPs s d, d)

    -- SHA-1 digest as a lazy ByteString, comparable with an MDC packet's body.
    sha1 :: BL.ByteString -> BL.ByteString
    sha1 = BL.fromStrict . BA.convert . (CH.hashlazy :: BL.ByteString -> CH.Digest CHA.SHA1)

    -- Remember the last literal data packet among freshly decrypted packets
    -- so that a later MDC packet can be checked against it.
    processLDPs :: RecursorState -> [Pkt] -> RecursorState
    processLDPs st ds = S.execState (mapM_ ldpCheck ds) st

    ldpCheck :: Pkt -> S.State RecursorState ()
    ldpCheck l@LiteralDataPkt{} = S.modify (\o -> o { _lastLDP = Just (fromPkt l) })
    ldpCheck _ = return ()
-- | Decrypt the body of a Symmetrically Encrypted Data packet.
--
-- The function proceeds in three steps:
--
-- 1. Prompt for a passphrase through the callback.
-- 2. Derive the key from the SKESK and the passphrase with 'skesk2Key', then
--    decrypt with 'decryptOpenPGPCfb'.
-- 3. Parse, decompress and recursively decrypt the resulting packets at the
--    given depth.
--
-- A decryption failure is raised with 'error' when the plaintext is demanded.
decryptSEDP :: (MonadBaseControl IO m, MonadIO m, MonadThrow m) => Int -> InputCallback IO -> SKESK -> BL.ByteString -> m [Pkt]
decryptSEDP depth cb skesk bs =
    liftIO (cb "Input the passphrase I want") >>= \passphrase ->
        let derivedKey = skesk2Key skesk passphrase
            plaintext  = either error id (decryptOpenPGPCfb (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) derivedKey)
        in runResourceT
            $ CB.sourceLbs (BL.fromStrict plaintext)
            $= conduitGet get
            $= conduitDecompress
            $= conduitDecrypt' depth cb
            $$ CL.consume
-- | Decrypt the body of a Symmetrically Encrypted Integrity Protected Data
-- packet.
--
-- The function proceeds in three steps:
--
-- 1. Prompt for a passphrase through the callback.
-- 2. Derive the key from the SKESK and the passphrase with 'skesk2Key', then
--    decrypt with 'decrypt'.
-- 3. Parse, decompress and recursively decrypt the resulting packets at the
--    given depth.
--
-- A decryption failure is raised with 'error' when the plaintext is demanded.
decryptSEIPDP :: (MonadBaseControl IO m, MonadIO m, MonadThrow m) => Int -> InputCallback IO -> SKESK -> BL.ByteString -> m [Pkt]
decryptSEIPDP depth cb skesk bs =
    liftIO (cb "Input the passphrase I want") >>= \passphrase ->
        let derivedKey = skesk2Key skesk passphrase
            plaintext  = either error id (decrypt (_skeskSymmetricAlgorithm skesk) (BL.toStrict bs) derivedKey)
        in runResourceT
            $ CB.sourceLbs (BL.fromStrict plaintext)
            $= conduitGet get
            $= conduitDecompress
            $= conduitDecrypt' depth cb
            $$ CL.consume