module Data.Array.Repa.Internals.EvalReduction
( foldS, foldP
, foldAllS, foldAllP)
where
import Data.Array.Repa.Internals.Elt
import Data.Array.Repa.Internals.Gang
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as M
import GHC.Base ( quotInt, divInt )
-- | Sequential segmented reduction.
--
--   Destination slot @i@ of the vector receives the reduction of the source
--   indices @[i * n, (i + 1) * n)@; the number of segments is the length of
--   the destination vector.
foldS :: Elt a
      => M.IOVector a   -- ^ destination vector, one slot per segment
      -> (Int -> a)     -- ^ fetches the source element at a linear index
      -> (a -> a -> a)  -- ^ combining function
      -> a              -- ^ starting value used for every segment
      -> Int            -- ^ number of source elements per segment
      -> IO ()
foldS vec !f !c !r !n = go 0 0
  where
    !segments = M.length vec

    -- 'out' is the destination slot being filled, 'lo' is the first source
    -- index of the segment that belongs to it.
    go !out !lo
      | out < segments = do
          let !hi = lo + n
          M.unsafeWrite vec out (reduce f c r lo hi)
          go (out + 1) hi
      | otherwise = return ()
-- | Parallel segmented reduction.
--
--   Destination slot @i@ of the vector receives the reduction of the source
--   indices @[i * n, (i + 1) * n)@. The destination slots are divided into
--   contiguous ranges of at most @step@ slots, and the worker with index
--   @tid@ fills the range @[split tid, split (tid + 1))@.
--
--   NOTE(review): this relies on 'gangIO' invoking the action once per gang
--   worker with indices @0 .. gangSize theGang - 1@ -- confirm against the
--   Gang module.
foldP :: Elt a
      => M.IOVector a   -- ^ destination vector, one slot per segment
      -> (Int -> a)     -- ^ fetches the source element at a linear index
      -> (a -> a -> a)  -- ^ combining function
      -> a              -- ^ starting value used for every segment
      -> Int            -- ^ number of source elements per segment
      -> IO ()
foldP vec !f !c !r !n
  = gangIO theGang
  $ \tid -> fill (split tid) (split (tid + 1))
  where
    !threads = gangSize theGang
    !len     = M.length vec

    -- Ceiling of @len / threads@, so that @threads * step >= len@ and every
    -- destination slot is owned by some worker.
    !step    = (len + threads - 1) `quotInt` threads

    -- First destination slot owned by worker 'ix', clamped to the vector
    -- length so that trailing workers get an empty range.
    split !ix = len `min` (ix * step)

    -- Fill destination slots [start, end). The source segment for slot
    -- 'start' begins at linear index @start * n@.
    fill !start !end = iter start (start * n)
      where
        iter !sh !sz
          | sh >= end = return ()
          | otherwise =
              let !next = sz + n
              in  M.unsafeWrite vec sh (reduce f c r sz next) >> iter (sh + 1) next
-- | Sequential reduction of all source elements with linear indices
--   @[0, len)@ to a single value. The result is forced to weak head normal
--   form before it is returned.
foldAllS :: Elt a
         => (Int -> a)     -- ^ fetches the source element at a linear index
         -> (a -> a -> a)  -- ^ combining function
         -> a              -- ^ starting value
         -> Int            -- ^ number of source elements
         -> IO a
foldAllS !f !c !r !len =
  let !total = reduce f c r 0 len
  in  return total
-- | Parallel reduction of all source elements with linear indices
--   @[0, len)@ to a single value.
--
--   For @len == 0@ the starting value is returned directly. Otherwise the
--   index space is divided into contiguous chunks of at most @step@
--   elements; the worker with index @tid@ reduces the chunk
--   @[split tid, split (tid + 1))@ and stores its partial result in slot
--   @tid@ of a temporary vector. The partial results are then combined
--   sequentially with the same combining function, starting from @r@.
--
--   NOTE(review): this relies on 'gangIO' invoking the action once per gang
--   worker with indices @0 .. gangSize theGang - 1@ -- confirm against the
--   Gang module.
foldAllP :: Elt a
         => (Int -> a)     -- ^ fetches the source element at a linear index
         -> (a -> a -> a)  -- ^ combining function
         -> a              -- ^ starting value
         -> Int            -- ^ number of source elements
         -> IO a
foldAllP !f !c !r !len
  | len == 0  = return r
  | otherwise = do
      mvec <- M.unsafeNew chunks
      gangIO theGang $ \tid -> fill mvec tid (split tid) (split (tid + 1))
      vec  <- V.unsafeFreeze mvec
      return $! V.foldl' c r vec
  where
    !threads = gangSize theGang

    -- Ceiling of @len / threads@: the largest chunk any worker handles.
    !step    = (len + threads - 1) `quotInt` threads

    -- Number of workers whose chunk is non-empty: the ceiling of
    -- @len / step@, capped at the gang size. Exactly the workers with
    -- @tid < chunks@ write a partial result, so every slot of the
    -- uninitialised temporary vector is written before it is read.
    chunks   = ((len + step - 1) `divInt` step) `min` threads

    -- First source index owned by worker 'ix', clamped to 'len' so that
    -- trailing workers get an empty range.
    split !ix = len `min` (ix * step)

    -- Reduce [start, end) into slot 'tid'. The chunk's first element seeds
    -- the reduction, so 'r' enters the result only in the final fold above.
    fill !mvec !tid !start !end
      | start >= end = return ()
      | otherwise    = M.unsafeWrite mvec tid (reduce f c (f start) (start + 1) end)
-- | Sequentially reduce the source elements with linear indices
--   @[start, end)@.
--
--   Indices are visited in increasing order. At each step the newly fetched
--   element is the left argument of the combining function and the running
--   result is the right argument. An empty range yields the starting value
--   unchanged.
reduce :: (Int -> a) -> (a -> a -> a) -> a -> Int -> Int -> a
reduce !f !c !r !start !end = go start r
  where
    -- 'acc' is kept strict so that no chain of thunks builds up.
    go !ix !acc
      | ix < end  = go (ix + 1) (c (f ix) acc)
      | otherwise = acc