{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.RNCryptor.V3.Decrypt
  ( parseHeader
  , decrypt
  , decryptBlock
  , decryptStream
  ) where

import           Control.Monad.State
import           Control.Exception           (throwIO)
import           Crypto.Cipher.AES           (AES256)
import           Crypto.Cipher.Types         (IV, makeIV, BlockCipher, cbcDecrypt)
import           Crypto.MAC.HMAC             (update, finalize)
import           Crypto.RNCryptor.Types
import           Crypto.RNCryptor.V3.Stream
import           Data.Bits                   (xor, (.|.))
import           Data.ByteArray              (convert)
import           Data.ByteString             (ByteString)
import qualified Data.ByteString as B
import           Data.Foldable
import           Data.Maybe                  (fromMaybe)
import           Data.Monoid
import           Data.Word
import qualified System.IO.Streams as S

--------------------------------------------------------------------------------
-- | Parse the fixed-size V3 header from the front of the input.
--
-- The layout consumed, in order, is: version (1 byte), options (1 byte),
-- encryption salt (8 bytes), HMAC salt (8 bytes) and IV (16 bytes), i.e.
-- 34 bytes in total. Any bytes after those 34 are ignored. When a field
-- cannot be read, the corresponding field parser reports it via 'error'.
parseHeader :: ByteString -> RNCryptorHeader
parseHeader = evalState $ do
  version        <- parseVersion
  options        <- parseOptions
  encryptionSalt <- parseEncryptionSalt
  hmacSalt       <- parseHMACSalt
  initVector     <- parseIV
  pure RNCryptorHeader
    { rncVersion        = version
    , rncOptions        = options
    , rncEncryptionSalt = encryptionSalt
    , rncHMACSalt       = hmacSalt
    , rncIV             = initVector
    }

--------------------------------------------------------------------------------
-- | Consume one byte from the parser state.
--
-- The remaining bytes become the new state. If no byte is available,
-- 'error' is called with the supplied message.
parseSingleWord8 :: String -> State ByteString Word8
parseSingleWord8 err = do
  (headChunk, remaining) <- gets (B.splitAt 1)
  put remaining
  case B.uncons headChunk of
    Just (byte, _) -> return byte
    Nothing        -> error err

--------------------------------------------------------------------------------
-- | Consume exactly @sz@ bytes from the parser state.
--
-- Whatever follows those bytes becomes the new state. If fewer than @sz@
-- bytes are available, 'error' is called with the supplied message. The
-- check is on the full length, so a truncated field is rejected rather
-- than returned short.
parseBSOfSize :: Int -> String -> State ByteString ByteString
parseBSOfSize sz err = do
  bs <- get
  let (v, vs) = B.splitAt sz bs
  put vs
  -- B.length is O(1) on strict ByteStrings; no need to unpack to a list.
  if B.length v < sz
    then error err
    else return v

--------------------------------------------------------------------------------
-- | Read the one-byte format version (the first header field).
parseVersion :: State ByteString Word8
parseVersion = parseSingleWord8 message
  where
    message = "parseVersion: not enough bytes."

--------------------------------------------------------------------------------
-- | Read the one-byte options field that follows the version.
parseOptions :: State ByteString Word8
parseOptions = parseSingleWord8 message
  where
    message = "parseOptions: not enough bytes."

--------------------------------------------------------------------------------
-- | Read the 8-byte encryption-key salt.
parseEncryptionSalt :: State ByteString ByteString
parseEncryptionSalt = parseBSOfSize saltLength message
  where
    saltLength = 8
    message    = "parseEncryptionSalt: not enough bytes."

--------------------------------------------------------------------------------
-- | Read the 8-byte HMAC-key salt.
parseHMACSalt :: State ByteString ByteString
parseHMACSalt = parseBSOfSize saltLength message
  where
    saltLength = 8
    message    = "parseHMACSalt: not enough bytes."

--------------------------------------------------------------------------------
-- | Read the 16-byte initialisation vector that ends the header.
parseIV :: State ByteString ByteString
parseIV = parseBSOfSize ivLength message
  where
    ivLength = 16
    message  = "parseIV: not enough bytes."

--------------------------------------------------------------------------------
-- | Strip trailing padding from a decrypted message.
--
-- This mirrors "post_decrypt_data" in the Python implementation:
-- data = data[:-bord(data[-1])]
-- https://github.com/RNCryptor/RNCryptor-python/blob/master/RNCryptor.py#L69
--
-- The value of the final byte is taken as the number of bytes to drop.
-- That is PKCS#7 unpadding, except the padding bytes themselves are not
-- validated. If the final byte exceeds the length, the result is empty.
-- Empty input is returned unchanged instead of crashing on 'B.last'.
removePaddingSymbols :: ByteString -> ByteString
removePaddingSymbols input
  | B.null input = input
  | otherwise    = B.take (B.length input - padLength) input
  where
    padLength = fromEnum (B.last input)

--------------------------------------------------------------------------------
-- | AES-256 CBC-decrypt the given bytes with the supplied raw IV.
--
-- Calls 'error' if the raw IV cannot be turned into an 'IV' for AES256
-- (i.e. 'makeIV' rejects it).
decryptBytes :: AES256 -> ByteString -> ByteString -> ByteString
decryptBytes cipher rawIV = cbcDecrypt cipher aesIV
  where
    aesIV :: IV AES256
    aesIV = case makeIV rawIV of
      Just validIV -> validIV
      Nothing      -> error "decryptBytes: makeIV failed."

--------------------------------------------------------------------------------
-- | Decrypt one cipher-text block and advance the decryption context.
--
-- Returns the updated context together with the clear text. The updated
-- context has:
--
--   * the cipher text folded into the running HMAC context, and
--   * its header IV replaced by the last 16 bytes of this cipher text.
--
-- The IV change chains CBC across successive blocks. (Thanks to Rob Napier
-- for the insight.)
decryptBlock :: RNCryptorContext
             -> ByteString
             -> (RNCryptorContext, ByteString)
decryptBlock ctx cipherText =
  blockLen `seq` hmacCtx' `seq` header' `seq`
    (ctx { ctxHeader = header', ctxHMACCtx = hmacCtx' }, clearText)
  where
    header    = ctxHeader ctx
    clearText = decryptBytes (ctxCipher ctx) (rncIV header) cipherText
    -- forced eagerly (see the seq chain) to avoid building up thunks
    -- across many blocks
    hmacCtx'  = update (ctxHMACCtx ctx) cipherText
    blockLen  = B.length cipherText
    header'   = header { rncIV = B.drop (blockLen - 16) cipherText }

--------------------------------------------------------------------------------
-- "A consistent time function needs to be clear on which parameter is secret and
-- which one is untrusted. Your complexity must always be proportional to the length
-- of the untrusted data, not the secret."
--
-- Below, untrusted == arrived in the message, secret == computed
--
-- | Compare two byte strings, doing work proportional to the untrusted length.
--
-- Returns 'True' only when both inputs have the same length and identical
-- contents. Every untrusted byte is XORed against the (cycled) secret and
-- the results are OR-ed together, so a mismatch does not stop the loop early.
-- An empty secret is compared against zeros instead of being cycled. @cycle []@
-- would throw, and the length mismatch already forces 'False'.
consistentTimeEqual :: ByteString -> ByteString -> Bool
consistentTimeEqual untrusted secret =
  0 == foldl' (.|.) initialResult xorResults
  where
    initialResult :: Word8
    initialResult = if B.length secret == B.length untrusted then 0 else 1
    secretCycle
      | B.null secret = repeat 0
      | otherwise     = cycle (B.unpack secret)
    xorResults = zipWith xor (B.unpack untrusted) secretCycle

--------------------------------------------------------------------------------
-- | Decrypt a complete encrypted message held in memory.
--
-- Be aware that this is a user-friendly but dangerous function: it holds
-- the *ENTIRE* input in memory. It is mostly suitable for small inputs
-- like passwords. For inputs that may exceed the available memory, use
-- 'decryptStream'.
--
-- The input is split into a 34-byte header, the cipher text, and a
-- trailing 32-byte HMAC. The HMAC is recomputed over header plus cipher
-- text and compared against the trailing one:
--
--   * if they match, the result is 'Right' with the padding-stripped
--     clear text;
--   * otherwise it is 'Left' ('InvalidHMACException' received computed).
--
-- NOTE(review): inputs too short to hold a header may raise an 'error' from
-- the header parsers instead of producing 'Left'.
decrypt :: ByteString -> ByteString -> Either RNCryptorException ByteString
decrypt input pwd
  | consistentTimeEqual msgHMAC expectedHMAC = Right (removePaddingSymbols clearText)
  | otherwise                                = Left (InvalidHMACException msgHMAC expectedHMAC)
  where
    (rawHdr, rest)        = B.splitAt 34 input
    -- the last 32 bytes are the HMAC; everything before is cipher text
    (cipherText, msgHMAC) = B.splitAt (B.length rest - 32) rest
    ctx                   = newRNCryptorContext pwd (parseHeader rawHdr)
    header                = ctxHeader ctx
    clearText             = decryptBytes (ctxCipher ctx) (rncIV header) cipherText
    expectedHMAC          = makeHMAC (rncHMACSalt header) pwd (rawHdr <> cipherText)

--------------------------------------------------------------------------------
-- | Decrypt an incoming stream of bytes, writing clear text to the output.
--
-- The 34-byte header is read first and folded into the HMAC context. The
-- remainder is handed to 'processStream', using 'decryptBlock' per block.
--
-- The final block carries the 32-byte HMAC trailer. It is decrypted and
-- checked; on mismatch an 'InvalidHMACException' is thrown with 'throwIO',
-- otherwise its padding-stripped clear text is written.
--
-- NOTE(review): earlier blocks appear to be emitted by 'processStream'
-- before the HMAC check happens. If so, callers should discard the output
-- when the exception is raised. Confirm against 'processStream'.
decryptStream :: ByteString
              -- ^ The user key (e.g. password)
              -> S.InputStream ByteString
              -- ^ The input source (mostly likely stdin)
              -> S.OutputStream ByteString
              -- ^ The output source (mostly likely stdout)
              -> IO ()
decryptStream userKey inS outS = do
  rawHdr <- S.readExactly 34 inS
  let baseCtx  = newRNCryptorContext userKey (parseHeader rawHdr)
      -- the HMAC covers the header bytes as well, so seed it with them
      startCtx = baseCtx { ctxHMACCtx = update (ctxHMACCtx baseCtx) rawHdr }
  processStream startCtx inS outS decryptBlock finaliseDecryption
  where
    finaliseDecryption :: ByteString -> RNCryptorContext -> IO ()
    finaliseDecryption lastBlock blockCtx = do
      let trailerStart          = B.length lastBlock - 32
          (cipherText, msgHMAC) = B.splitAt trailerStart lastBlock
          (finalCtx, clearText) = decryptBlock blockCtx cipherText
          computedHMAC          = convert (finalize (ctxHMACCtx finalCtx))
      if consistentTimeEqual msgHMAC computedHMAC
        then S.write (Just (removePaddingSymbols clearText)) outS
        else throwIO (InvalidHMACException msgHMAC computedHMAC)