module PrimitiveExtras.PrimMultiArray
(
  PrimMultiArray,
  create,
  replicateM,
  outerLength,
  toAssocsUnfoldl,
  toIndicesUnfoldl,
  toUnfoldlAt,
  toAssocsUnfoldlM,
  toIndicesUnfoldlM,
  toUnfoldlAtM,
  cerealGet,
  cerealGetAsInMemory,
  cerealPut,
  cerealPutAsInMemory,
  fold,
)
where

import PrimitiveExtras.Prelude hiding (replicateM, fold)
import PrimitiveExtras.Types
import qualified DeferredFolds.Unfoldl as Unfoldl
import qualified DeferredFolds.UnfoldlM as UnfoldlM
import qualified PrimitiveExtras.UnliftedArray as UnliftedArray
import qualified PrimitiveExtras.PrimArray as PrimArray
import qualified PrimitiveExtras.Folds as Folds
import qualified Data.Serialize as Cereal


-- Structural equality: the wrapped outer 'UnliftedArray' of inner 'PrimArray's
-- is compared; comparing the inner arrays is what requires both 'Eq' and
-- 'Prim' on the element type.
deriving instance (Eq a, Prim a) => Eq (PrimMultiArray a)

{-| Renders the multi-array as a nested list: one inner list per inner array, in outer-index order. -}
instance (Show a, Prim a) => Show (PrimMultiArray a) where
  show (PrimMultiArray outerArray) =
    let
      innerLists = map primArrayToList (unliftedArrayToList outerArray)
      in show innerLists

{-|
Given the size of the outer array and a function which executes a fold over
indexed elements in a monad, construct a prim multi-array.

The supplied function is run twice: first to compute the size of every inner
array from the 'Int' (outer index) half of each pair, then to fill the inner
arrays using those sizes.
-}
create :: (Monad m, Prim element) => Int -> (forall x. Fold (Int, element) x -> m x) -> m (PrimMultiArray element)
create outerArraySize runFold =
  runFold sizesFold >>= \ indexCounts ->
    runFold (Folds.primMultiArray (indexCounts :: PrimArray Word32))
  where
    -- Only the outer index of each pair matters when sizing the inner arrays.
    sizesFold = lmap fst (Folds.indexCounts outerArraySize)

{-|
Run the given action the specified number of times, collecting the produced
inner arrays into a multi-array in the order they were produced.

A negative count is treated as zero, like 'Control.Monad.replicateM'.

The inner arrays are first accumulated in a list. The outer array is only
allocated, filled and frozen once all of them have been produced, inside a
single self-contained IO action.

Keeping the mutable array out of the monadic computation makes this safe for
any 'Monad'. Without that, a computation that is run more than once (e.g. the
same 'Cereal.Get' value executed repeatedly) or that branches (e.g. the list
monad) would share one mutable array, overwrite earlier results, and write into
an already-frozen array.
-}
replicateM :: (Monad m, Prim a) => Int -> m (PrimArray a) -> m (PrimMultiArray a)
replicateM size elementM =
  collect 0 []
  where
    -- Clamp so that a negative size (e.g. read from corrupt input) never
    -- reaches the array allocation.
    outerSize = max 0 size
    -- Run the action 'outerSize' times, consing each result onto the
    -- accumulator, which therefore holds the elements in reverse order.
    collect !index reversedElements =
      if index < outerSize
        then do
          element <- elementM
          collect (succ index) (element : reversedElements)
        else return (fromReversedList reversedElements)
    -- Allocate, fill and freeze the outer array in one IO action, so the
    -- mutable array never escapes and is never written after freezing.
    fromReversedList reversedElements =
      PrimMultiArray $ unsafePerformIO $ do
        mutable <- unsafeNewUnliftedArray outerSize
        let
          -- Walk the reversed list from the last index down to 0.
          fill !index elements =
            case elements of
              element : rest -> do
                writeUnliftedArray mutable index element
                fill (pred index) rest
              [] -> return ()
        fill (pred outerSize) reversedElements
        unsafeFreezeUnliftedArray mutable

{-| The number of inner arrays, i.e. the length of the outer dimension of a primitive multi-array. -}
outerLength :: PrimMultiArray a -> Int
outerLength multiArray =
  case multiArray of
    PrimMultiArray outerArray -> sizeofUnliftedArray outerArray

{-| Unfold every element, paired with the outer index of the inner array it belongs to.
A pure specialisation of 'toAssocsUnfoldlM'. -}
toAssocsUnfoldl :: Prim a => PrimMultiArray a -> Unfoldl (Int, a)
toAssocsUnfoldl multiArray =
  Unfoldl.unfoldlM (toAssocsUnfoldlM multiArray)

{-| Unfold the outer indices, starting at 0 and ending at the last one (one less than the outer length). -}
toIndicesUnfoldl :: PrimMultiArray a -> Unfoldl Int
toIndicesUnfoldl (PrimMultiArray outerArray) =
  Unfoldl.intsInRange 0 lastIndex
  where
    -- Passed as the upper bound, so the range is presumably inclusive.
    lastIndex = sizeofUnliftedArray outerArray - 1

{-|
Unfold the elements of the inner array at the given outer index.

'empty' is the fallback handed to 'UnliftedArray.at', so presumably nothing is
produced when the index is out of bounds.
-}
toUnfoldlAt :: Prim prim => PrimMultiArray prim -> Int -> Unfoldl prim
toUnfoldlAt multiArray index =
  case multiArray of
    PrimMultiArray outerArray ->
      UnliftedArray.at outerArray index empty PrimArray.toElementsUnfoldl

{-| Unfold every element, paired with its outer index, in outer-index order.
For each index from 'toIndicesUnfoldlM', the elements come from 'toUnfoldlAtM'. -}
toAssocsUnfoldlM :: (Monad m, Prim a) => PrimMultiArray a -> UnfoldlM m (Int, a)
toAssocsUnfoldlM multiArray =
  toIndicesUnfoldlM multiArray >>= \ outerIndex ->
  toUnfoldlAtM multiArray outerIndex >>= \ innerElement ->
  return (outerIndex, innerElement)

{-| Monadic counterpart of 'toIndicesUnfoldl': unfold the outer indices from 0 up to the last one. -}
toIndicesUnfoldlM :: Monad m => PrimMultiArray a -> UnfoldlM m Int
toIndicesUnfoldlM (PrimMultiArray outerArray) =
  UnfoldlM.intsInRange 0 lastIndex
  where
    -- Passed as the upper bound, so the range is presumably inclusive.
    lastIndex = sizeofUnliftedArray outerArray - 1

{-|
Monadic counterpart of 'toUnfoldlAt': unfold the elements of the inner array
at the given outer index.

'empty' is the fallback handed to 'UnliftedArray.at', so presumably nothing is
produced when the index is out of bounds.
-}
toUnfoldlAtM :: (Monad m, Prim prim) => PrimMultiArray prim -> Int -> UnfoldlM m prim
toUnfoldlAtM multiArray index =
  case multiArray of
    PrimMultiArray outerArray ->
      UnliftedArray.at outerArray index empty PrimArray.toElementsUnfoldlM

{-|
Decode a multi-array.

The outer length is read with the given 'Int' decoder. That many inner arrays
then follow, each decoded by 'PrimArray.cerealGet' using the same 'Int' and
element decoders. This mirrors the layout written by 'cerealPut'.
-}
cerealGet :: Prim element => Cereal.Get Int -> Cereal.Get element -> Cereal.Get (PrimMultiArray element)
cerealGet getInt getElement =
  getInt >>= \ outerSize ->
    replicateM outerSize (PrimArray.cerealGet getInt getElement)

{-|
Decode a multi-array whose inner arrays were stored in their in-memory form.

The outer length is read with the given 'Int' decoder. That many inner arrays
then follow, each decoded by 'PrimArray.cerealGetAsInMemory'. This mirrors the
layout written by 'cerealPutAsInMemory'.
-}
cerealGetAsInMemory :: Prim element => Cereal.Get Int -> Cereal.Get (PrimMultiArray element)
cerealGetAsInMemory getInt =
  getInt >>= \ outerSize ->
    replicateM outerSize (PrimArray.cerealGetAsInMemory getInt)

{-|
Encode a multi-array.

The outer length is written with the given 'Int' encoder. Then every inner
array is written in outer-index order via 'PrimArray.cerealPut', using the same
'Int' and element encoders.
-}
cerealPut :: Prim element => Cereal.Putter Int -> Cereal.Putter element -> Cereal.Putter (PrimMultiArray element)
cerealPut putInt putElement (PrimMultiArray outerArray) =
  putInt (sizeofUnliftedArray outerArray)
    <> UnliftedArray.traverse_ (PrimArray.cerealPut putInt putElement) outerArray

{-|
Encode a multi-array, storing the inner arrays in their in-memory form.

The outer length is written with the given 'Int' encoder. Then every inner
array is written in outer-index order via 'PrimArray.cerealPutAsInMemory'.
-}
cerealPutAsInMemory :: Prim element => Cereal.Putter Int -> Cereal.Putter (PrimMultiArray element)
cerealPutAsInMemory putInt (PrimMultiArray outerArray) =
  putInt (sizeofUnliftedArray outerArray)
    <> UnliftedArray.traverse_ (PrimArray.cerealPutAsInMemory putInt) outerArray

{-|
Given the sizes of the inner arrays, computed beforehand, construct a fold over
indexed elements into a multi-array of those elements.

The sizes can be obtained with 'PrimitiveExtras.Folds.indexCounts' over the
outer indices, as 'create' does.

This allows building a multi-array in two passes over the indexed elements:
one to compute the inner sizes and one to place the elements.
-}
fold :: (Integral size, Prim size, Prim element) => PrimArray size -> Fold (Int, element) (PrimMultiArray element)
fold innerSizes = Folds.primMultiArray innerSizes