module PrimitiveExtras.PrimMultiArray
(
  PrimMultiArray,
  create,
  replicateM,
  outerLength,
  toAssocsUnfoldl,
  toIndicesUnfoldl,
  toUnfoldlAt,
  toAssocsUnfoldlM,
  toIndicesUnfoldlM,
  toUnfoldlAtM,
  cerealGet,
  cerealGetAsInMemory,
  cerealPut,
  cerealPutAsInMemory,
  fold,
)
where

import PrimitiveExtras.Prelude hiding (replicateM, fold)
import PrimitiveExtras.Types
import qualified DeferredFolds.Unfoldl as Unfoldl
import qualified DeferredFolds.UnfoldlM as UnfoldlM
import qualified PrimitiveExtras.UnliftedArray as UnliftedArray
import qualified PrimitiveExtras.PrimArray as PrimArray
import qualified PrimitiveExtras.Folds as Folds
import qualified Data.Serialize as Cereal


-- | Structural equality: two multi-arrays are equal when their wrapped outer
-- arrays of inner 'PrimArray's compare equal.
deriving instance (Prim a, Eq a) => Eq (PrimMultiArray a)

-- | Renders a multi-array as the 'show' of its nested list of elements:
-- one inner list per inner array, in outer-index order.
instance (Show a, Prim a) => Show (PrimMultiArray a) where
  show (PrimMultiArray outerArray) =
    show (fmap primArrayToList (unliftedArrayToList outerArray))

{-| Construct a prim multi-array from indexed elements in two passes.

@outerArraySize@ is the length of the outer dimension. @runFold@ is invoked
twice. The first run uses a fold that counts elements per outer index (only
the index half of each pair is looked at). The second run uses a fold that
places the elements into a multi-array sized by those counts. @runFold@ is
therefore expected to feed the same @(outer index, element)@ pairs to both
folds.
-}
create :: (Monad m, Prim element) => Int -> (forall x. Fold (Int, element) x -> m x) -> m (PrimMultiArray element)
create outerArraySize runFold =
  runFold (lmap fst (Folds.indexCounts outerArraySize)) >>= \ indexCounts ->
    -- Per-index counts are kept as Word32.
    runFold (Folds.primMultiArray (indexCounts :: PrimArray Word32))

{-| Run @elementM@ @size@ times and collect the produced inner arrays, in
order, into a multi-array whose outer length is the number of runs.

A non-positive @size@ yields an empty multi-array, mirroring
'Control.Monad.replicateM'. The outer length may come from untrusted input
(see 'cerealGet'), and the allocation primitive is not expected to validate
a negative size, so the size is clamped first.

NOTE(review): the mutable outer array is allocated and written with
'unsafeDupablePerformIO' inside an arbitrary monad @m@. This is only sound
when @m@ resumes each continuation at most once (e.g. 'IO', 'Identity', or a
parser run to completion). A monad that can re-enter a continuation (lists,
backtracking, or resumed partial parses) would share and mutate the same
array across runs.
-}
replicateM :: (Monad m, Prim a) => Int -> m (PrimArray a) -> m (PrimMultiArray a)
replicateM size elementM =
  do
    let !outerSize = max 0 size
    !mutable <- return (unsafeDupablePerformIO (unsafeNewUnliftedArray outerSize))
    let
      go index =
        if index < outerSize
          then do
            element <- elementM
            -- The bang forces the write before the next step is taken.
            let !() = unsafeDupablePerformIO (writeUnliftedArray mutable index element)
            go (succ index)
          else return (PrimMultiArray (unsafePerformIO (unsafeFreezeUnliftedArray mutable)))
    go 0

{-| Length of the outer dimension of a primitive multi-array, i.e. the number
of inner arrays it holds. -}
outerLength :: PrimMultiArray a -> Int
outerLength multiArray =
  case multiArray of
    PrimMultiArray outerArray -> sizeofUnliftedArray outerArray

-- | Pure unfold over every @(outer index, element)@ pair. It runs
-- 'toAssocsUnfoldlM' in 'Identity' through 'Unfoldl.unfoldlM'.
toAssocsUnfoldl :: Prim a => PrimMultiArray a -> Unfoldl (Int, a)
toAssocsUnfoldl multiArray = Unfoldl.unfoldlM (toAssocsUnfoldlM multiArray)

-- | Unfold over the outer indices. Passes 0 and the last valid outer index
-- (outer length minus one) to 'Unfoldl.intsInRange', which is assumed to be
-- inclusive on both ends; an empty outer array gives an empty range.
toIndicesUnfoldl :: PrimMultiArray a -> Unfoldl Int
toIndicesUnfoldl (PrimMultiArray outerArray) =
  let lastIndex = sizeofUnliftedArray outerArray - 1
  in Unfoldl.intsInRange 0 lastIndex

-- | Unfold over the elements of the inner array at the given outer index.
-- When 'UnliftedArray.at' yields no inner array (presumably an out-of-range
-- index), the result is 'empty'.
toUnfoldlAt :: Prim prim => PrimMultiArray prim -> Int -> Unfoldl prim
toUnfoldlAt (PrimMultiArray outerArray) =
  \ index -> UnliftedArray.at outerArray index empty PrimArray.toElementsUnfoldl

-- | Monadic unfold over every @(outer index, element)@ pair. It walks the
-- outer indices and, for each one, the elements of the inner array stored
-- there.
toAssocsUnfoldlM :: (Monad m, Prim a) => PrimMultiArray a -> UnfoldlM m (Int, a)
toAssocsUnfoldlM pma =
  toIndicesUnfoldlM pma >>= \ index ->
  toUnfoldlAtM pma index >>= \ element ->
  return (index, element)

-- | Monadic unfold over the outer indices. Passes 0 and the last valid outer
-- index to 'UnfoldlM.intsInRange', which is assumed to be inclusive on both
-- ends.
toIndicesUnfoldlM :: Monad m => PrimMultiArray a -> UnfoldlM m Int
toIndicesUnfoldlM (PrimMultiArray outerArray) =
  UnfoldlM.intsInRange 0 lastIndex
  where
    lastIndex = sizeofUnliftedArray outerArray - 1

-- | Monadic unfold over the elements of the inner array at the given outer
-- index. When 'UnliftedArray.at' yields no inner array (presumably an
-- out-of-range index), the result is 'empty'.
toUnfoldlAtM :: (Monad m, Prim prim) => PrimMultiArray prim -> Int -> UnfoldlM m prim
toUnfoldlAtM (PrimMultiArray outerArray) =
  \ index -> UnliftedArray.at outerArray index empty PrimArray.toElementsUnfoldlM

{-| Decode a multi-array in this layout: an outer length read with @int@,
followed by that many inner arrays, each read by 'PrimArray.cerealGet' with
@int@ and @element@. This mirrors the structure written by 'cerealPut'.

The outer length comes from the input, which may be untrusted. A negative
length makes the decoder fail instead of being passed on to array
allocation.
-}
cerealGet :: Prim element => Cereal.Get Int -> Cereal.Get element -> Cereal.Get (PrimMultiArray element)
cerealGet int element =
  do
    size <- int
    if size < 0
      then fail ("PrimitiveExtras.PrimMultiArray.cerealGet: negative outer length " <> show size)
      else replicateM size (PrimArray.cerealGet int element)

{-| Decode a multi-array in this layout: an outer length read with @int@,
followed by that many inner arrays, each read by
'PrimArray.cerealGetAsInMemory' with @int@. This mirrors the structure
written by 'cerealPutAsInMemory'.

The outer length comes from the input, which may be untrusted. A negative
length makes the decoder fail instead of being passed on to array
allocation.
-}
cerealGetAsInMemory :: Prim element => Cereal.Get Int -> Cereal.Get (PrimMultiArray element)
cerealGetAsInMemory int =
  do
    size <- int
    if size < 0
      then fail ("PrimitiveExtras.PrimMultiArray.cerealGetAsInMemory: negative outer length " <> show size)
      else replicateM size (PrimArray.cerealGetAsInMemory int)

-- | Encode a multi-array. The outer length is written with @int@, then each
-- inner array in outer-index order with 'PrimArray.cerealPut' @int@
-- @element@.
cerealPut :: Prim element => Cereal.Putter Int -> Cereal.Putter element -> Cereal.Putter (PrimMultiArray element)
cerealPut int element (PrimMultiArray outerArray) =
  let
    header, body :: Cereal.Put
    header = int (sizeofUnliftedArray outerArray)
    body = UnliftedArray.traverse_ (PrimArray.cerealPut int element) outerArray
  in header <> body

-- | Encode a multi-array. The outer length is written with @int@, then each
-- inner array in outer-index order with 'PrimArray.cerealPutAsInMemory'
-- @int@.
cerealPutAsInMemory :: Prim element => Cereal.Putter Int -> Cereal.Putter (PrimMultiArray element)
cerealPutAsInMemory int (PrimMultiArray outerArray) =
  let
    header, body :: Cereal.Put
    header = int (sizeofUnliftedArray outerArray)
    body = UnliftedArray.traverse_ (PrimArray.cerealPutAsInMemory int) outerArray
  in header <> body

{-|
Given inner-dimension sizes computed beforehand, build a fold that places
indexed elements into a multi-array of elements. The sizes can be computed,
for example, with 'PrimitiveExtras.Folds.indexCounts', as 'create' does.

Combined with such a counting fold, this constructs a multi-array in two
passes over the indexed elements.
-}
fold :: (Integral size, Prim size, Prim element) => PrimArray size -> Fold (Int, element) (PrimMultiArray element)
fold innerSizes = Folds.primMultiArray innerSizes