module Data.Yarr.Fold (
    -- * Fold support
    Fold,
    reduceL, reduceLeftM,
    reduceR, reduceRightM,

    -- * Fold runners
    runFold, runFoldP,
    runFoldSlicesSeparate, runFoldSlicesSeparateP,

    -- * Shortcuts
    toList,
) where

import Prelude as P
import Control.Monad as M
import Data.List (groupBy)
import Data.Function (on)

import Data.Yarr.Base
import Data.Yarr.Shape as S
import Data.Yarr.Eval
import Data.Yarr.Convolution

import Data.Yarr.Utils.FixedVector as V hiding (toList)
import Data.Yarr.Utils.Fork
import Data.Yarr.Utils.Parallel


-- | Selects how an array with load type @l@ and shape @ash@ is addressed
-- while folding. The functional dependency makes the fold index type @fsh@
-- a function of @l@ and @ash@.
class Shape fsh => Reduce l ash fsh | l ash -> fsh where
    -- | Element accessor keyed by the fold index type.
    getIndex :: USource r l ash a => UArray r l ash a -> (fsh -> IO a)
    -- | End bound of the fold range, in the fold index type.
    getSize :: USource r l ash a => UArray r l ash a -> fsh

-- Instances for the SH and CVL load types: the fold index is the array
-- shape itself, elements are read with 'index' and the end bound is 'extent'.
#define SH_GET_INDEX(l,sh)      \
instance Reduce l sh sh where { \
    getIndex = index;           \
    getSize = extent;           \
    {-# INLINE getIndex #-};    \
    {-# INLINE getSize #-};     \
}

SH_GET_INDEX(SH, Dim1)
SH_GET_INDEX(SH, Dim2)
SH_GET_INDEX(SH, Dim3)

SH_GET_INDEX(CVL, Dim1)
SH_GET_INDEX(CVL, Dim2)
SH_GET_INDEX(CVL, Dim3)

-- Instances for the L load type: the fold index is a plain Int, elements are
-- read with 'linearIndex' and the end bound is the 'size' of the 'extent'.
#define LINEAR_GET_INDEX(l,sh)   \
instance Reduce l sh Int where { \
    getIndex = linearIndex;      \
    getSize = size . extent;     \
    {-# INLINE getIndex #-};     \
    {-# INLINE getSize #-};      \
}

LINEAR_GET_INDEX(L, Dim1)
LINEAR_GET_INDEX(L, Dim2)
LINEAR_GET_INDEX(L, Dim3)


-- | Curried 'Foldl' or 'Foldr'.
-- Generalizes both partially applied left and right folds.
--
-- The arguments are, in order: the monadic zero, the element getter,
-- the start index, the end index; the result is the folded value.
--
-- See the source of the following 4 functions to construct more similar
-- ones, if you need.
type Fold sh a b =
       IO b          -- Zero
    -> (sh -> IO a)  -- Get
    -> sh            -- Start
    -> sh            -- End
    -> IO b          -- Result

-- | /O(0)/ Builds a 'Fold' from a monadic left reducing function.
-- The index passed by the underlying fold is ignored.
reduceLeftM
    :: Foldl sh a b      -- ^ 'S.foldl' or curried 'S.unrolledFoldl'
    -> (b -> a -> IO b)  -- ^ Monadic left reduce
    -> Fold sh a b       -- ^ Curried fold to be passed
                         --   to 'runFold' functions.
{-# INLINE reduceLeftM #-}
reduceLeftM fold rf = fold (\b _ a -> rf b a)

-- | /O(0)/ Builds a 'Fold' from a pure left reducing function.
-- The index passed by the underlying fold is ignored.
reduceL
    :: Foldl sh a b   -- ^ 'S.foldl' or curried 'S.unrolledFoldl'
    -> (b -> a -> b)  -- ^ Pure left reduce
    -> Fold sh a b    -- ^ Curried fold to be passed
                      --   to 'runFold' functions.
{-# INLINE reduceL #-}
reduceL fold rf = fold (\b _ a -> return $ rf b a)

-- | /O(0)/ Builds a 'Fold' from a monadic right reducing function.
-- The index passed by the underlying fold is ignored.
reduceRightM
    :: Foldr sh a b      -- ^ 'S.foldr' or curried 'S.unrolledFoldr'
    -> (a -> b -> IO b)  -- ^ Monadic right reduce
    -> Fold sh a b       -- ^ Curried fold to be passed
                         --   to 'runFold' functions.
{-# INLINE reduceRightM #-}
reduceRightM fold rf = fold (\_ a b -> rf a b)

-- | /O(0)/ Builds a 'Fold' from a pure right reducing function.
-- The index passed by the underlying fold is ignored.
reduceR
    :: Foldr sh a b   -- ^ 'S.foldr' or curried 'S.unrolledFoldr'
    -> (a -> b -> b)  -- ^ Pure right reduce
    -> Fold sh a b    -- ^ Curried fold to be passed
                      --   to 'runFold' functions.
{-# INLINE reduceR #-}
reduceR fold rf = fold (\_ a b -> return $ rf a b)


-- Joins a list of partial results left to right with the given monadic
-- function. An empty list cannot be joined, so it is reported as an IO
-- failure that names the calling function.
joinPartials :: String -> (b -> b -> IO b) -> [b] -> IO b
{-# INLINE joinPartials #-}
joinPartials _      joinRes (r:rs) = M.foldM joinRes r rs
joinPartials caller _       []     =
    fail (caller ++ ": no partial results to join")


-- | /O(n)/ Runs the fold over the range from 'zero' to the array size.
--
-- Example:
--
-- @'toList' = runFold ('reduceR' 'S.foldr' (:)) (return [])@
runFold
    :: (USource r l sh a, Reduce l sh fsh)
    => Fold fsh a b     -- ^ Curried folding worker function
    -> IO b             -- ^ Monadic fold zero. Wrap pure zero in 'return'.
    -> UArray r l sh a  -- ^ Source array
    -> IO b             -- ^ Fold result
{-# INLINE runFold #-}
runFold fold mz arr = do
    force arr
    res <- fold mz (getIndex arr) zero (getSize arr)
    touchArray arr
    return res

-- | /O(n)/ Run associative fold in parallel.
--
-- Each fork folds its part of the range starting from the given zero;
-- the partial results are then joined left to right.
-- Fails in 'IO' if the parallel run yields no partial results.
--
-- Example: associative image histogram filling (used in the library tests).
runFoldP
    :: (USource r l sh a, Reduce l sh fsh)
    => Threads          -- ^ Number of threads to parallelize folding on
    -> Fold fsh a b     -- ^ Curried folding worker function
    -> IO b             -- ^ Monadic fold zero. Wrap pure zero in 'return'.
    -> (b -> b -> IO b) -- ^ Associative monadic result joining function
    -> UArray r l sh a  -- ^ Source array
    -> IO b             -- ^ Fold result
{-# INLINE runFoldP #-}
runFoldP threads fold mz joinRes arr = do
    force arr
    ts <- threads
    partials <- parallel ts $
        makeFork ts zero (getSize arr) (fold mz (getIndex arr))
    touchArray arr
    joinPartials "Data.Yarr.Fold.runFoldP" joinRes partials

-- | /O(n)/ Runs the fold over every slice of an array of vectors,
-- one slice after another, with 'runFold'.
runFoldSlicesSeparate
    :: (UVecSource r slr l sh v e, Reduce l sh fsh)
    => Fold fsh e b             -- ^ Curried folding function
                                --   to work on slices
    -> IO b                     -- ^ Monadic fold zero.
                                --   Wrap pure zero in 'return'.
    -> UArray r l sh (v e)      -- ^ Source array of vectors
    -> IO (VecList (Dim v) b)   -- ^ Vector of fold results
{-# INLINE runFoldSlicesSeparate #-}
runFoldSlicesSeparate fold mz arr =
    V.mapM (\sl -> runFold fold mz sl) (slices arr)

-- | /O(n)/ Run associative fold over slices of array of vectors in parallel.
runFoldSlicesSeparateP
    :: (UVecSource r slr l sh v e, Reduce l sh fsh)
    => Threads                  -- ^ Number of threads
                                --   to parallelize folding on
    -> Fold fsh e b             -- ^ Curried folding function
                                --   to work on slices
    -> IO b                     -- ^ Monadic fold zero.
                                --   Wrap pure zero in 'return'.
    -> (b -> b -> IO b)         -- ^ Associative monadic result
                                --   joining function
    -> UArray r l sh (v e)      -- ^ Source array of vectors
    -> IO (VecList (Dim v) b)   -- ^ Vector of fold results
{-# INLINE runFoldSlicesSeparateP #-}
runFoldSlicesSeparateP threads fold mz joinRes arr = do
    force arr
    ts <- threads
    trs <- parallel ts $
        makeForkSlicesOnce
            ts
            (V.replicate (zero, getSize arr))
            (V.map (\sl -> fold mz (getIndex sl)) (slices arr))
    touchArray arr
    -- Every partial result is tagged with a key in its first component
    -- (presumably the slice number; confirm against makeForkSlicesOnce).
    -- 'groupBy' only merges adjacent entries, so this relies on entries
    -- with equal keys being adjacent in the concatenated thread results.
    let rsBySlices =
            P.map (P.map snd) $ groupBy ((==) `on` fst) $ concat trs
    -- 'groupBy' never produces an empty group, so the failure branch of
    -- 'joinPartials' is not expected to be reached here.
    sliceResults <-
        M.mapM
            (joinPartials "Data.Yarr.Fold.runFoldSlicesSeparateP" joinRes)
            rsBySlices
    return (VecList sliceResults)


-- | /O(n)/ Convert array to list.
toList
    :: (USource r l sh a, Reduce l sh fsh)
    => UArray r l sh a
    -> IO [a]
toList = runFold (reduceR S.foldr (:)) (return [])