-- |
-- Module      : K_PKE
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- The K-PKE public-key encryption scheme, the component scheme from which
-- ML-KEM (FIPS 203) is built.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
module K_PKE
    ( Params(..), dimension, keyGen, encrypt, decrypt
    , DecryptionKey, dkEncode, dkDecode
    , EncryptionKey, ekEncode, ekDecode
    ) where

import Basement.Nat
import Basement.Types.OffsetSize

import Control.DeepSeq (NFData(..))
import Control.Monad

import Data.ByteArray (ByteArrayAccess, Bytes, ScrubbedBytes)
import qualified Data.ByteArray as B

import Unsafe.Coerce

import Auxiliary (Rq, Tq, (..+), (..-))
import Builder (Builder)
import Iterate
import Marking (SecurityMarking(..), Leak(..))
import Vector (Vector)
import qualified Auxiliary as Aux
import qualified Crypto
import qualified Builder
import Math
import qualified Matrix
import qualified Vector

-- | K-PKE parameter set. The type-level natural @k@ is the vector dimension,
-- recovered at runtime with 'dimension'.
data Params (k :: Nat) = Params
    { eta1 :: {-# UNPACK #-} !Word
      -- ^ sampling parameter passed to 'Aux.samplePolyCBD' (via 'sample') for
      -- @s@ and @e@ in 'keyGen' and for @y@ in 'encrypt'
    , eta2 :: {-# UNPACK #-} !Word
      -- ^ sampling parameter for the errors @e1@ and @e2@ in 'encrypt'
    , du :: {-# UNPACK #-} !Int
      -- ^ compression width for the vector @u@; each of its polynomials
      -- occupies @32 * du@ bytes of the ciphertext
    , dv :: {-# UNPACK #-} !Int
      -- ^ compression width for the polynomial @v@, which occupies the final
      -- @32 * dv@ bytes of the ciphertext
    }

-- | The dimension @k@ of a parameter set, read back from its type-level
-- natural (the 'Params' value itself serves only as a proxy).
dimension :: KnownNat k => Params k -> Int
dimension p = fromIntegral (natVal p)

-- | Companion to 'Leak' (its superclass) that reclassifies every element of a
-- container from 'Sec' to 'Pub' in a single step.
--
-- The default method is a bare 'unsafeCoerce', which is why the instances
-- below declare no methods. NOTE(review): this is only sound if the security
-- marking is a phantom parameter that does not affect the runtime
-- representation of 'Tq'/'Rq' or of the container; confirm against "Marking"
-- and "Auxiliary" before adding further instances.
class Leak t => LeakVec vec t where
    leakVec :: vec (t Sec) -> vec (t Pub)
    leakVec = unsafeCoerce

-- Used to publish t-hat in 'keyGen' (a 'Tq' vector) and u in 'encrypt'
-- (an 'Rq' vector).
instance LeakVec (Vector k) Tq
instance LeakVec (Vector k) Rq

-- | K-PKE decryption key: the secret vector s-hat, kept in the NTT domain.
newtype DecryptionKey (k :: Nat) = DecryptionKey
    { dkS :: Vector k (Tq Sec)
    }

-- | K-PKE encryption key.
data EncryptionKey (k :: Nat) = EncryptionKey
    { ekT   :: Vector k (Tq Pub)
      -- ^ the public vector t-hat, in the NTT domain
    , ekRho :: Bytes
      -- ^ the seed from which the matrix is expanded (32 bytes in 'ekDecode')
    , ekA   :: Vector k (Vector k (Tq Pub))
      -- ^ cached matrix expanded from 'ekRho'; the field is deliberately
      -- non-strict so the matrix is only built when first needed
    }

instance Crypto.ConstEqW (DecryptionKey k) where
    -- A decryption key is nothing but its secret vector.
    constEqW (DecryptionKey x) (DecryptionKey y) = Crypto.constEqW x y

instance Crypto.ConstEqW (EncryptionKey k) where
    -- Only the serialised components are compared: every construction in
    -- this module sets ekA from ekRho, so the cached matrix adds nothing.
    constEqW x y = Crypto.andW sameT sameRho
      where
        sameT   = Crypto.constEqW (ekT x) (ekT y)
        sameRho = Crypto.constEqW (ekRho x) (ekRho y)

instance NFData (DecryptionKey k) where
    -- Fully evaluate the secret vector.
    rnf (DecryptionKey s) = Vector.toNormalForm s

instance NFData (EncryptionKey k) where
    -- The cached matrix ekA is intentionally left unevaluated: it is only a
    -- cache and is rebuilt from the seed on demand.
    rnf EncryptionKey { ekT = t, ekRho = rho } =
        Vector.toNormalForm t `seq` rnf rho

-- | Serialise an encryption key as the 12-bit encoding of t-hat followed by
-- the seed. The cached matrix is not written; 'ekDecode' rebuilds it from
-- the seed.
ekEncode :: EncryptionKey k -> Builder Pub
ekEncode ek = encodedT <> Builder.bytes (ekRho ek)
  where
    encodedT = Vector.concatMap Aux.byteEncode12 (ekT ek)

-- | Parse an encryption key of dimension @k@ from its byte encoding.
--
-- Returns 'Nothing' unless both checks pass:
--
-- * type check: the input is exactly @384 * k + 32@ bytes long;
-- * modulus check: re-encoding each decoded 384-byte chunk reproduces the
--   original bytes.
ekDecode :: (KnownNat k, ByteArrayAccess ba) => Params k -> ba -> Maybe (EncryptionKey k)
ekDecode params input = do
    -- type check: k chunks of 384 bytes followed by a 32-byte seed
    guard (B.length input == tLen + 32)
    let !tt  = Vector.create (Aux.byteDecode12 . chunk)
        !rho = B.convert (B.view input tLen 32)
        reencode off = Builder.run (Aux.byteEncode12 (Vector.index tt off))
    -- modulus check: every chunk must survive a decode/encode round trip
    forM_ (offsets k) $ \i ->
        guard (reencode (Offset i) `Crypto.eq` chunk (Offset i))
    return EncryptionKey { ekT = tt, ekRho = rho, ekA = createMatrix rho }
  where
    k    = dimension params
    tLen = 384 * k
    chunk (Offset i) = B.view input (384 * i) 384

-- | Serialise a decryption key as the concatenated 'Aux.byteEncode12'
-- encodings of the entries of s-hat.
dkEncode :: DecryptionKey k -> Builder Sec
dkEncode (DecryptionKey s) = Vector.concatMap Aux.byteEncode12 s

-- | Decode a decryption key of dimension @k@ from its byte encoding.
--
-- Entry @i@ of s-hat is decoded with 'Aux.byteDecode12' from the 384-byte
-- chunk starting at offset @384 * i@. No length or range validation is
-- performed and the result is not wrapped in 'Maybe', so the caller must
-- supply trusted input of at least @384 * k@ bytes.
-- NOTE(review): bytes beyond offset @384 * k@ are presumably never read,
-- which lets callers pass a longer key blob; confirm before relying on it.
dkDecode :: (KnownNat k, ByteArrayAccess ba) => ba -> DecryptionKey k
dkDecode input =
    -- strict binding: the vector is forced to WHNF before the key is returned
    let !s = Vector.create (Aux.byteDecode12 . view384)
     in DecryptionKey { dkS = s }
  where
    view384 (Offset i) = B.view input (384 * i) 384

-- | Expand the seed @rho@ into the k-by-k matrix of NTT-domain polynomials.
-- The entry at offsets @(i, j)@ is sampled with @j@ passed before @i@.
-- NOTE(review): this assumes 'Matrix.create' passes row then column.
createMatrix :: KnownNat k => Bytes -> Vector k (Vector k (Tq Pub))
createMatrix !rho = Matrix.create cell
  where
    cell (Offset row) (Offset col) =
        Aux.sampleNTT rho (fromIntegral col) (fromIntegral row)

-- | Sample a vector of noise polynomials from seed @s@, where the element at
-- offset @i@ uses PRF counter @i + j@.
createVector :: (KnownNat k, ByteArrayAccess s) => Word -> s -> Int -> Vector k (Rq Sec)
createVector !eta !s !j = Vector.create draw
  where
    draw (Offset i) = sample eta s (i + j)

-- | Sample one noise polynomial: run the PRF on seed @s@ with counter @n@,
-- then feed its output to the CBD sampler with parameter @eta@.
sample :: ByteArrayAccess s => Word -> s -> Int -> Rq Sec
sample eta s n = Aux.samplePolyCBD eta (Crypto.prf eta s (fromIntegral n))

-- | Uses the randomness @d@ to deterministically generate an encryption key
-- and the corresponding decryption key.
--
-- Both seeds come from 'Crypto.g' applied to @d@ with @k@ appended:
--
-- * @rho@ expands into the public matrix;
-- * @sigma@ feeds the PRF for the secret @s@ (counters @0 .. k-1@) and the
--   error @e@ (counters @k .. 2k-1@), both sampled with 'eta1'.
keyGen :: (KnownNat k, ByteArrayAccess d) => Params k -> d -> (EncryptionKey k, DecryptionKey k)
keyGen params@Params{..} d = (encKey, decKey)
  where
    k            = dimension params
    (rho, sigma) = Crypto.g (Crypto.snoc d (fromIntegral k))
    matA         = createMatrix rho
    !sHat        = Aux.ntt <$> createVector eta1 sigma 0
    eHat         = Aux.ntt <$> createVector eta1 sigma k
    -- presumably Matrix.mulw computes A-hat * s-hat + e-hat; the result is
    -- published as part of the encryption key, hence the declassification
    !tHat        = leakVec (Matrix.mulw matA sHat eHat)
    encKey       = EncryptionKey { ekT = tHat, ekRho = rho, ekA = matA }
    decKey       = DecryptionKey { dkS = sHat }

-- | Uses the encryption key to encrypt the plaintext @m@ with the randomness
-- @r@, returning the ciphertext @c1 || c2@.
--
-- All noise comes from @r@ through PRF counters:
--
-- * @y@ uses @0 .. k-1@ and is sampled with 'eta1';
-- * @e1@ uses @k .. 2k-1@ and is sampled with 'eta2';
-- * @e2@ uses @2k@ and is sampled with 'eta2'.
--
-- NOTE(review): @m@ goes straight to 'Aux.byteDecode1'; its length is not
-- checked here.
encrypt :: (KnownNat k, ByteArrayAccess m, ByteArrayAccess r) => Params k -> EncryptionKey k -> m -> r -> Bytes
encrypt params@Params{..} ek m r = Builder.run (encU <> encV)
  where
    k    = dimension params
    tHat = ekT ek
    matA = ekA ek
    yHat = Aux.ntt <$> createVector eta1 r 0
    err1 = createVector eta2 r k
    err2 = sample eta2 r (2 * k)
    -- u becomes part of the ciphertext, so it is declassified
    u    = leakVec ((Aux.nttInv <$> Matrix.muly matA yHat) .+ err1)
    mu   = Aux.rdecompress 1 (Aux.byteDecode1 m)
    v    = Aux.nttInv (tHat `Matrix.mulz` yHat) .+ err2 ..+ mu
    encU = Vector.concatMap (Aux.byteEncode du . Aux.rcompress du) u
    encV = Aux.byteEncode dv (Aux.rcompress dv v)

-- | Uses the decryption key to decrypt the ciphertext @c@.
--
-- The ciphertext is split as @c1 || c2@:
--
-- * @c1@ holds @k@ polynomials of @32 * du@ bytes each;
-- * @c2@ is the following @32 * dv@ bytes.
--
-- NOTE(review): the length of @c@ is not validated here.
decrypt :: KnownNat k => Params k -> DecryptionKey k -> Bytes -> ScrubbedBytes
decrypt params@Params{..} dk c = Builder.run (Aux.byteEncode1 (Aux.rcompress 1 w))
  where
    k      = dimension params
    polyU  = 32 * du
    uPrime = Vector.create $ \(Offset i) ->
        (Aux.rdecompress du (Aux.byteDecode du (B.view c (polyU * i) polyU)) :: Rq Pub)
    vPrime = Aux.rdecompress dv (Aux.byteDecode dv (B.view c (polyU * k) (32 * dv)))
    w      = vPrime ..- Aux.nttInv ((Aux.ntt <$> uPrime) `Matrix.mulz` dkS dk)
