{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
-- |
-- Module      : Crypto.PubKey.ML_KEM
-- License     : BSD-style
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : experimental
-- Portability : unknown
--
-- Module-Lattice-based Key-Encapsulation Mechanism (ML-KEM), defined
-- in <https://csrc.nist.gov/pubs/fips/203/final FIPS 203>.
module Crypto.PubKey.ML_KEM
    ( EncapsulationKey, DecapsulationKey, Ciphertext, SharedSecret
    -- * Operations
    , generate, encapsulate, decapsulate, generateWith, encapsulateWith
    -- * Parameter sets
    , ParamSet, ML_KEM_512, ML_KEM_768, ML_KEM_1024
    -- * Conversions and checks
    , Decode(..), Encode(..)
    , toPublic, checkKeyPair
    ) where

import Crypto.Random

import Data.ByteArray (ByteArrayAccess, ScrubbedBytes)
import qualified Data.ByteArray as B

import Internal

-- | Type-level tag selecting the ML-KEM-512 parameter set
-- (security category 1).
data ML_KEM_512 = ML_KEM_512
    deriving Show

-- | Type-level tag selecting the ML-KEM-768 parameter set
-- (security category 3).
data ML_KEM_768 = ML_KEM_768
    deriving Show

-- | Type-level tag selecting the ML-KEM-1024 parameter set
-- (security category 5).
data ML_KEM_1024 = ML_KEM_1024
    deriving Show

-- Each instance fixes the module rank @K@ and the remaining numeric
-- parameters.  The values agree with FIPS 203, Table 2.
-- NOTE(review): the four 'Params' arguments appear to be
-- (eta1, eta2, d_u, d_v) in that order; confirm against Internal.
instance ParamSet ML_KEM_512 where
    type K ML_KEM_512 = 2
    getParams = const (Params 3 2 10 4)

instance ParamSet ML_KEM_768 where
    type K ML_KEM_768 = 3
    getParams = const (Params 2 2 10 4)

instance ParamSet ML_KEM_1024 where
    type K ML_KEM_1024 = 4
    getParams = const (Params 2 2 11 5)

-- | Generate a fresh ML-KEM key pair.  Two independent 32-byte seeds
-- (@d@ first, then @z@) are drawn from the random source and handed to the
-- deterministic key-generation routine.
generate :: (ParamSet a, MonadRandom m)
         => proxy a -> m (EncapsulationKey a, DecapsulationKey a)
generate proxy =
    deriveKeys <$> getRandomBytes 32 <*> getRandomBytes 32
  where
    -- The annotation pins the otherwise ambiguous type of the first seed
    -- to scrubbed (zeroed-on-release) memory.
    deriveKeys seedD seedZ = Internal.keyGen proxy (seedD :: ScrubbedBytes) seedZ

-- | Deterministically derive an ML-KEM key pair from caller-supplied seeds
-- @(d, z)@.  Both seeds must be exactly 32 bytes long; otherwise the result
-- is 'Nothing'.  For testing purposes.
generateWith :: (ParamSet a, ByteArrayAccess d, ByteArrayAccess z)
             => proxy a -> d -> z -> Maybe (EncapsulationKey a, DecapsulationKey a)
generateWith proxy seedD seedZ
    | seedsOk   = Just (Internal.keyGen proxy seedD (B.convert seedZ))
    | otherwise = Nothing
  where
    seedsOk = B.length seedD == 32 && B.length seedZ == 32

-- | Produce a shared secret together with the ciphertext that transports it
-- to the holder of the matching decapsulation key.  A fresh 32-byte random
-- message is drawn from the random source for each call.
encapsulate :: (ParamSet a, MonadRandom m)
            => EncapsulationKey a -> m (SharedSecret a, Ciphertext a)
encapsulate ek = withMessage <$> getRandomBytes 32
  where
    -- Keep the random message in scrubbed memory.
    withMessage msg = Internal.encaps ek (msg :: ScrubbedBytes)

-- | Deterministic variant of 'encapsulate' that takes the random message
-- from the caller instead of from a random source.  The message must be
-- exactly 32 bytes, or the result is 'Nothing'.  It must never be reused
-- across encapsulations.  For testing purposes.
encapsulateWith :: (ParamSet a, ByteArrayAccess m)
                => EncapsulationKey a -> m -> Maybe (SharedSecret a, Ciphertext a)
encapsulateWith ek msg =
    if B.length msg == 32
        then Just (Internal.encaps ek msg)
        else Nothing

-- | Recover the shared secret carried by a ciphertext.  Tampering with the
-- ciphertext or the encapsulation key does not raise an error.  Instead it
-- is handled by implicit rejection, so a secret is always returned.
decapsulate :: ParamSet a
            => DecapsulationKey a  -- ^ private key of the recipient
            -> Ciphertext a        -- ^ ciphertext produced by 'encapsulate'
            -> SharedSecret a
decapsulate dk ct = Internal.decaps dk ct

-- | Heuristic consistency check for a key pair.  It encapsulates a random
-- message under the public key, decapsulates the result with the private
-- key, and reports whether both sides agree on the shared secret.  A 'True'
-- result does not fully guarantee that the key pair was properly generated.
checkKeyPair :: (ParamSet a, MonadRandom m)
             => (EncapsulationKey a, DecapsulationKey a) -> m Bool
checkKeyPair (ek, dk) = roundTrips <$> getRandomBytes 32
  where
    roundTrips msg =
        let (secret, ciphertext) = Internal.encaps ek (msg :: ScrubbedBytes)
         in Internal.decaps dk ciphertext == secret
