{-# LANGUAGE BangPatterns #-}
module Crypto.RNCryptor.V3.Encrypt
  ( encrypt
  , encryptBlock
  , encryptStream
  , encryptStreamWithContext
  ) where

import           Crypto.Cipher.AES          (AES256)
import           Crypto.Cipher.Types        (makeIV, IV, BlockCipher, cbcEncrypt)
import           Crypto.MAC.HMAC            (update, finalize)
import           Crypto.RNCryptor.Padding
import           Crypto.RNCryptor.Types
import           Crypto.RNCryptor.V3.Stream
import           Data.ByteArray             (convert)
import           Data.ByteString            (ByteString)
import qualified Data.ByteString as B
import           Data.Maybe                 (fromMaybe)
import           Data.Monoid
import qualified System.IO.Streams as S

-- | Encrypt @input@ with AES-256 in CBC mode, using the raw bytes @iv@ as the
-- initialisation vector.
--
-- This is partial: it calls 'error' when 'makeIV' rejects @iv@, and the error
-- message includes the offending IV bytes. NOTE(review): presumably 'makeIV'
-- fails when @iv@ is not exactly one AES block (16 bytes) long — confirm.
-- Any length requirement on the input itself is whatever 'cbcEncrypt'
-- imposes; callers in this module pad with PKCS#7 before encrypting.
encryptBytes :: AES256 -> ByteString -> ByteString -> ByteString
encryptBytes a iv = cbcEncrypt a iv'
  where
    iv' :: IV AES256
    iv' = fromMaybe (error makeIVFailure) (makeIV iv)

    makeIVFailure = "encryptBytes: makeIV failed (iv was: " <> show (B.unpack iv) <> ")"

--------------------------------------------------------------------------------
-- | Encrypt a raw 'ByteString' block with the context's AES-256 cipher in CBC
-- mode. Returns the encrypted block together with an updated
-- 'RNCryptorContext'.
--
-- The new context is needed for two reasons:
--
-- * The IV for the next block must be the last 16 bytes of the previous
--   cipher text (thanks to Rob Napier for the insight).
-- * Every cipher text produced is fed into the running HMAC.
--
-- No padding is applied here. The caller must supply input that CBC
-- encryption accepts (see 'encrypt', which appends PKCS#7 padding first).
-- If the block yields no cipher text, the current IV is kept unchanged.
encryptBlock :: RNCryptorContext
             -> ByteString
             -> (RNCryptorContext, ByteString)
encryptBlock ctx clearText =
  let hdr         = ctxHeader ctx
      cipherText  = encryptBytes (ctxCipher ctx) (rncIV hdr) clearText
      !newHmacCtx = update (ctxHMACCtx ctx) cipherText
      -- Chain from the cipher text actually produced, not from the clear-text
      -- length. An empty block encrypts nothing, so the chaining IV must not
      -- change: replacing it with an empty IV would make the next 'makeIV'
      -- call fail.
      nextIV      = if B.null cipherText
                      then rncIV hdr
                      else B.drop (B.length cipherText - 16) cipherText
      !newHeader  = hdr { rncIV = nextIV }
  in (ctx { ctxHeader = newHeader, ctxHMACCtx = newHmacCtx }, cipherText)

--------------------------------------------------------------------------------
-- | Encrypt a whole message in one go.
--
-- The result is the rendered header, followed by the PKCS#7-padded cipher
-- text, followed by the HMAC. The HMAC is computed over the rendered header
-- and then the cipher text.
--
-- Please be aware that this is a user-friendly but dangerous function, in the
-- sense that it will load the *ENTIRE* input in memory. It is mostly suitable
-- for small inputs such as passwords. For large inputs, where the size exceeds
-- the available memory, please use 'encryptStream'.
encrypt :: RNCryptorContext -> ByteString -> ByteString
encrypt ctx input =
  let msgHdr              = renderRNCryptorHeader (ctxHeader ctx)
      -- Authenticate the header before any cipher text is added to the HMAC.
      ctx'                = ctx { ctxHMACCtx = update (ctxHMACCtx ctx) msgHdr }
      padded              = input <> pkcs7Padding blockSize (B.length input)
      (ctx'', cipherText) = encryptBlock ctx' padded
      msgHMAC             = convert (finalize (ctxHMACCtx ctx''))
  in msgHdr <> cipherText <> msgHMAC

--------------------------------------------------------------------------------
-- | Efficiently encrypt an incoming stream of bytes using an existing
-- 'RNCryptorContext'.
--
-- The function first writes the rendered header to the output. It then hands
-- the streams to 'processStream', which (presumably) encrypts intermediate
-- chunks with 'encryptBlock' and passes the trailing chunk to
-- @finaliseEncryption@. That step pads and encrypts the final chunk, then
-- writes it followed by the finalised HMAC.
--
-- Unlike 'encrypt', this function does /not/ feed the rendered header into
-- the context's HMAC. The caller must already have done so, as
-- 'encryptStream' does. NOTE(review): a freshly built context passed here
-- would produce an HMAC that does not cover the header — confirm this is the
-- intended contract.
encryptStreamWithContext :: RNCryptorContext
                         -- ^ The RNCryptorContext
                         -> S.InputStream ByteString
                         -- ^ The input source (most likely stdin)
                         -> S.OutputStream ByteString
                         -- ^ The output sink (most likely stdout)
                         -> IO ()
encryptStreamWithContext ctx inS outS = do
  S.write (Just (renderRNCryptorHeader (ctxHeader ctx))) outS
  processStream ctx inS outS encryptBlock finaliseEncryption
  where
    -- Pad and encrypt the last chunk, then emit it followed by the HMAC.
    finaliseEncryption :: ByteString -> RNCryptorContext -> IO ()
    finaliseEncryption lastBlock lastCtx = do
      let padded             = lastBlock <> pkcs7Padding blockSize (B.length lastBlock)
          (ctx', cipherText) = encryptBlock lastCtx padded
      S.write (Just cipherText) outS
      S.write (Just (convert (finalize (ctxHMACCtx ctx')))) outS

--------------------------------------------------------------------------------
-- | Efficiently encrypt an incoming stream of bytes with a password.
--
-- This generates a new header with 'newRNCryptorHeader' and derives a
-- context from the password and that header. It feeds the rendered header
-- into the context's HMAC, then delegates the streaming work to
-- 'encryptStreamWithContext'.
encryptStream :: Password
              -- ^ The user key (e.g. password)
              -> S.InputStream ByteString
              -- ^ The input source (most likely stdin)
              -> S.OutputStream ByteString
              -- ^ The output sink (most likely stdout)
              -> IO ()
encryptStream userKey inS outS = do
  hdr <- newRNCryptorHeader
  let ctx    = newRNCryptorContext userKey hdr
      msgHdr = renderRNCryptorHeader hdr
      -- 'encryptStreamWithContext' writes the header but does not add it to
      -- the HMAC, so seed the HMAC with it here.
      ctx'   = ctx { ctxHMACCtx = update (ctxHMACCtx ctx) msgHdr }
  encryptStreamWithContext ctx' inS outS