{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE RecordWildCards            #-}

-- |
-- Module      : aead-api: Interface
-- Description : Generic interface to authenticated encryption.
-- Copyright   : (c) Piyush P Kurur, 2019
-- License     : Apache-2.0 OR BSD-3-Clause
-- Maintainer  : Piyush P Kurur <ppk@iitpkd.ac.in>
-- Stability   : experimental
--
-- The interface for an AEAD construction that combines a stream
-- cipher, such as chacha20, with a message authenticator, such as
-- poly1305.
module Interface( -- * Locking and unlocking stuff
                  unsafeLock, unlock
                  -- ** Additional data.
                , Locked, AuthTag, Cipher
                , unsafeLockWith, unlockWith
                , unsafeToNounce, unsafeToCipherText, unsafeToAuthTag
                , unsafeLocked
                , AEADMem
                , name
                , description
                ) where

import           Data.ByteString
import           System.IO.Unsafe ( unsafePerformIO )

import           Raaz.Core
import           Raaz.Primitive.AEAD.Internal
import qualified Cipher.Implementation as CI
import qualified Auth.Implementation   as AI

import qualified Cipher.Utils as CU
import qualified Auth.Utils as AU

import qualified Cipher.Buffer as CB

-- | The stream cipher used by this AEAD construction, taken from the
-- selected cipher implementation.
type Cipher  = CI.Prim

-- | The message authenticator whose tag protects every locked message,
-- taken from the selected authenticator implementation.
type AuthTag = AI.Prim

-- | A locked message: an 'AEAD' packet over 'Cipher' and 'AuthTag',
-- carrying the nounce, the cipher text and the authentication tag.
type Locked  = AEAD Cipher AuthTag

-- | Encrypt the plain text and authenticate it together with the
-- additional data @aad@, producing the corresponding 'Locked' message.
-- A peer who has the right key and the same @aad@ can recover the
-- unencrypted object using 'unlockWith'; the nounce travels inside the
-- 'Locked' message.
--
-- The nounce is supplied by the caller rather than generated here, so
-- the caller is responsible for never reusing a @(key, nounce)@ pair:
-- with a stream cipher such reuse exposes the XOR of the plain texts.
unsafeLockWith :: (Encodable plain, Encodable aad)
               => aad              -- ^ the authenticated additional data.
               -> Key Cipher       -- ^ The key for the stream cipher
               -> Nounce Cipher    -- ^ The nounce used by the stream cipher.
               -> plain            -- ^ the unencrypted object
               -> Locked
unsafeLockWith aad k n plain = unsafePerformIO $ withMemory $ \ mem -> do
  -- The key has to be set before the nounce: initialising the nounce
  -- also derives the authenticator state from the cipher's keystream.
  initialise k mem
  initialise n mem
  cText <- encrypt plain mem
  AEAD n cText <$> computeAuth aad cText mem

-- | Unlock an encrypted, authenticated message given the additional
-- data and the key. The nounce is read from the 'Locked' message
-- itself. The cipher text is decrypted only after its tag has been
-- verified. An attempt to unlock the message can result in `Nothing`
-- if any of the following is true.
--
-- 1. The key, or the nounce stored in the message, is incorrect.
--
-- 2. The authenticated additional data (@aad@) is incorrect.
--
-- 3. The Locked message is of the wrong type and hence
-- `fromByteString` failed.
--
-- 4. The Locked message has been tampered with.
--
-- The interface makes it impossible to know which of the above errors
-- occurred. This is a deliberate design, as revealing the nature of
-- the failure can leak information to a potential attacker.
unlockWith :: (Encodable plain, Encodable aad)
           => aad              -- ^ the authenticated additional data.
           -> Key Cipher       -- ^ The key for the stream cipher
           -> Locked           -- ^ The encrypted authenticated version of the data.
           -> Maybe plain
unlockWith aad k aead = unsafePerformIO $ withMemory $ \ mem -> do
  initialise k mem
  initialise (unsafeToNounce aead) mem
  isSuccess <- verify aad aead mem
  -- Unauthenticated input is never decrypted.
  if isSuccess then decrypt aead mem else pure Nothing


-- | Generate a locked version of an unencrypted object, using @()@ as
-- the additional data. You will need the exact same key to unlock the
-- object (the nounce is carried inside the 'Locked' message). As with
-- 'unsafeLockWith', the caller must never reuse a @(key, nounce)@ pair.
unsafeLock :: Encodable plain
           => Key Cipher
           -> Nounce Cipher
           -> plain
           -> Locked
unsafeLock = unsafeLockWith ()


-- | Unlock a packet that was locked with @()@ as the additional data,
-- e.g. one produced by 'unsafeLock'. Returns `Nothing` on any failure;
-- see 'unlockWith' for the possible causes.
unlock :: Encodable plain
       => Key Cipher
       -> Locked
       -> Maybe plain
unlock = unlockWith ()

-- | The internal memory used for computing the AEAD packet. When using
-- this memory for packet computation, it is important to initialise
-- the memory in the following order.
--
-- 1. Initialise with the key, either using the `initialise` function
--    or by writing the key directly into the memory through its
--    `WriteAccessible` instance (which exposes the cipher internals).
--
-- 2. Initialise the nounce; this also sets up the authenticator.
--
-- We are then all set to go.
data AEADMem = AEADMem { cipherInternals :: CI.Internals -- ^ State of the stream cipher.
                       , authInternals   :: AI.Internals -- ^ State of the authenticator.
                       , internBuffer    :: CB.Buffer 1  -- ^ Scratch buffer used to generate
                                                         --   keystream for the authenticator.
                       }

-- | All three components are allocated together. The pointer exposed
-- by 'unsafeToPointer' is that of the cipher internals.
instance Memory AEADMem where
  memoryAlloc     = AEADMem <$> memoryAlloc <*> memoryAlloc <*> memoryAlloc
  unsafeToPointer = unsafeToPointer . cipherInternals

-- | Initialise with the key of the cipher. Only the cipher internals
-- are touched here; the authenticator is set up later, when the
-- nounce is initialised.
instance Initialisable AEADMem (Key Cipher) where
  initialise k = initialise k . cipherInternals

-- | Write access is delegated to the cipher internals. This lets a key
-- be written directly into the memory instead of going through
-- 'initialise'.
instance WriteAccessible AEADMem where
  writeAccess          = writeAccess . cipherInternals
  afterWriteAdjustment = afterWriteAdjustment . cipherInternals

-- | Initialise with the nounce. This must be done after the key is
-- already initialised. Besides setting the nounce, it:
--
-- 1. resets the cipher's block counter to zero,
--
-- 2. runs the cipher over the zeroed internal buffer to obtain
--    keystream, and
--
-- 3. initialises the authenticator from that keystream.
instance Initialisable AEADMem (Nounce Cipher) where
  initialise n AEADMem{..} = do
    initialise n cipherInternals
    let zeroCount = 0 `blocksOf` (Proxy :: Proxy Cipher)
    initialise zeroCount cipherInternals

    --
    -- Generate the key stream
    --
    CB.memsetBuffer 0 internBuffer                -- clear the internal buffer
    CU.processBuffer internBuffer cipherInternals -- generate the keystream
    --
    -- Initialise the authenticator from the keystream.
    --
    memTransfer (destination authInternals) (source internBuffer)

--------------------- Internal functions ---------------------------------
---
-- These are some of the internal functions that are used by various
-- lock unlock functions. One of the constraints that we want to
-- enforce is that unauthenticated input should never be
-- decrypted. Hence, despite their cute names, these functions should
-- not be exposed to the user from this module

-- | Transform the input bytestring with the cipher. Because the cipher
-- is a stream cipher, the same operation is used both for encryption
-- (see 'encrypt') and for decryption (see 'decrypt').
transform :: ByteString -- ^ The input: plain text or cipher text.
          -> AEADMem
          -> IO ByteString
transform bs = CU.transform bs . cipherInternals

-- | Compute the authentication tag. The data fed to the authenticator
-- is, in order:
--
-- * the additional data, zero-padded (see 'padAndLen'),
--
-- * the cipher text, zero-padded the same way,
--
-- * the unpadded length of the additional data, and
--
-- * the unpadded length of the cipher text.
--
-- Both lengths are written as little-endian 64-bit words.
computeAuth :: Encodable aad
            => aad               -- ^ The additional data that needs
                                 -- to be authenticated
            -> ByteString        -- ^ The cipher text.
            -> AEADMem
            -> IO AuthTag
computeAuth aad cText aeadmem = do
  AU.processByteSource (toByteString authWr) authMem
  extract authMem
  where (aadWr, lAAD) = padAndLen aad
        (cWr, lC)     = padAndLen cText
        authWr        = aadWr <> cWr <> write lAAD <> write lC
        authMem       = authInternals aeadmem


-- | Recompute the tag over the additional data and the locked message's
-- cipher text, and check it against the tag stored in the message.
verify :: Encodable aad
       => aad
       -> Locked
       -> AEADMem
       -> IO Bool
verify aad aead = fmap matchTag . computeAuth aad (unsafeToCipherText aead)
  -- NOTE(review): the comparison uses the 'Eq' instance of 'AuthTag';
  -- confirm that instance is constant-time to avoid a timing side channel.
  where matchTag = (==) (unsafeToAuthTag aead)

-- | Encrypt a plain text object: encode it to bytes and run the cipher
-- over them.
encrypt :: Encodable plain
        => plain      -- ^ The plain object that needs encryption
        -> AEADMem
        -> IO ByteString
encrypt plain = transform (toByteString plain)


-- | Decrypt to recover the plain text object. We assume a stream
-- cipher, and hence 'transform' is both the encryption and the
-- decryption routine. Yields `Nothing` when the decrypted bytes do not
-- decode to a @plain@ value. Only call this on a message whose tag has
-- already been verified.
decrypt :: Encodable plain
        => Locked
        -> AEADMem
        -> IO (Maybe plain)
decrypt aead = fmap fromByteString . transform (unsafeToCipherText aead)

-- | Compute the zero-padded write of an encodable element, together
-- with its length in bytes before padding. The padding unit is one
-- block of 'AuthTag'.
padAndLen :: Encodable a => a -> (WriteTo, LE Word64)
padAndLen a = (padWrite 0 pL aWr, len)
  where aWr   = writeEncodable a
        len   = toLen (transferSize aWr)  -- length before padding
        toLen = toEnum . fromEnum
        pL    = 1 `blocksOf` (Proxy :: Proxy AuthTag)

-- | Create the locked message from the associated nounce, cipher
-- text, and authentication tag. No check is made here that the tag
-- matches; 'unlockWith' rejects the message later if it does not.
unsafeLocked :: Nounce Cipher
             -> ByteString
             -> AuthTag
             -> Locked
unsafeLocked = AEAD

-- | Name of this AEAD implementation, formed from the names of the
-- cipher and authenticator implementations.
name :: String
name = unwords [CI.name, "+", AI.name]

-- | Human-readable description of this AEAD implementation.
description :: String
description = unwords [ "AEAD implementation based on", name]