{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}
#include "fusion-phases.h"

-- | Standard combinators for distributed types.
module Data.Array.Parallel.Unlifted.Distributed.Combinators (
  generateD, generateD_cheap,
  imapD, mapD, zipD, unzipD, fstD, sndD, zipWithD, izipWithD,
  foldD, scanD, mapAccumLD,

  -- * Monadic combinators
  mapDST_, mapDST, zipWithDST_, zipWithDST
) where
import Data.Array.Parallel.Base ( ST, runST)
import Data.Array.Parallel.Unlifted.Distributed.Gang ( Gang, gangSize)
import Data.Array.Parallel.Unlifted.Distributed.Types (
  DT, Dist, MDist, indexD, zipD, unzipD, fstD, sndD, deepSeqD,
  newMD, writeMD, unsafeFreezeMD,
  checkGangD, measureD, debugD)
import Data.Array.Parallel.Unlifted.Distributed.DistST
import Debug.Trace


-- Prefix a function name with this module's name, for use as the location
-- label handed to 'checkGangD'.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Combinators." ++ s


-- | Create a distributed value, given a function that makes the value in
--   each thread. The function is applied to the index returned by 'myIndex'.
generateD :: DT a => Gang -> (Int -> a) -> Dist a
{-# NOINLINE generateD #-}
generateD g f
  = runDistST g (myIndex >>= return . f)


-- | Like 'generateD', but the computation is run with 'runDistST_seq'
--   instead of 'runDistST'.
--
--   NOTE(review): presumably this evaluates the elements sequentially
--   rather than on the gang threads -- confirm against the DistST module.
generateD_cheap :: DT a => Gang -> (Int -> a) -> Dist a
{-# NOINLINE generateD_cheap #-}
generateD_cheap g f
  = runDistST_seq g (myIndex >>= return . f)


-- Mapping --------------------------------------------------------------------
-- | Map a function across all elements of a distributed value.
--   The worker function also gets the current thread index.
--   As opposed to `imapD'` this version also deepSeqs each element before
--   passing it to the function.
imapD :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
{-# INLINE [0] imapD #-}
imapD g f d = imapD' g (\i x -> x `deepSeqD` f i x) d


-- | Map a function across all elements of a distributed value.
--   The worker function also gets the current thread index.
imapD' :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
{-# NOINLINE imapD' #-}
imapD' gang worker !dist
  = checkGangD (here "imapD") gang dist
  $ runDistST gang
      (do ix  <- myIndex
          elt <- myD dist
          return (worker ix elt))


-- | Apply a function to a distributed value, ignoring the thread index.
--   Defined in terms of 'imapD'.
mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
{-# INLINE mapD #-}
mapD gang = imapD gang . const


-- Fusion rules: collapse an 'imapD' applied to the result of a generator or
-- of another 'imapD' into a single traversal.
{-# RULES

"imapD/generateD"
  forall gang outer inner.
  imapD gang outer (generateD gang inner)
    = generateD gang (\i -> outer i (inner i))

"imapD/generateD_cheap"
  forall gang outer inner.
  imapD gang outer (generateD_cheap gang inner)
    = generateD gang (\i -> outer i (inner i))

"imapD/imapD"
  forall gang outer inner dist.
  imapD gang outer (imapD gang inner dist)
    = imapD gang (\i x -> outer i (inner i x)) dist

  #-}


-- Zipping --------------------------------------------------------------------
-- | Zip two distributed values together and combine the paired elements
--   with the given function.
zipWithD :: (DT a, DT b, DT c)
         => Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE zipWithD #-}
zipWithD gang op xs ys
  = mapD gang (uncurry op) (zipD xs ys)


-- | Like 'zipWithD', but the combining function is additionally handed
--   the index of the current thread.
izipWithD :: (DT a, DT b, DT c)
          => Gang -> (Int -> a -> b -> c) -> Dist a -> Dist b -> Dist c
{-# INLINE izipWithD #-}
izipWithD gang op xs ys
  = imapD gang (\ix -> uncurry (op ix)) (zipD xs ys)


-- Fusion rules: push a 'zipD' underneath an 'imapD' or 'generateD' on
-- either of its arguments.
{-# RULES

"zipD/imapD[1]"
  forall gang fn as bs.
  zipD (imapD gang fn as) bs
    = imapD gang (\i (x,y) -> (fn i x,y)) (zipD as bs)

"zipD/imapD[2]"
  forall gang fn as bs.
  zipD as (imapD gang fn bs)
    = imapD gang (\i (x,y) -> (x, fn i y)) (zipD as bs)

"zipD/generateD[1]"
  forall gang fn as.
  zipD (generateD gang fn) as
    = imapD gang (\i x -> (fn i, x)) as

"zipD/generateD[2]"
  forall gang fn as.
  zipD as (generateD gang fn)
    = imapD gang (\i x -> (x, fn i)) as

  #-}


-- Folding --------------------------------------------------------------------
-- | Fold a distributed value.
--   The element at index 0 seeds the result; the elements at indices
--   1 .. gangSize-1 are then combined into it from left to right.
foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
{-# NOINLINE foldD #-}
foldD g f !d
  -- Use 'here' so the label carries the module prefix like every other
  -- combinator in this module (it used to be the bare string "here foldD").
  = checkGangD (here "foldD") g d
  $ fold 1 (d `indexD` 0)
  where
    !n = gangSize g

    -- Combine the running result x with the elements at indices i .. n-1.
    fold i x
      | i == n    = x
      | otherwise = fold (i+1) (f x $ d `indexD` i)


-- | Prefix sum of a distributed value.
--
--   Element @i@ of the resulting 'Dist' holds the accumulator as it was
--   /before/ element @i@ of the input was combined into it, so index 0
--   holds the initial value @z@. The second component of the result is the
--   accumulator after all elements have been combined.
scanD :: forall a. DT a => Gang -> (a -> a -> a) -> a -> Dist a -> (Dist a, a)
{-# NOINLINE scanD #-}
scanD g f z !d
  = checkGangD (here "scanD") g d
  $ runST (do
      md <- newMD g
      s  <- scan md 0 z
      d' <- unsafeFreezeMD md
      return (d',s))
  where
    !n = gangSize g

    -- Write the current accumulator at index i, then fold in element i.
    scan :: forall s. MDist a s -> Int -> a -> ST s a
    scan md i !x
      | i == n    = return x
      | otherwise = do
          writeMD md i x
          scan md (i+1) (f x $ d `indexD` i)


-- | Combination of map and fold.
--
--   The worker is applied to the accumulator and each element in index
--   order (0 .. gangSize-1). The second component of each result is
--   written to the output 'Dist', and the first is threaded through as the
--   accumulator for the next element. Returns the final accumulator
--   together with the output.
mapAccumLD :: forall a b acc. (DT a, DT b)
           => Gang -> (acc -> a -> (acc,b))
           -> acc -> Dist a -> (acc,Dist b)
{-# INLINE_DIST mapAccumLD #-}
mapAccumLD g f acc !d
  = checkGangD (here "mapAccumLD") g d
  $ runST (do
      md   <- newMD g
      acc' <- go md 0 acc
      d'   <- unsafeFreezeMD md
      return (acc',d'))
  where
    !n = gangSize g

    go :: MDist b s -> Int -> acc -> ST s acc
    go md i acc
      | i == n    = return acc
      | otherwise
      = case f acc (d `indexD` i) of
          (acc',b) -> do
            writeMD md i b
            go md (i+1) acc'


-- Versions that work on DistST -----------------------------------------------
-- NOTE: The following combinators must be strict in the Dists because if they
-- are not, the Dist might be evaluated (in parallel) when it is requested in
-- the current computation which, again, is parallel. This would break our
-- model and lead to a deadlock. Hence the bangs.
-- | Feed the local element of a distributed value (obtained with 'myD')
--   to a 'DistST' action, run via 'distST_', discarding any result.
--   Each element is deepSeq'd before it is passed to the action.
mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
{-# INLINE mapDST_ #-}
-- The bang on the Dist matches the sibling combinators here; 'mapDST_''
-- forces the Dist anyway, so this does not change behaviour.
mapDST_ g p !d
  = mapDST_' g (\x -> x `deepSeqD` p x) d


-- | Like 'mapDST_', but does not deepSeq the element first.
mapDST_' :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_' g p !d
  = checkGangD (here "mapDST_") g d
  $ distST_ g (myD d >>= p)


-- | Feed the local element of a distributed value (obtained with 'myD')
--   to a 'DistST' action, run via 'distST', collecting the results into a
--   new distributed value. Each element is deepSeq'd before it is passed
--   to the action.
mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
{-# INLINE mapDST #-}
mapDST g p !d
  = mapDST' g (\x -> x `deepSeqD` p x) d


-- | Like 'mapDST', but does not deepSeq the element first.
mapDST' :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST' g p !d
  -- Label corrected: this used to report itself as "mapDST_".
  = checkGangD (here "mapDST") g d
  $ distST g (myD d >>= p)


-- | Zip two distributed values and run a 'DistST' action on the paired
--   elements, discarding the result. Defined in terms of 'mapDST_'.
zipWithDST_ :: (DT a, DT b)
            => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
{-# INLINE zipWithDST_ #-}
zipWithDST_ g p !dx !dy
  = mapDST_ g (uncurry p) (zipD dx dy)


-- | Zip two distributed values and run a 'DistST' action on the paired
--   elements, collecting the results. Defined in terms of 'mapDST'.
zipWithDST :: (DT a, DT b, DT c)
           => Gang
           -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
{-# INLINE zipWithDST #-}
zipWithDST g p !dx !dy
  = mapDST g (uncurry p) (zipD dx dy)