{-# LANGUAGE TypeFamilies #-}
-------------------------------------------------
-- |
-- Module      : Crypto.Noise.Cipher.AESGCM
-- Maintainer  : John Galt <jgalt@centromere.net>
-- Stability   : experimental
-- Portability : POSIX
module Crypto.Noise.Cipher.AESGCM
  ( -- * Types
    AESGCM
  ) where

import Crypto.Error        (throwCryptoError)
import Crypto.Cipher.AES   (AES256)
import Crypto.Cipher.Types (AuthTag(..), AEADMode(AEAD_GCM), cipherInit,
                            aeadInit, aeadSimpleEncrypt, aeadSimpleDecrypt)
import Data.ByteArray      (ByteArray, Bytes, ScrubbedBytes, convert, take,
                            drop, length, copyAndFreeze, zero, append,
                            replicate)
import Data.Word           (Word8)
import Foreign.Ptr
import Foreign.Storable
import Prelude hiding      (drop, length, replicate, take)

import Crypto.Noise.Cipher

-- | Represents the AES256 cipher with GCM for AEAD.
--
-- The type has no constructors; it serves only as the index of the
-- 'Cipher' instance defined in this module.
data AESGCM

-- | 'Cipher' instance for AES-256 in GCM mode.
--
-- The associated data families wrap the raw representations used by the
-- helper functions in this module:
--
-- * a ciphertext is an authentication tag paired with the encrypted bytes,
-- * a symmetric key is a 'ScrubbedBytes' buffer,
-- * a nonce is a 'Bytes' buffer.
--
-- Every class method simply delegates to the like-named top-level helper.
instance Cipher AESGCM where
  newtype Ciphertext   AESGCM = CTAES (AuthTag, ScrubbedBytes)
  newtype SymmetricKey AESGCM = SKAES ScrubbedBytes
  newtype Nonce        AESGCM = NAES  Bytes

  -- NOTE(review): the literal is used at type ScrubbedBytes, which presumably
  -- relies on OverloadedStrings being enabled for the project -- confirm.
  cipherName _      = "AESGCM"
  cipherEncrypt     = encrypt
  cipherDecrypt     = decrypt
  cipherZeroNonce   = zeroNonce
  cipherMaxNonce    = maxNonce
  cipherIncNonce    = incNonce
  cipherNonceEq     = nonceEq
  cipherNonceCmp    = nonceCmp
  cipherBytesToSym  = bytesToSym
  cipherSymToBytes  = symToBytes
  cipherTextToBytes = ctToBytes
  cipherBytesToText = bytesToCt

-- | Encrypts a plaintext with AES-256-GCM.
--
-- The associated data is authenticated but not encrypted. The result pairs
-- the authentication tag (requested with a length of 16 bytes) with the
-- encrypted bytes.
--
-- Both initialisation steps go through 'throwCryptoError', so a key or nonce
-- rejected by the underlying library surfaces as an exception rather than
-- as a value.
encrypt :: SymmetricKey AESGCM
        -> Nonce AESGCM
        -> AssocData
        -> Plaintext
        -> Ciphertext AESGCM
encrypt (SKAES k) (NAES n) ad plaintext =
  CTAES $ aeadSimpleEncrypt aead ad plaintext 16
  where
    -- Block cipher context built from the raw key bytes.
    state = throwCryptoError . cipherInit $ k :: AES256
    -- GCM AEAD context for this key/nonce pair.
    aead  = throwCryptoError $ aeadInit AEAD_GCM state n

-- | Decrypts and authenticates an AES-256-GCM ciphertext.
--
-- The tag stored in the 'Ciphertext' is handed to 'aeadSimpleDecrypt'
-- together with the associated data and the encrypted bytes; whatever that
-- function returns is passed through, so 'Nothing' signals that the input
-- was rejected.
--
-- Both initialisation steps go through 'throwCryptoError', so a key or nonce
-- rejected by the underlying library surfaces as an exception rather than
-- as 'Nothing'.
decrypt :: SymmetricKey AESGCM
        -> Nonce AESGCM
        -> AssocData
        -> Ciphertext AESGCM
        -> Maybe Plaintext
decrypt (SKAES k) (NAES n) ad (CTAES (authTag, ct)) =
  aeadSimpleDecrypt aead ad ct authTag
  where
    -- Block cipher context built from the raw key bytes.
    state = throwCryptoError . cipherInit $ k :: AES256
    -- GCM AEAD context for this key/nonce pair.
    aead  = throwCryptoError $ aeadInit AEAD_GCM state n

-- | The initial nonce: 12 bytes, all zero.
zeroNonce :: Nonce AESGCM
zeroNonce = NAES . zero $ 12

-- | The largest nonce value: 4 zero bytes followed by 8 bytes of @0xFF@
-- (12 bytes in total).
maxNonce :: Nonce AESGCM
maxNonce = NAES $ zero 4 `append` replicate 8 255

-- | Returns the nonce that follows the given one.
--
-- The nonce bytes are incremented by one via 'ivAdd', which carries from the
-- last byte towards the first. No bound check is performed here: incrementing
-- 'maxNonce' carries into the leading four bytes.
incNonce :: Nonce AESGCM
         -> Nonce AESGCM
incNonce (NAES n) = NAES $ ivAdd n 1

-- | Tests two nonces for equality by comparing their underlying bytes.
nonceEq :: Nonce AESGCM
        -> Nonce AESGCM
        -> Bool
nonceEq (NAES a) (NAES b) = a == b

-- | Orders two nonces using the 'Ord' instance of their underlying bytes.
nonceCmp :: Nonce AESGCM
         -> Nonce AESGCM
         -> Ordering
nonceCmp (NAES a) (NAES b) = compare a b

-- | Builds a symmetric key from raw bytes, keeping at most the first 32.
--
-- No length validation happens here; a key that the cipher rejects is only
-- detected when 'encrypt' or 'decrypt' initialises AES.
bytesToSym :: ScrubbedBytes
           -> SymmetricKey AESGCM
bytesToSym = SKAES . take 32

-- | Exposes the raw bytes of a symmetric key.
symToBytes :: SymmetricKey AESGCM
           -> ScrubbedBytes
symToBytes (SKAES sk) = sk

-- | Serialises a ciphertext as the encrypted bytes followed by the
-- authentication tag. 'bytesToCt' performs the reverse split.
ctToBytes :: Ciphertext AESGCM
          -> ScrubbedBytes
ctToBytes (CTAES (a, ct)) = ct `mappend` convert a

-- | Parses the serialised form produced by 'ctToBytes'.
--
-- The final 16 bytes are taken as the authentication tag and everything
-- before them as the encrypted bytes.
--
-- NOTE(review): inputs shorter than 16 bytes are not rejected here; the split
-- point becomes negative and the outcome depends on how 'take' and 'drop'
-- treat negative counts -- confirm against the memory package.
bytesToCt :: ScrubbedBytes
          -> Ciphertext AESGCM
bytesToCt bytes =
  CTAES ( AuthTag . convert $ drop splitAt' bytes
        , take splitAt' bytes
        )
  where
    -- Offset at which the trailing 16-byte tag begins.
    splitAt' = length bytes - 16

-- Adapted from cryptonite's Crypto.Cipher.Types.Block module.
--
-- | Adds an integer to a byte array interpreted as a big-endian number.
--
-- The addition runs on a copy made by 'copyAndFreeze'; the argument itself
-- is not written to. Starting from the last byte, each byte is replaced by
-- the low 8 bits of (carry + byte) and the remainder is carried to the byte
-- before it. A carry left over after the first byte is discarded, so the
-- result has the same length as the input.
ivAdd :: ByteArray b
      => b
      -> Int
      -> b
ivAdd b i = copyAndFreeze b $ loop i (length b - 1)
  where
    -- loop carry offset ptr: process the byte at the offset, then move one
    -- byte towards the front; stops once the offset drops below zero.
    loop :: Int -> Int -> Ptr Word8 -> IO ()
    loop acc ofs p
      | ofs < 0   = return ()
      | otherwise = do
          v <- peek (p `plusPtr` ofs) :: IO Word8
          let accv     = acc + fromIntegral v
              (hi, lo) = accv `divMod` 256
          poke (p `plusPtr` ofs) (fromIntegral lo :: Word8)
          loop hi (ofs - 1) p