{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module PrimitiveExtras.PrimArray where

import qualified Data.ByteString.Short.Internal as ShortByteString
import qualified Data.Serialize as Cereal
import qualified Data.Vector.Primitive as PrimitiveVector
import qualified Data.Vector.Unboxed as UnboxedVector
import qualified PrimitiveExtras.FoldMs as FoldMs
import qualified PrimitiveExtras.Folds as Folds
import PrimitiveExtras.Prelude hiding (replicateM, traverse_)

-- |
-- Construct from a primitive vector.
--
-- When the vector starts at offset zero and its length equals the element
-- count of the underlying buffer, that buffer is reused as is, which makes
-- the operation /O(1)/. Otherwise the vector is a slice, and its elements are
-- copied into a freshly allocated array.
primitiveVector :: (Prim a) => PrimitiveVector.Vector a -> PrimArray a
primitiveVector (PrimitiveVector.Vector offset length (ByteArray unliftedByteArray)) =
  fromWholeBuffer (PrimArray unliftedByteArray)
  where
    -- Taking the array as an argument keeps its element type tied to the
    -- result type without needing a local type signature.
    fromWholeBuffer wholeBuffer
      | offset == 0 && length == sizeofPrimArray wholeBuffer =
          wholeBuffer
      | otherwise =
          runST $ do
            slice <- newPrimArray length
            -- Destination, destination offset, source, source offset, count.
            copyPrimArray slice 0 wholeBuffer offset length
            unsafeFreezePrimArray slice

-- |
-- Allocate an array of the given size and write a single value at the given
-- index. No other slot is assigned by this function, and the index is not
-- checked against the size.
--
-- NOTE(review): the array comes straight from 'newPrimArray', which the
-- @primitive@ documentation describes as leaving memory uninitialized, so
-- the remaining elements presumably hold arbitrary data. Confirm whether
-- zero-filling was intended.
oneHot ::
  (Prim a) =>
  -- | Size
  Int ->
  -- | Index
  Int ->
  a ->
  PrimArray a
oneHot size index value =
  runST $ do
    mutable <- newPrimArray size
    writePrimArray mutable index value
    unsafeFreezePrimArray mutable

-- |
-- Build an array of the given size by running the action once per index,
-- in ascending order starting from zero, and storing each result at that
-- index.
generate :: (Prim a) => Int -> (Int -> IO a) -> IO (PrimArray a)
generate size elementIO = do
  mutable <- newPrimArray size
  let fill position
        | position < size = do
            element <- elementIO position
            writePrimArray mutable position element
            fill (succ position)
        | otherwise =
            -- Every slot below `size` has been written; freeze in place.
            unsafeFreezePrimArray mutable
  fill 0

-- |
-- Build an array of the given size by running the same action once per
-- slot, in ascending index order, and storing each result.
replicate :: (Prim a) => Int -> IO a -> IO (PrimArray a)
replicate size elementIO = do
  mutable <- newPrimArray size
  let fill position
        | position < size = do
            element <- elementIO
            writePrimArray mutable position element
            fill (succ position)
        | otherwise =
            -- Every slot below `size` has been written; freeze in place.
            unsafeFreezePrimArray mutable
  fill 0

-- |
-- Run the action the given number of times in an arbitrary monad and pack
-- the results into an array, in the order they were produced.
--
-- The results are first accumulated in a list and the array is then built
-- purely inside 'runST'. This replaces an earlier approach that mutated a
-- single buffer through @unsafePerformIO@: with a monad that runs its
-- continuation more than once (the list monad, for instance) every result
-- ended up aliasing the same memory.
--
-- A non-positive size produces an empty array without running the action.
replicateM :: (Monad m, Prim element) => Int -> m element -> m (PrimArray element)
replicateM size elementM =
  collect 0 []
  where
    -- Gathers the elements in reverse, so the head is the last one produced.
    collect index reversedElements
      | index < size = do
          -- Forced here so that evaluation happens at each step, as it did
          -- when every element was written to the buffer immediately.
          !element <- elementM
          collect (succ index) (element : reversedElements)
      | otherwise =
          return (fromReversedList reversedElements)
    fromReversedList reversedElements =
      runST $ do
        mutable <- newPrimArray (max 0 size)
        writeBackwards mutable (pred size) reversedElements
        unsafeFreezePrimArray mutable
    -- Writes the reversed list starting from the last index and going down.
    writeBackwards mutable index remaining =
      case remaining of
        element : rest -> do
          writePrimArray mutable index element
          writeBackwards mutable (pred index) rest
        [] ->
          return ()

-- |
-- Apply an effectful action to every element of the array, discarding the
-- results. Delegates to 'traversePrimArray_'.
traverse_ :: (Applicative f, Prim a) => (a -> f b) -> PrimArray a -> f ()
traverse_ action primArray =
  traversePrimArray_ action primArray

-- |
-- Run the action on each index and element of the array for indices from
-- the first bound (inclusive) up to the second bound (exclusive), in
-- ascending order. Each element is forced before the action is applied.
--
-- The bounds are not checked against the size of the array.
traverseWithIndexInRange_ :: (Prim a) => PrimArray a -> Int -> Int -> (Int -> a -> IO ()) -> IO ()
traverseWithIndexInRange_ primArray from to action =
  go from
  where
    go index
      | index < to = do
          action index $! indexPrimArray primArray index
          go (succ index)
      | otherwise =
          return ()

-- |
-- Expose the elements of the array as an 'Unfoldl', implemented as a strict
-- left fold over the array ('foldlPrimArray'').
toElementsUnfoldl :: (Prim prim) => PrimArray prim -> Unfoldl prim
toElementsUnfoldl primArray =
  Unfoldl (\step initial -> foldlPrimArray' step initial primArray)

-- |
-- Expose the elements of the array as an 'UnfoldlM', implemented as a
-- strict monadic left fold over the array ('foldlPrimArrayM'').
toElementsUnfoldlM :: (Monad m, Prim prim) => PrimArray prim -> UnfoldlM m prim
toElementsUnfoldlM primArray =
  UnfoldlM (\step initial -> foldlPrimArrayM' step initial primArray)

-- |
-- Reinterpret the array as a 'ByteArray' by rewrapping the same underlying
-- unlifted buffer. No data is copied.
toByteArray :: PrimArray a -> ByteArray
toByteArray (PrimArray buffer) = ByteArray buffer

-- |
-- View the array as a primitive vector spanning the whole buffer:
-- offset zero, length equal to the element count of the array.
-- The underlying buffer is shared, not copied.
toPrimitiveVector :: (Prim a) => PrimArray a -> PrimitiveVector.Vector a
toPrimitiveVector primArray =
  PrimitiveVector.Vector 0 elementCount buffer
  where
    elementCount = sizeofPrimArray primArray
    buffer = toByteArray primArray

-- |
-- Convert the array into an unboxed vector by coercing the result of
-- 'toPrimitiveVector'.
--
-- NOTE(review): the conversion goes through 'unsafeCoerce', so it presumably
-- relies on the unboxed vector of @a@ being represented as a newtype over
-- the primitive vector of @a@. Nothing in the signature enforces that —
-- confirm it for every element type this is used with.
toUnboxedVector :: (Prim a) => PrimArray a -> UnboxedVector.Vector a
toUnboxedVector =
  unsafeCoerce . toPrimitiveVector

-- |
-- Deserialize an array: first read the element count using the provided
-- 'Int' decoder, then run the element decoder that many times.
cerealGet :: (Prim element) => Cereal.Get Int -> Cereal.Get element -> Cereal.Get (PrimArray element)
cerealGet int element =
  int >>= \size -> replicateM size element

-- |
-- Deserialize an array from its raw in-memory representation: read a size
-- using the provided 'Int' decoder, read that many bytes as a short
-- bytestring, and rewrap its buffer as a 'PrimArray'.
--
-- The size is passed to 'Cereal.getShortByteString', so it is a count of
-- bytes rather than of elements.
cerealGetAsInMemory :: (Prim element) => Cereal.Get Int -> Cereal.Get (PrimArray element)
cerealGetAsInMemory int = do
  byteCount <- int
  ShortByteString.SBS buffer <- Cereal.getShortByteString byteCount
  return (PrimArray buffer)

-- |
-- Serialize an array: write the element count using the provided 'Int'
-- encoder, followed by every element using the provided element encoder.
cerealPut :: (Prim element) => Cereal.Putter Int -> Cereal.Putter element -> Cereal.Putter (PrimArray element)
cerealPut int element primArrayValue =
  int (sizeofPrimArray primArrayValue)
    <> traverse_ element primArrayValue

-- |
-- Serialize an array as its raw in-memory representation: write the size of
-- the underlying buffer in bytes using the provided 'Int' encoder, followed
-- by the buffer itself as a short bytestring.
cerealPutAsInMemory :: (Prim element) => Cereal.Putter Int -> Cereal.Putter (PrimArray element)
cerealPutAsInMemory int (PrimArray buffer) =
  int byteCount <> Cereal.putShortByteString (ShortByteString.SBS buffer)
  where
    byteCount = sizeofByteArray (ByteArray buffer)

-- |
-- Given a size of the array,
-- construct a fold, which produces an array of index counts.
--
-- Delegates to @PrimitiveExtras.Folds.indexCounts@.
indexCountsFold ::
  (Integral count, Prim count) =>
  -- | Array size
  Int ->
  Fold Int (PrimArray count)
indexCountsFold size =
  Folds.indexCounts size

-- |
-- Given a size of the array,
-- construct a fold, which produces an array of elements.
--
-- Delegates to @PrimitiveExtras.FoldMs.primArray@.
elementsFoldM ::
  (Prim a) =>
  -- | Array size
  Int ->
  FoldM IO a (PrimArray a)
elementsFoldM size =
  FoldMs.primArray size