{-# LANGUAGE ScopedTypeVariables #-}

{-|
Description: Reversibly generate filepaths from arbitrary serializable types in a secure fashion
License: BSD3

Given a value of a serializable type (like 'Int') we perform serialization and
compute a cryptographic hash of the associated namespace (carried as a phantom
type of kind 'Symbol').
The serialized payload is encrypted using a symmetric cipher in CBC mode using
the hashed namespace as an initialization vector (IV).

The ciphertext is then
<https://hackage.haskell.org/package/sandi/docs/Codec-Binary-Base32.html base32>-encoded
and padding stripped.

Rather than being indicated by the amount of padding, the length of the
serialized plaintext is instead carried at the type level within
'CryptoFileName' (analogously to the namespace).
Mismatches in serialized plaintext length are checked for but are /not/
guaranteed to cause runtime errors in all cases.

Since the serialized payload is padded up to the next multiple of the cipher
block length, we can detect namespace mismatches by checking that all bytes
expected to have been inserted during padding are nil.
A mismatch goes unnoticed only if all \(b \left \lceil \frac{l}{b} \right \rceil - l\)
padding bits happen to decrypt to zero, so the probability of detecting a
namespace mismatch is \(1 - 2^{l - b \left \lceil \frac{l}{b} \right \rceil}\)
where \(l\) is the length of the serialized payload and \(b\) the length of a
ciphertext block (both in bits).
-}
module System.FilePath.Cryptographic
  ( CryptoID(..)
  , CryptoFileName
  , module Data.Binary.SerializationLength
  , encrypt
  , decrypt
  , CryptoIDError(..)
  ) where

import Data.CryptoID
import Data.CryptoID.Poly hiding (encrypt, decrypt)
import Data.CryptoID.ByteString (cipherBlockSize)
import qualified Data.CryptoID.Poly as Poly (encrypt, decrypt)

import System.FilePath (FilePath)
import qualified Codec.Binary.Base32 as Base32
import Data.CaseInsensitive (CI)
import qualified Data.CaseInsensitive as CI
import Data.Binary
import Data.Binary.SerializationLength
import Data.Char (toUpper)

import Data.Ratio ((%))
import Data.List

import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Char8 as ByteString.Char8

import Control.Monad
import Control.Monad.Catch

import Data.Proxy
import GHC.TypeLits


-- | A file name as produced by 'encrypt': the base32 rendering of the
-- ciphertext with trailing @=@ padding stripped, wrapped in 'CI' so that it
-- compares case-insensitively, and tagged with its @namespace@ at type level.
type CryptoFileName (namespace :: Symbol) = CryptoID namespace (CI FilePath)


-- | Round a length up to the nearest multiple of 'cipherBlockSize'.
--
-- Lengths that already are an exact multiple of the block size (including
-- @0@) are returned unchanged.
paddedLength :: Integral a => a -> a
paddedLength l = case l `divMod` blockSize of
    (_, 0)      -> l
    (blocks, _) -> (blocks + 1) * blockSize
  where
    blockSize = fromIntegral cipherBlockSize

-- | Encrypt an arbitrary serializable value
--
-- We only expect to fail if the given value is not serialized in such a fashion
-- that it meets the expected length given at type level.
-- | Encrypt an arbitrary serializable value
--
-- The value is serialized via its 'Binary' instance and the serialization is
-- checked against the length given at type level by 'SerializationLength';
-- 'paddedLength' of that length is handed to the underlying encryption as the
-- length to pad the plaintext to. The ciphertext is rendered as base32 with
-- the trailing @=@ padding removed.
--
-- We only expect to fail (with 'CiphertextConversionFailed') if the given
-- value is not serialized in such a fashion that it meets the expected length
-- given at type level.
encrypt :: forall a m namespace.
           ( KnownSymbol namespace
           , Binary a
           , MonadThrow m
           , HasFixedSerializationLength a
           ) => CryptoIDKey -> a -> m (CryptoFileName namespace)
encrypt = Poly.encrypt checkLength renderCiphertext
  where
    -- Length of the serialized plaintext (in bytes), fixed at type level.
    expectedLength :: Integer
    expectedLength = natVal (Proxy :: Proxy (SerializationLength a))

    -- Reject serializations of unexpected length; otherwise request padding
    -- up to a whole number of cipher blocks.
    checkLength serialized
      | toInteger actualLength /= expectedLength =
          throwM $ CiphertextConversionFailed serialized
      | otherwise = return . Just $ paddedLength actualLength
      where
        actualLength = ByteString.length serialized

    -- Unpadded base32, wrapped for case-insensitive comparison.
    renderCiphertext ciphertext =
      return . CI.mk . dropWhileEnd (== '=') . ByteString.Char8.unpack $ Base32.encode ciphertext


-- | Decrypt an arbitrary serializable value
--
-- Since no integrity guarantees can be made (we do not sign the values we
-- 'encrypt') it is likely that deserialization will fail emitting
-- 'DeserializationError' or 'InvalidNamespaceDetected'.
-- | Decrypt an arbitrary serializable value
--
-- The file name is upper-cased (the base32 alphabet is upper-case, whereas a
-- 'CryptoFileName' is case-insensitive). The @=@ padding that 'encrypt'
-- stripped is restored from the type-level serialization length, and the
-- result is base32-decoded. A decoding failure is reported as
-- 'CiphertextConversionFailed'.
--
-- Since no integrity guarantees can be made (we do not sign the values we
-- 'encrypt') it is likely that deserialization will fail emitting
-- 'DeserializationError' or 'InvalidNamespaceDetected'.
decrypt :: forall a m namespace.
           ( KnownSymbol namespace
           , Binary a
           , MonadThrow m
           , HasFixedSerializationLength a
           ) => CryptoIDKey -> CryptoFileName namespace -> m a
decrypt = Poly.decrypt decodeFileName
  where
    -- Length of the serialized plaintext (in bytes), fixed at type level.
    plaintextLength :: Integer
    plaintextLength = natVal (Proxy :: Proxy (SerializationLength a))

    decodeFileName fileName =
      let padded = ByteString.Char8.pack . restorePadding . map toUpper $ CI.original fileName
      in either (\_ -> throwM $ CiphertextConversionFailed padded) return (Base32.decode padded)

    -- Assumes the ciphertext is exactly as long as the block-padded
    -- plaintext; append the '=' characters base32 requires for that length.
    restorePadding str =
      str ++ replicate (paddingFor (paddedLength plaintextLength `mod` 5)) '='

    -- Base32 encodes 5-byte groups as 8 characters; a trailing group of
    -- 0, 1, 2, 3 or 4 bytes is followed by 0, 6, 4, 3 or 1 '=' characters.
    paddingFor :: Integer -> Int
    paddingFor groupRemainder = case groupRemainder of
      0 -> 0
      1 -> 6
      2 -> 4
      3 -> 3
      _ -> 1 -- remainder 4; `mod` 5 yields nothing else