{-# LANGUAGE BangPatterns  #-}
{-# LANGUAGE MagicHash     #-}
{-# LANGUAGE TypeFamilies  #-}
{-# LANGUAGE UnboxedTuples #-}
-- |Memory access primitives
module Data.Flat.Memory (
    chunksToByteString,
    chunksToByteArray,
    ByteArray(..),
    pokeByteArray,
    pokeByteString,
    unsafeCreateUptoN',
    minusPtr,
    ) where

import           Control.Monad
import           Control.Monad.Primitive  (PrimMonad (..))
import qualified Data.ByteString.Internal as BS
import           Data.Primitive.ByteArray
import           Foreign                  hiding (void)
import           GHC.Prim                 (copyAddrToByteArray#,copyByteArrayToAddr#)
import           GHC.Ptr                  (Ptr (..))
import           GHC.Types                (IO (..), Int (..))
import           System.IO.Unsafe
import qualified Data.ByteString                as B

-- | Pure counterpart of 'createUptoN'': allocate a buffer of @size@ bytes,
-- let @fill@ write into it, and return the resulting 'BS.ByteString'
-- together with @fill@'s extra result.
--
-- The action runs under 'unsafeDupablePerformIO'. Per that function's
-- contract, @fill@ may be executed more than once if the result is demanded
-- concurrently, so it must be free of observable side effects apart from
-- writing into the supplied buffer.
unsafeCreateUptoN' :: Int -> (Ptr Word8 -> IO (Int, a)) -> (BS.ByteString, a)
unsafeCreateUptoN' size fill = unsafeDupablePerformIO $ createUptoN' size fill
{-# INLINE unsafeCreateUptoN' #-}

-- | Allocate a buffer of @l@ bytes with 'BS.mallocByteString' and run @f@ on a
-- pointer to its start.
--
-- @f@ returns @(l', res)@. @l'@ is the number of bytes it actually wrote, and
-- the result is a 'BS.ByteString' over the first @l'@ bytes of the buffer,
-- paired with @res@.
--
-- Raises an 'error' if @l'@ is larger than @l@ or negative.
--
-- NOTE: the overflow check runs only after @f@ has returned. If @f@ wrote past
-- the end of the buffer, memory outside it may already be corrupted, so this
-- is a diagnostic, not a guard. Callers must size @l@ correctly.
createUptoN' :: Int -> (Ptr Word8 -> IO (Int, a)) -> IO (BS.ByteString, a)
createUptoN' l f = do
  fp <- BS.mallocByteString l
  (l', res) <- withForeignPtr fp f
  when (l' > l) $
    error (unwords ["Buffer overflow, allocated:",show l,"bytes, used:",show l',"bytes"])
  -- A negative length would produce a ByteString that violates its own
  -- invariant (length >= 0), so reject it instead of returning garbage.
  when (l' < 0) $
    error (unwords ["Invalid negative length reported:",show l',"bytes"])
  return (BS.PS fp 0 l', res)
{-# INLINE createUptoN' #-}

-- | Copy every byte of a strict 'B.ByteString' to @dst@ and return the
-- pointer just past the last byte written.
--
-- No bounds check is performed. The memory at @dst@ must have room for at
-- least as many bytes as the ByteString holds.
pokeByteString :: B.ByteString -> Ptr Word8 -> IO (Ptr Word8)
pokeByteString (BS.PS fptr off len) dst =
  plusPtr dst len <$ withForeignPtr fptr (\src -> BS.memcpy dst (plusPtr src off) len)

-- | Copy @n@ bytes from the unlifted byte array @arr@, starting at byte offset
-- @off@, to @dst@. Returns the address immediately after the copied region.
--
-- The returned pointer is forced before it is handed back, so callers never
-- receive a thunk.
pokeByteArray :: ByteArray# -> Int -> Int -> Ptr Word8 -> IO (Ptr Word8)
pokeByteArray arr off n dst =
  copyByteArrayToAddr arr off dst n >> (return $! plusPtr dst n)
{-# INLINE pokeByteArray #-}

-- | Lift the @copyByteArrayToAddr#@ primop into 'IO'.
--
-- Copies @len@ bytes from @arr@, starting at byte offset @offset@, to the
-- address @addr@. No bounds checking is done here.
-- Adapted from the store-core package.
copyByteArrayToAddr :: ByteArray# -> Int -> Ptr a -> Int -> IO ()
copyByteArrayToAddr src (I# off#) (Ptr dst#) (I# n#) = IO $ \s0 ->
  case copyByteArrayToAddr# src off# dst# n# s0 of
    s1 -> (# s1, () #)
{-# INLINE copyByteArrayToAddr  #-}

-- toByteString :: Ptr Word8 -> Int -> BS.ByteString
-- toByteString sourcePtr sourceLength = BS.unsafeCreate sourceLength $ \destPointer -> BS.memcpy destPointer sourcePtr sourceLength

-- | Gather a sequence of chunks from raw memory into one contiguous
-- 'BS.ByteString'.
--
-- The source layout is: chunk of length @l1@, one skipped byte, chunk of
-- length @l2@, one skipped byte, and so on. The skipped bytes are not copied.
-- The result holds the concatenation of the chunks and has length @sum lens@.
chunksToByteString :: (Ptr Word8,[Int]) -> BS.ByteString
chunksToByteString (sourcePtr,lens) =
  BS.unsafeCreate (sum lens) (\dst -> copyChunks dst sourcePtr lens)
  where
    -- Copy the current chunk, then advance the destination by its length and
    -- the source by its length plus the one separator byte.
    copyChunks _ _ [] = return ()
    copyChunks dst src (n:ns) = do
      BS.memcpy dst src n
      copyChunks (dst `plusPtr` n) (src `plusPtr` (n + 1)) ns

-- | Gather a sequence of chunks from raw memory into a freshly allocated
-- 'ByteArray'. Returns the array together with its length in bytes.
--
-- The source layout is the same as for 'chunksToByteString': each chunk is
-- followed by one byte that is skipped and not copied.
chunksToByteArray :: (Ptr Word8,[Int]) -> (ByteArray,Int)
chunksToByteArray (sourcePtr,lens) = unsafePerformIO $ do
  let total = sum lens
  marr <- newByteArray total
  let -- Copy each chunk to the next free offset, skipping one source byte
      -- after every chunk.
      fill _ _ [] = return ()
      fill off src (n:ns) = do
        copyAddrToByteArray src marr off n
        fill (off + n) (src `plusPtr` (n + 1)) ns
  fill 0 sourcePtr lens
  frozen <- unsafeFreezeByteArray marr
  return (frozen, total)

-- | Lift the @copyAddrToByteArray#@ primop into 'IO'.
--
-- Copies @len@ bytes from the address @addr@ into the mutable byte array
-- @arr@, starting at byte offset @offset@. No bounds checking is done here.
-- Adapted from the store-core package.
copyAddrToByteArray :: Ptr a -> MutableByteArray (PrimState IO) -> Int -> Int -> IO ()
copyAddrToByteArray (Ptr src#) (MutableByteArray dst#) (I# off#) (I# n#) = IO $ \s0 ->
  case copyAddrToByteArray# src# dst# off# n# s0 of
    s1 -> (# s1, () #)
{-# INLINE copyAddrToByteArray  #-}