{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
module Data.Prim.Memory.Bytes.Internal
( Bytes(..)
, MBytes(..)
, Pinned(..)
, isSameBytes
, isSamePinnedBytes
, isPinnedBytes
, isPinnedMBytes
, castPinnedBytes
, castPinnedMBytes
, allocMBytes
, allocPinnedMBytes
, allocAlignedMBytes
, allocUnpinnedMBytes
, callocAlignedMBytes
, reallocMBytes
, freezeMBytes
, thawBytes
, shrinkMBytes
, resizeMBytes
, indexOffBytes
, indexByteOffBytes
, compareByteOffBytes
, byteCountBytes
, countBytes
, getCountMBytes
, getByteCountMBytes
, setMBytes
, copyByteOffBytesToMBytes
, moveByteOffMBytesToMBytes
, readOffMBytes
, readByteOffMBytes
, writeOffMBytes
, writeByteOffMBytes
, toPtrBytes
, toPtrMBytes
, withPtrBytes
, withPtrMBytes
, withNoHaltPtrBytes
, withNoHaltPtrMBytes
, toForeignPtrBytes
, toForeignPtrMBytes
, fromForeignPtrBytes
, byteStringConvertError
) where
import Control.DeepSeq
import Control.Prim.Monad
import Control.Prim.Monad.Unsafe
import Data.Prim
import Data.Prim.Class
import GHC.ForeignPtr
import Data.Typeable
import Foreign.Prim
-- | Tag that is used at the type level (promoted with @DataKinds@) as the first parameter of
-- 'Bytes' and 'MBytes'.
--
-- In this module @'Pin@ is the tag produced by 'allocPinnedMBytes' and 'allocAlignedMBytes'
-- and the one required by the pointer conversions such as 'toPtrBytes'. @'Inc@ is the tag
-- produced by 'allocUnpinnedMBytes' and 'resizeMBytes'; for such memory 'isPinnedBytes' and
-- 'isPinnedMBytes' perform a runtime check.
data Pinned = Pin | Inc
-- | Immutable region of memory: a wrapper around a @ByteArray#@ that carries a 'Pinned' tag
-- as a phantom type parameter.
data Bytes (p :: Pinned) = Bytes ByteArray#
-- The nominal role prevents the 'Pinned' tag from being changed with @coerce@;
-- 'castPinnedBytes' is the explicit way to re-tag.
type role Bytes nominal
-- | Mutable region of memory: a wrapper around a @MutableByteArray#@ that carries a 'Pinned'
-- tag @p@ and the state token type @s@.
data MBytes (p :: Pinned) s = MBytes (MutableByteArray# s)
-- Both parameters are nominal, so neither the 'Pinned' tag nor the state token type can be
-- changed with @coerce@; 'castPinnedMBytes' is the explicit way to re-tag.
type role MBytes nominal nominal
-- | The only field is a @ByteArray#@, so bringing the wrapper to weak head normal form is
-- all there is to evaluate.
instance NFData (Bytes p) where
  rnf bytes = bytes `seq` ()
-- | The only field is a @MutableByteArray#@, so bringing the wrapper to weak head normal form
-- is all there is to evaluate. The contents of the array are not inspected.
instance NFData (MBytes p s) where
  rnf mbytes = mbytes `seq` ()
-- | Compare a region of one 'Bytes' with a region of another.
--
-- Both offsets are in bytes. The length is supplied as a number of elements of type @e@ and
-- is passed to @compareByteArrays#@ after conversion with @fromCount#@. The primitive result
-- is turned into an 'Ordering' by @toOrdering#@. No bounds checking is done here.
compareByteOffBytes ::
     Prim e => Bytes p1 -> Off Word8 -> Bytes p2 -> Off Word8 -> Count e -> Ordering
compareByteOffBytes (Bytes lhs#) (Off (I# lhsOff#)) (Bytes rhs#) (Off (I# rhsOff#)) count =
  case compareByteArrays# lhs# lhsOff# rhs# rhsOff# (fromCount# count) of
    res# -> toOrdering# res#
{-# INLINE compareByteOffBytes #-}
-- | Read an element from immutable memory at an offset expressed in elements of type @e@.
-- The offset is forwarded to @indexByteArray#@ as is; no bounds checking is done here.
indexOffBytes :: Prim e => Bytes p -> Off e -> e
indexOffBytes bytes off =
  case bytes of
    Bytes arr# ->
      case off of
        Off (I# ix#) -> indexByteArray# arr# ix#
{-# INLINE indexOffBytes #-}
-- | Read an element of type @e@ from immutable memory at an offset expressed in bytes.
-- The offset is forwarded to @indexByteOffByteArray#@ as is; no bounds checking is done here.
indexByteOffBytes :: Prim e => Bytes p -> Off Word8 -> e
indexByteOffBytes bytes byteOff =
  case bytes of
    Bytes arr# ->
      case byteOff of
        Off (I# ix#) -> indexByteOffByteArray# arr# ix#
{-# INLINE indexByteOffBytes #-}
-- | Allocate mutable memory for the requested number of elements, selecting the allocator
-- from the type-level 'Pinned' tag @p@.
--
-- The tag is inspected with 'eqT': @'Pin@ dispatches to 'allocPinnedMBytes' and @'Inc@ to
-- 'allocUnpinnedMBytes'. Any other tag is reported through @errorImpossible@.
--
-- Inlining is delayed until phase 0 so that the rewrite rules that follow this definition
-- get a chance to fire first.
allocMBytes ::
     forall p e s m. (Typeable p, Prim e, MonadPrim s m)
  => Count e
  -> m (MBytes p s)
allocMBytes c
  | Just Refl <- (eqT :: Maybe (p :~: 'Pin)) = allocPinnedMBytes c
  | Just Refl <- (eqT :: Maybe (p :~: 'Inc)) = allocUnpinnedMBytes c
  | otherwise =
    errorImpossible "allocMBytes" $
    "Unexpected 'Pinned' kind: '" ++ showsType (Proxy :: Proxy (Bytes p)) "'."
{-# INLINE[0] allocMBytes #-}
-- Rewrite 'allocMBytes' directly to the specialised allocator wherever the types match,
-- i.e. where the 'Pinned' tag of the result is already known to be 'Inc or 'Pin. This skips
-- the 'eqT' dispatch in 'allocMBytes'; its INLINE[0] pragma keeps the call around long
-- enough for these rules to apply.
{-# RULES
"allocUnpinnedMBytes" allocMBytes = allocUnpinnedMBytes
"allocPinnedMBytes" allocMBytes = allocPinnedMBytes
#-}
-- | Allocate mutable memory for the given number of elements with @newByteArray#@.
-- The result is tagged @'Inc@. This function does not initialise the contents.
allocUnpinnedMBytes :: (MonadPrim s m, Prim e) => Count e -> m (MBytes 'Inc s)
allocUnpinnedMBytes c = prim newUnpinned
  where
    newUnpinned s0 =
      case newByteArray# (fromCount# c) s0 of
        (# s1, mba# #) -> (# s1, MBytes mba# #)
{-# INLINE allocUnpinnedMBytes #-}
-- | Allocate pinned mutable memory for the given number of elements with
-- @newPinnedByteArray#@. The result is tagged @'Pin@. This function does not initialise the
-- contents.
allocPinnedMBytes :: (MonadPrim s m, Prim e) => Count e -> m (MBytes 'Pin s)
allocPinnedMBytes c = prim newPinned
  where
    newPinned s0 =
      case newPinnedByteArray# (fromCount# c) s0 of
        (# s1, mba# #) -> (# s1, MBytes mba# #)
{-# INLINE allocPinnedMBytes #-}
-- | Allocate pinned mutable memory for the given number of elements with
-- @newAlignedPinnedByteArray#@, requesting the alignment that @alignment#@ reports for the
-- element type @e@. The result is tagged @'Pin@. This function does not initialise the
-- contents; see 'callocAlignedMBytes' for a zeroing variant.
allocAlignedMBytes ::
     forall e m s. (MonadPrim s m, Prim e)
  => Count e
  -> m (MBytes 'Pin s)
allocAlignedMBytes c = prim newAligned
  where
    newAligned s0 =
      case newAlignedPinnedByteArray# (fromCount# c) (alignment# (proxy# :: Proxy# e)) s0 of
        (# s1, mba# #) -> (# s1, MBytes mba# #)
{-# INLINE allocAlignedMBytes #-}
-- | Same as 'allocAlignedMBytes', except that the freshly allocated memory is filled with
-- zeros before it is returned. The fill is done bytewise with 'setMBytes', starting at
-- offset 0 and covering the byte count that corresponds to the requested element count.
callocAlignedMBytes ::
     (MonadPrim s m, Prim e)
  => Count e
  -> m (MBytes 'Pin s)
callocAlignedMBytes n = do
  mb <- allocAlignedMBytes n
  mb <$ setMBytes mb 0 (toByteCount n) 0
{-# INLINE callocAlignedMBytes #-}
-- | Get the current size of mutable memory in bytes, as reported by
-- @getSizeofMutableByteArray#@. This is a monadic action because the size can change, see
-- 'shrinkMBytes' and 'resizeMBytes'.
getByteCountMBytes :: MonadPrim s m => MBytes p s -> m (Count Word8)
getByteCountMBytes mb =
  case mb of
    MBytes arr# ->
      prim $ \s0 ->
        case getSizeofMutableByteArray# arr# s0 of
          (# s1, size# #) -> (# s1, Count (I# size#) #)
{-# INLINE getByteCountMBytes #-}
-- | Convert mutable memory into immutable memory with @unsafeFreezeByteArray#@, keeping the
-- 'Pinned' tag. The primitive does not copy, so the argument should not be mutated
-- afterwards.
freezeMBytes :: MonadPrim s m => MBytes p s -> m (Bytes p)
freezeMBytes (MBytes mutable#) = prim freeze
  where
    freeze s0 =
      case unsafeFreezeByteArray# mutable# s0 of
        (# s1, frozen# #) -> (# s1, Bytes frozen# #)
{-# INLINE freezeMBytes #-}
-- | Convert immutable memory into mutable memory with @unsafeThawByteArray#@, keeping the
-- 'Pinned' tag. No copy is made here, so writes through the result affect the memory that
-- backs the argument.
thawBytes :: MonadPrim s m => Bytes p -> m (MBytes p s)
thawBytes (Bytes frozen#) = prim thaw
  where
    thaw s0 =
      case unsafeThawByteArray# frozen# s0 of
        (# s1, mutable# #) -> (# s1, MBytes mutable# #)
{-# INLINE thawBytes #-}
-- | Copy a region of immutable memory into mutable memory with @copyByteArray#@.
--
-- Both offsets are in bytes. The amount to copy is supplied as a number of elements of type
-- @e@ and is converted with @fromCount#@. No bounds checking is done here.
copyByteOffBytesToMBytes ::
     (MonadPrim s m, Prim e)
  => Bytes ps
  -> Off Word8
  -> MBytes pd s
  -> Off Word8
  -> Count e
  -> m ()
copyByteOffBytesToMBytes (Bytes from#) (Off (I# fromOff#)) (MBytes to#) (Off (I# toOff#)) count =
  prim_ (copyByteArray# from# fromOff# to# toOff# (fromCount# count))
{-# INLINE copyByteOffBytesToMBytes #-}
-- | Copy a region of mutable memory into mutable memory with @copyMutableByteArray#@.
--
-- Both offsets are in bytes. The amount to copy is supplied as a number of elements of type
-- @e@ and is converted with @fromCount#@. No bounds checking is done here. GHC documents the
-- underlying primitive as permitting overlapping regions when source and destination are
-- the same array.
moveByteOffMBytesToMBytes ::
     (MonadPrim s m, Prim e)
  => MBytes ps s
  -> Off Word8
  -> MBytes pd s
  -> Off Word8
  -> Count e
  -> m ()
moveByteOffMBytesToMBytes (MBytes from#) (Off (I# fromOff#)) (MBytes to#) (Off (I# toOff#)) count =
  prim_ $ copyMutableByteArray# from# fromOff# to# toOff# (fromCount# count)
{-# INLINE moveByteOffMBytesToMBytes #-}
-- | Size of immutable memory in bytes, as reported by @sizeofByteArray#@.
byteCountBytes :: Bytes p -> Count Word8
byteCountBytes bytes =
  case bytes of
    Bytes arr# -> Count (I# (sizeofByteArray# arr#))
{-# INLINE byteCountBytes #-}
-- | Shrink mutable memory in place with @shrinkMutableByteArray#@. The new size is given as
-- a number of elements of type @e@ and converted with @fromCount#@. This function does not
-- verify that the new size is no larger than the current one.
shrinkMBytes :: (MonadPrim s m, Prim e) => MBytes p s -> Count e -> m ()
shrinkMBytes mb newCount =
  case mb of
    MBytes arr# -> prim_ $ shrinkMutableByteArray# arr# (fromCount# newCount)
{-# INLINE shrinkMBytes #-}
-- | Resize mutable memory with @resizeMutableByteArray#@. The new size is given as a number
-- of elements of type @e@ and converted with @fromCount#@.
--
-- The result is always tagged @'Inc@, whatever the tag of the argument was. As with the
-- underlying primitive, the returned value is the one to use from here on.
resizeMBytes ::
     (MonadPrim s m, Prim e) => MBytes p s -> Count e -> m (MBytes 'Inc s)
resizeMBytes (MBytes old#) newCount = prim resize
  where
    resize s0 =
      case resizeMutableByteArray# old# (fromCount# newCount) s0 of
        (# s1, new# #) -> (# s1, MBytes new# #)
{-# INLINE resizeMBytes #-}
-- | Change the size of mutable memory to hold the given number of elements, preserving the
-- 'Pinned' tag.
--
-- * When the requested size in bytes equals the current one, the argument is returned as is.
--
-- * When it is smaller, the argument is shrunk in place with 'shrinkMBytes' and returned.
--
-- * When it is larger and the tag is @'Pin@, a new buffer is obtained from
--   'allocPinnedMBytes' and the old contents are copied to its beginning.
--
-- * When it is larger and the tag is anything else, 'resizeMBytes' is used and its result is
--   re-tagged with 'castPinnedMBytes'.
reallocMBytes ::
     forall e p m s. (MonadPrim s m, Typeable p, Prim e)
  => MBytes p s
  -> Count e
  -> m (MBytes p s)
reallocMBytes mb c = do
  oldByteCount <- getByteCountMBytes mb
  let newByteCount = toByteCount c
  case newByteCount <= oldByteCount of
    True
      | newByteCount < oldByteCount -> mb <$ shrinkMBytes mb newByteCount
      | otherwise -> pure mb
    False ->
      case eqT :: Maybe (p :~: 'Pin) of
        Just Refl -> do
          -- Pinned memory is grown by allocating anew, so that the result is pinned as well.
          oldFrozen <- freezeMBytes mb
          grown <- allocPinnedMBytes newByteCount
          grown <$ copyByteOffBytesToMBytes oldFrozen 0 grown 0 oldByteCount
        Nothing -> fmap castPinnedMBytes (resizeMBytes mb newByteCount)
{-# INLINABLE reallocMBytes #-}
-- | Change the 'Pinned' tag of immutable memory to any other tag. The underlying array is
-- reused unchanged and nothing is checked, so picking a truthful tag is up to the caller.
castPinnedBytes :: Bytes p' -> Bytes p
castPinnedBytes bytes =
  case bytes of
    Bytes arr# -> Bytes arr#
-- | Change the 'Pinned' tag of mutable memory to any other tag. The underlying array is
-- reused unchanged and nothing is checked, so picking a truthful tag is up to the caller.
castPinnedMBytes :: MBytes p' s -> MBytes p s
castPinnedMBytes mbytes =
  case mbytes of
    MBytes arr# -> MBytes arr#
-- | Size of immutable memory expressed as a number of elements of type @e@: the result of
-- 'byteCountBytes' converted with @fromByteCount@.
countBytes :: Prim e => Bytes p -> Count e
countBytes bytes = fromByteCount (byteCountBytes bytes)
{-# INLINE countBytes #-}
-- | Current size of mutable memory expressed as a number of elements of type @e@: the
-- result of 'getByteCountMBytes' converted with @fromByteCount@.
getCountMBytes :: (MonadPrim s m, Prim e) => MBytes p s -> m (Count e)
getCountMBytes mb = fmap fromByteCount (getByteCountMBytes mb)
{-# INLINE getCountMBytes #-}
-- | Read an element from mutable memory at an offset expressed in elements of type @e@.
-- The offset is forwarded to @readMutableByteArray#@ as is; no bounds checking is done here.
readOffMBytes :: (MonadPrim s m, Prim e) => MBytes p s -> Off e -> m e
readOffMBytes mb off =
  case mb of
    MBytes arr# ->
      case off of
        Off (I# ix#) -> prim $ readMutableByteArray# arr# ix#
{-# INLINE readOffMBytes #-}
-- | Read an element of type @e@ from mutable memory at an offset expressed in bytes. The
-- offset is forwarded to @readByteOffMutableByteArray#@ as is; no bounds checking is done
-- here.
readByteOffMBytes :: (MonadPrim s m, Prim e) => MBytes p s -> Off Word8 -> m e
readByteOffMBytes mb byteOff =
  case mb of
    MBytes arr# ->
      case byteOff of
        Off (I# ix#) -> prim $ readByteOffMutableByteArray# arr# ix#
{-# INLINE readByteOffMBytes #-}
-- | Write an element into mutable memory at an offset expressed in elements of type @e@.
-- The offset is forwarded to @writeMutableByteArray#@ as is; no bounds checking is done
-- here.
writeOffMBytes :: (MonadPrim s m, Prim e) => MBytes p s -> Off e -> e -> m ()
writeOffMBytes mb off val =
  case mb of
    MBytes arr# ->
      case off of
        Off (I# ix#) -> prim_ $ writeMutableByteArray# arr# ix# val
{-# INLINE writeOffMBytes #-}
-- | Write an element of type @e@ into mutable memory at an offset expressed in bytes. The
-- offset is forwarded to @writeByteOffMutableByteArray#@ as is; no bounds checking is done
-- here.
writeByteOffMBytes :: (MonadPrim s m, Prim e) => MBytes p s -> Off Word8 -> e -> m ()
writeByteOffMBytes mb byteOff val =
  case mb of
    MBytes arr# ->
      case byteOff of
        Off (I# ix#) -> prim_ $ writeByteOffMutableByteArray# arr# ix# val
{-# INLINE writeByteOffMBytes #-}
-- | Check at runtime, with @isByteArrayPinned#@, whether immutable memory is pinned.
--
-- Inlining is delayed until phase 0 so that the @"isPinnedBytes"@ rewrite rule can replace
-- calls on @'Bytes' ''Pin'@ with 'True'.
isPinnedBytes :: Bytes p -> Bool
isPinnedBytes bytes =
  case bytes of
    Bytes arr# -> isTrue# (isByteArrayPinned# arr#)
{-# INLINE[0] isPinnedBytes #-}
-- | Check at runtime, with @isMutableByteArrayPinned#@, whether mutable memory is pinned.
--
-- Inlining is delayed until phase 0 so that the @"isPinnedMBytes"@ rewrite rule can replace
-- calls on @'MBytes' ''Pin'@ with 'True'.
isPinnedMBytes :: MBytes p d -> Bool
isPinnedMBytes mbytes =
  case mbytes of
    MBytes arr# -> isTrue# (isMutableByteArrayPinned# arr#)
{-# INLINE[0] isPinnedMBytes #-}
-- When the argument is statically tagged 'Pin the answer is known at compile time, so the
-- runtime check is replaced with a constant 'True'.
{-# RULES
"isPinnedBytes" forall (x :: Bytes 'Pin) . isPinnedBytes x = True
"isPinnedMBytes" forall (x :: MBytes 'Pin s) . isPinnedMBytes x = True
#-}
-- | Fill a region of mutable memory with copies of one element.
--
-- Arguments, in order: the memory to write to, the starting offset in elements of type @e@,
-- the number of elements to write, and the element itself. Everything is forwarded to
-- @setMutableByteArray#@; no bounds checking is done here.
setMBytes ::
     (MonadPrim s m, Prim e)
  => MBytes p s
  -> Off e
  -> Count e
  -> e
  -> m ()
setMBytes (MBytes arr#) (Off (I# start#)) (Count (I# len#)) val =
  prim_ $ setMutableByteArray# arr# start# len# val
{-# INLINE setMBytes #-}
-- | Get the address of pinned immutable memory as a 'Ptr', using @byteArrayContents#@.
--
-- Nothing here keeps the underlying array alive; 'withPtrBytes' pairs the pointer with a
-- @touch@ for that purpose.
toPtrBytes :: Bytes 'Pin -> Ptr e
toPtrBytes (Bytes pinned#) =
  case byteArrayContents# pinned# of
    addr# -> Ptr addr#
{-# INLINE toPtrBytes #-}
-- | Get the address of pinned mutable memory as a 'Ptr', using @mutableByteArrayContents#@.
--
-- Nothing here keeps the underlying array alive; 'withPtrMBytes' pairs the pointer with a
-- @touch@ for that purpose.
toPtrMBytes :: MBytes 'Pin s -> Ptr e
toPtrMBytes (MBytes pinned#) =
  case mutableByteArrayContents# pinned# of
    addr# -> Ptr addr#
{-# INLINE toPtrMBytes #-}
-- | Run an action on the pointer to pinned immutable memory. Once the action has produced
-- its result, the memory is @touch@ed and the result is returned.
withPtrBytes :: MonadPrim s m => Bytes 'Pin -> (Ptr e -> m b) -> m b
withPtrBytes b f =
  f (toPtrBytes b) >>= \result -> result <$ touch b
{-# INLINE withPtrBytes #-}
-- | Run an action on the pointer to pinned immutable memory, wrapped in
-- @withAliveUnliftPrim@ applied to that same memory.
--
-- NOTE(review): presumably @withAliveUnliftPrim@ keeps the memory alive for the whole
-- duration of the action, including actions that never return - confirm against its
-- definition.
withNoHaltPtrBytes :: MonadUnliftPrim s m => Bytes 'Pin -> (Ptr e -> m b) -> m b
withNoHaltPtrBytes b f = withAliveUnliftPrim b (f (toPtrBytes b))
{-# INLINE withNoHaltPtrBytes #-}
-- | Run an action on the pointer to pinned mutable memory. Once the action has produced its
-- result, the memory is @touch@ed and the result is returned.
withPtrMBytes :: MonadPrim s m => MBytes 'Pin s -> (Ptr e -> m b) -> m b
withPtrMBytes mb f =
  f (toPtrMBytes mb) >>= \result -> result <$ touch mb
{-# INLINE withPtrMBytes #-}
-- | Run an action on the pointer to pinned mutable memory, wrapped in
-- @withAliveUnliftPrim@ applied to that same memory.
--
-- NOTE(review): presumably @withAliveUnliftPrim@ keeps the memory alive for the whole
-- duration of the action, including actions that never return - confirm against its
-- definition.
withNoHaltPtrMBytes :: MonadUnliftPrim s m => MBytes 'Pin s -> (Ptr e -> m b) -> m b
withNoHaltPtrMBytes mb f = withAliveUnliftPrim mb (f (toPtrMBytes mb))
{-# INLINE withNoHaltPtrMBytes #-}
-- | Convert pinned immutable memory into a 'ForeignPtr' that points to its beginning and
-- holds on to the array through a @PlainPtr@.
--
-- @PlainPtr@ stores a @MutableByteArray#@, which is why the immutable array is passed
-- through @unsafeCoerce#@.
toForeignPtrBytes :: Bytes 'Pin -> ForeignPtr e
toForeignPtrBytes (Bytes ba#) =
  case byteArrayContents# ba# of
    addr# -> ForeignPtr addr# (PlainPtr (unsafeCoerce# ba#))
{-# INLINE toForeignPtrBytes #-}
-- | Convert pinned mutable memory into a 'ForeignPtr' that points to its beginning and
-- holds on to the array through a @PlainPtr@.
--
-- The address is taken with @mutableByteArrayContents#@, the same primitive that
-- 'toPtrMBytes' uses, rather than by coercing the array to an immutable one first.
toForeignPtrMBytes :: MBytes 'Pin s -> ForeignPtr e
toForeignPtrMBytes (MBytes mba#) =
  -- The remaining coercion only changes the state token type of the array, which is needed
  -- because @PlainPtr@ stores a @MutableByteArray# RealWorld@.
  ForeignPtr (mutableByteArrayContents# mba#) (PlainPtr (unsafeCoerce# mba#))
{-# INLINE toForeignPtrMBytes #-}
-- | Try to recover the pinned immutable memory that backs a 'ForeignPtr'.
--
-- Succeeds only when the contents are a @PlainPtr@ or a @MallocPtr@ (any second field of
-- the latter is ignored) and the pointer refers to the very beginning of the backing array.
-- All other cases produce a 'Left' with an explanation.
fromForeignPtrBytes :: ForeignPtr e -> Either String (Bytes 'Pin)
fromForeignPtrBytes (ForeignPtr addr# content) =
  case content of
    PlainPtr mbaRW# -> checkConvert mbaRW#
    MallocPtr mbaRW# _ -> checkConvert mbaRW#
    _ -> Left "Cannot convert a C allocated pointer"
  where
    -- Freeze the backing array and accept it only if the pointer is the array's own start
    -- address. The IO action run here consists solely of 'freezeMBytes'.
    checkConvert mba# =
      case unsafePerformIO (freezeMBytes (MBytes mba#)) of
        frozen@(Bytes ba#)
          | isTrue# (byteArrayContents# ba# `eqAddr#` addr#) -> Right frozen
          | otherwise ->
            Left "ForeignPtr does not point to the beginning of the associated MutableByteArray#"
{-# INLINE fromForeignPtrBytes #-}
-- | Check with @isSameByteArray#@ whether two 'Bytes' refer to the same underlying array.
-- The 'Pinned' tags of the arguments do not have to agree.
--
-- Inlining is delayed until phase 0 so that the @"isSamePinnedBytes"@ rewrite rule gets a
-- chance to fire first.
isSameBytes :: Bytes p1 -> Bytes p2 -> Bool
isSameBytes (Bytes lhs#) (Bytes rhs#) =
  case isSameByteArray# lhs# rhs# of
    same# -> isTrue# same#
{-# INLINE[0] isSameBytes #-}
-- Wherever the types match, i.e. both arguments are tagged 'Pin, replace 'isSameBytes' with
-- the pointer comparison done by 'isSamePinnedBytes'.
{-# RULES
"isSamePinnedBytes" isSameBytes = isSamePinnedBytes
#-}
-- | Check whether two pinned 'Bytes' are the same memory by comparing the pointers obtained
-- with 'toPtrBytes'.
isSamePinnedBytes :: Bytes 'Pin -> Bytes 'Pin -> Bool
isSamePinnedBytes pb1 pb2 = ptr1 == ptr2
  where
    ptr1 = toPtrBytes pb1
    ptr2 = toPtrBytes pb2
{-# INLINE isSamePinnedBytes #-}
-- | Fail with 'error', using a message that consists of a fixed @ByteString@ conversion
-- prefix followed by the supplied explanation. Never returns a value.
byteStringConvertError :: String -> a
byteStringConvertError reason = error (showString "Cannot convert 'ByteString'. " reason)
{-# NOINLINE byteStringConvertError #-}