-- |
-- Module      : Crypto.Cipher.ChaChaPoly1305
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : stable
-- Portability : good
--
-- A simple AEAD scheme using ChaCha20 and Poly1305. See
-- <https://tools.ietf.org/html/rfc7539 RFC 7539>.
--
-- The State is not modified in place, so each function changing the State,
-- returns a new State.
--
-- Authenticated Data needs to be added before any call to 'encrypt' or 'decrypt',
-- and once all the data has been added, 'finalizeAAD' needs to be called.
--
-- Once 'finalizeAAD' has been called, no further 'appendAAD' call must be made.
--
-- >import Data.ByteString.Char8 as B
-- >import Data.ByteArray
-- >import Crypto.Error
-- >import Crypto.Cipher.ChaChaPoly1305 as C
-- >
-- >encrypt
-- >    :: ByteString -- nonce (12 random bytes)
-- >    -> ByteString -- symmetric key
-- >    -> ByteString -- optional associated data (won't be encrypted)
-- >    -> ByteString -- input plaintext to be encrypted
-- >    -> CryptoFailable ByteString -- ciphertext with a 128-bit tag attached
-- >encrypt nonce key header plaintext = do
-- >    st1 <- C.nonce12 nonce >>= C.initialize key
-- >    let
-- >        st2 = C.finalizeAAD $ C.appendAAD header st1
-- >        (out, st3) = C.encrypt plaintext st2
-- >        auth = C.finalize st3
-- >    return $ out `B.append` Data.ByteArray.convert auth
--
module Crypto.Cipher.ChaChaPoly1305
    ( State
    , Nonce
    , nonce12
    , nonce8
    , incrementNonce
    , initialize
    , appendAAD
    , finalizeAAD
    , encrypt
    , decrypt
    , finalize
    ) where

import           Control.Monad             (when)
import           Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, Bytes, ScrubbedBytes)
import qualified Crypto.Internal.ByteArray as B
import           Crypto.Internal.Imports
import           Crypto.Error
import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Crypto.MAC.Poly1305  as Poly1305
import           Data.Memory.Endian
import qualified Data.ByteArray.Pack as P
import           Foreign.Ptr
import           Foreign.Storable

-- | A ChaChaPoly1305 State.
--
-- The state is immutable: every operation that advances it returns a new
-- 'State' value instead of modifying the one it was given.
--
-- It bundles the ChaCha cipher state, the running Poly1305 MAC state and the
-- two byte counters that are appended to the MAC input by 'finalize'.
data State = State !ChaCha.State
                   !Poly1305.State
                   !Word64 -- AAD length
                   !Word64 -- ciphertext length

-- | Valid Nonce for ChaChaPoly1305.
--
-- It can be created with 'nonce8' or 'nonce12'.
--
-- A @Nonce8@ holds the 4-byte constant followed by the 8-byte IV given to
-- 'nonce8'; a @Nonce12@ holds the 12-byte IV given to 'nonce12'.
data Nonce = Nonce8 Bytes | Nonce12 Bytes

-- Both constructors wrap a plain 'Bytes' buffer, so every method simply
-- delegates to the wrapped value.
instance ByteArrayAccess Nonce where
  length nonce = case nonce of
    Nonce8  raw -> B.length raw
    Nonce12 raw -> B.length raw

  withByteArray nonce = case nonce of
    Nonce8  raw -> B.withByteArray raw
    Nonce12 raw -> B.withByteArray raw

-- Based on the following pseudo code:
--
-- chacha20_aead_encrypt(aad, key, iv, constant, plaintext):
--     nonce = constant | iv
--     otk = poly1305_key_gen(key, nonce)
--     ciphertext = chacha20_encrypt(key, 1, nonce, plaintext)
--     mac_data = aad | pad16(aad)
--     mac_data |= ciphertext | pad16(ciphertext)
--     mac_data |= num_to_8_le_bytes(aad.length)
--     mac_data |= num_to_8_le_bytes(ciphertext.length)
--     tag = poly1305_mac(mac_data, otk)
--     return (ciphertext, tag)

-- | Zero padding needed to bring a message of @n@ bytes up to the next
-- multiple of 16 bytes. The result is empty when @n@ is already a multiple
-- of 16.
pad16 :: Word64 -> Bytes
pad16 n =
    case fromIntegral (n `mod` 16) :: Int of
        0        -> B.empty
        overhang -> B.replicate (16 - overhang) 0

-- | Smart constructor building a 'Nonce' from a 12-byte IV.
--
-- Fails with 'CryptoError_IvSizeInvalid' when the IV is not exactly
-- 12 bytes long.
nonce12 :: ByteArrayAccess iv => iv -> CryptoFailable Nonce
nonce12 iv
    | B.length iv == 12 = CryptoPassed (Nonce12 (B.convert iv))
    | otherwise         = CryptoFailed CryptoError_IvSizeInvalid

-- | Smart constructor building a 'Nonce' from a 4-byte constant and an
-- 8-byte IV. The resulting nonce is the constant followed by the IV.
--
-- Fails with 'CryptoError_IvSizeInvalid' when either argument does not have
-- the expected length.
nonce8 :: ByteArrayAccess ba
       => ba -- ^ 4 bytes constant
       -> ba -- ^ 8 bytes IV
       -> CryptoFailable Nonce
nonce8 constant iv
    | sizesOk   = CryptoPassed (Nonce8 (B.concat [constant, iv]))
    | otherwise = CryptoFailed CryptoError_IvSizeInvalid
  where
    sizesOk = B.length constant == 4 && B.length iv == 8

-- | Increment a nonce and return the new value.
--
-- For a nonce built with 'nonce8' the leading 4-byte constant is skipped and
-- only the 8-byte IV part is incremented; for a nonce built with 'nonce12'
-- all 12 bytes are incremented.
incrementNonce :: Nonce -> Nonce
incrementNonce nonce = case nonce of
    Nonce8  raw -> Nonce8  (incrementNonce' raw 4)
    Nonce12 raw -> Nonce12 (incrementNonce' raw 0)

-- | Increment, as a little-endian integer, the bytes of @b@ from index
-- @offset@ up to the end of the buffer, and return the result as a fresh
-- copy of @b@.
--
-- The carry propagates towards the end of the buffer. When every byte of the
-- counter is 0xff the counter wraps around to zero. The bytes located before
-- @offset@ are never modified, and nothing is ever written past the end of
-- the buffer.
incrementNonce' :: Bytes -> Int -> Bytes
incrementNonce' b offset = B.copyAndFreeze b bump
  where
    -- Run the carry loop over the counter part of the copied buffer.
    bump :: Ptr Word8 -> IO ()
    bump start = loop (start `plusPtr` offset) (start `plusPtr` B.length b)

    -- 'cur' is the byte being incremented, 'end' is one past the last byte.
    loop :: Ptr Word8 -> Ptr Word8 -> IO ()
    loop cur end
        -- The carry ran off the end of the counter: stop here (wrap around)
        -- instead of touching memory outside the buffer.
        | cur >= end = return ()
        | otherwise  = do
            r <- (+ 1) <$> peek cur
            poke cur r
            -- A byte that wrapped to 0 carries into the next byte.
            when (r == 0) $ loop (cur `plusPtr` 1) end

-- | Initialize a new ChaChaPoly1305 State.
--
-- The key needs to be 256 bits (32 bytes) long, otherwise
-- 'CryptoError_KeySizeInvalid' is returned. The nonce is obtained from
-- either 'nonce8' or 'nonce12'.
initialize :: ByteArrayAccess key
           => key -> Nonce -> CryptoFailable State
initialize key nonce = case nonce of
    Nonce8  raw -> initialize' key raw
    Nonce12 raw -> initialize' key raw

-- | Build the initial 'State' from a key and the raw nonce bytes.
--
-- Returns 'CryptoError_KeySizeInvalid' unless the key is exactly 32 bytes.
initialize' :: ByteArrayAccess key
            => key -> Bytes -> CryptoFailable State
initialize' key nonce
    | B.length key == 32 = CryptoPassed (State cipherState macState 0 0)
    | otherwise          = CryptoFailed CryptoError_KeySizeInvalid
  where
    -- The first 64 bytes of keystream are consumed here: their first 32
    -- bytes become the Poly1305 key and the rest is dropped. The cipher
    -- state returned alongside is the one used for the payload.
    (otk, cipherState) = ChaCha.generate (ChaCha.initialize 20 key nonce) 64
    macState           =
        throwCryptoError (Poly1305.initialize (B.take 32 otk :: ScrubbedBytes))

-- | Feed a chunk of Authenticated Data into the State and return the
-- updated State.
--
-- The chunk is added to the MAC and its length is added to the AAD byte
-- counter. Once all chunks have been appended, the user should call
-- 'finalizeAAD'.
appendAAD :: ByteArrayAccess ba => ba -> State -> State
appendAAD ba (State encState macState aadLength plainLength) =
    State encState
          (Poly1305.update macState ba)
          (aadLength + fromIntegral (B.length ba))
          plainLength

-- | Close the Authenticated Data section and return the resulting State.
--
-- The MAC is fed the zero padding that brings the accumulated AAD length up
-- to a multiple of 16 bytes; the counters are left untouched.
finalizeAAD :: State -> State
finalizeAAD (State encState macState aadLength plainLength) =
    State encState paddedMac aadLength plainLength
  where
    paddedMac = Poly1305.update macState (pad16 aadLength)

-- | Encrypt a chunk of plaintext, returning the ciphertext together with
-- the updated State.
--
-- The produced ciphertext is what gets added to the MAC, and the chunk
-- length is added to the ciphertext byte counter.
encrypt :: ByteArray ba => ba -> State -> (ba, State)
encrypt input (State encState macState aadLength plainLength) =
    let (ciphertext, encState') = ChaCha.combine encState input
        macState'               = Poly1305.update macState ciphertext
        plainLength'            = plainLength + fromIntegral (B.length input)
     in (ciphertext, State encState' macState' aadLength plainLength')

-- | Decrypt a chunk of ciphertext, returning the plaintext together with
-- the updated State.
--
-- The received ciphertext (the input, not the output) is what gets added to
-- the MAC, and the chunk length is added to the ciphertext byte counter.
-- This function does not check any authentication tag by itself; the tag
-- computed by 'finalize' has to be compared by the caller.
decrypt :: ByteArray ba => ba -> State -> (ba, State)
decrypt input (State encState macState aadLength plainLength) =
    let (plaintext, encState') = ChaCha.combine encState input
        macState'              = Poly1305.update macState input
        plainLength'           = plainLength + fromIntegral (B.length input)
     in (plaintext, State encState' macState' aadLength plainLength')

-- | Generate the authentication tag from the State.
--
-- The MAC is completed with the zero padding of the ciphertext (up to a
-- multiple of 16 bytes) followed by a 16-byte block holding the AAD length
-- and the ciphertext length, each as a little-endian 64-bit integer.
finalize :: State -> Poly1305.Auth
finalize (State _ macState aadLength plainLength) =
    Poly1305.finalize (Poly1305.updates macState [pad16 plainLength, lengthBlock])
  where
    lengthBlock :: Bytes
    lengthBlock =
        case P.fill 16 packLengths of
            -- Two 8-byte values always fill the 16-byte block exactly.
            Left _      -> error "finalize: internal error"
            Right block -> block

    packLengths = do
        P.putStorable (toLE aadLength)
        P.putStorable (toLE plainLength)