{-# LANGUAGE CPP #-}
-- | Streaming ECIES-style encryption over a fixed elliptic curve.
--
-- 'encrypt' emits an encoded curve point followed by the output of the
-- ChaCha20-Poly1305 conduit; 'decrypt' consumes that same layout.
--
-- The CPP extension is required because the @MIN_VERSION_cryptonite@ macro is
-- used below to bridge two versions of the "Crypto.PubKey.ECIES" API.
module Crypto.PubKey.ECIES.Conduit
  ( encrypt
  , decrypt
  ) where
import Control.Monad.Catch (MonadThrow, throwM)
import Control.Monad.Trans.Class (lift)
import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Crypto.Cipher.ChaChaPoly1305.Conduit as ChaCha
import qualified Crypto.ECC as ECC
import qualified Crypto.Error as CE
import Crypto.Hash (SHA512 (..), hashWith)
import Crypto.PubKey.ECIES (deriveDecrypt,
deriveEncrypt)
import Crypto.Random (MonadRandom)
import qualified Data.ByteArray as BA
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.Conduit (ConduitM, yield)
import qualified Data.Conduit.Binary as CB
import Data.Proxy (Proxy (..))
import System.IO.Unsafe (unsafePerformIO)
-- | Deterministically expand a shared secret into a @(nonce, key)@ pair.
--
-- The secret is hashed with SHA-512 and the first 40 bytes of the digest are
-- fed to 'ChaCha.initializeSimple'. The first 12 bytes drawn from the
-- resulting generator become the nonce and the next 32 bytes become the key.
getNonceKey :: ECC.SharedSecret -> (ByteString, ByteString)
getNonceKey shared = (nonce, key)
  where
    -- 40 bytes is the amount of seed material handed to the generator.
    seed = B.take 40 (BA.convert (hashWith SHA512 shared))
    gen0 = ChaCha.initializeSimple seed
    -- Order matters: the nonce is drawn first, then the key from the
    -- advanced generator state.
    (nonce, gen1) = ChaCha.generateSimple gen0 12
    (key, _) = ChaCha.generateSimple gen1 32
-- | The single elliptic curve this module works with (P-256, per the
-- 'ECC.Curve_P256R1' tag). Both 'encrypt' and 'decrypt' are fixed to it.
type Curve = ECC.Curve_P256R1
-- | Proxy value selecting 'Curve'; passed to every "Crypto.ECC" and
-- "Crypto.PubKey.ECIES" call in this module that needs to know the curve.
proxy :: Proxy Curve
proxy = Proxy
-- | Number of bytes 'ECC.encodePoint' produces for a point on 'Curve'.
--
-- 'decrypt' reads exactly this many bytes from the head of the stream and
-- decodes them as a point.
--
-- The value is obtained by encoding the public point of a throwaway key pair
-- and measuring the result. Only the length is kept, so the randomness
-- consumed by 'unsafePerformIO' never influences the result, provided the
-- encoding has the same length for every point on the curve (the stream
-- format already relies on that).
pointBinarySize :: Int
pointBinarySize = B.length (ECC.encodePoint proxy point)
  where
    -- Throwaway public point; its coordinates are irrelevant, only its
    -- encoded size is observed.
    point = unsafePerformIO (ECC.keypairGetPublic <$> ECC.curveGenerateKeyPair proxy)
-- Keep this a single shared constant: without NOINLINE the compiler may
-- duplicate the 'unsafePerformIO' call and generate a key pair per use site.
{-# NOINLINE pointBinarySize #-}
-- | Turn a 'CE.CryptoFailable' into a 'MonadThrow' action: a passed result is
-- returned with 'pure', a failure is raised with 'throwM' using the carried
-- error value.
throwOnFail :: MonadThrow m => CE.CryptoFailable a -> m a
throwOnFail outcome =
  case outcome of
    CE.CryptoFailed err -> throwM err
    CE.CryptoPassed value -> pure value
-- | Conduit that encrypts a byte stream towards the given curve point.
--
-- Steps, in order:
--
-- 1. 'deriveEncrypt' is run in the base monad to obtain a fresh point and a
--    shared secret; a failed derivation is thrown via 'throwOnFail'.
-- 2. The fresh point is emitted downstream in its 'ECC.encodePoint' form.
-- 3. The rest of the stream is handed to 'ChaCha.encrypt' with the nonce and
--    key produced by 'getNonceKey' from the shared secret.
encrypt
  :: (MonadThrow m, MonadRandom m)
  => ECC.Point Curve
  -> ConduitM ByteString ByteString m ()
encrypt recipient = do
  derived <- lift (deriveEncryptCompat proxy recipient)
  (ephemeral, shared) <- throwOnFail derived
  let (nonce, key) = getNonceKey shared
  -- The encoded point goes first so that 'decrypt' can recover the secret.
  yield (ECC.encodePoint proxy ephemeral)
  ChaCha.encrypt nonce key
  where
    -- Uniform view of 'deriveEncrypt' across cryptonite versions: once the
    -- version check below passes it already yields a 'CE.CryptoFailable';
    -- otherwise its bare result is wrapped in 'CE.CryptoPassed'.
#if MIN_VERSION_cryptonite(0,23,999)
    deriveEncryptCompat prx p = deriveEncrypt prx p
#else
    deriveEncryptCompat prx p = fmap CE.CryptoPassed (deriveEncrypt prx p)
#endif
-- | Conduit that decrypts a stream laid out as 'encrypt' produces it, using
-- the given scalar.
--
-- The first 'pointBinarySize' bytes are consumed and decoded as a curve
-- point, and the shared secret is derived from that point and the scalar.
-- The remainder of the stream is passed to 'ChaCha.decrypt' with the key
-- from 'getNonceKey'.
--
-- A point that fails to decode, or a failed secret derivation, is thrown via
-- 'throwOnFail'.
decrypt
  :: (MonadThrow m)
  => ECC.Scalar Curve
  -> ConduitM ByteString ByteString m ()
decrypt scalar = do
  encodedPoint <- BL.toStrict <$> CB.take pointBinarySize
  ephemeral <- throwOnFail (ECC.decodePoint proxy encodedPoint)
  shared <- throwOnFail (deriveDecryptCompat proxy ephemeral scalar)
  -- Only the key half is needed: 'ChaCha.decrypt' is given no nonce argument.
  ChaCha.decrypt (snd (getNonceKey shared))
  where
    -- Uniform view of 'deriveDecrypt' across cryptonite versions: once the
    -- version check below passes it already yields a 'CE.CryptoFailable';
    -- otherwise its bare result is wrapped in 'CE.CryptoPassed'.
#if MIN_VERSION_cryptonite(0,23,999)
    deriveDecryptCompat prx p s = deriveDecrypt prx p s
#else
    deriveDecryptCompat prx p s = CE.CryptoPassed (deriveDecrypt prx p s)
#endif