-- |
-- Module      : BlockN
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- A secure block whose length is tracked at the type level.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
module BlockN
    ( BlockN, MutableBlockN, create, index, iterModify, BlockN.map
    , BlockN.new, BlockN.read, BlockN.thaw, BlockN.unsafeCast
    , BlockN.unsafeFreeze, BlockN.write, BlockN.zipWith
#ifdef ML_KEM_TESTING
    , BlockN.fromList, BlockN.replicate, BlockN.toList
#endif
    ) where

import Basement.Monad
import Basement.Nat
import Basement.NormalForm
import Basement.PrimType
import Basement.Types.OffsetSize

import Data.Proxy

import Block (MutableBlock, blockRead, blockWrite)
import Marking (Classified, SecurityMarking)
import SecureBlock (SecureBlock)
import qualified SecureBlock
import Math

-- | An immutable secure block of elements of type @a@ whose length @n@ is
-- carried in its type.
--
-- Only the type constructor is exported, so values are built by the
-- functions of this module: 'create' and 'new' size the block from @n@, and
-- the testing-only 'fromList' rejects lists of any other length.
newtype BlockN marking (n :: Nat) a = BlockN
    { unBlockN :: SecureBlock marking a
      -- ^ The wrapped block, without the type-level length.
    }

#ifdef ML_KEM_TESTING
-- | Two blocks are equal when their wrapped 'SecureBlock's are equal
-- according to 'SecureBlock.eq'.
instance (Classified marking, PrimType a) => Eq (BlockN marking n a) where
    lhs == rhs = unBlockN lhs `SecureBlock.eq` unBlockN rhs

-- | Shows the wrapped 'SecureBlock' by delegating to 'SecureBlock.showsPrec'
-- with the same precedence.
instance (Classified marking, PrimType a, Show a) => Show (BlockN marking n a) where
    showsPrec prec (BlockN blk) = SecureBlock.showsPrec prec blk
#endif

-- | Reducing a 'BlockN' to normal form reduces its wrapped 'SecureBlock'.
instance Classified marking => NormalForm (BlockN marking n a) where
    toNormalForm (BlockN blk) = SecureBlock.toNormalForm blk

-- | Element-wise arithmetic, lifted from the 'Add' instance of the element
-- type. Every method returns a new block of the same length @n@.
instance (Classified marking, KnownNat n, PrimType a, Add a) => Add (BlockN marking n a) where
    -- Every element is the element-type zero, whatever its offset.
    zero = create (\_ -> zero)
    {-# INLINE zero #-}
    -- Point-wise negation.
    neg = BlockN.map neg
    {-# INLINE neg #-}
    -- Point-wise addition.
    (.+) = BlockN.zipWith (.+)
    {-# INLINE (.+) #-}
    -- Point-wise subtraction.
    (.-) = BlockN.zipWith (.-)
    {-# INLINE (.-) #-}

-- | A mutable block of @n@ elements of type @a@, parameterised by the state
-- token @m@ (the functions below instantiate it with @PrimState prim@).
--
-- It wraps a plain 'MutableBlock': @marking@ does not occur on the right-hand
-- side and is carried only at the type level.
newtype MutableBlockN (marking :: SecurityMarking) (n :: Nat) a m = MutableBlockN
    { unMutableBlockN :: MutableBlock a m
      -- ^ The wrapped mutable block.
    }

-- | Element stored at the given offset.
--
-- Delegates to 'SecureBlock.index'; any bounds checking is whatever that
-- function performs (not visible here).
index :: (Classified marking, PrimType a) => BlockN marking n a -> Offset a -> a
index (BlockN blk) ofs = SecureBlock.index blk ofs

#ifdef ML_KEM_TESTING
-- | Block holding @n@ copies of the given element (testing only).
replicate :: forall marking n a. (Classified marking, KnownNat n, PrimType a) => a -> BlockN marking n a
replicate x = BlockN (SecureBlock.replicate count x)
  where
    -- Length taken from the type-level parameter.
    !count = fromIntegral (natVal (Proxy :: Proxy n))

-- | Wrap a list as a 'BlockN' (testing only).
--
-- The list is first converted with 'SecureBlock.fromList'. The result is
-- 'Just' only when the resulting length equals @n@, and 'Nothing' otherwise.
fromList :: forall marking n a. (Classified marking, KnownNat n, PrimType a) => [a] -> Maybe (BlockN marking n a)
fromList elems =
    let blk = SecureBlock.fromList elems
        !expected = fromIntegral (natVal (Proxy :: Proxy n))
     in if SecureBlock.length blk == CountOf expected
            then Just (BlockN blk)
            else Nothing

-- | Elements of the block as a list, via 'SecureBlock.toList' (testing only).
toList :: (Classified marking, PrimType a) => BlockN marking n a -> [a]
toList (BlockN blk) = SecureBlock.toList blk
#endif

-- | Build a block of exactly @n@ elements from an initializer. The
-- initializer maps an element offset to the value stored at that offset.
--
-- The element count comes from the type-level @n@ and is handed to
-- 'SecureBlock.create'.
create :: forall marking n ty. (Classified marking, KnownNat n, PrimType ty)
       => (Offset ty -> ty)
       -> BlockN marking n ty
create initializer =
    let !len = fromIntegral (natVal (Proxy :: Proxy n))
     in BlockN (SecureBlock.create (CountOf len) initializer)
{-# INLINE create #-}

-- | Apply a function to every element. The result keeps the same
-- type-level length @n@ and marking.
map :: (Classified marking, PrimType a, PrimType b) => (a -> b) -> BlockN marking n a -> BlockN marking n b
-- Only @f@ appears on the left-hand side so the INLINE pragma still fires on
-- partial applications such as @BlockN.map neg@.
map f = \(BlockN src) -> BlockN (SecureBlock.map f src)
{-# INLINE map #-}

-- | Apply a function to every element of a mutable block, delegating to
-- 'SecureBlock.iterModify' on the wrapped 'MutableBlock' (presumably updating
-- it in place — confirm in SecureBlock).
iterModify :: (PrimType ty, PrimMonad prim)
           => (ty -> ty)
           -> MutableBlockN marking n ty (PrimState prim)
           -> prim ()
iterModify f = \(MutableBlockN mb) -> SecureBlock.iterModify f mb
{-# INLINE iterModify #-}

-- | Combine two blocks of the same length @n@ element by element.
--
-- The inputs and the result may carry three different security markings.
-- Both wrapped input blocks are forced to WHNF before the result is built.
zipWith :: (Classified ma, Classified mb, Classified mc, KnownNat n, PrimType a, PrimType b, PrimType c)
        => (a -> b -> c) -> BlockN ma n a -> BlockN mb n b -> BlockN mc n c
zipWith f (BlockN !xs) (BlockN !ys) = create combine
  where
    -- 'create' supplies offsets typed for the result element; unwrap the
    -- index and rewrap it so it can address the @a@ and @b@ inputs.
    combine (Offset i) =
        f (SecureBlock.index xs (Offset i)) (SecureBlock.index ys (Offset i))
{-# INLINE zipWith #-}

-- | View the contents as elements of type @b@ via 'SecureBlock.unsafeCast'.
--
-- The result is a plain 'SecureBlock': the type-level length @n@ is dropped,
-- presumably because the element count depends on the element size (confirm
-- in SecureBlock).
unsafeCast :: (Classified marking, PrimType b) => BlockN marking n a -> SecureBlock marking b
unsafeCast (BlockN blk) = SecureBlock.unsafeCast blk

-- | Read the element at the given offset of a mutable block, via 'blockRead'
-- on the wrapped 'MutableBlock'.
read :: (PrimMonad prim, PrimType a) => MutableBlockN marking n a (PrimState prim) -> Offset a -> prim a
read mblk ofs = blockRead (unMutableBlockN mblk) ofs

-- | Store a value at the given offset of a mutable block, via 'blockWrite'
-- on the wrapped 'MutableBlock'.
write :: (PrimMonad prim, PrimType a) => MutableBlockN marking n a (PrimState prim) -> Offset a -> a -> prim ()
write mblk ofs val = blockWrite (unMutableBlockN mblk) ofs val

-- | Allocate a mutable block of exactly @n@ elements with 'SecureBlock.new',
-- which receives the marking proxy.
--
-- The initial contents are whatever 'SecureBlock.new' provides (not visible
-- here).
new :: forall proxy marking n a prim. (Classified marking, KnownNat n, PrimMonad prim, PrimType a) => proxy marking -> prim (MutableBlockN marking n a (PrimState prim))
new prx =
    let !count = fromIntegral (natVal (Proxy :: Proxy n))
     in fmap MutableBlockN (SecureBlock.new prx (CountOf count))

-- | Obtain a mutable block from an immutable one via 'SecureBlock.thaw'.
-- This presumably copies, in contrast to 'unsafeFreeze' (confirm in
-- SecureBlock). The type-level length is preserved.
thaw :: (Classified marking, PrimMonad prim, PrimType a) => BlockN marking n a -> prim (MutableBlockN marking n a (PrimState prim))
thaw (BlockN blk) = MutableBlockN <$> SecureBlock.thaw blk

-- | Turn a mutable block into an immutable 'BlockN' via
-- 'SecureBlock.unsafeFreeze'. As the name suggests, the mutable block should
-- presumably not be written afterwards (confirm in SecureBlock).
unsafeFreeze :: (Classified marking, PrimMonad prim) => MutableBlockN marking n a (PrimState prim) -> prim (BlockN marking n a)
unsafeFreeze (MutableBlockN mblk) = BlockN <$> SecureBlock.unsafeFreeze mblk
