{-# LANGUAGE BangPatterns #-}

-- | Low-level reduction loops: folds along the innermost dimension of an
--   array ('foldS', 'foldP') and folds of every element down to a single
--   value ('foldAllS', 'foldAllP'), each in a sequential and a gang-parallel
--   flavour.
--
--   Source elements are never read from an array directly; they are obtained
--   by calling a caller-supplied @Int -> a@ function with a linear index.
module Data.Array.Repa.Internals.EvalReduction
        ( foldS,    foldP
        , foldAllS, foldAllP)
where
import Data.Array.Repa.Internals.Elt
import Data.Array.Repa.Internals.Gang
import qualified Data.Vector.Unboxed            as V
import qualified Data.Vector.Unboxed.Mutable    as M
import GHC.Base                                 ( quotInt, divInt )


-- | Sequential reduction of a multidimensional array along the innermost
--   dimension.
--
--   Output slot @i@ of the vector receives the reduction of the source
--   indices @[i * n, (i + 1) * n)@, seeded with the starting value. The number
--   of reductions performed is the length of the output vector.
foldS :: Elt a
      => M.IOVector a   -- ^ vector to write elements into
      -> (Int -> a)     -- ^ function to get an element from the given index
      -> (a -> a -> a)  -- ^ binary associative combination function
      -> a              -- ^ starting value (typically an identity)
      -> Int            -- ^ inner dimension (length to fold over)
      -> IO ()
{-# INLINE foldS #-}
foldS vec !f !c !r !n = iter 0 0
  where
    !end = M.length vec

    -- @sh@ is the output slot being filled, @sz@ the linear index of the
    -- first source element that belongs to it.
    {-# INLINE iter #-}
    iter !sh !sz
      | sh >= end = return ()
      | otherwise =
          let !next = sz + n
          in  M.unsafeWrite vec sh (reduce f c r sz next) >> iter (sh + 1) next


-- | Parallel reduction of a multidimensional array along the innermost
--   dimension. Each output value is computed by a single thread, with the
--   output values distributed evenly amongst the available threads.
--
--   Thread @tid@ fills the output slots @[split tid, split (tid + 1))@, so no
--   two threads write to the same slot.
foldP :: Elt a
      => M.IOVector a   -- ^ vector to write elements into
      -> (Int -> a)     -- ^ function to get an element from the given index
      -> (a -> a -> a)  -- ^ binary associative combination operator
      -> a              -- ^ starting value. Must be neutral with respect
                        --   to the operator. eg @0 + a = a@.
      -> Int            -- ^ inner dimension (length to fold over)
      -> IO ()
{-# INLINE foldP #-}
foldP vec !f !c !r !n
  = gangIO theGang $ \tid -> fill (split tid) (split (tid + 1))
  where
    !threads = gangSize theGang
    !len     = M.length vec

    -- Output slots per thread, rounded up so that every slot is covered.
    !step    = (len + threads - 1) `quotInt` threads

    -- First output slot owned by thread @ix@, clamped to the vector length
    -- so trailing threads may receive an empty range.
    {-# INLINE split #-}
    split !ix = len `min` (ix * step)

    -- Fill the output slots @[start, end)@; the source data for slot @sh@
    -- begins at linear index @sh * n@.
    {-# INLINE fill #-}
    fill !start !end = iter start (start * n)
      where
        {-# INLINE iter #-}
        iter !sh !sz
          | sh >= end = return ()
          | otherwise =
              let !next = sz + n
              in  M.unsafeWrite vec sh (reduce f c r sz next) >> iter (sh + 1) next


-- | Sequential reduction of all the elements in an array.
--
--   Elements are combined in increasing index order, starting from the
--   given starting value. A non-positive element count yields the starting
--   value unchanged.
foldAllS :: Elt a
         => (Int -> a)     -- ^ function to get an element from the given index
         -> (a -> a -> a)  -- ^ binary associative combining function
         -> a              -- ^ starting value
         -> Int            -- ^ number of elements
         -> IO a
{-# INLINE foldAllS #-}
foldAllS !f !c !r !len = return $! reduce f c r 0 len


-- | Parallel tree reduction of an array to a single value. Each thread takes
--   an equally sized chunk of the data and computes a partial sum. The main
--   thread then reduces the array of partial sums to the final result.
--
--   We don't require that the initial value be a neutral element, so each
--   thread computes a fold1 on its chunk of the data, and the seed element is
--   only applied in the final reduction step.
--
--   Chunks are reduced in index order and the partial results are combined
--   in chunk order, so for an associative operator the result agrees with
--   'foldAllS' even when the operator is not commutative.
--
--   A non-positive element count yields the starting value, as 'foldAllS'
--   does.
foldAllP :: Elt a
         => (Int -> a)     -- ^ function to get an element from the given index
         -> (a -> a -> a)  -- ^ binary associative combining function
         -> a              -- ^ starting value
         -> Int            -- ^ number of elements
         -> IO a
{-# INLINE foldAllP #-}
foldAllP !f !c !r !len
  -- Guard on @<=@ rather than @==@: with a negative length @step@ can be
  -- zero, and computing @chunks@ would then divide by zero.
  | len <= 0  = return r
  | otherwise = do
      mvec <- M.unsafeNew chunks
      gangIO theGang $ \tid -> fill mvec tid (split tid) (split (tid + 1))
      vec  <- V.unsafeFreeze mvec
      return $! V.foldl' c r vec
  where
    !threads = gangSize theGang

    -- Elements per chunk, rounded up.
    !step    = (len + threads - 1) `quotInt` threads

    -- Number of non-empty chunks. Exactly the threads with @tid < chunks@
    -- have a non-empty range, so every slot of the freshly allocated
    -- (uninitialised) vector is written before it is frozen and read.
    chunks   = ((len + step - 1) `divInt` step) `min` threads

    -- First element index owned by thread @ix@, clamped to @len@.
    {-# INLINE split #-}
    split !ix = len `min` (ix * step)

    -- Reduce the elements @[start, end)@ into slot @tid@, seeding with the
    -- first element of the chunk. Threads with an empty range write nothing.
    {-# INLINE fill #-}
    fill !mvec !tid !start !end
      | start >= end = return ()
      | otherwise    = M.unsafeWrite mvec tid (reduce f c (f start) (start + 1) end)


-- | Sequentially reduce values between the given indices.
--
--   Computes the left fold
--   @(((r \`c\` f start) \`c\` f (start + 1)) ... \`c\` f (end - 1))@ over the
--   half-open range @[start, end)@, returning @r@ when the range is empty.
--
--   The accumulator is kept on the left so that elements are combined in
--   index order; 'foldAllP' relies on this to give the same answer as
--   'foldAllS' for associative operators that are not commutative.
{-# INLINE reduce #-}
reduce :: (Int -> a) -> (a -> a -> a) -> a -> Int -> Int -> a
reduce !f !c !r !start !end = iter start r
  where
    {-# INLINE iter #-}
    iter !i !z
      | i >= end  = z
      | otherwise = iter (i + 1) (z `c` f i)