-- |
-- Module      : ScrubbedBlock
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- A block that is always pinned in memory and whose contents are
-- automatically zero-filled by a finalizer once it is no longer referenced.
-- This follows the same pattern as ScrubbedBytes from the memory package, but
-- for basement blocks.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
module ScrubbedBlock
    ( ScrubbedBlock, create, foldZipWith, index
    , ScrubbedBlock.length, ScrubbedBlock.map, new
    , thaw, unsafeCast, unsafeFreeze, Block.withMutablePtr
#ifdef ML_KEM_TESTING
    , ScrubbedBlock.fromList, ScrubbedBlock.replicate, ScrubbedBlock.toList
#endif
    , erasePtr
    ) where

import Basement.Block (Block(..), MutableBlock(..), isPinned)
import Basement.Block.Mutable (mutableLengthBytes, unsafeCopyBytesRO)
import Basement.Compat.Primitive

#ifdef ML_KEM_TESTING
import Basement.Compat.IsList
#endif
import Basement.Monad
import Basement.NormalForm
import Basement.PrimType
import Basement.Types.OffsetSize

import Control.Exception (assert)
import Control.Monad.ST

import Data.Word

import Foreign.Ptr (Ptr)

import Block (blockIndex, blockWrite)
import qualified Block

#if MIN_VERSION_base(4,19,0)
import GHC.Base (Int(I#), setAddrRange#)
import GHC.Exts (Ptr(Ptr))
#else
import Data.Memory.PtrMethods (memSet)
#endif
import GHC.Base (IO(IO), setByteArray#)
import GHC.Exts (getSizeofMutableByteArray#, mkWeak#)

-- | A pinned 'Block' whose memory is zero-filled by a finalizer once the
-- garbage collector finds it unreachable.
--
-- The data constructor is not exported, so every value is built inside this
-- module.  The finalizer is registered when a mutable block is frozen with
-- 'unsafeFreeze'.
--
-- NOTE(review): the derived 'Show' instance presumably prints the contents
-- through Block's 'Show' instance, which may be undesirable for secrets.
newtype ScrubbedBlock ty = ScrubbedBlock (Block ty)
    deriving (Eq, Show, NormalForm)

-- | Build a 'ScrubbedBlock' of the given size by calling the initializer
-- once for every offset, from 0 up to (but excluding) the count.
--
-- The block is allocated pinned with 'new' and frozen with 'unsafeFreeze',
-- so the result carries its scrubbing finalizer.
create :: PrimType ty
       => CountOf ty
       -> (Offset ty -> ty)
       -> ScrubbedBlock ty
create n initializer = runST $ do
    dst <- new n
    let fill !ofs
            | ofs .==# n = pure ()
            | otherwise = do
                blockWrite dst ofs (initializer ofs)
                fill (ofs + 1)
    fill 0
    unsafeFreeze dst
{-# INLINE create #-}

-- | Fold a step function over the elements of two blocks pairwise, starting
-- from the given accumulator.
--
-- This delegates to 'Block.foldZipWith'; see that function for how blocks of
-- unequal length are handled.
foldZipWith :: (PrimType a, PrimType b)
            => (c -> a -> b -> c) -> c -> ScrubbedBlock a -> ScrubbedBlock b -> c
foldZipWith step acc0 (ScrubbedBlock xs) (ScrubbedBlock ys) =
    Block.foldZipWith step acc0 xs ys
{-# INLINE foldZipWith #-}

-- | Read the element stored at an offset.
--
-- This delegates to 'blockIndex'; bounds behaviour is whatever that function
-- provides.
index :: PrimType ty => ScrubbedBlock ty -> Offset ty -> ty
index (ScrubbedBlock blk) ofs = blockIndex blk ofs

-- | Number of elements held by the block.
length :: PrimType ty => ScrubbedBlock ty -> CountOf ty
length sb = case sb of
    ScrubbedBlock blk -> Block.length blk

-- | Apply a function to every element, producing a new 'ScrubbedBlock'.
--
-- The result is built with 'create', so it is pinned and has its own
-- scrubbing finalizer.
map :: (PrimType a, PrimType b) => (a -> b) -> ScrubbedBlock a -> ScrubbedBlock b
map f (ScrubbedBlock src) = create (CountOf count) element
  where
    CountOf count = Block.length src

    -- Offsets and counts are indexed by element type, so unwrap and rewrap
    -- the raw Int to move between the source type and the target type.
    element (Offset ofs) = f (blockIndex src (Offset ofs))
{-# INLINE map #-}

-- | Allocate a mutable block with room for the given number of elements.
--
-- Allocation is always pinned, so that a later 'unsafeFreeze' attaches its
-- scrubbing finalizer to memory the garbage collector does not relocate.
new :: (PrimType ty, PrimMonad prim) => CountOf ty -> prim (MutableBlock ty (PrimState prim))
new count = Block.newPinned count

-- | Copy a 'ScrubbedBlock' into a freshly allocated, pinned mutable block.
--
-- No scrubbing finalizer is registered on the copy here.  It gets one only
-- if it is later frozen with 'unsafeFreeze'.
thaw :: (PrimType ty, PrimMonad m) => ScrubbedBlock ty -> m (MutableBlock ty (PrimState m))
thaw (ScrubbedBlock src) = do
    dst <- new (Block.length src)
    -- Copy the destination's full byte size from the start of the source.
    dst <$ unsafeCopyBytesRO dst 0 src 0 (mutableLengthBytes dst)

-- | Reinterpret the block's contents as elements of another primitive type.
--
-- This delegates to 'Block.unsafeCast'.  The caller is responsible for the
-- byte size being meaningful for the target type.
unsafeCast :: PrimType b => ScrubbedBlock a -> ScrubbedBlock b
unsafeCast sb = case sb of
    ScrubbedBlock blk -> ScrubbedBlock $ Block.unsafeCast blk

-- | Freeze a mutable block and register the finalizer that zero-fills its
-- memory once it becomes unreachable.
--
-- As with 'Block.unsafeFreeze', the mutable block should not be written to
-- afterwards.
unsafeFreeze :: PrimMonad prim => MutableBlock ty (PrimState prim) -> prim (ScrubbedBlock ty)
unsafeFreeze mblk = scrubbed =<< Block.unsafeFreeze mblk

#ifdef ML_KEM_TESTING
-- | A block of the given size with every element equal to the given value.
replicate :: PrimType ty => CountOf ty -> ty -> ScrubbedBlock ty
replicate count value = create count (\_ -> value)

-- | Build a 'ScrubbedBlock' holding the list elements in order.
--
-- The list is traversed once to size the allocation and once more to write
-- the elements.
fromList :: PrimType ty => [ty] -> ScrubbedBlock ty
fromList elems = runST $ do
    dst <- new (CountOf (Prelude.length elems))
    let write !_ [] = pure ()
        write !ofs (x:rest) = blockWrite dst ofs x >> write (ofs + 1) rest
    write 0 elems
    unsafeFreeze dst

-- | The elements of the block as a list, in order.
toList :: PrimType ty => ScrubbedBlock ty -> [ty]
toList sb = case sb of
    ScrubbedBlock blk -> Basement.Compat.IsList.toList blk
#endif


{- internal -}

-- | Return the second argument, after checking with 'assert' that the block
-- is pinned.
--
-- The check runs only when assertions are enabled.  Pinning matters because
-- scrubbing in place only reaches the secret bytes if the garbage collector
-- never copies them elsewhere.
assertPinned :: Block ty -> a -> a
assertPinned blk x = assert (isPinned blk == Pinned) x

-- | Wrap a frozen, pinned 'Block' as a 'ScrubbedBlock'.
--
-- A finalizer is registered that zero-fills the block's bytes, viewed as
-- 'Word8', once the block becomes unreachable.
scrubbed :: PrimMonad prim => Block ty -> prim (ScrubbedBlock ty)
scrubbed blk = assertPinned blk (unsafePrimFromIO register)
  where
    register = do
        addBlockFinalizer blk (scrub (Block.unsafeCast blk))
        pure (ScrubbedBlock blk)

-- | Finalizer body: zero-fill a block's bytes.
--
-- The block is thawed with 'Block.unsafeThaw', presumably without a copy, so
-- the erase reaches the original memory.
scrub :: Block Word8 -> IO ()
scrub blk = erase =<< Block.unsafeThaw blk

-- | Register an 'IO' action as a finalizer for the 'ByteArray#' underlying a
-- 'Block'.
--
-- The weak pointer is keyed on the unlifted array itself rather than on the
-- lifted 'Block' wrapper.  Per the System.Mem.Weak documentation, weak
-- pointers on lifted wrappers can be finalized early if the optimiser
-- unboxes them.  The weak pointer produced by @mkWeak#@ is discarded: it
-- exists only to carry the finalizer, and its value is @()@.
addBlockFinalizer :: Block ty -> IO () -> IO ()
addBlockFinalizer (Block barr) (IO finalizer) = IO $ \s ->
   case mkWeak# barr () finalizer s of { (# s1, _ #) -> (# s1, () #) }

-- | Zero-fill the whole byte array underlying a mutable block.
--
-- The length comes from the array itself ('getSizeofMutableByteArray#') and
-- is measured in bytes.  Every byte of the allocation is therefore cleared,
-- whatever the element type @ty@ is.
erase :: MutableBlock ty RealWorld -> IO ()
erase (MutableBlock mbarr) = IO $ \s1 ->
    case getSizeofMutableByteArray# mbarr s1 of
        -- len is a byte count; set bytes [0, len) to the value 0.
        (# s2, len #) -> case setByteArray# mbarr 0# len 0# s2 of
            s3 -> (# s3, () #)

-- | Overwrite the first @n@ bytes at the given address with zeros.
--
-- A non-positive count means there is nothing to erase.  The underlying
-- routines are memset-style: memory's 'memSet' converts the count to an
-- unsigned size.  A negative @n@ would therefore become a huge length and
-- clobber unrelated memory instead of being a harmless no-op.
erasePtr :: Int -> Ptr Word8 -> IO ()
erasePtr n ptr
    | n <= 0    = return ()
    | otherwise = erasePtrUnchecked n ptr

-- Zero exactly @n@ bytes starting at the pointer; the caller ensures n > 0.
-- base >= 4.19 provides the setAddrRange# primop directly; older versions
-- fall back to memory's memSet.
erasePtrUnchecked :: Int -> Ptr Word8 -> IO ()
#if MIN_VERSION_base(4,19,0)
erasePtrUnchecked (I# n) (Ptr addr) = IO $ \s1 ->
    case setAddrRange# addr n 0# s1 of
        s2 -> (# s2, () #)
#else
erasePtrUnchecked n ptr = memSet ptr 0 n
#endif
