{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

{-|
Description: Encryption of serializable values using a type level nonce for determinism
License: BSD3

Given a value of an arbitrary serializable type (like 'Int') we perform
serialization and compute a cryptographic hash of the associated namespace
(carried as a phantom type of kind 'Symbol').
The serialized payload is then encrypted using the symmetric cipher in CBC mode
using the hashed namespace as an initialization vector (IV).

Since the serialized payload is padded such that its length is an integer
multiple of the block size we can detect namespace mismatches by checking that
all bytes expected to have been inserted during padding are nil.

The probability of detecting a namespace mismatch is thus \(1 - 2^{-p}\) where
\(p = (64 - (l \ \text{mod} \ 64)) \ \text{mod} \ 64\) is the number of padding
bits and \(l\) is the length of the serialized payload in bits.
-}
module Data.CryptoID.Poly
  ( encrypt
  , decrypt
  , module Data.CryptoID.ByteString
  ) where

import Data.CryptoID.ByteString hiding (encrypt, decrypt)
import qualified Data.CryptoID.ByteString as ByteString (encrypt, decrypt)
import Data.CryptoID.Class (HasCryptoID)
import qualified Data.CryptoID.Class as Class (HasCryptoID(..))

import Data.Binary

import Data.Monoid

import Data.ByteString (ByteString)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy.ByteString

import GHC.TypeLits

import Control.Monad
import Control.Monad.Catch (MonadThrow(..))
  

-- | Run an effectful function over the value wrapped by the 'CryptoID'
-- constructor and rewrap the result, leaving the namespace parameter unchanged.
_ciphertext :: Functor m => (a -> m b) -> CryptoID n a -> m (CryptoID n b)
_ciphertext transform (CryptoID payload) = fmap CryptoID (transform payload)

  
-- | Encrypt a serialized value
--
-- The steps are, in order:
--
--   1. serialize the plaintext with 'encode' and make it strict,
--   2. ask the first argument for an optional target length,
--   3. right-pad the serialized bytes with zero bytes up to that length,
--   4. hand the result to the 'ByteString'-level @encrypt@,
--   5. convert the resulting ciphertext with the second argument.
--
-- Throws 'CiphertextConversionFailed' (carrying the serialized plaintext) when
-- a target length is requested that is shorter than the serialized plaintext.
-- Any exception thrown by the supplied callbacks or by the 'ByteString'-level
-- @encrypt@ propagates unchanged.
encrypt :: forall a m c namespace.
           ( KnownSymbol namespace
           , MonadThrow m
           , Binary a
           ) => (ByteString -> m (Maybe Int)) -- ^ Ensure the resulting ciphertext is of the provided length (needs to be a multiple of the block size of 'CryptoCipher' in bytes, otherwise an exception will be thrown at runtime). The computation has access to the serialized plaintext
        -> (ByteString -> m c) -- ^ Conversion applied to the raw ciphertext after encryption
        -> CryptoIDKey -- ^ Key passed on to the 'ByteString'-level @encrypt@
        -> a -- ^ Value to serialize and encrypt
        -> m (CryptoID namespace c)
encrypt pLength' encode' key plaintext = do
  let serialized = Lazy.ByteString.toStrict (encode plaintext)
  requested <- pLength' serialized
  padded <- padTo requested serialized
  rawID <- ByteString.encrypt key padded
  _ciphertext encode' rawID
  where
    -- No target length requested: pass the serialized bytes through untouched.
    padTo Nothing str = return str
    -- Target length requested: fill up with zero bytes, or fail if the
    -- serialized bytes are already longer than the target.
    padTo (Just target) str
      | actual <= target = return $ str <> ByteString.replicate (target - actual) 0
      | otherwise        = throwM $ CiphertextConversionFailed str
      where
        actual = ByteString.length str


-- | Decrypt a serialized value
--
-- The ciphertext is first converted to a 'ByteString' with the supplied
-- function, then decrypted with the 'ByteString'-level @decrypt@, and finally
-- deserialized with 'decodeOrFail'.
--
-- Throws 'DeserializationError' when the decrypted bytes cannot be decoded.
-- Throws 'InvalidNamespaceDetected' when decoding succeeds but the bytes left
-- over after the decoded value are not all zero.
decrypt :: forall a m c namespace.
           ( KnownSymbol namespace
           , MonadThrow m
           , Binary a
           ) => (c -> m ByteString) -> CryptoIDKey -> CryptoID namespace c -> m a
decrypt decode' key cID = do
  rawID <- _ciphertext decode' cID
  decrypted <- ByteString.decrypt key rawID
  either (const $ throwM DeserializationError) checkPadding
    . decodeOrFail
    $ Lazy.ByteString.fromStrict decrypted
  where
    -- Only accept the decoded value if every unconsumed byte is zero.
    checkPadding (trailing, _, value)
      | Lazy.ByteString.all (== 0) trailing = return value
      | otherwise                           = throwM InvalidNamespaceDetected

-- | Instance for raw 'ByteString' ciphertexts.
--
-- Both methods take their key via 'cryptoIDKey'. Encryption never requests a
-- target length (the length callback always yields 'Nothing') and neither
-- direction converts the ciphertext ('return' is used for both conversions).
instance ( MonadCrypto m
         , MonadCryptoKey m ~ CryptoIDKey
         , KnownSymbol namespace
         , Binary a
         ) => HasCryptoID namespace ByteString a m where
  encrypt plaintext = cryptoIDKey $ \key ->
    encrypt (\_ -> return Nothing) return key plaintext
  decrypt cID = cryptoIDKey $ \key ->
    decrypt return key cID