module Data.Primitive.PrimArray
(
PrimArray(..)
, MutablePrimArray(..)
, newPrimArray
, emptyPrimArray
, singletonPrimArray
, readPrimArray
, writePrimArray
, indexPrimArray
, unsafeFreezePrimArray
, unsafeThawPrimArray
, copyPrimArray
, copyMutablePrimArray
, copyPrimArrayToPtr
, copyPtrToMutablePrimArray
, copyPtrToPrimArray
, setPrimArray
, sameMutablePrimArray
, getSizeofMutablePrimArray
, sizeofPrimArray
) where
import GHC.Prim
import GHC.Exts (isTrue#,IsList(..))
import GHC.Int
import GHC.Ptr
import Data.Primitive
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Semigroup (Semigroup(..))
import qualified Data.Semigroup as SG
import qualified Data.List as L
import qualified Data.Primitive.Types as PT
-- | An immutable array of unboxed elements of type @a@, stored in a raw
-- 'ByteArray#'. The type parameter @a@ does not appear in the
-- representation (it is phantom); functions in this module interpret the
-- bytes through @a@'s 'Prim' instance, e.g. element count is the byte size
-- divided by @sizeOf# a@.
data PrimArray a = PrimArray ByteArray#
-- | A mutable array of unboxed elements of type @a@, backed by a
-- 'MutableByteArray#' belonging to state thread @s@. The element type @a@
-- is phantom in the representation and is only used, via its 'Prim'
-- instance, by the functions that read, write, copy or measure the array.
data MutablePrimArray s a = MutablePrimArray (MutableByteArray# s)
-- | Element-wise equality: two arrays are equal when they hold the same
-- number of elements and the elements at every index compare equal with
-- @a@'s own '(==)'. Elements are compared from the last index down to 0,
-- stopping at the first mismatch.
instance (Eq a, Prim a) => Eq (PrimArray a) where
  a1 == a2 = n == sizeofPrimArray a2 && loop (n - 1)
    where
      n = sizeofPrimArray a1
      -- Walk indices downwards; running past index 0 means every element
      -- matched. Only reached when both lengths are equal, so every index
      -- is in bounds for both arrays.
      loop !i
        | i < 0 = True
        | otherwise = indexPrimArray a1 i == indexPrimArray a2 i && loop (i - 1)
-- | Conversions between 'PrimArray' and ordinary lists. 'fromList' first
-- measures the list so the array can be allocated at its final size;
-- 'fromListN' trusts the caller-supplied length.
instance Prim a => IsList (PrimArray a) where
  type Item (PrimArray a) = a
  fromList xs = let n = L.length xs in primArrayFromList n xs
  fromListN n xs = primArrayFromList n xs
  toList arr = primArrayToList arr
-- | An array is rendered exactly like the list of its elements, honouring
-- the surrounding precedence.
instance (Prim a, Show a) => Show (PrimArray a) where
  showsPrec d arr = showsPrec d (primArrayToList arr)
-- | Build a 'PrimArray' of exactly @len@ elements from a list, in order.
--
-- The list must contain exactly @len@ elements and @len@ must be
-- non-negative. Otherwise this calls 'error':
--
-- * A longer list would make the fill loop write past the end of the
--   freshly allocated buffer.
-- * A shorter list would freeze uninitialised trailing bytes.
-- * A negative length would be passed straight to the allocator.
primArrayFromList :: forall a. Prim a => Int -> [a] -> PrimArray a
primArrayFromList len vs
  | len < 0 = die "negative length"
  | otherwise = runST run
  where
    die :: String -> b
    die msg = error ("Data.Primitive.PrimArray.primArrayFromList: " ++ msg)

    run :: forall s. ST s (PrimArray a)
    run = do
      arr <- newPrimArray len
      let go :: [a] -> Int -> ST s ()
          go [] !ix
            | ix == len = return ()
            | otherwise = die "list is shorter than the requested length"
          go (x : xs) !ix
            -- Only write while the index is inside the allocation.
            | ix < len = do
                writePrimArray arr ix x
                go xs (ix + 1)
            | otherwise = die "list is longer than the requested length"
      go vs 0
      unsafeFreezePrimArray arr
-- | The elements of a 'PrimArray' as a list, from index 0 upwards. The
-- list is built lazily: each element is read from the array only when it
-- is demanded.
primArrayToList :: forall a. Prim a => PrimArray a -> [a]
primArrayToList arr = L.map (indexPrimArray arr) [0 .. sizeofPrimArray arr - 1]
-- | Concatenate two arrays into a freshly allocated one: all elements of
-- the first argument followed by all elements of the second. Neither input
-- is modified.
appendPrimArray :: Prim a => PrimArray a -> PrimArray a -> PrimArray a
appendPrimArray xs ys =
  let nx = sizeofPrimArray xs
      ny = sizeofPrimArray ys
  in runST (do
       out <- newPrimArray (nx + ny)
       copyPrimArray out 0 xs 0 nx
       copyPrimArray out nx ys 0 ny
       unsafeFreezePrimArray out)
-- | Combining two arrays concatenates them (left operand first).
instance Prim a => Semigroup (PrimArray a) where
  x <> y = appendPrimArray x y
-- | The empty array is the identity for concatenation; 'mappend' simply
-- defers to the 'Semigroup' operation.
instance Prim a => Monoid (PrimArray a) where
  mempty = emptyPrimArray
  mappend x y = x SG.<> y
-- | A 'PrimArray' with no elements. It allocates a zero-byte buffer and
-- freezes it immediately. No element-size computation is involved, so no
-- 'Prim' constraint is needed.
emptyPrimArray :: PrimArray a
emptyPrimArray = runST (primitive (\s0 ->
  case newByteArray# 0# s0 of
    (# s1, marr #) -> case unsafeFreezeByteArray# marr s1 of
      (# s2, frozen #) -> (# s2, PrimArray frozen #)))
-- | A one-element 'PrimArray' containing the given value.
singletonPrimArray :: Prim a => a -> PrimArray a
singletonPrimArray x =
  runST (newPrimArray 1 >>= \marr -> writePrimArray marr 0 x >> unsafeFreezePrimArray marr)
-- | Allocate a mutable array with room for @n@ elements of type @a@
-- (@n * sizeOf# a@ bytes). This function does not initialise the contents.
newPrimArray :: forall m a. (PrimMonad m, Prim a) => Int -> m (MutablePrimArray (PrimState m) a)
newPrimArray (I# n#) = primitive alloc
  where
    bytes# = n# *# sizeOf# (undefined :: a)
    alloc s# = case newByteArray# bytes# s# of
      (# s'#, marr# #) -> (# s'#, MutablePrimArray marr# #)
-- | Read the element at the given 0-based index (in elements, not bytes).
-- No bounds check is performed.
readPrimArray :: (Prim a, PrimMonad m) => MutablePrimArray (PrimState m) a -> Int -> m a
readPrimArray (MutablePrimArray marr#) (I# ix#) =
  primitive (\s# -> readByteArray# marr# ix# s#)
-- | Overwrite the element at the given 0-based index (in elements, not
-- bytes). No bounds check is performed.
writePrimArray
  :: (Prim a, PrimMonad m)
  => MutablePrimArray (PrimState m) a -- ^ array to modify
  -> Int                              -- ^ element index
  -> a                                -- ^ new value
  -> m ()
writePrimArray (MutablePrimArray marr#) (I# ix#) v =
  primitive_ (\s# -> writeByteArray# marr# ix# v s#)
-- | Copy a run of elements from one mutable array into another (argument
-- order: destination first). Offsets and count are in elements and are
-- scaled to bytes by @sizeOf# a@. No bounds checks are performed.
copyMutablePrimArray :: forall m a.
     (PrimMonad m, Prim a)
  => MutablePrimArray (PrimState m) a -- ^ destination array
  -> Int                              -- ^ offset into destination, in elements
  -> MutablePrimArray (PrimState m) a -- ^ source array
  -> Int                              -- ^ offset into source, in elements
  -> Int                              -- ^ number of elements to copy
  -> m ()
copyMutablePrimArray (MutablePrimArray dst#) (I# doff#) (MutablePrimArray src#) (I# soff#) (I# n#) =
  primitive_ (copyMutableByteArray# src# (soff# *# sz#) dst# (doff# *# sz#) (n# *# sz#))
  where
    sz# = sizeOf# (undefined :: a)
-- | Copy a run of elements from an immutable array into a mutable one
-- (argument order: destination first). Offsets and count are in elements
-- and are scaled to bytes by @sizeOf# a@. No bounds checks are performed.
copyPrimArray :: forall m a.
     (PrimMonad m, Prim a)
  => MutablePrimArray (PrimState m) a -- ^ destination array
  -> Int                              -- ^ offset into destination, in elements
  -> PrimArray a                      -- ^ source array
  -> Int                              -- ^ offset into source, in elements
  -> Int                              -- ^ number of elements to copy
  -> m ()
copyPrimArray (MutablePrimArray dst#) (I# doff#) (PrimArray src#) (I# soff#) (I# n#) =
  primitive_ (copyByteArray# src# (soff# *# sz#) dst# (doff# *# sz#) (n# *# sz#))
  where
    sz# = sizeOf# (undefined :: a)
-- | Copy @n@ elements, starting at element offset @soff@ of an immutable
-- array, to the memory at the given pointer. The pointer is written from
-- its start (no destination offset). Sizes are scaled to bytes by
-- @sizeOf# a@; no bounds checks are performed.
copyPrimArrayToPtr :: forall m a. (PrimMonad m, Prim a)
  => Ptr a       -- ^ destination pointer
  -> PrimArray a -- ^ source array
  -> Int         -- ^ offset into source, in elements
  -> Int         -- ^ number of elements to copy
  -> m ()
copyPrimArrayToPtr (Ptr addr#) (PrimArray ba#) (I# soff#) (I# n#) =
  primitive_ (copyByteArrayToAddr# ba# (soff# *# width#) addr# (n# *# width#))
  where
    width# = sizeOf# (undefined :: a)
-- | Copy @n@ elements from the memory at the given pointer into a mutable
-- array, starting at element offset @doff@. The pointer is read from its
-- start. Sizes are scaled to bytes by @sizeOf# a@; no bounds checks are
-- performed.
copyPtrToMutablePrimArray :: forall m a. (PrimMonad m, Prim a)
  => MutablePrimArray (PrimState m) a -- ^ destination array
  -> Int                              -- ^ offset into destination, in elements
  -> Ptr a                            -- ^ source pointer
  -> Int                              -- ^ number of elements to copy
  -> m ()
copyPtrToMutablePrimArray (MutablePrimArray ba#) (I# doff#) (Ptr addr#) (I# n#) =
  primitive_ (copyAddrToByteArray# addr# ba# (doff# *# width#) (n# *# width#))
  where
    width# = sizeOf# (undefined :: a)
-- | Snapshot @len@ elements starting at the given pointer into a new
-- immutable 'PrimArray'.
--
-- NOTE(review): the read happens inside 'runST', so the result is treated
-- as a pure value. Callers presumably must guarantee the pointed-to memory
-- is valid for @len@ elements and is not mutated concurrently. Confirm
-- against call sites.
copyPtrToPrimArray :: forall a. Prim a
  => Ptr a -- ^ source pointer
  -> Int   -- ^ number of elements to copy
  -> PrimArray a
copyPtrToPrimArray ptr len = runST (do
  marr <- newPrimArray len
  copyPtrToMutablePrimArray marr 0 ptr len
  unsafeFreezePrimArray marr)
-- | Fill a slice of a mutable array with one value, via the element type's
-- 'PT.setByteArray#'. Per the 'Prim' class contract, offset and length are
-- counted in elements. No bounds checks are performed here.
setPrimArray
  :: (Prim a, PrimMonad m)
  => MutablePrimArray (PrimState m) a -- ^ array to fill
  -> Int                              -- ^ starting offset, in elements
  -> Int                              -- ^ number of elements to set
  -> a                                -- ^ value to store
  -> m ()
setPrimArray (MutablePrimArray dst#) (I# doff#) (I# sz#) x =
  primitive_ (\s# -> PT.setByteArray# dst# doff# sz# x s#)
-- | Current number of elements in a mutable array: its byte size divided
-- (with 'quotInt#') by @sizeOf# a@, so any trailing partial element is
-- ignored.
getSizeofMutablePrimArray :: forall m a. (PrimMonad m, Prim a)
  => MutablePrimArray (PrimState m) a
  -> m Int
getSizeofMutablePrimArray (MutablePrimArray arr#) = primitive query
  where
    width# = sizeOf# (undefined :: a)
    query s# = case getSizeofMutableByteArray# arr# s# of
      (# s'#, bytes# #) -> (# s'#, I# (bytes# `quotInt#` width#) #)
-- | Whether two mutable arrays refer to the same underlying buffer, as
-- decided by 'sameMutableByteArray#'. This compares identity, not contents.
sameMutablePrimArray :: MutablePrimArray s a -> MutablePrimArray s a -> Bool
sameMutablePrimArray (MutablePrimArray lhs#) (MutablePrimArray rhs#) =
  isTrue# (lhs# `sameMutableByteArray#` rhs#)
-- | View a mutable array as immutable without copying, via
-- 'unsafeFreezeByteArray#'. The result shares the buffer, so the mutable
-- array should not be written afterwards.
unsafeFreezePrimArray
  :: PrimMonad m => MutablePrimArray (PrimState m) a -> m (PrimArray a)
unsafeFreezePrimArray (MutablePrimArray marr#) = primitive freeze
  where
    freeze s# = case unsafeFreezeByteArray# marr# s# of
      (# s'#, frozen# #) -> (# s'#, PrimArray frozen# #)
-- | View an immutable array as mutable without copying, by coercing its
-- 'ByteArray#' to a 'MutableByteArray#' with 'unsafeCoerce#'. The buffer is
-- shared, so writes through the result are visible to every holder of the
-- original 'PrimArray'.
unsafeThawPrimArray
  :: PrimMonad m => PrimArray a -> m (MutablePrimArray (PrimState m) a)
unsafeThawPrimArray (PrimArray frozen#) = primitive thaw
  where
    thaw s# = (# s#, MutablePrimArray (unsafeCoerce# frozen#) #)
-- | The element at the given 0-based index (in elements, not bytes). No
-- bounds check is performed.
indexPrimArray :: forall a. Prim a => PrimArray a -> Int -> a
indexPrimArray arr ix = case arr of
  PrimArray ba# -> case ix of
    I# ix# -> indexByteArray# ba# ix#
-- | Number of elements in the array: the underlying byte count divided
-- (with 'quotInt#') by @sizeOf# a@.
sizeofPrimArray :: forall a. Prim a => PrimArray a -> Int
sizeofPrimArray (PrimArray arr#) =
  let bytes# = sizeofByteArray# arr#
      width# = sizeOf# (undefined :: a)
  in I# (quotInt# bytes# width#)