{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilyDependencies #-}
module Marking
    ( SecurityMarking(..), Classified(..), Leak(..)
    ) where

import Basement.Monad
import Basement.NormalForm
import Basement.PrimType
import Basement.Types.OffsetSize
#ifdef ML_KEM_TESTING
import qualified Basement.Compat.IsList
#endif

import Control.Monad.ST

import Data.ByteArray (Bytes, ScrubbedBytes)
import qualified Data.ByteArray as B

import Data.Kind

import Foreign.Ptr (Ptr)

import Unsafe.Coerce

import Block (Block, MutableBlock, blockIndex)
import ScrubbedBlock (ScrubbedBlock)
import qualified Block
import qualified ByteArrayST as ST
import qualified ScrubbedBlock

-- | Sensitivity label for data.  Promoted via DataKinds and used as the kind
-- of the marking index in 'Classified' and 'Leak'.
data SecurityMarking
    = Sec  -- ^ secret information
    | Pub  -- ^ public information

-- | Declassification: turn a value marked 'Sec' into the same value marked
-- 'Pub'.
--
-- 'leak' is meant to be called only at the expected location in the LWE
-- problem, after noise has been added to the secret information.
--
-- Block and ScrubbedBlock share the same representation, so the default
-- implementation simply coerces from Sec to Pub, even though the underlying
-- block will still actually be scrubbed.  This is simpler than copying into a
-- real non-scrubbed block.
--
-- NOTE(review): the default relies on @t Sec@ and @t Pub@ having identical
-- runtime representations; an instance for a type where that does not hold
-- would make 'leak' unsound.
class Leak t where
    leak :: t Sec -> t Pub
    leak secret = unsafeCoerce secret

-- | Block and byte-array operations whose concrete representation is chosen
-- by a 'SecurityMarking'.  Both associated type families are injective, so
-- the marking can be recovered from the concrete block or bytes type.
class Classified (marking :: SecurityMarking) where
    -- | Immutable block type used for this marking.
    type SecureBlock marking = (block :: Type -> Type) | block -> marking

    -- | Byte-array type used for this marking.
    type SecureBytes marking = bytes | bytes -> marking

    -- | Build a block of the given element count from a function of each
    -- element's offset.
    create
        :: PrimType ty
        => CountOf ty -> (Offset ty -> ty) -> SecureBlock marking ty

    -- | Element stored at the given offset.
    index :: PrimType ty => SecureBlock marking ty -> Offset ty -> ty

    -- | Map a function over the elements of a block.
    map
        :: (PrimType a, PrimType b)
        => (a -> b) -> SecureBlock marking a -> SecureBlock marking b

    -- | Allocate a mutable block.  The result type does not mention the
    -- marking, so the proxy argument is what selects the instance.
    new
        :: (PrimType ty, PrimMonad prim)
        => proxy marking -> CountOf ty -> prim (MutableBlock ty (PrimState prim))

    -- | Obtain a mutable block from an immutable one.
    thaw
        :: (PrimType ty, PrimMonad m)
        => SecureBlock marking ty -> m (MutableBlock ty (PrimState m))

    -- | Reinterpret a block at another element type.
    -- NOTE(review): presumably unchecked (size/alignment) — confirm in backend.
    unsafeCast :: PrimType b => SecureBlock marking a -> SecureBlock marking b

    -- | Turn a mutable block into an immutable block of this marking.
    unsafeFreeze
        :: PrimMonad prim
        => MutableBlock ty (PrimState prim) -> prim (SecureBlock marking ty)

    -- | Reduce a block to normal form.
    toNormalForm :: SecureBlock marking ty -> ()

    -- Testing-only helpers, present only when built with ML_KEM_TESTING.
#ifdef ML_KEM_TESTING
    eq :: PrimType ty => SecureBlock marking ty -> SecureBlock marking ty -> Bool
    showsPrec :: (PrimType ty, Show ty) => Int -> SecureBlock marking ty -> ShowS
    fromList :: PrimType ty => [ty] -> SecureBlock marking ty
    replicate :: PrimType ty => CountOf ty -> ty -> SecureBlock marking ty
    toList :: PrimType ty => SecureBlock marking ty -> [ty]
    lengthBlock :: PrimType ty => SecureBlock marking ty -> CountOf ty
#endif

    -- | Allocate a byte array of the given size and fill it through a raw
    -- pointer.  NOTE(review): the Int is presumably a byte count — confirm.
    unsafeCreate :: Int -> (forall s. Ptr a -> ST s ()) -> SecureBytes marking

    -- | Length of a byte array.
    lengthBytes :: SecureBytes marking -> Int

    -- | Copy the contents of a byte array to a raw pointer.
    -- NOTE(review): the destination presumably needs at least 'lengthBytes'
    -- bytes of space — confirm with callers.
    copyByteArrayToPtr :: SecureBytes marking -> Ptr a -> IO ()

-- | Public data is held in plain 'Block's and 'Bytes'.
instance Classified Pub where
    type SecureBlock Pub = Block
    type SecureBytes Pub = Bytes

    -- Block operations, delegated to the Block module.
    create = Block.create
    {-# INLINE create #-}
    map = Block.map
    {-# INLINE map #-}
    index blk off = blockIndex blk off
    new = const Block.new
    thaw blk = Block.thaw blk
    unsafeCast blk = Block.unsafeCast blk
    unsafeFreeze mblk = Block.unsafeFreeze mblk
    toNormalForm blk = Basement.NormalForm.toNormalForm blk
#ifdef ML_KEM_TESTING
    -- Testing-only helpers.
    eq x y = x == y
    showsPrec d blk = Prelude.showsPrec d blk
    fromList xs = Basement.Compat.IsList.fromList xs
    replicate n x = Block.replicate n x
    toList blk = Basement.Compat.IsList.toList blk
    lengthBlock blk = Block.length blk
#endif

    -- Byte-array operations, via Data.ByteArray and ByteArrayST.
    unsafeCreate = ST.unsafeCreate
    {-# INLINE unsafeCreate #-}
    lengthBytes bs = B.length bs
    copyByteArrayToPtr bs dst = B.copyByteArrayToPtr bs dst

-- | Secret data is held in 'ScrubbedBlock's and 'ScrubbedBytes'.
instance Classified Sec where
    type SecureBlock Sec = ScrubbedBlock
    type SecureBytes Sec = ScrubbedBytes

    -- Block operations, delegated to the ScrubbedBlock module.
    create = ScrubbedBlock.create
    {-# INLINE create #-}
    map = ScrubbedBlock.map
    {-# INLINE map #-}
    index blk off = ScrubbedBlock.index blk off
    new = const ScrubbedBlock.new
    thaw blk = ScrubbedBlock.thaw blk
    unsafeCast blk = ScrubbedBlock.unsafeCast blk
    unsafeFreeze mblk = ScrubbedBlock.unsafeFreeze mblk
    toNormalForm blk = Basement.NormalForm.toNormalForm blk
#ifdef ML_KEM_TESTING
    -- Testing-only helpers.
    eq x y = x == y
    showsPrec d blk = Prelude.showsPrec d blk
    fromList xs = ScrubbedBlock.fromList xs
    replicate n x = ScrubbedBlock.replicate n x
    toList blk = ScrubbedBlock.toList blk
    lengthBlock blk = ScrubbedBlock.length blk
#endif

    -- Byte-array operations, via Data.ByteArray and ByteArrayST.
    unsafeCreate = ST.unsafeCreate
    {-# INLINE unsafeCreate #-}
    lengthBytes bs = B.length bs
    copyByteArrayToPtr bs dst = B.copyByteArrayToPtr bs dst
