-- |
-- Module      : Crypto.Saltine.Core.AEAD.ChaCha20Poly1305
-- Copyright   : (c) Thomas DuBuisson 2017
--               (c) Max Amanshauser 2021
-- License     : MIT
--
-- Maintainer  : max@lambdalifting.org
-- Stability   : experimental
-- Portability : non-portable
--
-- Secret-key authenticated encryption with additional data (AEAD):
-- "Crypto.Saltine.Core.AEAD.ChaCha20Poly1305"
--
-- Generating nonces for the functions in this module randomly
-- is not recommended, due to the risk of generating collisions.

module Crypto.Saltine.Core.AEAD.ChaCha20Poly1305 (
  Key, Nonce,
  aead, aeadOpen,
  aeadDetached, aeadOpenDetached,
  newKey, newNonce
  ) where

import Crypto.Saltine.Internal.AEAD.ChaCha20Poly1305
            ( c_aead
            , c_aead_open
            , c_aead_detached
            , c_aead_open_detached
            , Key(..)
            , Nonce(..)
            )
import Crypto.Saltine.Internal.Util as U
import Data.ByteString              (ByteString)
import Foreign.Ptr

import qualified Crypto.Saltine.Internal.AEAD.ChaCha20Poly1305  as Bytes
import qualified Data.ByteString                                as S

-- | Creates a random 'ChaCha20Poly1305' key.
--
-- The key wraps a random byte string of
-- 'Bytes.aead_chacha20poly1305_keybytes' bytes obtained from
-- 'randomByteString'.
newKey :: IO Key
newKey = Key <$> randomByteString Bytes.aead_chacha20poly1305_keybytes

-- | Creates a random 'ChaCha20Poly1305' nonce.
--
-- The nonce wraps a random byte string of
-- 'Bytes.aead_chacha20poly1305_npubbytes' bytes obtained from
-- 'randomByteString'. Generating nonces randomly is not recommended for
-- this construction, due to the risk of generating collisions.
newNonce :: IO Nonce
newNonce = Nonce <$> randomByteString Bytes.aead_chacha20poly1305_npubbytes


-- | Encrypts a message. It is infeasible for an attacker to decrypt
-- the message so long as the 'Nonce' is never repeated.
--
-- The returned buffer is @length message +
-- aead_chacha20poly1305_abytes@ bytes long.
aead
    :: Key
    -> Nonce
    -> ByteString
    -- ^ Message
    -> ByteString
    -- ^ AAD
    -> ByteString
    -- ^ Ciphertext
aead (Key key) (Nonce nonce) msg aad =
  -- The status code returned by 'c_aead' is dropped; only the filled
  -- output buffer is kept.
  snd . buildUnsafeByteString clen $ \pc ->
    constByteStrings [key, msg, aad, nonce] $ \
      [(pk, _), (pm, _), (pa, _), (pn, _)] ->
          -- Both 'nullPtr' arguments fill pointer parameters that this
          -- wrapper does not use (presumably the optional output-length
          -- and secret-nonce parameters of the C call; confirm against
          -- the libsodium documentation).
          c_aead pc nullPtr pm (fromIntegral mlen) pa (fromIntegral alen) nullPtr pn pk
  where mlen    = S.length msg
        alen    = S.length aad
        -- Output holds the message plus 'aead_chacha20poly1305_abytes'
        -- extra bytes.
        clen    = mlen + Bytes.aead_chacha20poly1305_abytes

-- | Decrypts a message. Returns 'Nothing' if the keys and message do
-- not match.
--
-- 'Nothing' is also returned when 'safeSubtract' cannot take
-- 'Bytes.aead_chacha20poly1305_abytes' away from the ciphertext length
-- (presumably when the ciphertext is shorter than that).
aeadOpen
    :: Key
    -> Nonce
    -> ByteString
    -- ^ Ciphertext
    -> ByteString
    -- ^ AAD
    -> Maybe ByteString
    -- ^ Message
aeadOpen (Key key) (Nonce nonce) cipher aad = do
  let clen   = S.length cipher
      alen   = S.length aad
  -- Size of the output buffer; a failing subtraction short-circuits the
  -- whole computation to 'Nothing' before any C call is made.
  mlen <- clen `safeSubtract` Bytes.aead_chacha20poly1305_abytes
  let (err, vec) = buildUnsafeByteString mlen $ \pm ->
        constByteStrings [key, cipher, aad, nonce] $ \
          [(pk, _), (pc, _), (pa, _), (pn, _)] ->
            -- Both 'nullPtr' arguments fill pointer parameters that this
            -- wrapper does not use.
            c_aead_open pm nullPtr nullPtr pc (fromIntegral clen) pa (fromIntegral alen) pn pk
  -- The C status code decides the result: 'handleErrno' turns it into an
  -- 'Either' and 'hush' discards the error text.
  hush . handleErrno err $ vec

-- | Encrypts a message. It is infeasible for an attacker to decrypt
-- the message so long as the 'Nonce' is never repeated.
--
-- The tag and the ciphertext are returned separately. The ciphertext has
-- the same length as the message, and the tag is
-- 'Bytes.aead_chacha20poly1305_abytes' bytes long.
aeadDetached
    :: Key
    -> Nonce
    -> ByteString
    -- ^ Message
    -> ByteString
    -- ^ AAD
    -> (ByteString,ByteString)
    -- ^ Tag, Ciphertext
aeadDetached (Key key) (Nonce nonce) msg aad =
  -- The outer buffer receives the ciphertext. The inner action allocates
  -- the tag buffer and returns it, so the pair is (tag, ciphertext).
  buildUnsafeByteString clen $ \pc ->
    -- The status code returned by 'c_aead_detached' is dropped; only the
    -- tag buffer is kept.
    fmap snd . buildUnsafeByteString' tlen $ \pt ->
      constByteStrings [key, msg, aad, nonce] $ \
        [(pk, _), (pm, _), (pa, _), (pn, _)] ->
            -- Both 'nullPtr' arguments fill pointer parameters that this
            -- wrapper does not use.
            c_aead_detached pc pt nullPtr pm (fromIntegral mlen) pa (fromIntegral alen) nullPtr pn pk
  where mlen    = S.length msg
        alen    = S.length aad
        clen    = mlen
        tlen    = Bytes.aead_chacha20poly1305_abytes

-- | Decrypts a message. Returns 'Nothing' if the keys and message do
-- not match.
--
-- 'Nothing' is also returned, without calling into C, when the tag is
-- not exactly 'Bytes.aead_chacha20poly1305_abytes' bytes long.
aeadOpenDetached
    :: Key
    -> Nonce
    -> ByteString
    -- ^ Tag
    -> ByteString
    -- ^ Ciphertext
    -> ByteString
    -- ^ AAD
    -> Maybe ByteString
    -- ^ Message
aeadOpenDetached (Key key) (Nonce nonce) tag cipher aad
    -- Reject a wrongly sized tag before its pointer is handed to C.
    | S.length tag /= tlen = Nothing
    | otherwise =
        let (err, vec) = buildUnsafeByteString len $ \pm ->
              constByteStrings [key, tag, cipher, aad, nonce] $ \
                [(pk, _), (pt, _), (pc, _), (pa, _), (pn, _)] ->
                  -- The 'nullPtr' argument fills a pointer parameter that
                  -- this wrapper does not use.
                  c_aead_open_detached pm nullPtr pc (fromIntegral len) pt pa (fromIntegral alen) pn pk
        -- The C status code decides the result: 'handleErrno' turns it
        -- into an 'Either' and 'hush' discards the error text.
        in hush . handleErrno err $ vec
  where len    = S.length cipher   -- the output buffer is as long as the ciphertext
        alen   = S.length aad
        tlen   = Bytes.aead_chacha20poly1305_abytes