{-# LANGUAGE StrictData #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module Wai.CryptoCookie.Encryption.AEAD_AES_256_GCM_SIV () where

import Crypto.Cipher.AES qualified as CAES
import Crypto.Cipher.AESGCMSIV qualified as CAGS
import Crypto.Cipher.Types qualified as CAES
import Crypto.Error qualified as C
import Crypto.Random qualified as C
import Data.ByteArray qualified as BA
import Data.ByteArray.Parse qualified as BAP
import Data.ByteArray.Sized qualified as BAS
import Data.ByteString qualified as B
import Data.ByteString.Lazy qualified as BL

import Wai.CryptoCookie.Encryption

-- | @AEAD_AES_256_GCM_SIV@ is a nonce-misuse resistant AEAD encryption scheme
-- defined in <https://tools.ietf.org/html/rfc8452 RFC 8452>.
--
-- The output of 'encrypt' is laid out as @nonce <> tag <> ciphertext@.
-- 'decrypt' reads a 12-byte nonce, then a 16-byte authentication tag, and
-- treats the remainder as ciphertext. No associated data is used (an empty
-- 'B.ByteString' is passed on both sides).
instance Encryption "AEAD_AES_256_GCM_SIV" where
   -- Exactly 32 bytes of key material, stored as 'BA.ScrubbedBytes'.
   newtype Key "AEAD_AES_256_GCM_SIV"
      = Key (BAS.SizedByteArray 32 BA.ScrubbedBytes)
      deriving newtype (Eq)
   type KeyLength "AEAD_AES_256_GCM_SIV" = 32

   -- Encryption state: the initialised cipher, the DRG used to derive
   -- subsequent nonces, and the nonce that the next 'encrypt' call will use.
   data Encrypt "AEAD_AES_256_GCM_SIV"
      = Encrypt CAES.AES256 C.ChaChaDRG CAGS.Nonce

   -- Decryption only needs the initialised cipher; the nonce travels with
   -- each message.
   newtype Decrypt "AEAD_AES_256_GCM_SIV"
      = Decrypt CAES.AES256

   -- Draw 32 random bytes. The unchecked size tag is sound because the
   -- requested length matches the type-level length.
   genKey = fmap (Key . BAS.unsafeSizedByteArray) (C.getRandomBytes 32)

   -- Accept any byte array of exactly the key length; anything else yields
   -- @Left "Bad length"@.
   keyFromBytes =
      maybe (Left "Bad length") (Right . Key) . BAS.fromByteArrayAccess

   keyToBytes (Key key) = BAS.convert key

   -- Build the paired encryption/decryption contexts for a key. A fresh DRG
   -- is seeded from the ambient 'C.MonadRandom' and immediately used to
   -- produce the first nonce.
   initial (Key key0) = do
      drg0 <- C.drgNew
      let (nonce, drg1) = C.withDRG drg0 CAGS.generateNonce
          -- NOTE(review): 'C.throwCryptoError' is partial; presumably
          -- 'CAES.cipherInit' cannot fail here because the key is statically
          -- 32 bytes long — confirm against the cipher's key-size spec.
          aes = C.throwCryptoError $ CAES.cipherInit $ BAS.unSizedByteArray key0
      pure (Encrypt aes drg1 nonce, Decrypt aes)

   -- Discard the current nonce and derive the next one from the DRG, so the
   -- following 'encrypt' call uses a different nonce.
   advance (Encrypt aes drg0 _) =
      let (nonce, drg1) = C.withDRG drg0 CAGS.generateNonce
      in  Encrypt aes drg1 nonce

   -- Encrypt with the nonce held in the state and emit
   -- @nonce <> tag <> ciphertext@ as a lazy 'BL.ByteString' of three chunks.
   encrypt (Encrypt aes _ nonce) plain =
      let (tag, cry) = CAGS.encrypt aes nonce B.empty $ B.toStrict plain
      in  BL.fromChunks [BA.convert nonce, BA.convert tag, cry]

   -- Split the input into nonce, tag and ciphertext, then authenticate and
   -- decrypt. Parse errors are returned as 'Left' with the parser's message;
   -- an authentication failure is reported as @Left "Can't decrypt"@.
   decrypt = \(Decrypt aes) raw -> do
      (nonce, tag, cry) <- fromResult $ BAP.parse p (B.toStrict raw)
      case CAGS.decrypt aes nonce B.empty cry tag of
         Just x -> pure $ BL.fromStrict x
         Nothing -> Left "Can't decrypt"
     where
      p :: BAP.Parser B.ByteString (CAGS.Nonce, CAES.AuthTag, B.ByteString)
      p = do
         -- A failed match on 'C.CryptoPassed' goes through the parser's
         -- 'fail', turning a rejected nonce into a parse failure.
         C.CryptoPassed nonce <- CAGS.nonce <$> BAP.take 12
         tag <- CAES.AuthTag . BA.convert <$> BAP.take 16
         cry <- BAP.takeAll
         pure (nonce, tag, cry)

-- | Collapse a "Data.ByteArray.Parse" result into an 'Either'.
--
-- * 'BAP.ParseOK' with no unconsumed input becomes 'Right'; with unconsumed
--   input it becomes @Left "Leftovers"@.
-- * 'BAP.ParseMore' is resumed by passing 'Nothing' to its continuation
--   (no further input is supplied) and the outcome is processed recursively.
-- * 'BAP.ParseFail' becomes 'Left' carrying the parser's own message.
fromResult :: BAP.Result B.ByteString a -> Either String a
fromResult = \case
   BAP.ParseOK rest a
      | B.null rest -> Right a
      | otherwise -> Left "Leftovers"
   BAP.ParseMore f -> fromResult (f Nothing)
   BAP.ParseFail e -> Left e