module Data.Repa.Eval.Generic.Par.Reduction
        ( foldAll
        , foldInner)
where
import Data.Repa.Eval.Gang
import GHC.Exts
import qualified Data.Repa.Eval.Generic.Seq.Reduction     as Seq
import Data.IORef


-- | Parallel reduction of an array to a single value.
--
--   The index range @[0, len)@ is cut into one contiguous chunk per gang
--   thread, each of at most @ceiling (len / threads)@ elements. Every thread
--   reduces its own chunk sequentially and then atomically merges that partial
--   result into a shared accumulator, which is read back once the gang
--   operation has returned.
--
--   We don't require that the starting value be a neutral element: each thread
--   seeds the fold over its chunk with the chunk's first element (a fold1),
--   and the starting value is used exactly once, as the initial contents of
--   the shared accumulator.
--
--   Partial results are merged in whatever order the threads reach the merge
--   step, so (assuming 'gangIO' runs the per-thread actions concurrently) the
--   result is only independent of scheduling when the combining function is
--   commutative as well as associative.
--
--   When the number of elements is zero the starting value is returned
--   directly and the gang is not used.
--
foldAll :: Gang                -- ^ Gang to run the operation on.
        -> (Int# -> a)         -- ^ Function to get an element from the source.
        -> (a -> a -> a)       -- ^ Binary associative combining function.
        -> a                   -- ^ Starting value.
        -> Int#                -- ^ Number of elements.
        -> IO a

foldAll !gang f c !z !len
 | 1# <- len ==# 0#   = return z
 | otherwise
 = do   -- Shared accumulator, initialised with the starting value.
        result  <- newIORef z

        -- Thread 'tid' reduces the half-open index range
        -- [split tid, split (tid + 1)).
        gangIO gang
         $ \tid -> fill result (split tid) (split (tid +# 1#))

        readIORef result
  where
        -- Number of threads in the gang.
        !threads    = gangSize gang

        -- Chunk length: 'len' divided by 'threads', rounded up.
        !step       = (len +# threads -# 1#) `quotInt#` threads

        -- Starting index of the chunk with the given number, clamped to 'len'
        -- so that trailing chunks may be short or empty.
        split !ix   = len `foldAll_min` (ix *# step)

        -- Minimum of two unboxed ints.
        foldAll_min x y
         = case x <=# y of
                1# -> x
                _  -> y
        {-# NOINLINE foldAll_min #-}
        --  NOINLINE to hide the branch from the simplifier.

        -- Atomically replace the accumulator contents @acc@ with @c x acc@.
        foldAll_combine result x
         = atomicModifyIORef result (\x' -> (c x x', ()))
        {-# NOINLINE foldAll_combine #-}
        --  NOINLINE because we want to keep the final use of the combining
        --  function separate from the main use in 'fill'. If the combining
        --  function contains a branch then the combination of two instances
        --  can cause code explosion.

        -- Reduce the chunk [start, end) and merge it into the accumulator.
        -- An empty chunk contributes nothing.
        fill !result !start !end
         | 1# <- start >=# end = return ()
         | otherwise
         = let  -- The first element of the chunk seeds the sequential fold,
                -- which then covers the remaining indices start + 1 .. end.
                !x      = Seq.foldRange f c (f start) (start +# 1#) end
           in   foldAll_combine result x
        {-# INLINE fill #-}

{-# INLINE [1] foldAll #-}


-- | Parallel reduction of a multidimensional array along the innermost dimension.
--   Each output value is computed by a single thread, with the output values
--   distributed evenly amongst the available threads.
--
--   Output element @i@ is produced by a sequential 'Seq.foldRange' over the
--   source indices @[i * n, (i + 1) * n)@, seeded with the starting value,
--   and is handed to the write function together with the index @i@.
--
--   NOTE(review): the implementation uses the length argument as the number
--   of output elements: it writes indices @0 .. len - 1@ and reads source
--   indices up to @len * n - 1@. The parameter documentation calls it the
--   total length of the source. Confirm which is intended against callers.
foldInner 
        :: Gang                 -- ^ Gang to run the operation on.
        -> (Int# -> a -> IO ()) -- ^ Function to write into the result buffer.
        -> (Int# -> a)          -- ^ Function to get an element from the source.
        -> (a -> a -> a)        -- ^ Binary associative combination operator.
        -> a                    -- ^ Neutral starting value.
        -> Int#                 -- ^ Total length of source.
        -> Int#                 -- ^ Inner dimension (length to fold over).
        -> IO ()

foldInner gang write f c !r !len !n
 = gangIO gang
 $ \tid -> fill (split tid) (split (tid +# 1#))
  where
        -- Number of threads in the gang.
        !threads = gangSize gang

        -- Output elements per thread: 'len' divided by 'threads', rounded up.
        !step    = (len +# threads -# 1#) `quotInt#` threads

        -- First output index owned by the chunk with the given number,
        -- clamped to 'len' so that trailing chunks may be short or empty.
        split !ix
         = let !ix' = ix *# step
           in  case len <# ix' of
                1# -> len
                _  -> ix'
        {-# INLINE split #-}

        -- Compute and write the output elements with indices [start, end).
        fill !start !end
         = iter start (start *# n)
         where
          -- 'sh' is the current output index, and 'sz' is the source index
          -- where the segment belonging to that output begins (sh * n).
          iter !sh !sz
           | 1# <- sh >=# end = return ()
           | otherwise
           = do let !next = sz +# n
                write sh (Seq.foldRange f c r sz next)
                iter (sh +# 1#) next
          {-# INLINE iter #-}
        {-# INLINE fill #-}

{-# INLINE [1] foldInner #-}