-- |
-- Module      : Block
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
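-- An array of primitive (unlifted) elements.  This module currently exposes
-- the implementation from basement, replaces its @create@ function with a
-- version marked INLINE (fixing a lack of inlining in the basement one), and
-- adds a few helpers: element accessors that switch between the checked and
-- @unsafe@ basement primitives depending on a CPP flag, an in-place map
-- ('iterModify') and a fold over two blocks in lockstep ('foldZipWith').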
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module Block
    ( Block, MutableBlock, blockIndex, blockRead, blockWrite
    , create, foldZipWith, iterModify, Block.length
    , Block.new, Block.newPinned, Block.thaw, Block.unsafeCast
    , Block.unsafeFreeze, Block.unsafeThaw, Block.withMutablePtr
    ) where

import Basement.Block (Block)
import Basement.Block.Mutable (MutableBlock)
import qualified Basement.Block as Block hiding (create, map)
import qualified Basement.Block.Mutable as Block
import Basement.Monad
import Basement.PrimType
import Basement.Types.OffsetSize

import Control.Exception (assert)
import Control.Monad.ST

-- | Read the element stored at an offset of an immutable block.
--
-- Testing builds (see the @#ifdef@ below) use 'Block.index'; other builds
-- use 'Block.unsafeIndex'.  NOTE(review): the intent appears to be bounds
-- checking during tests only; confirm against the basement documentation.
blockIndex :: PrimType ty => Block ty -> Offset ty -> ty

-- | Read the element stored at an offset of a mutable block
-- ('Block.read' in testing builds, 'Block.unsafeRead' otherwise).
blockRead :: (PrimMonad prim, PrimType ty) => MutableBlock ty (PrimState prim) -> Offset ty -> prim ty

-- | Store a value at an offset of a mutable block
-- ('Block.write' in testing builds, 'Block.unsafeWrite' otherwise).
blockWrite :: (PrimMonad prim, PrimType ty) => MutableBlock ty (PrimState prim) -> Offset ty -> ty -> prim ()

#ifdef ML_KEM_TESTING
blockIndex blk ofs = Block.index blk ofs
blockRead mblk ofs = Block.read mblk ofs
blockWrite mblk ofs val = Block.write mblk ofs val
#else
blockIndex blk ofs = Block.unsafeIndex blk ofs
blockRead mblk ofs = Block.unsafeRead mblk ofs
blockWrite mblk ofs val = Block.unsafeWrite mblk ofs val
#endif

-- | Build a block of the given length in which the element at every offset
-- @i@ is @initializer i@.  Offsets are filled in increasing order.
--
-- Defined locally, and marked INLINE, instead of re-exporting the basement
-- version, which lacks inlining.
create :: PrimType ty
       => CountOf ty        -- ^ number of elements in the new block
       -> (Offset ty -> ty) -- ^ value to store at each offset
       -> Block ty
create n initializer = runST $ do
    dst <- Block.new n
    let fill ofs
            | ofs .==# n = pure ()
            | otherwise = do
                Block.unsafeWrite dst ofs (initializer ofs)
                fill (ofs + 1)
    fill 0
    -- Every offset in [0, n) has been written, so freezing without a copy is
    -- fine: the mutable handle never escapes this ST computation.
    Block.unsafeFreeze dst
{-# INLINE create #-}

-- | Apply @f@ in place to every element of a mutable block, visiting
-- offsets from 0 up to (but excluding) the block length.
iterModify :: (PrimType ty, PrimMonad prim)
           => (ty -> ty)
           -> MutableBlock ty (PrimState prim)
           -> prim ()
iterModify f ma = go 0
  where
    -- Bound strictly (bang pattern) and shared by every iteration of go.
    !total = Block.mutableLength ma

    go ofs
        | ofs .==# total = pure ()
        | otherwise = do
            old <- Block.unsafeRead ma ofs
            Block.unsafeWrite ma ofs (f old)
            go (ofs + 1)
{-# INLINE iterModify #-}

-- | Strict left fold over two blocks in lockstep.  Starting from @c@, the
-- accumulator is combined with the elements of @a@ and @b@ found at the same
-- offset, visiting offsets in increasing order.
--
-- Both blocks are expected to have the same length.  This is checked with
-- 'assert' when assertions are enabled.  If the lengths differ anyway, only
-- the common prefix (the length of the shorter block) is folded, so no read
-- ever goes past the end of either block.
foldZipWith :: (PrimType a, PrimType b)
            => (c -> a -> b -> c) -> c -> Block a -> Block b -> c
foldZipWith f c a b = assert (sa == sb) $
    loop c 0
  where
    CountOf sa = Block.length a
    CountOf sb = Block.length b

    -- GHC removes 'assert' in optimised builds unless -fno-ignore-asserts is
    -- given, so the loop bound must not depend on the assertion holding:
    -- stop at the shorter length, which keeps 'Block.unsafeIndex' in range.
    end = min sa sb

    loop !acc i
        | i == end = acc
        | otherwise =
            let va = Block.unsafeIndex a (Offset i)
                vb = Block.unsafeIndex b (Offset i)
             in loop (f acc va vb) (i + 1)
{-# INLINE foldZipWith #-}
