-- |
-- Module      : Marking
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- Infrastructure that associates a type-level security marking with every
-- buffer created by the library.  The marking determines whether a buffer
-- uses the scrubbed variant (Sec) or the regular variant (Pub).
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilyDependencies #-}
module Marking
    ( SecurityMarking(..), Classified(..), Leak(..), index
    , Marking.toNormalForm, unsafeCast
#ifdef ML_KEM_TESTING
    , Marking.toList
#endif
    ) where

import Basement.Monad
import Basement.NormalForm
import Basement.PrimType
import Basement.Types.OffsetSize
#ifdef ML_KEM_TESTING
import qualified Basement.Compat.IsList
#endif

import Control.Monad.ST

import Data.ByteArray (Bytes, ScrubbedBytes)
import qualified Data.ByteArray as B

import Data.Kind

import Foreign.Ptr (Ptr)

import Unsafe.Coerce

import Block (Block, MutableBlock, blockIndex)
import ScrubbedBlock (ScrubbedBlock)
import qualified Block
import qualified ByteArrayST as ST
import qualified ScrubbedBlock

-- | Security marking carried at type level (promoted with DataKinds) by the
-- buffers of this library.
data SecurityMarking
    = Sec  -- ^ secret information (scrubbed buffers)
    | Pub  -- ^ public information (regular buffers)

-- Transformation to be called only at the expected location in the LWE
-- problem, after noise has been added to the secret information.
--
-- Because Block and ScrubbedBlock have the same representation, we can force
-- a coercion from Sec to Pub even though the underlying block will actually
-- still be scrubbed.  This is simpler than copying into a real non-scrubbed
-- block.
class Leak t where
    -- | Re-mark a secret value as public.
    --
    -- The default implementation reinterprets the very same value with
    -- 'unsafeCoerce' (no copy).  It is only sound for types whose runtime
    -- representation does not depend on the marking, e.g. when the marking
    -- merely selects between 'Block' and 'ScrubbedBlock'.
    leak :: t Sec -> t Pub
    leak secret = unsafeCoerce secret

-- | Marking-dependent buffer operations.
--
-- An instance chooses, for its marking, the concrete immutable block type
-- ('SecureBlock') and byte-array type ('SecureBytes'), and implements the
-- allocation and conversion primitives for them.  In this module 'Pub' maps
-- to 'Block' / 'Bytes' and 'Sec' maps to 'ScrubbedBlock' / 'ScrubbedBytes'.
class Classified (marking :: SecurityMarking) where
    -- | Immutable block type used for this marking.  The family is injective
    -- (@block -> marking@), so the marking can be inferred from a block type
    -- without needing a proxy.
    type SecureBlock marking = (block :: Type -> Type) | block -> marking

    -- | Build an immutable block with the given element count; presumably
    -- each element is obtained by applying the function to its offset.
    create :: PrimType ty => CountOf ty -> (Offset ty -> ty) -> SecureBlock marking ty
    -- | Allocate a mutable block with the given element count.  The proxy
    -- argument only selects the marking, because the result type does not
    -- mention it.
    new :: (PrimType ty, PrimMonad prim) => proxy marking -> CountOf ty -> prim (MutableBlock ty (PrimState prim))
    -- | Obtain a mutable block from an immutable one (delegates to the
    -- marking's @thaw@; whether it copies depends on that implementation).
    thaw :: (PrimType ty, PrimMonad m) => SecureBlock marking ty -> m (MutableBlock ty (PrimState m))
    -- | Turn a mutable block into an immutable block carrying this marking.
    -- NOTE(review): as the name suggests this presumably does not copy, so
    -- the mutable block should not be written afterwards -- confirm against
    -- Block/ScrubbedBlock.
    unsafeFreeze :: PrimMonad prim => MutableBlock ty (PrimState prim) -> prim (SecureBlock marking ty)

#ifdef ML_KEM_TESTING
    -- Test-only helpers, compiled only when ML_KEM_TESTING is defined.

    -- | Equality of two blocks with the same marking (the instances use the
    -- block type's '==').
    eq :: PrimType ty => SecureBlock marking ty -> SecureBlock marking ty -> Bool
    -- | 'Prelude.showsPrec' for blocks of this marking.
    showsPrec :: (PrimType ty, Show ty) => Int -> SecureBlock marking ty -> ShowS
    -- | Number of elements in a block.
    lengthBlock :: PrimType ty => SecureBlock marking ty -> CountOf ty
#endif

    -- | Byte-array type used for this marking; injective like 'SecureBlock'.
    type SecureBytes marking = bytes | bytes -> marking
    -- | Allocate a byte array of the given size (presumably in bytes) and
    -- initialise it through the supplied pointer action running in 'ST'.
    unsafeCreate :: Int -> (forall s. Ptr a -> ST s ()) -> SecureBytes marking
    -- | Length of the byte array, as reported by 'B.length'.
    lengthBytes :: SecureBytes marking -> Int
    -- | Copy all bytes of the array to the destination pointer, which must
    -- have room for 'lengthBytes' bytes.
    copyByteArrayToPtr :: SecureBytes marking -> Ptr a -> IO ()

-- | Public data is stored in ordinary, non-scrubbed buffers.
instance Classified Pub where
    type SecureBlock Pub = Block
    type SecureBytes Pub = Bytes

    -- block primitives, all delegated to the Block module
    create = Block.create
    {-# INLINE create #-}
    new _proxy count = Block.new count
    thaw blk = Block.thaw blk
    unsafeFreeze mblk = Block.unsafeFreeze mblk

    -- byte-array primitives
    unsafeCreate = ST.unsafeCreate
    {-# INLINE unsafeCreate #-}
    lengthBytes bytes = B.length bytes
    copyByteArrayToPtr src dst = B.copyByteArrayToPtr src dst

#ifdef ML_KEM_TESTING
    eq lhs rhs = lhs == rhs
    showsPrec prec blk = Prelude.showsPrec prec blk
    lengthBlock blk = Block.length blk
#endif

-- | Secret data is stored in scrubbed buffers.
instance Classified Sec where
    type SecureBlock Sec = ScrubbedBlock
    type SecureBytes Sec = ScrubbedBytes

    -- block primitives, all delegated to the ScrubbedBlock module
    create = ScrubbedBlock.create
    {-# INLINE create #-}
    new _proxy count = ScrubbedBlock.new count
    thaw blk = ScrubbedBlock.thaw blk
    unsafeFreeze mblk = ScrubbedBlock.unsafeFreeze mblk

    -- byte-array primitives
    unsafeCreate = ST.unsafeCreate
    {-# INLINE unsafeCreate #-}
    lengthBytes bytes = B.length bytes
    copyByteArrayToPtr src dst = B.copyByteArrayToPtr src dst

#ifdef ML_KEM_TESTING
    eq lhs rhs = lhs == rhs
    showsPrec prec blk = Prelude.showsPrec prec blk
    lengthBlock blk = ScrubbedBlock.length blk
#endif


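-- Some functions below rely on Block and ScrubbedBlock (the types that
-- SecureBlock selects for Pub and Sec) having the same representation and
-- implementation, so a SecureBlock of either marking can be coerced to and
-- from a plain Block.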

-- | View a marked block as a plain 'Block', without copying.  This is sound
-- only because 'Block' and 'ScrubbedBlock' share one runtime representation.
unwrap :: SecureBlock marking a -> Block a
unwrap secureBlk = unsafeCoerce secureBlk

-- | Inverse of 'unwrap': give a plain 'Block' the block type of any marking,
-- without copying.  Sound for the same representation-sharing reason.
wrap :: Block b -> SecureBlock marking b
wrap plainBlk = unsafeCoerce plainBlk

-- | Read the element at an offset of a block of either marking, by applying
-- 'blockIndex' to the underlying 'Block'.
index :: PrimType ty => SecureBlock marking ty -> Offset ty -> ty
index blk = blockIndex (unwrap blk)

#ifdef ML_KEM_TESTING
-- | Test-only: the elements of a block of either marking, in offset order as
-- produced by the 'IsList' instance of the underlying 'Block'.
toList :: PrimType ty => SecureBlock marking ty -> [ty]
toList blk = Basement.Compat.IsList.toList (unwrap blk)
#endif

-- | Reduce a block of either marking to normal form, using the 'NormalForm'
-- instance of the underlying 'Block'.
toNormalForm :: SecureBlock marking ty -> ()
toNormalForm blk = Basement.NormalForm.toNormalForm (unwrap blk)

-- | Reinterpret the elements of a block as another primitive type while
-- keeping its marking.  No check is performed here; the size semantics are
-- those of 'Block.unsafeCast' (presumably the byte length is preserved --
-- confirm).
unsafeCast :: PrimType b => SecureBlock marking a -> SecureBlock marking b
unsafeCast blk = wrap (Block.unsafeCast (unwrap blk))
