{-# LANGUAGE TypeFamilies #-}
---------------------------------------------------
-- |
-- Module      : Crypto.Noise.Cipher.ChaChaPoly1305
-- Maintainer  : John Galt <jgalt@centromere.net>
-- Stability   : experimental
-- Portability : POSIX
module Crypto.Noise.Cipher.ChaChaPoly1305
  ( -- * Types
    ChaChaPoly1305
  ) where

import           Crypto.Error    (throwCryptoError)
import qualified Crypto.Cipher.ChaChaPoly1305 as CCP
import qualified Crypto.MAC.Poly1305          as P
import           Data.ByteArray  (ScrubbedBytes, Bytes, convert, take, drop,
                                  length, replicate, constEq)
import           Data.ByteString (ByteString, reverse)
import           Prelude hiding  (drop, length, replicate, take, reverse)

import Crypto.Noise.Cipher

-- | Type-level tag for the ChaCha cipher combined with the Poly1305 MAC as an
--   AEAD construction. It has no values; it only selects the 'Cipher'
--   instance below.
data ChaChaPoly1305

instance Cipher ChaChaPoly1305 where
  -- Encrypted body paired with its Poly1305 authentication tag.
  newtype Ciphertext   ChaChaPoly1305 = CTCCP1305 (ScrubbedBytes, P.Auth)
  -- Raw key bytes.
  newtype SymmetricKey ChaChaPoly1305 = SKCCP1305 ScrubbedBytes
  -- Nonce in cryptonite's ChaChaPoly1305 representation.
  newtype Nonce        ChaChaPoly1305 = NCCP1305  CCP.Nonce

  -- The proxy argument only fixes the cipher type; its value is ignored.
  cipherName _      = "ChaChaPoly"
  cipherEncrypt     = encrypt
  cipherDecrypt     = decrypt
  cipherZeroNonce   = zeroNonce
  cipherMaxNonce    = maxNonce
  cipherIncNonce    = incNonce
  cipherNonceEq     = nonceEq
  cipherNonceCmp    = nonceCmp
  cipherBytesToSym  = bytesToSym
  cipherSymToBytes  = symToBytes
  cipherTextToBytes = ctToBytes
  cipherBytesToText = bytesToCt

-- | Encrypts @plaintext@ with key @k@ and nonce @n@, authenticating the
--   associated data @ad@ as well, and returns the encrypted body together
--   with its Poly1305 tag.
--
--   Throws (via 'throwCryptoError') if 'CCP.initialize' rejects the key or
--   nonce.
encrypt :: SymmetricKey ChaChaPoly1305
        -> Nonce ChaChaPoly1305
        -> AssocData
        -> Plaintext
        -> Ciphertext ChaChaPoly1305
encrypt (SKCCP1305 k) (NCCP1305 n) ad plaintext = CTCCP1305 (out, authTag)
  where
    initState       = throwCryptoError $ CCP.initialize k n
    -- The associated data is fed in, and the AAD phase closed, before any
    -- plaintext is processed.
    afterAAD        = CCP.finalizeAAD (CCP.appendAAD ad initState)
    (out, afterEnc) = CCP.encrypt plaintext afterAAD
    -- 'CCP.finalize' already yields a 'P.Auth', so no re-wrapping is needed.
    authTag         = CCP.finalize afterEnc

-- | Authenticates and decrypts a ciphertext under key @k@, nonce @n@ and
--   associated data @ad@.
--
--   Returns 'Nothing' when the received Poly1305 tag does not match the tag
--   recomputed over @ad@ and the encrypted body. Throws (via
--   'throwCryptoError') if 'CCP.initialize' rejects the key or nonce.
decrypt :: SymmetricKey ChaChaPoly1305
        -> Nonce ChaChaPoly1305
        -> AssocData
        -> Ciphertext ChaChaPoly1305
        -> Maybe Plaintext
decrypt (SKCCP1305 k) (NCCP1305 n) ad (CTCCP1305 (ct, auth))
  -- Compare tags with an explicitly constant-time equality so the check does
  -- not leak how many leading tag bytes matched.
  | auth `constEq` calcAuthTag = Just out
  | otherwise                  = Nothing
  where
    initState       = throwCryptoError $ CCP.initialize k n
    afterAAD        = CCP.finalizeAAD (CCP.appendAAD ad initState)
    (out, afterDec) = CCP.decrypt ct afterAAD
    calcAuthTag     = CCP.finalize afterDec

-- | The initial nonce: a 4-byte all-zero prefix followed by an all-zero
--   8-byte counter.
zeroNonce :: Nonce ChaChaPoly1305
zeroNonce = NCCP1305 (throwCryptoError (CCP.nonce8 prefix counter))
  where
    -- NOTE(review): presumably 'CCP.nonce8' only fails on wrong input
    -- lengths; these are constant, so 'throwCryptoError' should never fire.
    prefix  = replicate 4 0 :: Bytes
    counter = replicate 8 0 :: Bytes

-- | The largest nonce: a 4-byte all-zero prefix followed by an 8-byte
--   counter with every byte set to 255.
maxNonce :: Nonce ChaChaPoly1305
maxNonce = NCCP1305 (throwCryptoError (CCP.nonce8 prefix counter))
  where
    -- NOTE(review): presumably 'CCP.nonce8' only fails on wrong input
    -- lengths; these are constant, so 'throwCryptoError' should never fire.
    prefix  = replicate 4 0   :: Bytes
    counter = replicate 8 255 :: Bytes

-- | Advances a nonce by delegating to cryptonite's 'CCP.incrementNonce'.
incNonce :: Nonce ChaChaPoly1305
         -> Nonce ChaChaPoly1305
incNonce (NCCP1305 n) = NCCP1305 (CCP.incrementNonce n)

-- | Nonce equality, checked with 'constEq' over the nonces' raw bytes.
nonceEq :: Nonce ChaChaPoly1305
        -> Nonce ChaChaPoly1305
        -> Bool
nonceEq (NCCP1305 a) (NCCP1305 b) = a `constEq` b

-- | Orders two nonces.
--
--   Since nonces in this cipher are little endian, their bytes are reversed
--   before comparison so that lexicographic order follows numeric order. A
--   'ByteString' is used for the comparison because it uses memcmp under the
--   hood.
nonceCmp :: Nonce ChaChaPoly1305
         -> Nonce ChaChaPoly1305
         -> Ordering
nonceCmp (NCCP1305 a) (NCCP1305 b) = compare (bigEndian a) (bigEndian b)
  where
    -- Copy the nonce into a ByteString and put its most significant byte
    -- first.
    bigEndian :: CCP.Nonce -> ByteString
    bigEndian = reverse . convert

-- | Wraps raw bytes as a symmetric key, keeping at most the first 32 bytes.
--
--   Shorter input is not padded. NOTE(review): a short key is only rejected
--   later, when 'encrypt'/'decrypt' call 'CCP.initialize' under
--   'throwCryptoError'. Confirm that callers always supply at least 32 bytes.
bytesToSym :: ScrubbedBytes
           -> SymmetricKey ChaChaPoly1305
bytesToSym keyBytes = SKCCP1305 (take 32 keyBytes)

-- | Unwraps a symmetric key back to its raw bytes.
symToBytes :: SymmetricKey ChaChaPoly1305
           -> ScrubbedBytes
symToBytes (SKCCP1305 sk) = sk

-- | Serialises a ciphertext as the encrypted body immediately followed by
--   its Poly1305 tag.
ctToBytes :: Ciphertext ChaChaPoly1305
          -> ScrubbedBytes
ctToBytes (CTCCP1305 (ct, tag)) = ct `mappend` convert tag

-- | Splits serialised bytes into a ciphertext. Everything except the final
--   16 bytes is the encrypted body; the final 16 bytes are the Poly1305 tag.
--
--   NOTE(review): for input shorter than 16 bytes the tag offset is
--   negative. This relies on "Data.ByteArray"'s 'take'/'drop' treating that
--   as 0, giving an empty body and the whole input as the tag. Confirm that
--   such a short tag fails authentication in 'decrypt'.
bytesToCt :: ScrubbedBytes
          -> Ciphertext ChaChaPoly1305
bytesToCt bytes = CTCCP1305 (body, P.Auth (convert tag))
  where
    -- Offset at which the trailing 16-byte tag begins.
    tagStart = length bytes - 16
    body     = take tagStart bytes
    tag      = drop tagStart bytes