-- |
-- Module      : Internal
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- ML-KEM main internal algorithms
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
module Internal
    ( ParamSet(..), Params(..), Encode(..), Decode(..)
    , DecapsulationKey, EncapsulationKey, Ciphertext, SharedSecret
    , keyGen, toPublic, encaps, decaps
    ) where

import Basement.Nat

import Control.DeepSeq (NFData(..))
import Control.Monad

import Data.ByteArray (ByteArray, ByteArrayAccess, Bytes, ScrubbedBytes)
import qualified Data.ByteArray as B

import qualified Builder
import qualified Crypto
import qualified K_PKE as K
import K_PKE (Params(..))
import Marking (Leak(..))

-- | Class of the ML-KEM parameter sets.
--
-- An instance chooses a type-level natural @'K' a@, used to index the K-PKE
-- key types, and provides the K-PKE 'Params' matching that natural.
class KnownNat (K a) => ParamSet a where
    -- | Type-level dimension of the parameter set (presumably the module
    -- rank @k@ of FIPS 203; TODO confirm against K_PKE).
    type K a :: Nat

    -- | K-PKE parameters of the parameter set designated by the proxy type.
    getParams :: p a -> Params (K a)

-- | Serialization of ML-KEM objects into byte arrays.
class Encode t where
    -- | Produces the byte encoding of an object, in any 'ByteArray' type
    -- chosen by the caller.
    encode :: ByteArray out => t a -> out

-- | Deserialization of ML-KEM objects from byte arrays.
class Decode t where
    -- | Parses an object for the parameter set given by the proxy type,
    -- returning 'Nothing' when the input is rejected.
    decode :: (ParamSet ps, ByteArrayAccess inp) => p ps -> inp -> Maybe (t ps)

-- | An ML-KEM decapsulation key, aka private key.
--
-- Fields, in order:
--
--   * the K-PKE decryption key,
--   * the K-PKE encryption key (the public part),
--   * the hash @H(ek)@ of the encoded encryption key,
--   * the secret value @z@, used by 'decaps' to derive the key returned
--     when a ciphertext is rejected (32 bytes when produced by 'decode').
--
-- Fields are lazy; the 'Decode' instance forces them explicitly.
data DecapsulationKey a = DK (K.DecryptionKey (K a)) (K.EncryptionKey (K a)) Bytes ScrubbedBytes

-- | An ML-KEM encapsulation key, aka public key.
--
-- The first field caches the hash @H(ek)@ of the encoded key and the second
-- holds the K-PKE encryption key. The hash is a lazy field derived from the
-- key. It is ignored by the 'Eq' and 'NFData' instances and by 'encode'.
data EncapsulationKey a = EK Bytes (K.EncryptionKey (K a))

-- | The ciphertext produced by the encapsulation function and consumed by the
-- decapsulation function.
--
-- Unlike the key and secret types, equality here is the derived byte
-- comparison rather than one built on @Crypto.constEqW@.
newtype Ciphertext a = C Bytes deriving (Eq, ByteArrayAccess)

-- | A shared secret returned by the encapsulation and decapsulation functions.
-- Length is 32 bytes for all defined parameter sets.
--
-- Equality is computed with @Crypto.constEqW@. Outside of builds defining
-- @ML_KEM_TESTING@, the 'Show' instance does not print the secret bytes.
newtype SharedSecret a = S ScrubbedBytes deriving ByteArrayAccess

instance Eq (DecapsulationKey a) where
    -- All four components are compared, and the partial results are combined
    -- with Crypto.andW before one final conversion to Bool (presumably so
    -- that no early exit reveals which component differs).
    DK dkA ekA hA zA == DK dkB ekB hB zB = Crypto.toBool verdict
      where
        verdict = sameDk `Crypto.andW` sameEk `Crypto.andW` sameH `Crypto.andW` sameZ
        sameDk  = Crypto.constEqW dkA dkB
        sameEk  = Crypto.constEqW ekA ekB
        sameH   = Crypto.constEqW hA hB
        sameZ   = Crypto.constEqW zA zB

instance Eq (EncapsulationKey a) where
    -- The cached hash is derived from the key, so comparing keys suffices.
    (==) (EK _ lhs) (EK _ rhs) = Crypto.toBool (Crypto.constEqW lhs rhs)

instance Eq (SharedSecret a) where
    -- Compare the underlying bytes with Crypto.constEqW.
    S lhs == S rhs = Crypto.toBool (lhs `Crypto.constEqW` rhs)

instance Show (DecapsulationKey a) where
#ifdef ML_KEM_TESTING
    -- Test builds print the complete encoding of the key.
    showsPrec prec key =
        showParen (prec > 10) (showString "DecapsulationKey " . showsPrec 11 encoded)
      where
        encoded = encode key :: Bytes
#else
    -- Regular builds print only the type name, never key material.
    showsPrec _ _ = showString "DecapsulationKey"
#endif

instance Show (EncapsulationKey a) where
    -- Shows the public key through its byte encoding.
    showsPrec prec key = showParen (prec > 10) rendered
      where
        rendered = showString "EncapsulationKey " . showsPrec 11 (encode key :: Bytes)

instance Show (Ciphertext a) where
    -- Parenthesize only when used as an argument (precedence above 10).
    showsPrec prec (C raw)
        | prec > 10 = showChar '(' . body . showChar ')'
        | otherwise = body
      where
        body = showString "Ciphertext " . showsPrec 11 raw

instance Show (SharedSecret a) where
#ifdef ML_KEM_TESTING
    -- Test builds reveal the secret bytes.
    showsPrec prec (S secret) =
        showParen (prec > 10) (showString "SharedSecret " . showsPrec 11 secret)
#else
    -- Regular builds print only the type name.
    showsPrec _ _ = showString "SharedSecret"
#endif

instance NFData (DecapsulationKey a) where
    -- The 4-tuple instance forces its components left to right, i.e. in the
    -- same order as the constructor fields.
    rnf (DK dk ek h z) = rnf (dk, ek, h, z)

instance NFData (EncapsulationKey a) where
    -- Only the key is forced; the hash is a lazily evaluated cache and is
    -- deliberately left as is.
    rnf key = case key of
        EK _ pub -> rnf pub

instance NFData (Ciphertext a) where
    -- Force the wrapped ciphertext bytes.
    rnf ct = case ct of
        C raw -> rnf raw

instance NFData (SharedSecret a) where
    -- Force the wrapped secret bytes.
    rnf secret = case secret of
        S raw -> rnf raw

instance Encode EncapsulationKey where
    -- Only the key is serialized; the cached hash is not part of the encoding.
    encode key = case key of
        EK _ pub -> Builder.runRelaxed (K.ekEncode pub)

instance Decode EncapsulationKey where
    -- Validation is delegated to K.ekDecode. The hash of the raw input is
    -- attached as the lazily evaluated cache.
    decode proxy raw = fmap (EK (Crypto.h raw)) (K.ekDecode (getParams proxy) raw)

instance Encode DecapsulationKey where
    -- Layout: dk_PKE || ek || H(ek) || z. The two secret parts (dk_PKE
    -- and z) are wrapped with leak from the Marking module.
    encode (DK dk ek h z) = Builder.runRelaxed (dkPart <> ekPart <> hashPart <> zPart)
      where
        dkPart   = leak (K.dkEncode dk)
        ekPart   = K.ekEncode ek
        hashPart = Builder.bytes h
        zPart    = leak (Builder.bytes z)

instance Decode DecapsulationKey where
    -- Expected layout, with k the dimension of the parameter set:
    --   [0, 384k)            dk_PKE
    --   [384k, 768k+32)      ek
    --   [768k+32, 768k+64)   h, which must equal H(ek)
    --   [768k+64, 768k+96)   z
    decode p encoded = do
        -- decapsulation key type check: exact total length
        guard (B.length encoded == zOffset + 32)
        let !hashBytes = B.convert $ B.view encoded hashOffset 32
        -- hash check: the stored hash must match the embedded ek
        guard (Crypto.toBool $ Crypto.constEqW hashBytes (Crypto.h ekView))
        let !dkPke = K.dkDecode dkView
        !ekPke <- K.ekDecode params ekView
        let !zBytes = B.convert $ B.view encoded zOffset 32
        return (DK dkPke ekPke hashBytes zBytes)
      where
        params     = getParams p
        k          = K.dimension params
        ekOffset   = 384 * k
        hashOffset = ekOffset + 384 * k + 32
        zOffset    = hashOffset + 32
        dkView     = B.view encoded 0 ekOffset
        ekView     = B.view encoded ekOffset (384 * k + 32)

instance Decode Ciphertext where
    -- ciphertext type check: anything but exactly 32 * (du * k + dv) bytes
    -- is rejected
    decode p input
        | B.length input /= 32 * (du * k + dv) = Nothing
        | otherwise = Just (C (B.convert input))
      where
        params@Params{..} = getParams p
        k = K.dimension params

instance Decode SharedSecret where
    -- The proxy is not needed: every parameter set uses a 32-byte secret.
    decode _ input = S (B.convert input) <$ guard (B.length input == 32)

-- | Uses randomness to generate an encapsulation key and a corresponding
-- decapsulation key.
--
-- The randomness @d@ is passed to K-PKE key generation. The value @z@ is
-- stored in the decapsulation key, and 'decaps' uses it to derive the key
-- returned when a ciphertext is rejected. The function is pure: the same
-- inputs always give the same key pair.
--
-- Neither input is length-checked here. Callers must supply fresh random
-- values of the expected size (32 bytes each in FIPS 203; TODO confirm the
-- public wrappers enforce this, since a @z@ of another length yields a key
-- whose encoding 'decode' would reject).
--
-- The hash @H(ek)@ is computed at most once and shared by both returned keys.
keyGen :: (ParamSet a, ByteArrayAccess d) => proxy a -> d -> ScrubbedBytes -> (EncapsulationKey a, DecapsulationKey a)
keyGen p d z = (EK h ek, DK dk ek h z)
  where
    params = getParams p
    (ek, dk) = K.keyGen params d
    h = Crypto.h $ Builder.run (K.ekEncode ek)

-- | Extracts the encapsulation (public) key stored inside a decapsulation key,
-- reusing its cached hash.
--
-- 'decode' only checks that the stored hash matches the stored encryption
-- key. Nothing ties that encryption key to the secret part, so for a key
-- obtained from an untrusted source the result may not correspond to the
-- secret key.
toPublic :: DecapsulationKey a -> EncapsulationKey a
toPublic key = case key of
    DK _ embeddedEk cachedHash _ -> EK cachedHash embeddedEk

-- | Uses the encapsulation key and randomness to generate a shared secret and
-- an associated ciphertext.
--
-- The pair @(K, r)@ is obtained as @G(m || H(ek))@, using the hash cached in
-- the key. The ciphertext is the K-PKE encryption of @m@ under @ek@ with
-- coins @r@, and @K@ is returned as the shared secret. The function is pure
-- in @m@.
--
-- @m@ is not length-checked here (32 bytes in FIPS 203; TODO confirm the
-- public wrappers enforce this).
encaps :: (ParamSet a, ByteArrayAccess m) => EncapsulationKey a -> m -> (SharedSecret a, Ciphertext a)
encaps p@(EK h ek) m = (S kk, C c)
  where
    params = getParams p
    (kk, r) = Crypto.g (m `Crypto.append` h)
    c = K.encrypt params ek m r

-- | Uses the decapsulation key to produce a shared secret from a ciphertext,
-- with implicit rejection.
--
-- The message @m'@ is decrypted from @c@. The pair @(K', r')@ is derived as
-- @G(m' || h)@, and @m'@ is re-encrypted with coins @r'@ into @c'@.
-- If @c'@ equals @c@, the result is @K'@. Otherwise it is @J(z || c)@, so an
-- invalid ciphertext yields an unrelated value rather than an error.
--
-- The choice is made by @Crypto.constSelectBytes@ on the result of
-- @Crypto.constEqW@ rather than with an @if@, presumably so that control
-- flow does not depend on the secret comparison outcome.
decaps :: ParamSet a => DecapsulationKey a -> Ciphertext a -> SharedSecret a
decaps p@(DK dk ek h z) (C c) = S $
    Crypto.constSelectBytes
        (Crypto.constEqW c c')            -- condition
        kk'                               -- when equal
        (Crypto.j (z `Crypto.append` c))  -- when different
  where
    params    = getParams p
    m'        = K.decrypt params dk c
    (kk', r') = Crypto.g (m' `Crypto.append` h)
    c'        = K.encrypt params ek m' r'
