{-# LANGUAGE ForeignFunctionInterface #-}
-- |Asymmetric cipher encryption using an encrypted symmetric key
-- (envelope encryption). This is the opposite of "OpenSSL.EVP.Open".
module OpenSSL.EVP.Seal
    ( seal
    , sealBS
    , sealLBS
    )
    where
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy.Char8 as L8
import           Foreign
import           Foreign.C
import           OpenSSL.EVP.Cipher hiding (cipher)
import           OpenSSL.EVP.PKey
import           OpenSSL.EVP.Internal
import           OpenSSL.Utils


-- Raw binding to the C function @EVP_SealInit@, imported as an
-- @unsafe@ call. As used by 'sealInit', the arguments are, in order:
-- the cipher context, the symmetric cipher, an array of output
-- buffers for the encrypted keys, an array receiving the length of
-- each encrypted key, the output buffer for the IV, an array of
-- public key pointers, and the number of public keys. 'sealInit'
-- treats a result of 0 as failure.
foreign import ccall unsafe "EVP_SealInit"
        _SealInit :: Ptr EVP_CIPHER_CTX
                  -> Cipher
                  -> Ptr (Ptr CChar)
                  -> Ptr CInt
                  -> CString
                  -> Ptr (Ptr EVP_PKEY)
                  -> CInt
                  -> IO CInt


-- |Set up a cipher context for envelope encryption by calling
-- @EVP_SealInit@.
--
-- Given a symmetric cipher and a non-empty list of public keys, this
-- returns the initialised cipher context, the symmetric key encrypted
-- with each public key (in the same order as the input list), and the
-- IV.
--
-- Calls 'fail' when no public key is supplied, and raises an OpenSSL
-- error when @EVP_SealInit@ reports a failure.
sealInit :: Cipher
         -> [SomePublicKey]
         -> IO (CipherCtx, [B8.ByteString], B8.ByteString)

sealInit _ []
    = fail "sealInit: at least one public key is required"

sealInit cipher pubKeys
    -- Every temporary buffer lives in a memory pool, so all of them
    -- are released when this action finishes, whether it returns
    -- normally or an exception is thrown part-way through.
    = withPool $ \ pool ->
      do ctx <- newCipherCtx

         -- Allocate a list of buffers to write encrypted symmetric
         -- keys. Each key will be at most pkeySize bytes long.
         encKeyBufs <- mapM (mallocEncKeyBuf pool) pubKeys

         -- encKeyBufs is [Ptr CChar] but we want Ptr (Ptr CChar).
         encKeyBufsPtr <- pooledNewArray pool encKeyBufs

         -- Allocate a buffer to write the length of each encrypted
         -- symmetric key.
         encKeyBufsLenPtr <- pooledMallocArray pool nKeys

         -- Allocate a buffer to write the IV.
         ivPtr <- pooledMallocArray pool ivLen

         -- Create Ptr (Ptr EVP_PKEY) from [PKey]. The raw pointers
         -- are only valid while the keys are alive, hence the
         -- touchPKey right after the foreign call.
         pkeys      <- mapM toPKey pubKeys
         pubKeysPtr <- pooledNewArray pool $ map unsafePKeyToPtr pkeys

         -- Call EVP_SealInit finally.
         ret <- withCipherCtxPtr ctx $ \ ctxPtr ->
                _SealInit ctxPtr
                          cipher
                          encKeyBufsPtr
                          encKeyBufsLenPtr
                          ivPtr
                          pubKeysPtr
                          (fromIntegral nKeys)
         mapM_ touchPKey pkeys

         -- Treat negative results as failures too: OpenSSL can return
         -- -1 when encrypting the symmetric key with one of the public
         -- keys fails, and the length array must not be read then.
         if ret <= 0 then
             raiseOpenSSLError
           else
             do encKeysLen <- peekArray nKeys encKeyBufsLenPtr
                -- Copy the results out of the pool before it is freed.
                encKeys    <- mapM B8.packCStringLen
                              $ zip encKeyBufs (fromIntegral `fmap` encKeysLen)
                iv         <- B8.packCStringLen (ivPtr, ivLen)
                return (ctx, encKeys, iv)
    where
      -- Number of recipients, i.e. number of encrypted keys produced.
      nKeys :: Int
      nKeys = length pubKeys

      -- Length in bytes of the IV for the chosen cipher.
      ivLen :: Int
      ivLen = cipherIvLength cipher

      -- Pool-allocate an output buffer of pkeySize bytes for one key.
      mallocEncKeyBuf :: PKey k => Pool -> k -> IO (Ptr CChar)
      mallocEncKeyBuf pool = pooledMallocArray pool . pkeySize

-- |@'seal'@ lazily encrypts a stream of data. The input string
-- doesn't necessarily have to be finite.
--
-- This is a thin 'String' wrapper around 'sealLBS': the input is
-- packed into a lazy ByteString and every result is unpacked again.
seal :: Cipher          -- ^ symmetric cipher algorithm to use
     -> [SomePublicKey] -- ^ A list of public keys to encrypt a
                        --   symmetric key. At least one public key
                        --   must be supplied. If two or more keys are
                        --   given, the symmetric key is encrypted by
                        --   each public key so that any of the
                        --   corresponding private keys can decrypt
                        --   the message.
     -> String          -- ^ input string to encrypt
     -> IO ( String
           , [String]
           , String
           ) -- ^ (encrypted string, list of encrypted symmetric
             -- keys, one per public key, IV)
{-# DEPRECATED seal "Use sealBS or sealLBS instead." #-}
seal cipher pubKeys input
    = do (output, encKeys, iv) <- sealLBS cipher pubKeys $ L8.pack input
         return ( L8.unpack output
                , B8.unpack `fmap` encKeys
                , B8.unpack iv
                )

-- |@'sealBS'@ strictly encrypts a chunk of data.
--
-- A fresh sealing context is created with the given cipher and public
-- keys, and the whole input is run through it in one go.
sealBS :: Cipher          -- ^ symmetric cipher algorithm to use
       -> [SomePublicKey] -- ^ list of public keys to encrypt a
                          --   symmetric key
       -> B8.ByteString   -- ^ input string to encrypt
       -> IO ( B8.ByteString
             , [B8.ByteString]
             , B8.ByteString
             ) -- ^ (encrypted string, list of encrypted symmetric
               -- keys, one per public key, IV)
sealBS cipher pubKeys input
    = do (ctx, encKeys, iv) <- sealInit cipher pubKeys
         output             <- cipherStrictly ctx input
         return (output, encKeys, iv)

-- |@'sealLBS'@ lazily encrypts a stream of data. The input string
-- doesn't necessarily have to be finite.
--
-- A fresh sealing context is created with the given cipher and public
-- keys, and the input is fed through it with 'cipherLazily'.
sealLBS :: Cipher          -- ^ symmetric cipher algorithm to use
        -> [SomePublicKey] -- ^ list of public keys to encrypt a
                           --   symmetric key
        -> L8.ByteString   -- ^ input string to encrypt
        -> IO ( L8.ByteString
              , [B8.ByteString]
              , B8.ByteString
              ) -- ^ (encrypted string, list of encrypted symmetric
                -- keys, one per public key, IV)
sealLBS cipher pubKeys input
    = do (ctx, encKeys, iv) <- sealInit cipher pubKeys
         output             <- cipherLazily ctx input
         return (output, encKeys, iv)