{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
-- |
-- Module      : Data.Massiv.Array.Delayed.Push
-- Copyright   : (c) Alexey Kuleshevich 2019-2021
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Array.Delayed.Push
  ( DL(..)
  , Array(..)
  , Loader
  , toLoadArray
  , makeLoadArrayS
  , makeLoadArray
  , unsafeMakeLoadArray
  , unsafeMakeLoadArrayAdjusted
  , fromStrideLoad
  , appendOuterM
  , concatOuterM
  ) where

import Control.Monad
import Control.Scheduler as S (traverse_)
import Data.Foldable as F
import Data.Massiv.Core.Common
import Prelude hiding (map, zipWith)

#include "massiv.h"

-- | Delayed load representation. Also known as Push array.
--
-- A 'DL' array does not store elements; it stores an action ('Loader') that pushes its
-- elements into a target when the array is loaded.
data DL = DL deriving Show

-- | The loading action stored inside every 'DL' array. It is polymorphic in the 'ST'
-- state thread and receives a scheduler, a starting offset and two writing actions.
type Loader e
   = forall st.
     Scheduler st ()                -- ^ Scheduler that will be used for loading
  -> Ix1                            -- ^ Start loading at this linear index
  -> (Ix1 -> e -> ST st ())         -- ^ Linear element writing action
  -> (Ix1 -> Sz1 -> e -> ST st ())  -- ^ Linear region setting action
  -> ST st ()


-- | A push array: a computation strategy, a size and a loading action.
data instance Array DL ix e = DLArray
  { dlComp :: !Comp
  -- ^ Computation strategy used when the array gets loaded
  , dlSize :: !(Sz ix)
  -- ^ Size of the array
  , dlLoad :: Loader e
  -- ^ Action that writes the elements of the array into a linear target
  }

-- | The computation strategy of a 'DL' array is read from and written to its 'dlComp'
-- field.
instance Strategy DL where
  getComp = dlComp
  {-# INLINE getComp #-}
  setComp c arr = arr {dlComp = c}
  {-# INLINE setComp #-}


-- | 'maxLinearSize' of a 'DL' array is always 'Just', computed from 'elemsCount'.
instance Index ix => Shape DL ix where
  maxLinearSize = Just . SafeSz . elemsCount
  {-# INLINE maxLinearSize #-}


-- | The size of a 'DL' array is stored in its 'dlSize' field.
instance Size DL where
  size = dlSize
  {-# INLINE size #-}
  -- Only the recorded size is replaced; the loader is left untouched and no check is
  -- performed here that the number of elements stays the same.
  unsafeResize !sz !arr = arr {dlSize = sz}
  {-# INLINE unsafeResize #-}

-- | Concatenation of flat push arrays, implemented by 'mappendDL': the elements of the
-- right array are loaded right after the elements of the left one.
instance Semigroup (Array DL Ix1 e) where
  (<>) = mappendDL
  {-# INLINE (<>) #-}

-- | The identity is an array with zero elements whose loader does nothing.
instance Monoid (Array DL Ix1 e) where
  mempty = DLArray {dlComp = mempty, dlSize = zeroSz, dlLoad = \_ _ _ _ -> pure ()}
  {-# INLINE mempty #-}
  -- Canonical definition (avoids -Wnoncanonical-monoid-instances); '(<>)' is 'mappendDL'.
  mappend = (<>)
  {-# INLINE mappend #-}
  -- Short lists are handled directly; longer ones go through 'mconcatDL'.
  mconcat [] = mempty
  mconcat [x] = x
  mconcat [x, y] = x <> y
  mconcat xs = mconcatDL xs
  {-# INLINE mconcat #-}

-- | Concatenate a list of flat push arrays into a single one.
--
-- The resulting size is the sum of all sizes, and the computation strategies are
-- combined with 'foldMap'. When loaded, each array's loader is handed to
-- 'scheduleWork_'. Each array starts at the linear index right after the last element
-- of the previous array.
mconcatDL :: forall e . [Array DL Ix1 e] -> Array DL Ix1 e
mconcatDL !arrs =
  DLArray {dlComp = foldMap getComp arrs, dlSize = SafeSz k, dlLoad = load}
  where
    -- Total number of elements across all arrays.
    !k = F.foldl' (+) 0 (unSz . size <$> arrs)
    load :: forall s .
      Scheduler s () -> Ix1 -> (Ix1 -> e -> ST s ()) -> (Ix1 -> Sz1 -> e -> ST s ()) -> ST s ()
    load scheduler startAt dlWrite dlSet =
      let -- Schedule loading of one array at the current offset and return the offset
          -- at which the next array has to start.
          loadArr :: Ix1 -> Array DL Ix1 e -> ST s Ix1
          loadArr !startAtCur DLArray {dlSize = SafeSz kCur, dlLoad} = do
            let !endAtCur = startAtCur + kCur
            scheduleWork_ scheduler $ dlLoad scheduler startAtCur dlWrite dlSet
            pure endAtCur
          {-# INLINE loadArr #-}
       in foldM_ loadArr startAt arrs
    {-# INLINE load #-}
{-# INLINE mconcatDL #-}


-- | Append two flat push arrays.
--
-- The size of the result is the sum of both sizes, and the strategies are combined
-- with '<>'. On load, the left loader is scheduled at the requested start index. The
-- right loader is scheduled at the index immediately after the left array's elements.
mappendDL :: forall e . Array DL Ix1 e -> Array DL Ix1 e -> Array DL Ix1 e
mappendDL (DLArray c1 sz1 load1) (DLArray c2 sz2 load2) =
  DLArray {dlComp = c1 <> c2, dlSize = SafeSz (k1 + k2), dlLoad = load}
  where
    !k1 = unSz sz1
    !k2 = unSz sz2
    load :: forall s.
      Scheduler s () -> Ix1 -> (Ix1 -> e -> ST s ()) -> (Ix1 -> Sz1 -> e -> ST s ()) -> ST s ()
    load scheduler !startAt dlWrite dlSet = do
      scheduleWork_ scheduler $ load1 scheduler startAt dlWrite dlSet
      -- The second array occupies the linear region that follows the first one.
      scheduleWork_ scheduler $ load2 scheduler (startAt + k1) dlWrite dlSet
    {-# INLINE load #-}
{-# INLINE mappendDL #-}

-- | Append two arrays together along the outer most dimension. Inner dimensions must
-- agree, otherwise a `SizeMismatchException` is thrown.
--
-- @since 0.4.4
appendOuterM ::
     forall ix e m. (Index ix, MonadThrow m)
  => Array DL ix e
  -> Array DL ix e
  -> m (Array DL ix e)
appendOuterM (DLArray c1 sz1 load1) (DLArray c2 sz2 load2) = do
  -- Split each size into its outer dimension and the remaining inner dimensions.
  let (!i1, !szl1) = unconsSz sz1
      (!i2, !szl2) = unconsSz sz2
  unless (szl1 == szl2) $ throwM $ SizeMismatchException sz1 sz2
  pure $
    DLArray
      { dlComp = c1 <> c2
      , dlSize = consSz (liftSz2 (+) i1 i2) szl1
      , dlLoad = load
      }
  where
    load :: Loader e
    load scheduler !startAt dlWrite dlSet = do
      scheduleWork_ scheduler $ load1 scheduler startAt dlWrite dlSet
      -- Elements of the second array start right after all `totalElem sz1` elements
      -- of the first one.
      scheduleWork_ scheduler $ load2 scheduler (startAt + totalElem sz1) dlWrite dlSet
    {-# INLINE load #-}
{-# INLINE appendOuterM #-}

-- | Concat arrays together along the outer most dimension. Inner dimensions must agree
-- for all arrays in the list, otherwise a `SizeMismatchException` is thrown. An empty
-- list results in an empty array.
--
-- @since 0.4.4
concatOuterM ::
     forall ix e m. (Index ix, MonadThrow m)
  => [Array DL ix e]
  -> m (Array DL ix e)
concatOuterM =
  \case
    []     -> pure empty
    -- Left fold, so arrays end up in list order along the outer dimension.
    (x:xs) -> F.foldlM appendOuterM x xs
{-# INLINE concatOuterM #-}


-- | Describe how an array should be loaded into memory sequentially. For parallelizable
-- version see `makeLoadArray`.
--
-- @since 0.3.1
makeLoadArrayS ::
     forall ix e. Index ix
  => Sz ix
  -- ^ Size of the resulting array
  -> e
  -- ^ Default value to use for all cells that might have been omitted by the writing function
  -> (forall m. Monad m => (ix -> e -> m Bool) -> m ())
  -- ^ Writing function that describes which elements to write into the target array.
  -- The supplied element writer returns 'False' and writes nothing when the index is out
  -- of bounds, and 'True' otherwise.
  -> Array DL ix e
makeLoadArrayS sz defVal writer = DLArray Seq sz load
  where
    load :: forall s.
      Scheduler s () -> Ix1 -> (Ix1 -> e -> ST s ()) -> (Ix1 -> Sz1 -> e -> ST s ()) -> ST s ()
    load _scheduler !startAt uWrite uSet = do
      -- Pre-fill the whole region, so cells skipped by the writer still hold 'defVal'.
      uSet startAt (toLinearSz sz) defVal
      let safeWrite :: ix -> e -> ST s Bool
          safeWrite !ix !e
            | isSafeIndex sz ix = uWrite (startAt + toLinearIndex sz ix) e >> pure True
            | otherwise = pure False
          {-# INLINE safeWrite #-}
      writer safeWrite
    {-# INLINE load #-}
{-# INLINE makeLoadArrayS #-}

-- | Specify how an array should be loaded into memory. Unlike `makeLoadArrayS`, loading
-- function accepts a scheduler, thus can be parallelized. If you need an unsafe version
-- of this function see `unsafeMakeLoadArray`.
--
-- @since 0.4.0
makeLoadArray ::
     forall ix e. Index ix
  => Comp
  -- ^ Computation strategy to use. Directly affects the scheduler that gets created for
  -- the loading function.
  -> Sz ix
  -- ^ Size of the resulting array
  -> e
  -- ^ Default value to use for all cells that might have been omitted by the writing function
  -> (forall s. Scheduler s () -> (ix -> e -> ST s Bool) -> ST s ())
  -- ^ Writing function that describes which elements to write into the target array. It
  -- accepts a scheduler, that can be used for parallelization, as well as a safe element
  -- writing function. The safe writer returns 'False' and writes nothing when the index
  -- is out of bounds, and 'True' otherwise.
  -> Array DL ix e
makeLoadArray comp sz defVal writer = DLArray comp sz load
  where
    load :: forall s.
      Scheduler s () -> Ix1 -> (Ix1 -> e -> ST s ()) -> (Ix1 -> Sz1 -> e -> ST s ()) -> ST s ()
    load scheduler !startAt uWrite uSet = do
      -- Pre-fill the whole region, so cells skipped by the writer still hold 'defVal'.
      uSet startAt (toLinearSz sz) defVal
      let safeWrite :: ix -> e -> ST s Bool
          safeWrite !ix !e
            | isSafeIndex sz ix = True <$ uWrite (startAt + toLinearIndex sz ix) e
            | otherwise = pure False
          {-# INLINE safeWrite #-}
      writer scheduler safeWrite
    {-# INLINE load #-}
{-# INLINE makeLoadArray #-}

-- | Specify how an array can be loaded/computed through creation of a `DL` array. Unlike
-- `makeLoadArrayS` or `makeLoadArray` this function is unsafe, since there is no
-- guarantee that all elements will be initialized and the supplied element writing
-- function does not perform any bounds checking.
--
-- @since 0.3.1
unsafeMakeLoadArray ::
     forall ix e. Index ix
  => Comp
  -- ^ Computation strategy to use. Directly affects the scheduler that gets created for
  -- the loading function.
  -> Sz ix
  -- ^ Size of the array
  -> Maybe e
  -- ^ An element to use for initialization of the mutable array that will be created in
  -- the future
  -> (forall s. Scheduler s () -> Ix1 -> (Ix1 -> e -> ST s ()) -> ST s ())
  -- ^ This function accepts:
  --
  -- * A scheduler that can be used for parallelization of loading
  --
  -- * Linear index at which this load array will start (an offset that must be added to
  --   every index passed to the linear writing function)
  --
  -- * Linear element writing function
  -> Array DL ix e
unsafeMakeLoadArray comp sz mDefVal writer = DLArray comp sz load
  where
    load :: Loader e
    load scheduler startAt uWrite uSet = do
      -- When a default element is supplied, set the whole region to it first.
      S.traverse_ (uSet startAt (toLinearSz sz)) mDefVal
      -- The raw offset is passed through; the writer is responsible for applying it.
      writer scheduler startAt uWrite
    {-# INLINE load #-}
{-# INLINE unsafeMakeLoadArray #-}

-- | Same as `unsafeMakeLoadArray`, except will ensure that starting index is correctly
-- adjusted. Which means the writing function gets one less argument.
--
-- @since 0.5.2
unsafeMakeLoadArrayAdjusted ::
     forall ix e. Index ix
  => Comp
  -- ^ Computation strategy to use
  -> Sz ix
  -- ^ Size of the array
  -> Maybe e
  -- ^ Optional element that, when present, is written to every cell before the writing
  -- function runs
  -> (forall s. Scheduler s () -> (Ix1 -> e -> ST s ()) -> ST s ())
  -- ^ Writing function. It receives a scheduler and a linear element writer whose
  -- indices are relative to the start of this array (the offset is added internally).
  -> Array DL ix e
unsafeMakeLoadArrayAdjusted comp sz mDefVal writer = DLArray comp sz load
  where
    load :: forall s.
      Scheduler s () -> Ix1 -> (Ix1 -> e -> ST s ()) -> (Ix1 -> Sz1 -> e -> ST s ()) -> ST s ()
    load scheduler !startAt uWrite dlSet = do
      S.traverse_ (dlSet startAt (toLinearSz sz)) mDefVal
      -- Shift every index by the start offset so the writer can count from zero.
      writer scheduler (\i -> uWrite (startAt + i))
    {-# INLINE load #-}
{-# INLINE unsafeMakeLoadArrayAdjusted #-}

-- | Convert any `Load`able array into `DL` representation.
--
-- @since 0.3.0
toLoadArray ::
     forall r ix e. (Size r, Load r ix e)
  => Array r ix e
  -> Array DL ix e
toLoadArray arr = DLArray (getComp arr) sz load
  where
    !sz = size arr
    load :: forall s.
      Scheduler s () -> Ix1 -> (Ix1 -> e -> ST s ()) -> (Ix1 -> Sz1 -> e -> ST s ()) -> ST s ()
    load scheduler !startAt dlWrite dlSet =
      -- Offset both the element writer and the region setter by the start index.
      iterArrayLinearWithSetST_
        scheduler
        arr
        (dlWrite . (+ startAt))
        (\offset -> dlSet (offset + startAt))
    {-# INLINE load #-}
-- Inlining is delayed to phase 1 so the rewrite rule below has a chance to fire. The
-- rule's right-hand side is 'id', so it can only apply when the input is already 'DL'.
{-# INLINE[1] toLoadArray #-}
{-# RULES "toLoadArray/id" toLoadArray = id #-}

-- | Convert an array that can be loaded with stride into `DL` representation.
--
-- @since 0.3.0
fromStrideLoad ::
     forall r ix e. (StrideLoad r ix e)
  => Stride ix
  -> Array r ix e
  -> Array DL ix e
fromStrideLoad stride arr = DLArray (getComp arr) newsz load
  where
    -- Size of the result, computed by 'strideSize' from the outer size of the source.
    !newsz = strideSize stride (outerSize arr)
    load :: Loader e
    -- Only the element writer is used; the region setter is ignored.
    load scheduler !startAt dlWrite _ =
      iterArrayLinearWithStrideST_ scheduler stride newsz arr (\ !i -> dlWrite (i + startAt))
    {-# INLINE load #-}
{-# INLINE fromStrideLoad #-}

-- | Loading a 'DL' array runs its stored 'dlLoad' action starting at linear index @0@.
instance Index ix => Load DL ix e where
  makeArrayLinear comp sz f = DLArray comp sz load
    where
      load :: Loader e
      -- Elements are produced by 'f' and written with the element writer; the region
      -- setter is not used.
      load scheduler startAt dlWrite _ =
        splitLinearlyWithStartAtM_ scheduler startAt (totalElem sz) (pure . f) dlWrite
      {-# INLINE load #-}
  {-# INLINE makeArrayLinear #-}
  -- 'makeLoadArray' pre-fills every cell with the default value, so the writer itself
  -- has nothing to do.
  replicate comp !sz !e = makeLoadArray comp sz e $ \_ _ -> pure ()
  {-# INLINE replicate #-}
  iterArrayLinearWithSetST_ scheduler DLArray {dlLoad} = dlLoad scheduler 0
  {-# INLINE iterArrayLinearWithSetST_ #-}

-- | 'fmap' wraps the stored loader with 'loadFunctor'. '(<$)' is implemented by
-- 'overwriteFunctor', which does not run the original loader at all.
instance Index ix => Functor (Array DL ix) where
  fmap f arr = arr {dlLoad = loadFunctor arr f}
  {-# INLINE fmap #-}
  (<$) = overwriteFunctor
  {-# INLINE (<$) #-}

-- | Replace every element of the array with the same value.
--
-- The original loader is never run. The new loader sets the whole linear region in a
-- single call to the region setting action, ignoring the scheduler and the element
-- writer.
overwriteFunctor :: forall ix a b. Index ix => a -> Array DL ix b -> Array DL ix a
overwriteFunctor e arr = arr {dlLoad = load}
  where
    load :: Loader a
    load _ !startAt _ dlSet = dlSet startAt (linearSize arr) e
    {-# INLINE load #-}
{-# INLINE overwriteFunctor #-}


-- | Run the loader of @arr@, applying @f@ to every value before it reaches the supplied
-- element writer or region setter. For a region, @f@ is applied once to the region's
-- value.
loadFunctor ::
     Array DL ix a
  -> (a -> b)
  -> Scheduler s ()
  -> Ix1
  -> (Ix1 -> b -> ST s ())
  -> (Ix1 -> Sz1 -> b -> ST s ())
  -> ST s ()
loadFunctor arr f scheduler startAt uWrite uSet =
  dlLoad arr scheduler startAt (\ !i e -> uWrite i (f e)) (\o sz e -> uSet o sz (f e))
{-# INLINE loadFunctor #-}