-- |
-- Module      : ScrubbedBlock
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- A block that is always pinned in memory and is automatically erased by a
-- finalizer once it is no longer referenced.  This follows the same pattern
-- as @ScrubbedBytes@ from the @memory@ package, but for 'Block's.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
module ScrubbedBlock
    ( ScrubbedBlock, create, foldZipWith, ScrubbedBlock.length
    , new, thaw, unsafeFreeze, Block.withMutablePtr
    , erasePtr
    ) where

import Basement.Block (Block(..), MutableBlock(..), isPinned)
import Basement.Block.Mutable (mutableLengthBytes, unsafeCopyBytesRO)
import Basement.Compat.Primitive

import Basement.Monad
import Basement.PrimType
import Basement.Types.OffsetSize

import Control.Exception (assert)
import Control.Monad.ST

import Data.Word

import Foreign.Ptr (Ptr)

import Block (blockWrite)
import qualified Block

#if MIN_VERSION_base(4,19,0)
import GHC.Base (setAddrRange#)
import GHC.Exts (Ptr(Ptr))
#else
import Data.Memory.PtrMethods (memSet)
#endif
import GHC.Base (IO(IO), Int(I#), setByteArray#)
import GHC.Exts (mkWeak#)

-- | A 'Block' whose backing memory is pinned and which, once built through
-- 'unsafeFreeze' (or 'create', which uses it), is zeroed by a finalizer when
-- the block is no longer referenced.  The constructor is not exported, so
-- values can only be obtained through those functions.
--
-- NOTE(review): the derived 'Show' instance presumably prints the block's
-- elements, and the derived 'Eq' delegates to the underlying 'Block'
-- comparison, which is presumably not constant-time.  Confirm both are
-- acceptable if the block holds secret material.
newtype ScrubbedBlock ty = ScrubbedBlock (Block ty)
    deriving (Eq, Show)

-- | Build a 'ScrubbedBlock' of @n@ elements, where the element stored at
-- offset @i@ is @initializer i@.  The storage is allocated pinned with 'new',
-- and the scrubbing finalizer is attached when it is frozen.
create :: PrimType ty
       => CountOf ty
       -> (Offset ty -> ty)
       -> ScrubbedBlock ty
create n initializer = runST $ new n >>= \mb -> fill mb 0 >> unsafeFreeze mb
  where
    -- write initializer results at offsets 0 .. n-1, in increasing order
    fill !dst ofs
        | ofs .==# n = return ()
        | otherwise  = do
            blockWrite dst ofs (initializer ofs)
            fill dst (ofs + 1)
{-# INLINE create #-}

-- | Fold over the two blocks' elements pairwise, starting from @acc0@.
-- This unwraps both blocks and delegates to 'Block.foldZipWith'.
foldZipWith :: (PrimType a, PrimType b)
            => (c -> a -> b -> c) -> c -> ScrubbedBlock a -> ScrubbedBlock b -> c
foldZipWith step acc0 (ScrubbedBlock xs) (ScrubbedBlock ys) =
    Block.foldZipWith step acc0 xs ys
{-# INLINE foldZipWith #-}

-- | Number of @ty@ elements held by the block (see 'Block.length').
length :: PrimType ty => ScrubbedBlock ty -> CountOf ty
length sb = case sb of
    ScrubbedBlock inner -> Block.length inner

-- | Allocate a mutable block of @count@ elements.  The allocation is always
-- pinned, which 'unsafeFreeze' asserts.  No scrubbing is scheduled until the
-- block is frozen with 'unsafeFreeze'.
new :: (PrimType ty, PrimMonad prim) => CountOf ty -> prim (MutableBlock ty (PrimState prim))
new count = Block.newPinned count

-- | Copy a 'ScrubbedBlock' into a freshly allocated pinned mutable block.
-- The copy has no finalizer of its own until it is frozen again with
-- 'unsafeFreeze'.
thaw :: (PrimType ty, PrimMonad m) => ScrubbedBlock ty -> m (MutableBlock ty (PrimState m))
thaw (ScrubbedBlock !src) =
    new (Block.length src) >>= \dst ->
        dst <$ unsafeCopyBytesRO dst 0 src 0 (mutableLengthBytes dst)

-- | Freeze a mutable block in place (no copy) and register the finalizer
-- that zeroes it once it is unreachable.  The block is asserted to be
-- pinned.  As with any unsafe freeze, it should not be mutated afterwards.
unsafeFreeze :: PrimMonad prim => MutableBlock ty (PrimState prim) -> prim (ScrubbedBlock ty)
unsafeFreeze mb = do
    frozen <- Block.unsafeFreeze mb
    scrubbed frozen


{- internal -}

-- | Return the second argument, after an 'assert' (active only when
-- assertions are enabled) that the block's memory is pinned.
assertPinned :: Block ty -> a -> a
assertPinned blk = assert pinned
  where
    pinned = isPinned blk == Pinned

-- | Wrap a frozen, pinned block as a 'ScrubbedBlock', registering the
-- finalizer that will zero its bytes.
scrubbed :: PrimMonad prim => Block ty -> prim (ScrubbedBlock ty)
scrubbed b = assertPinned b (unsafePrimFromIO register)
  where
    register = do
        scheduleBlockScrubbing b
        return (ScrubbedBlock b)

-- | Attach a finalizer to the block that zeroes its storage.  The block is
-- reinterpreted as 'Word8's so that the erase covers its full length in
-- bytes, whatever the element type @ty@ is.
scheduleBlockScrubbing :: Block ty -> IO ()
scheduleBlockScrubbing b = addBlockFinalizer b eraser
  where
    eraser = scrub (Block.unsafeCast b)
{-# NOINLINE scheduleBlockScrubbing #-}

-- | Overwrite every byte of the block with zero, in place: the frozen block
-- is viewed as mutable through 'Block.unsafeThaw' rather than copied.
scrub :: Block Word8 -> IO ()
scrub b = do
    mb <- Block.unsafeThaw b
    let CountOf len = Block.length b
    erase len mb

-- | Register @finalizer@ to run once the block's underlying 'ByteArray#'
-- becomes unreachable.
--
-- The 'Weak#' produced by 'mkWeak#' is dropped immediately.  This mirrors
-- what @System.Mem.addFinalizer@ does.  NOTE(review): it relies on the RTS
-- keeping the weak pointer (and so the finalizer) registered until the key
-- dies; confirm against the @System.Mem.Weak@ documentation.
addBlockFinalizer :: Block ty -> IO () -> IO ()
addBlockFinalizer (Block barr) (IO finalizer) = IO $ \s ->
   -- The weak key is the unlifted array itself, not the boxed 'Block'.
   -- Presumably this avoids the documented caveat that weak pointers keyed
   -- on boxed wrappers may fire early.
   case mkWeak# barr () finalizer s of { (# s1, _ #) -> (# s1, () #) }

-- | Set the first @len@ bytes of the mutable block to zero, in place, using
-- the 'setByteArray#' primop (offset 0, fill value 0).
erase :: Int -> MutableBlock Word8 RealWorld -> IO ()
erase (I# count) (MutableBlock arr) = IO zeroFill
  where
    zeroFill st0 = case setByteArray# arr 0# count 0# st0 of
        st1 -> (# st1, () #)

-- | Overwrite @n@ bytes starting at @ptr@ with zeros.
--
-- A non-positive @n@ is a no-op.  Both backends below behave like C
-- @memset@ and take the count as a raw size.  A negative count would
-- otherwise be reinterpreted as a huge length and overwrite unrelated
-- memory.
erasePtr :: Int -> Ptr Word8 -> IO ()
erasePtr n ptr
    | n <= 0    = return ()
    | otherwise = zeroBytes n ptr
  where
    zeroBytes :: Int -> Ptr Word8 -> IO ()
#if MIN_VERSION_base(4,19,0)
    -- setAddrRange# addr len c fills [addr, addr + len) with the byte c
    zeroBytes (I# len) (Ptr addr) = IO $ \s1 ->
        case setAddrRange# addr len 0# s1 of
            s2 -> (# s2, () #)
#else
    zeroBytes len p = memSet p 0 len
#endif
