module PrimitiveExtras.PrimArray
where
import PrimitiveExtras.Prelude hiding (replicateM, traverse_)
import PrimitiveExtras.Types
import qualified Data.Serialize as Cereal
import qualified Data.Vector.Unboxed as UnboxedVector
import qualified Data.Vector.Primitive as PrimitiveVector
import qualified PrimitiveExtras.Folds as Folds
import qualified PrimitiveExtras.FoldMs as FoldMs
import qualified Data.ByteString.Short.Internal as ShortByteString
-- | Build an array of @size@ elements in which slot @index@ holds @value@
-- and every other slot holds the all-zero bit pattern (which is @0@ for the
-- standard integral and floating-point 'Prim' types).
--
-- 'newPrimArray' does not initialise the memory it allocates, so the storage
-- is explicitly zero-filled first; without that, the remaining slots would
-- contain whatever bytes happened to be in the fresh buffer.
--
-- Precondition: @0 <= index < size@. The write is not bounds-checked here.
oneHot :: Prim a => Int -> Int -> a -> PrimArray a
oneHot size index value =
  runST $ do
    marr@(MutablePrimArray rawBytes) <- newPrimArray size
    -- Zero every byte of the backing storage (size elements of sizeOf bytes).
    fillByteArray (MutableByteArray rawBytes) 0 (size * sizeOf value) 0
    writePrimArray marr index value
    unsafeFreezePrimArray marr
-- | Allocate an array of @size@ elements and fill it by running the supplied
-- action once per index, in ascending index order starting at 0. The action
-- receives the index it is producing an element for.
generate :: Prim a => Int -> (Int -> IO a) -> IO (PrimArray a)
generate size elementIO = do
  buffer <- newPrimArray size
  let fill position
        | position >= size = unsafeFreezePrimArray buffer
        | otherwise = do
            elementIO position >>= writePrimArray buffer position
            fill (position + 1)
  fill 0
-- | Allocate an array of @size@ elements, filling each slot with the result
-- of a fresh run of the supplied action. The action runs @size@ times, in
-- index order.
replicate :: Prim a => Int -> IO a -> IO (PrimArray a)
replicate size elementIO = generate size (const elementIO)
-- | Run a monadic action @size@ times and collect the results, in order, into
-- an immutable array.
--
-- Elements are accumulated purely, and the array is only materialised after
-- every action has run. This keeps the function sound for every 'Monad',
-- including non-deterministic or backtracking ones. Writing into one shared
-- mutable buffer through 'unsafePerformIO' would let different branches (and
-- a lazily frozen result) observe and overwrite each other's elements.
--
-- A non-positive @size@ runs no actions and yields an empty array.
replicateM :: (Monad m, Prim element) => Int -> m element -> m (PrimArray element)
replicateM size elementM =
  collect 0 []
  where
    -- Tracks how many elements have been produced, together with the
    -- elements themselves in reverse order of production.
    collect !count reversedElements =
      if count < size
        then do
          !element <- elementM
          collect (succ count) (element : reversedElements)
        else return (primArrayFromListN count (reverse reversedElements))
-- | Run an applicative action on every element of the array, from left to
-- right, discarding the results. Alias for 'traversePrimArray_'.
--
-- The explicit signature keeps this binding polymorphic. Without it, a bare
-- pattern binding with class constraints is subject to the monomorphism
-- restriction.
traverse_ :: (Applicative f, Prim a) => (a -> f b) -> PrimArray a -> f ()
traverse_ = traversePrimArray_
-- | Call @action@ with each index in the half-open range @[from, to)@ and the
-- element stored at that index, in ascending order. Each element is forced
-- before the action receives it.
--
-- NOTE(review): indices are passed straight to 'indexPrimArray' with no
-- range check, so callers must ensure @0 <= from@ and
-- @to <= sizeofPrimArray primArray@.
traverseWithIndexInRange_ :: Prim a => PrimArray a -> Int -> Int -> (Int -> a -> IO ()) -> IO ()
traverseWithIndexInRange_ primArray from to action =
  visit from
  where
    visit position
      | position >= to = return ()
      | otherwise = do
          action position $! indexPrimArray primArray position
          visit (position + 1)
-- | Expose the array's elements, from left to right, as an 'Unfold' backed
-- by a strict left fold over the array.
toElementsUnfold :: Prim prim => PrimArray prim -> Unfold prim
toElementsUnfold primArray =
  Unfold (\step initial -> foldlPrimArray' step initial primArray)
-- | Expose the array's elements, from left to right, as a monadic 'UnfoldM'
-- backed by a strict monadic left fold over the array.
toElementsUnfoldM :: (Monad m, Prim prim) => PrimArray prim -> UnfoldM m prim
toElementsUnfoldM primArray =
  UnfoldM (\step initial -> foldlPrimArrayM' step initial primArray)
-- | Reinterpret the array's storage as an untyped 'ByteArray'. Nothing is
-- copied; the result wraps the very same underlying bytes.
toByteArray :: PrimArray a -> ByteArray
toByteArray primArray =
  case primArray of
    PrimArray rawBytes -> ByteArray rawBytes
-- | View the array as a primitive vector without copying. The vector starts
-- at offset 0, spans every element, and shares the array's bytes.
toPrimitiveVector :: Prim a => PrimArray a -> PrimitiveVector.Vector a
toPrimitiveVector primArray =
  PrimitiveVector.Vector offset elementCount bytes
  where
    offset = 0
    elementCount = sizeofPrimArray primArray
    bytes = toByteArray primArray
-- | Reinterpret the array as an unboxed vector without copying, by coercing
-- the primitive vector produced by 'toPrimitiveVector'.
--
-- This uses 'unsafeCoerce'. It is only sound when the unboxed vector
-- representation for @a@ is a newtype over the primitive vector of the same
-- element type. NOTE(review): that appears to hold for the plain numeric
-- element types in @vector@; confirm it for every element type actually
-- used. Element types whose unboxed representation differs, or which have no
-- unboxed instance at all, must not be passed here.
toUnboxedVector :: Prim a => PrimArray a -> UnboxedVector.Vector a
toUnboxedVector = unsafeCoerce . toPrimitiveVector
-- | Decode an array stored as an element-count prefix (read with @int@)
-- followed by each element in turn (read with @element@). This is the format
-- written by 'cerealPut'.
--
-- The count comes from untrusted input, so a negative value is rejected with
-- a decoding failure instead of being used as an array size.
cerealGet :: Prim element => Cereal.Get Int -> Cereal.Get element -> Cereal.Get (PrimArray element)
cerealGet int element =
  do
    size <- int
    if size < 0
      then fail ("PrimitiveExtras.PrimArray.cerealGet: negative element count " <> show size)
      else replicateM size element
-- | Decode an array stored as a byte-count prefix (read with @int@) followed
-- by that many raw bytes. This is the format written by
-- 'cerealPutAsInMemory'. The bytes become the array's storage unchanged, so
-- the data is only meaningful on hosts that share the writer's in-memory
-- element layout.
cerealGetAsInMemory :: Prim element => Cereal.Get Int -> Cereal.Get (PrimArray element)
cerealGetAsInMemory int = do
  byteCount <- int
  shortBytes <- Cereal.getShortByteString byteCount
  case shortBytes of
    ShortByteString.SBS rawBytes -> return (PrimArray rawBytes)
-- | Encode an array as its element count (written with @int@) followed by
-- every element, left to right, written with @element@. 'cerealGet' reads
-- this format back.
cerealPut :: Prim element => Cereal.Putter Int -> Cereal.Putter element -> Cereal.Putter (PrimArray element)
cerealPut int element primArrayValue =
  int (sizeofPrimArray primArrayValue) <> traverse_ element primArrayValue
-- | Encode an array as the byte length of its storage (written with @int@)
-- followed by those raw bytes, copied verbatim. 'cerealGetAsInMemory' reads
-- this format back.
cerealPutAsInMemory :: Prim element => Cereal.Putter Int -> Cereal.Putter (PrimArray element)
cerealPutAsInMemory int (PrimArray rawBytes) =
  int (sizeofByteArray (ByteArray rawBytes))
    <> Cereal.putShortByteString (ShortByteString.SBS rawBytes)
-- | A fold from 'Int' inputs to a 'PrimArray' of counts. It delegates
-- directly to 'Folds.indexCounts'; see that function for the meaning of the
-- 'Int' argument.
indexCountsFold :: (Integral count, Prim count) => Int -> Fold Int (PrimArray count)
indexCountsFold n = Folds.indexCounts n
-- | An 'IO' fold that collects its inputs into a 'PrimArray'. It delegates
-- directly to 'FoldMs.primArray'; see that function for the meaning of the
-- 'Int' argument.
elementsFoldM :: Prim a => Int -> FoldM IO a (PrimArray a)
elementsFoldM n = FoldMs.primArray n