{-# LANGUAGE ScopedTypeVariables #-}

{-|
Description: Encryption of bytestrings using a type level nonce for determinism
License: BSD3

Given a value of an arbitrary serializable type (like 'Int') we perform
serialization and compute a cryptographic hash of the associated namespace
(carried as a phantom type of kind 'Symbol').
The serialized payload is then encrypted with a symmetric cipher in CBC mode,
using the hashed namespace as the initialization vector (IV).

Since the serialized payload is padded such that its length is an integer
multiple of the block size, we can detect namespace mismatches by checking that
all bytes expected to have been inserted during padding are zero.

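Assuming that decryption under a mismatched namespace turns the padding into
uniformly random bits, the probability of detecting the mismatch is thus
\(1 - 2^{-p}\), where \(p = (-l) \bmod 64\) is the number of padding bits and
\(l\) is the length of the serialized payload in bits (so a payload whose length
is already a multiple of 64 bits offers no detection at all).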
-}
module Data.CryptoID.Poly
  ( CryptoID(..)
  , CryptoIDKey
  , genKey, readKeyFile
  , encrypt
  , decrypt
  , CryptoIDError(..)
  , CryptoCipher, CryptoHash
  ) where

import Data.CryptoID
import Data.CryptoID.ByteString hiding (encrypt, decrypt)
import qualified Data.CryptoID.ByteString as ByteString (encrypt, decrypt)

import Data.Binary

import Data.ByteString (ByteString)
import qualified Data.ByteString.Lazy as Lazy.ByteString

import GHC.TypeLits

import Control.Monad.Catch (MonadThrow(..))
  

-- | Lens-shaped accessor for the payload of a 'CryptoID': runs the effectful
-- function on the wrapped value and rewraps the result, leaving the
-- namespace (phantom type @n@) unchanged.
_ciphertext :: Functor m => (a -> m b) -> CryptoID n a -> m (CryptoID n b)
_ciphertext f cID = case cID of
  CryptoID payload -> fmap CryptoID (f payload)

  
-- | Encrypt an arbitrary 'Binary'-serializable value.
--
-- The value is serialized with its 'Binary' instance, the resulting strict
-- bytes are encrypted under @namespace@ by 'ByteString.encrypt', and the
-- ciphertext is finally converted to the caller's representation @c@ with
-- @encode'@. Failures from either step are thrown in @m@.
encrypt :: forall a m c namespace.
           ( KnownSymbol namespace
           , MonadThrow m
           , Binary a
           ) => (ByteString -> m c) -> CryptoIDKey -> a -> m (CryptoID namespace c)
encrypt encode' key plaintext =
    ByteString.encrypt key serialized >>= _ciphertext encode'
  where
    -- 'ByteString.encrypt' operates on strict bytes, 'encode' yields lazy ones.
    serialized = Lazy.ByteString.toStrict (encode plaintext)


-- | Decrypt a value previously produced by 'encrypt' with the same key and
-- @namespace@.
--
-- The ciphertext is first converted back to strict bytes with @decode'@,
-- decrypted by 'ByteString.decrypt', and then deserialized with the 'Binary'
-- instance of @a@.
--
-- Throws, via 'MonadThrow':
--
--   * 'DeserializationError' (carrying the parser failure) if the decrypted
--     bytes cannot be parsed as an @a@;
--   * 'InvalidNamespaceDetected' if bytes remain after parsing and any of
--     them is non-zero.
--
-- Failures raised by @decode'@ or 'ByteString.decrypt' propagate unchanged.
decrypt :: forall a m c namespace.
           ( KnownSymbol namespace
           , MonadThrow m
           , Binary a
           ) => (c -> m ByteString) -> CryptoIDKey -> CryptoID namespace c -> m a
-- Named decode' (mirroring encode' in 'encrypt') so it does not shadow
-- 'Data.Binary.decode'.
decrypt decode' key cID = do
  cID' <- _ciphertext decode' cID
  plaintext <- Lazy.ByteString.fromStrict <$> ByteString.decrypt key cID'

  case decodeOrFail plaintext of
    Left err -> throwM $ DeserializationError err
    Right (rest, _, res)
      -- Bytes left over after parsing should be the zero padding added
      -- before encryption; a non-zero byte is treated as evidence of a
      -- namespace mismatch (see the module header).
      | Lazy.ByteString.all (== 0) rest -> return res
      | otherwise -> throwM InvalidNamespaceDetected