module Data.Repa.Eval.Generic.Par.Reduction
        ( foldAll
        , foldInner)
where
import Data.Repa.Eval.Gang
import GHC.Exts
import qualified Data.Repa.Eval.Generic.Seq.Reduction     as Seq
import Data.IORef


-- | Parallel tree reduction of an array to a single value. Each thread takes an
--   equally sized chunk of the data and computes a partial sum. The main thread
--   then reduces the array of partial sums to the final result.
--
--   We don't require that the initial value be a neutral element, so each thread
--   computes a fold1 on its chunk of the data, and the seed element is only
--   applied in the final reduction step.
--
foldAll :: Gang                -- ^ Gang to run the operation on.
        -> (Int# -> a)         -- ^ Function to get an element from the source.
        -> (a -> a -> a)       -- ^ Binary associative combining function.
        -> a                   -- ^ Starting value.
        -> Int#                -- ^ Number of elements.
        -> IO a

foldAll !gang f c !z !len
 | 1# <- len ==# 0#   = return z
 | otherwise   
 = do   result  <- newIORef z

        gangIO gang
         $ \tid -> fill result (split tid) (split (tid +# 1#))

        readIORef result
  where
        !threads    = gangSize gang
        !step       = (len +# threads -# 1#) `quotInt#` threads

        split !ix   = len `foldAll_min` (ix *# step)

        foldAll_min x y
         = case x <=# y of
                1# -> x 
                _  -> y
        {-# NOINLINE foldAll_min #-}
        --  NOINLINE to hide the branch from the simplifier.

        foldAll_combine result x 
         = atomicModifyIORef result (\x' -> (c x x', ()))
        {-# NOINLINE foldAll_combine #-}
        --  NOINLINE because we want to keep the final use of the combining 
        --  function separate from the main use in 'fill'. If the combining
        --  function contains a branch then the combination of two instances
        --  can cause code explosion.

        fill !result !start !end
         | 1# <- start >=# end = return ()
         | otherwise    
         = let  !x      = Seq.foldRange f c (f start) (start +# 1#) end
           in   foldAll_combine result x
        {-# INLINE fill #-}

{-# INLINE [1] foldAll #-}


-- | Parallel reduction of a multidimensional array along the innermost dimension.
--   Each output value is computed by a single thread, with the output values
--   distributed evenly amongst the available threads.
foldInner 
        :: Gang                 -- ^ Gang to run the operation on.
        -> (Int# -> a -> IO ()) -- ^ Function to write into the result buffer.
        -> (Int# -> a)          -- ^ Function to get an element from the source.
        -> (a -> a -> a)        -- ^ Binary associative combination operator.
        -> a                    -- ^ Neutral starting value.
        -> Int#                 -- ^ Total length of source.
        -> Int#                 -- ^ Inner dimension (length to fold over).
        -> IO ()

foldInner gang write f c !r !len !n
 = gangIO gang
 $ \tid -> fill (split tid) (split (tid +# 1#))
  where
        !threads = gangSize gang
        !step    = (len +# threads -# 1#) `quotInt#` threads

        split !ix 
         = let !ix' = ix *# step
           in  case len <# ix' of
                1# -> len
                _  -> ix'
        {-# INLINE split #-}

        fill !start !end 
         = iter start (start *# n)
         where
          iter !sh !sz 
           | 1# <- sh >=# end = return ()
           | otherwise 
           = do let !next = sz +# n
                write sh (Seq.foldRange f c r sz next)
                iter (sh +# 1#) next
          {-# INLINE iter #-}
        {-# INLINE fill #-}

{-# INLINE [1] foldInner #-}