{-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE CPP #-} #include "fusion-phases.h" -- | Standard combinators for distributed types. module Data.Array.Parallel.Unlifted.Distributed.Combinators ( W.What (..) , imapD, mapD , zipD, unzipD , fstD, sndD , zipWithD, izipWithD , foldD , scanD , mapAccumLD) where import Data.Array.Parallel.Base ( ST, runST) import Data.Array.Parallel.Unlifted.Distributed.Primitive import Data.Array.Parallel.Unlifted.Distributed.Data.Tuple import Data.Array.Parallel.Unlifted.Distributed.Data.Maybe () import qualified Data.Array.Parallel.Unlifted.Distributed.What as W here s = "Data.Array.Parallel.Unlifted.Distributed.Combinators." ++ s -- Mapping -------------------------------------------------------------------- -- -- Fusing maps -- ~~~~~~~~~~~ -- The staging here is important. -- Our rewrite rules only operate on the imapD form, so fusion between the worker -- functions of consecutive maps takes place before phase [0]. -- -- At phase [0] we then inline imapD which introduces the call to imapD' which -- uses the gang to evaluate its (now fused) worker. -- -- | Map a function to every instance of a distributed value. -- -- This applies the function to every thread, but not every value held -- by the thread. If you want that then use something like: -- -- @mapD theGang (V.map (+ 1)) :: Dist (Vector Int) -> Dist (Vector Int)@ -- mapD :: (DT a, DT b) => W.What -- ^ What is the worker function doing. -> Gang -> (a -> b) -> Dist a -> Dist b mapD wFn gang = imapD wFn gang . const {-# INLINE mapD #-} -- INLINE because this is just a convenience wrapper for imapD. -- None of our rewrite rules are particular to mapD. -- | Map a function across all elements of a distributed value. -- The worker function also gets the current thread index. -- As opposed to `imapD'` this version also deepSeqs each element before -- passing it to the function. imapD :: (DT a, DT b) => W.What -- ^ What is the worker function doing. -> Gang -> (Int -> a -> b) -> Dist a -> Dist b imapD wFn gang f d = imapD' wFn gang (\i x -> x `deepSeqD` f i x) d {-# INLINE [0] imapD #-} -- INLINE [0] because we want to wait until phase [0] before introducing -- the call to imapD'. Our rewrite rules operate directly on the imapD -- formp, so once imapD is inlined no more fusion can take place. {-# RULES "imapD/generateD" forall wMap wGen gang f g . imapD wMap gang f (generateD wGen gang g) = generateD (W.WFMapGen wMap wGen) gang (\i -> f i (g i)) "imapD/generateD_cheap" forall wMap wGen gang f g . imapD wMap gang f (generateD_cheap wGen gang g) = generateD (W.WFMapGen wMap wGen) gang (\i -> f i (g i)) "imapD/imapD" forall wMap1 wMap2 gang f g d . imapD wMap1 gang f (imapD wMap2 gang g d) = imapD (W.WFMapMap wMap1 wMap2) gang (\i x -> f i (g i x)) d #-} -- Zipping -------------------------------------------------------------------- -- | Combine two distributed values with the given function. zipWithD :: (DT a, DT b, DT c) => W.What -- ^ What is the worker function doing. -> Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c zipWithD what g f dx dy = mapD what g (uncurry f) (zipD dx dy) {-# INLINE zipWithD #-} -- | Combine two distributed values with the given function. -- The worker function also gets the index of the current thread. izipWithD :: (DT a, DT b, DT c) => W.What -- ^ What is the worker function doing. -> Gang -> (Int -> a -> b -> c) -> Dist a -> Dist b -> Dist c izipWithD what g f dx dy = imapD what g (\i -> uncurry (f i)) (zipD dx dy) {-# INLINE izipWithD #-} {-# RULES "zipD/imapD[1]" forall gang f xs ys what . zipD (imapD what gang f xs) ys = imapD what gang (\i (x,y) -> (f i x, y)) (zipD xs ys) "zipD/imapD[2]" forall gang f xs ys what . zipD xs (imapD what gang f ys) = imapD what gang (\i (x,y) -> (x, f i y)) (zipD xs ys) "zipD/generateD[1]" forall gang f xs what . zipD (generateD what gang f) xs = imapD what gang (\i x -> (f i, x)) xs "zipD/generateD[2]" forall gang f xs what . zipD xs (generateD what gang f) = imapD what gang (\i x -> (x, f i)) xs #-} -- MapAccumL ------------------------------------------------------------------ -- | Combination of map and fold. mapAccumLD :: forall a b acc. (DT a, DT b) => Gang -> (acc -> a -> (acc, b)) -> acc -> Dist a -> (acc, Dist b) mapAccumLD g f acc !d = checkGangD (here "mapAccumLD") g d $ runST (do md <- newMD g acc' <- go md 0 acc d' <- unsafeFreezeMD md return (acc',d')) where !n = gangSize g go :: MDist b s -> Int -> acc -> ST s acc go md i acc' | i == n = return acc' | otherwise = case f acc' (indexD (here "mapAccumLD") d i) of (acc'',b) -> do writeMD md i b go md (i+1) acc'' {-# INLINE_DIST mapAccumLD #-}