{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

{-|
Description: Encryption of bytestrings using a type level nonce for determinism
License: BSD3

Given a strict 'ByteString' we compute a cryptographic hash of the associated
namespace (carried as a phantom type of kind 'Symbol').
The payload is then encrypted using the symmetric cipher in CBC mode using the
hashed namespace as an initialization vector (IV).

The probability of a namespace mismatch going undetected is thus (roughly) the
density of valid payloads within all 'ByteString's of the correct length.
-}
module Data.CryptoID.ByteString
  ( CryptoByteString
  , HasCryptoByteString
  , CryptoIDKey
  , genKey, readKeyFile
  , encrypt
  , decrypt
  , CryptoIDError(..)
  , CryptoCipher, CryptoHash
  , cipherBlockSize
  , module Data.CryptoID
  , module Data.CryptoID.Class
  ) where

import Data.CryptoID
import Data.CryptoID.Class hiding (encrypt, decrypt)
import qualified Data.CryptoID.Class as Class (encrypt, decrypt)

import Data.Binary
import Data.Binary.Put
import Data.Binary.Get

import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Char8 as ByteString.Char

import Data.List (sortOn)
import Data.Ord (Down(..))

import Data.ByteArray (ByteArrayAccess)
import qualified Data.ByteArray as ByteArray

import Data.Foldable (asum)
import Control.Monad.Catch (MonadThrow(..))
import Control.Monad.IO.Class
import Control.Monad
import Control.Exception
import System.IO.Error

import Data.Typeable
import GHC.TypeLits

import Crypto.Cipher.Types
import Crypto.Cipher.Blowfish (Blowfish)
import Crypto.Hash (hash, Digest)
import Crypto.Hash.Algorithms (SHAKE128)
import Crypto.Error

import Crypto.Random.Entropy

import System.Directory
import System.FilePath


-- | The symmetric cipher ('BlockCipher') this module uses
type CryptoCipher = Blowfish
-- | The cryptographic 'HashAlgorithm' this module uses
--
-- We expect the block size of 'CryptoCipher' to be exactly the size of the
-- 'Digest' generated by 'CryptoHash' (since a 'Digest' is used as an 'IV').
--
-- Violation of this expectation causes runtime errors: the namespace digest
-- cannot be turned into an 'IV', so every 'encrypt' and 'decrypt' fails with
-- 'NamespaceHashIsWrongLength'.
--
-- NOTE(review): the type-level @64@ is presumably the digest length in bits
-- (i.e. 8 bytes, the Blowfish block) — re-check whenever either type changes.
type CryptoHash   = SHAKE128 64


-- | Block size of 'CryptoCipher', in bytes
--
-- Plaintexts passed to 'encrypt' and ciphertexts passed to 'decrypt' must have
-- a length that is a multiple of this value.
cipherBlockSize :: Int
cipherBlockSize = blockSize cipher
  where
    -- Only the type is inspected, never the value
    cipher :: CryptoCipher
    cipher = undefined


-- | Raw key material for 'CryptoCipher'
--
-- This newtype ensures only keys of the correct length can be created: its
-- constructor is not exported, so values can only come from 'genKey' or from
-- the 'Binary' instance, both of which choose a length permitted by
-- 'cipherKeySize'.
--
-- Use 'genKey' to securely generate keys.
--
-- Use the 'Binary' instance to save and restore values of 'CryptoIDKey' across
-- executions.
newtype CryptoIDKey = CryptoIDKey { keyMaterial :: ByteString }
  deriving (Typeable, ByteArrayAccess)

-- | Deliberately opaque: only the type name is rendered, so 'show' never
--   reveals any key material.
instance Show CryptoIDKey where
  show key = show (typeOf key)

instance Binary CryptoIDKey where
  -- Key material is emitted raw (no length prefix); 'get' reconstructs the
  -- length from the key sizes accepted by 'CryptoCipher'.
  put (CryptoIDKey material) = putByteString material

  get = CryptoIDKey <$> getSized (cipherKeySize (undefined :: CryptoCipher))
    where
      -- Candidate sizes are attempted longest-first; 'Get' commits to the
      -- first branch that parses, so the longest acceptable key that fits the
      -- remaining input is read.
      getSized (KeySizeFixed len)   = getByteString len
      getSized (KeySizeEnum lens)   = asum (map (getSized . KeySizeFixed) (sortOn Down lens))
      getSized (KeySizeRange lo hi) = getSized (KeySizeEnum [lo .. hi])


-- | Error cases that can be encountered during 'encrypt' and 'decrypt'
--
-- Care has been taken to ensure that presenting values of 'CryptoIDError' to an
-- attacker leaks no plaintext (it does leak information about the length of the
-- plaintext).
data CryptoIDError
  = AlgorithmError CryptoError
    -- ^ One of the underlying cryptographic algorithms
    --   ('CryptoHash' or 'CryptoCipher') failed.
  | PlaintextIsWrongLength Int
    -- ^ The length of the plaintext is not a multiple of the block size of
    --   'CryptoCipher'
    --
    -- The length of the offending plaintext is included.
  | CiphertextIsWrongLength ByteString
    -- ^ The length of the ciphertext is not a multiple of the block size of
    --   'CryptoCipher'
    --
    -- The offending ciphertext is included.
  | NamespaceHashIsWrongLength ByteString
    -- ^ The length of the digest produced by 'CryptoHash' does
    --   not match the block size of 'CryptoCipher'.
    --
    -- The offending digest is included.
    --
    -- This error should not occur and is included primarily
    -- for the sake of totality.
  | CiphertextConversionFailed ByteString
    -- ^ The produced 'ByteString' is the wrong length for deserialization into
    --   a ciphertext.
    --
    -- The offending 'ByteString' is included.
  | DeserializationError
    -- ^ The plaintext obtained by decrypting a ciphertext with the given
    --   'CryptoIDKey' in the context of the @namespace@ could not be
    --   deserialized into a value of the expected @payload@-type.
    --
    -- This is expected behaviour if the @namespace@ or @payload@-type does not
    -- match the ones used during 'encrypt'ion or if the 'ciphertext' was
    -- tampered with.
  | InvalidNamespaceDetected
    -- ^ We have determined that, although deserialization succeeded, the
    --   ciphertext was likely modified during transit or created using a
    --   different namespace.
  deriving (Show, Eq)

-- | Allows these errors to be raised with 'throwM' (as 'encrypt' and
--   'decrypt' do) and caught as ordinary exceptions.
instance Exception CryptoIDError

-- | Securely generate a new key from system entropy (via 'getEntropy')
--
-- If 'CryptoCipher' admits several key lengths, the largest one is used.
genKey :: MonadIO m => m CryptoIDKey
genKey = liftIO $ CryptoIDKey <$> getEntropy largestKeySize
  where
    largestKeySize :: Int
    largestKeySize = case cipherKeySize (undefined :: CryptoCipher) of
      KeySizeFixed len     -> len
      KeySizeEnum lens     -> maximum lens
      KeySizeRange _ upper -> upper

-- | Try to read a 'CryptoIDKey' from a file.
--   If the file does not exist, securely generate a key (using 'genKey') and
--   save it to the file, creating missing parent directories first.
--
--   Only a "does not exist" 'IOError' triggers generation; every other
--   exception raised while reading the key file propagates unchanged.
--
--   NOTE(review): a freshly generated key file is created with the process's
--   default permissions; restrict access externally if that matters.
readKeyFile :: MonadIO m => FilePath -> m CryptoIDKey
readKeyFile keyFile = liftIO $
  -- 'catchJust' only intercepts the exceptions selected by the predicate, so
  -- there is no need to catch-and-rethrow anything else.
  catchJust (guard . isDoesNotExistError) (decodeFile keyFile) (const generateInstead)
  where
    generateInstead :: IO CryptoIDKey
    generateInstead = do
      createDirectoryIfMissing True $ takeDirectory keyFile
      key <- genKey
      encodeFile keyFile key
      return key


-- | A 'CryptoID' in namespace @namespace@ whose ciphertext is a 'ByteString'
--   (the type produced by 'encrypt' and consumed by 'decrypt')
type CryptoByteString (namespace :: Symbol) = CryptoID namespace ByteString

-- | 'HasCryptoID' with its second argument fixed to 'ByteString'.
--   Presumably that argument is the ciphertext type, matching
--   'CryptoByteString' — confirm against "Data.CryptoID.Class".
type HasCryptoByteString (namespace :: Symbol) = HasCryptoID namespace ByteString


-- | Derive the 'IV' for a namespace by hashing its 'Symbol' with 'CryptoHash'
--
-- Throws 'NamespaceHashIsWrongLength' (carrying the digest) when the digest
-- cannot be used as an 'IV' for 'CryptoCipher'.
--
-- NOTE(review): 'ByteString.Char.pack' truncates every 'Char' to its low 8
-- bits, so namespaces that differ only in code points above U+00FF share an
-- 'IV'. Any fix must keep the bytes of existing namespaces unchanged, or
-- previously issued IDs stop decrypting.
namespace' :: forall proxy namespace m.
              ( KnownSymbol namespace, MonadThrow m
              ) => proxy namespace -> m (IV CryptoCipher)
namespace' p = maybe wrongLength return (makeIV digest)
  where
    digest :: Digest CryptoHash
    digest = hash (ByteString.Char.pack (symbolVal p))

    wrongLength :: m (IV CryptoCipher)
    wrongLength = throwM (NamespaceHashIsWrongLength (ByteArray.convert digest))

-- | Run a 'CryptoFailable' result in any 'MonadThrow', raising a failure as
--   'AlgorithmError' wrapping the original 'CryptoError'.
cryptoFailable :: MonadThrow m => CryptoFailable a -> m a
cryptoFailable (CryptoPassed result) = return result
cryptoFailable (CryptoFailed err)    = throwM (AlgorithmError err)

-- | Encrypt a serialized value
--
-- The 'IV' is derived from @namespace@ alone, so under the same key equal
-- plaintexts in the same namespace yield equal ciphertexts.
--
-- Throws, in this order of checking: 'AlgorithmError' if 'cipherInit' fails
-- for the key, 'NamespaceHashIsWrongLength' if no 'IV' can be derived, and
-- 'PlaintextIsWrongLength' if the plaintext length is not a multiple of the
-- cipher's block size.
encrypt :: forall m namespace.
             ( KnownSymbol namespace
             , MonadThrow m
             ) => CryptoIDKey -> ByteString -> m (CryptoID namespace ByteString)
encrypt (keyMaterial -> key) plaintext = do
  cipher <- cryptoFailable (cipherInit key :: CryptoFailable CryptoCipher)
  iv     <- namespace' (Proxy :: Proxy namespace)
  let len = ByteString.length plaintext
  unless (len `mod` blockSize cipher == 0) $
    throwM (PlaintextIsWrongLength len)
  return (CryptoID (cbcEncrypt cipher iv plaintext))


-- | Decrypt a serialized value
--
-- No integrity check happens here: decrypting under a different key or
-- namespace is not detected and simply yields different bytes.
--
-- Throws, in this order of checking: 'AlgorithmError' if 'cipherInit' fails
-- for the key, 'NamespaceHashIsWrongLength' if no 'IV' can be derived, and
-- 'CiphertextIsWrongLength' if the ciphertext length is not a multiple of the
-- cipher's block size.
decrypt :: forall m namespace.
             ( KnownSymbol namespace
             , MonadThrow m
             ) => CryptoIDKey -> CryptoID namespace ByteString -> m ByteString
decrypt (keyMaterial -> key) (CryptoID ct) = do
  cipher <- cryptoFailable (cipherInit key :: CryptoFailable CryptoCipher)
  iv     <- namespace' (Proxy :: Proxy namespace)
  unless (ByteString.length ct `mod` blockSize cipher == 0) $
    throwM (CiphertextIsWrongLength ct)
  return (cbcDecrypt cipher iv ct)

-- | Encrypts and decrypts 'ByteString's using the 'CryptoIDKey' provided by
--   the surrounding 'MonadCrypto' context (obtained through 'cryptoIDKey').
--
-- This instance is somewhat improper in that it works only for plain- and
-- ciphertexts whose length is a multiple of 'cipherBlockSize'
--
-- Improper plaintext lengths throw 'PlaintextIsWrongLength'
--
-- Improper ciphertext lengths throw 'CiphertextIsWrongLength'
instance ( MonadCrypto m
         , MonadCryptoKey m ~ CryptoIDKey
         , KnownSymbol namespace
         ) => HasCryptoID namespace ByteString ByteString m where
  -- The unqualified 'encrypt' / 'decrypt' on the right-hand sides are this
  -- module's key-taking functions, not the class methods.
  encrypt plaintext = cryptoIDKey $ \key -> encrypt key plaintext
  decrypt cryptoID  = cryptoIDKey $ \key -> decrypt key cryptoID