{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}

-- |
-- Module      : Data.MemPack.Buffer
-- Copyright   : (c) Alexey Kuleshevich 2024
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <alexey@kuleshevi.ch>
-- Stability   : experimental
-- Portability : non-portable
module Data.MemPack.Buffer where

import Data.Array.Byte
import qualified Data.ByteString as BS
import qualified Data.ByteString.Short.Internal as SBS
import qualified Data.ByteString.Internal as BS
import GHC.Exts
import GHC.ST
import GHC.ForeignPtr

-- | Immutable memory buffer.
--
-- Abstracts over the two representations used by the instances in this module:
-- memory held in an unlifted 'ByteArray#' and memory reachable through a raw
-- 'Addr#'.
class Buffer b where
  -- | Number of bytes contained in the buffer.
  bufferByteCount :: b -> Int

  -- | Eliminate a buffer by supplying one continuation per representation: the
  -- first is applied when the buffer is backed by a 'ByteArray#', the second
  -- when it is backed by an 'Addr#'. An instance applies exactly one of them.
  --
  -- NOTE(review): for pointer-backed buffers the 'Addr#' is presumably only
  -- guaranteed valid while the continuation runs, so it should not escape in
  -- the result -- confirm against the instance being used.
  buffer :: b -> (ByteArray# -> a) -> (Addr# -> a) -> a

-- | A 'ByteArray' wraps a 'ByteArray#', so 'buffer' always takes the array
-- continuation and never the address one.
instance Buffer ByteArray where
  -- Size in bytes as reported by the primitive array.
  bufferByteCount (ByteArray ba#) = I# (sizeofByteArray# ba#)
  {-# INLINE bufferByteCount #-}

  -- Unwrap and hand the 'ByteArray#' to the first continuation; the 'Addr#'
  -- continuation is ignored.
  buffer (ByteArray ba#) f _ = f ba#
  {-# INLINE buffer #-}

-- | A 'SBS.ShortByteString' is matched with the 'SBS.SBS' pattern to expose its
-- 'ByteArray#', so 'buffer' always takes the array continuation.
instance Buffer SBS.ShortByteString where
  bufferByteCount = SBS.length
  {-# INLINE bufferByteCount #-}

  -- Unwrap and hand the 'ByteArray#' to the first continuation; the 'Addr#'
  -- continuation is ignored.
  buffer (SBS.SBS ba#) f _ = f ba#
  {-# INLINE buffer #-}

-- | A strict 'BS.ByteString' is accessed through a raw address, so 'buffer'
-- always takes the 'Addr#' continuation.
instance Buffer BS.ByteString where
  bufferByteCount = BS.length
  {-# INLINE bufferByteCount #-}

  -- The 'ByteArray#' continuation is ignored. The address continuation is run
  -- inside 'withPtrByteStringST', and its result is forced with '$!' so that
  -- the read from the pointer is evaluated to WHNF inside the supplied action.
  buffer bs _ f =
    runST $ withPtrByteStringST bs $ \(Ptr addr#) -> pure $! f addr#
  {-# INLINE buffer #-}


-- | Allocate a new 'MutableByteArray' of the given size.
--
-- The 'Bool' selects the allocator: 'True' uses 'newPinnedByteArray#', 'False'
-- uses 'newByteArray#'. The 'Int' is passed to the primop as the requested
-- size; it is not validated here.
newMutableByteArray :: Bool -> Int -> ST s (MutableByteArray s)
newMutableByteArray isPinned (I# len#) =
  ST $ \s# ->
    case (if isPinned then newPinnedByteArray# else newByteArray#) len# s# of
      (# s'#, mba# #) -> (# s'#, MutableByteArray mba# #)
{-# INLINE newMutableByteArray #-}

-- | Turn a 'MutableByteArray' into an immutable 'ByteArray' using
-- 'unsafeFreezeByteArray#'.
--
-- Because the unsafe freeze primop is used, the mutable array must not be
-- written to after this call.
freezeMutableByteArray :: MutableByteArray d -> ST d ByteArray
freezeMutableByteArray (MutableByteArray mba#) =
  ST $ \s# ->
    case unsafeFreezeByteArray# mba# s# of
      (# s'#, ba# #) -> (# s'#, ByteArray ba# #)
{-# INLINE freezeMutableByteArray #-}

-- | Run an 'ST' action with a pointer to the first byte of a 'BS.ByteString'.
--
-- It is ok to use a 'BS.ByteString' within 'ST', as long as the underlying
-- pointer is never mutated or returned from the supplied action.
--
-- The result of the action is forced to WHNF before the 'BS.ByteString'
-- contents are touched, and the touch happens before the result is returned.
withPtrByteStringST :: BS.ByteString -> (Ptr a -> ST s b) -> ST s b
#if MIN_VERSION_bytestring(0,11,0)
withPtrByteStringST (BS.BS (ForeignPtr addr# ptrContents) _) f = do
#else
withPtrByteStringST (BS.PS (ForeignPtr addr0# ptrContents) (I# offset#) _) f = do
  -- Older representation carries an explicit offset into the allocation.
  let !addr# = addr0# `plusAddr#` offset#
#endif
  -- Force the result while the pointer is still guaranteed to be alive.
  !r <- f (Ptr addr#)
  -- It is safe to use `touch#` within ST, so `unsafeCoerce#` is OK. The state
  -- token is coerced to `State# RealWorld` for `touch#` and back again.
  ST $ \s# -> (# unsafeCoerce# (touch# ptrContents (unsafeCoerce# s#)), () #)
  pure r
{-# INLINE withPtrByteStringST #-}

-- | Wrap a 'ByteArray' as a strict 'BS.ByteString' covering all of its bytes
-- (offset @0@, length 'sizeofByteArray#').
--
-- The bytes are not copied here: the result points at the array's contents via
-- 'pinnedByteArrayToForeignPtr'. As the name says, the array must be pinned;
-- this is not checked.
pinnedByteArrayToByteString :: ByteArray -> BS.ByteString
pinnedByteArrayToByteString (ByteArray ba#) =
  BS.PS (pinnedByteArrayToForeignPtr ba#) 0 (I# (sizeofByteArray# ba#))
{-# INLINE pinnedByteArrayToByteString #-}

-- | Build a 'ForeignPtr' pointing at the contents of a 'ByteArray#'.
--
-- The address comes from 'byteArrayContents#', and the array itself is stored
-- (coerced to a @MutableByteArray# RealWorld@) in a 'PlainPtr', so the
-- 'ForeignPtr' holds a reference to the array. The array must be pinned; this
-- is not checked.
pinnedByteArrayToForeignPtr :: ByteArray# -> ForeignPtr a
pinnedByteArrayToForeignPtr ba# =
  ForeignPtr (byteArrayContents# ba#) (PlainPtr (unsafeCoerce# ba#))
{-# INLINE pinnedByteArrayToForeignPtr #-}


-- | Rewrap the 'ByteArray#' inside a 'ByteArray' as a 'SBS.ShortByteString'.
-- The same underlying array is reused.
byteArrayToShortByteString :: ByteArray -> SBS.ShortByteString
byteArrayToShortByteString (ByteArray ba#) = SBS.SBS ba#
{-# INLINE byteArrayToShortByteString #-}

-- | Rewrap the 'ByteArray#' inside a 'SBS.ShortByteString' as a 'ByteArray'.
-- The same underlying array is reused.
byteArrayFromShortByteString :: SBS.ShortByteString -> ByteArray
byteArrayFromShortByteString (SBS.SBS ba#) = ByteArray ba#
{-# INLINE byteArrayFromShortByteString #-}