#include "fusion-phases.h"
module Data.Array.Parallel.Unlifted.Distributed.Combinators
( generateD, generateD_cheap
, imapD, mapD
, zipD, unzipD
, fstD, sndD
, zipWithD, izipWithD
, foldD
, scanD
, mapAccumLD
, mapDST_, mapDST, zipWithDST_, zipWithDST)
where
import Data.Array.Parallel.Base ( ST, runST)
import Data.Array.Parallel.Unlifted.Distributed.Gang
import Data.Array.Parallel.Unlifted.Distributed.Types
import Data.Array.Parallel.Unlifted.Distributed.DistST
-- | Prefix a function name with this module's fully qualified name.
--   Used to build the location label passed to 'checkGangD' and
--   'indexD' calls in this module.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Combinators." ++ s
-- | Build a 'Dist' whose element for each gang index @i@ is @f i@.
--   The per-index computation is executed via 'runDistST'.
generateD :: DT a => Gang -> (Int -> a) -> Dist a
generateD g f = runDistST g (myIndex >>= \i -> return (f i))
-- | Like 'generateD', but executed via 'runDistST_seq' instead of
--   'runDistST'. NOTE(review): presumably this runs the generator
--   sequentially, for cheap @f@ — confirm against DistST.
generateD_cheap :: DT a => Gang -> (Int -> a) -> Dist a
generateD_cheap g f = runDistST_seq g (do { i <- myIndex; return (f i) })
-- | Map an index-aware function over a 'Dist'. Each element is first
--   forced with 'deepSeqD' before @f@ is applied to it.
imapD :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
imapD g f d = imapD' g forced d
  where
    forced i x = deepSeqD x (f i x)
-- | Worker for 'imapD'. Checks the 'Dist' against the gang with
--   'checkGangD', then, inside 'runDistST', reads this index and this
--   position's element and applies @f@ to both. The 'Dist' argument is
--   forced to WHNF up front.
imapD' :: (DT a, DT b) => Gang -> (Int -> a -> b) -> Dist a -> Dist b
imapD' g f !d
  = checkGangD (here "imapD") g d
  $ runDistST g (myIndex >>= \i -> myD d >>= \x -> return (f i x))
-- | Map a function over every element of a 'Dist'. This is 'imapD' with
--   the index ignored.
mapD :: (DT a, DT b) => Gang -> (a -> b) -> Dist a -> Dist b
mapD g f = imapD g (\_ -> f)
-- | Combine two 'Dist's element-wise with a binary function, by zipping
--   them with 'zipD' and mapping over the pairs.
zipWithD :: (DT a, DT b, DT c)
         => Gang -> (a -> b -> c) -> Dist a -> Dist b -> Dist c
-- The lazy pattern keeps exactly the semantics of 'uncurry'.
zipWithD g f dx dy = mapD g (\ ~(x, y) -> f x y) (zipD dx dy)
-- | Index-aware element-wise combination of two 'Dist's, built on
--   'zipD' and 'imapD'.
izipWithD :: (DT a, DT b, DT c)
          => Gang -> (Int -> a -> b -> c) -> Dist a -> Dist b -> Dist c
-- Lazy tuple pattern: same strictness as @uncurry (f i)@.
izipWithD g f dx dy = imapD g (\i ~(x, y) -> f i x y) (zipD dx dy)
-- | Combine all elements of a 'Dist' with a binary operator. This is a
--   left fold with no seed: it starts from element 0 and folds in
--   elements 1 .. @gangSize g - 1@ in order.
foldD :: DT a => Gang -> (a -> a -> a) -> Dist a -> a
foldD g f !d
  -- The label must go through 'here' like every other check in this
  -- module. It was previously the string literal "here foldD".
  = checkGangD (here "foldD") g d
  $ fold 1 (indexD (here "foldD") d 0)
  where
    !n = gangSize g
    fold i x
      | i == n    = x
      | otherwise = fold (i + 1) (f x (indexD (here "foldD") d i))
-- | Exclusive left scan over a 'Dist'. Returns a 'Dist' whose element
--   @k@ is @z@ combined (via @f@, left to right) with elements
--   0 .. @k-1@ of the input, plus the fold of @z@ over all elements.
scanD :: forall a. DT a => Gang -> (a -> a -> a) -> a -> Dist a -> (Dist a, a)
scanD g f z !d
  = checkGangD (here "scanD") g d
  $ runST (do
      mdist  <- newMD g
      total  <- loop mdist 0 z
      frozen <- unsafeFreezeMD mdist
      return (frozen, total))
  where
    !gsize = gangSize g

    -- Store the running prefix at slot k, then fold in element k.
    loop :: forall s. MDist a s -> Int -> a -> ST s a
    loop mdist k !acc
      | k == gsize = return acc
      | otherwise  = do
          writeMD mdist k acc
          loop mdist (k + 1) (f acc (indexD (here "scanD") d k))
-- | Thread an accumulator left to right over the elements of a 'Dist'.
--   Each element is mapped to an output value. Returns the final
--   accumulator and the 'Dist' of outputs.
mapAccumLD
  :: forall a b acc. (DT a, DT b)
  => Gang
  -> (acc -> a -> (acc, b))
  -> acc -> Dist a -> (acc, Dist b)
mapAccumLD g f acc !d
  = checkGangD (here "mapAccumLD") g d
  $ runST (do
      mout   <- newMD g
      accEnd <- step mout 0 acc
      dout   <- unsafeFreezeMD mout
      return (accEnd, dout))
  where
    !gsize = gangSize g

    -- The 'case' forces the result pair before writing, as before.
    step :: MDist b s -> Int -> acc -> ST s acc
    step mout k cur
      | k == gsize = return cur
      | otherwise  =
          case f cur (indexD (here "mapAccumLD") d k) of
            (next, y) -> writeMD mout k y >> step mout (k + 1) next
-- | Run a 'DistST' action on every element of a 'Dist' inside 'ST',
--   discarding results. Each element is first forced with 'deepSeqD'.
mapDST_ :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_ g p d = mapDST_' g forceThenRun d
  where
    forceThenRun x = deepSeqD x (p x)
-- | Worker for 'mapDST_'. Checks the 'Dist' against the gang, then uses
--   'distST_' to read the element at each position and feed it to @p@.
mapDST_' :: DT a => Gang -> (a -> DistST s ()) -> Dist a -> ST s ()
mapDST_' g p !d
  = checkGangD (here "mapDST_") g d
  $ distST_ g (p =<< myD d)
-- | Run a 'DistST' action on every element of a 'Dist' inside 'ST',
--   collecting the results into a new 'Dist'. Each element is first
--   forced with 'deepSeqD'.
mapDST :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST g p !d = mapDST' g strictP d
  where
    strictP x = deepSeqD x (p x)
-- | Worker for 'mapDST'. Checks the 'Dist' against the gang, then uses
--   'distST' to read the element at each position and run @p@ on it,
--   collecting the results.
mapDST' :: (DT a, DT b) => Gang -> (a -> DistST s b) -> Dist a -> ST s (Dist b)
mapDST' g p !d
  -- Label with the public name this worker backs ("mapDST"), matching
  -- imapD'/mapDST_'. It was previously mislabelled "mapDST_".
  = checkGangD (here "mapDST") g d
  $ distST g (myD d >>= p)
-- | Run a two-argument 'DistST' action over corresponding elements of
--   two 'Dist's inside 'ST', discarding results. Both 'Dist's are forced
--   to WHNF first.
zipWithDST_
  :: (DT a, DT b)
  => Gang -> (a -> b -> DistST s ()) -> Dist a -> Dist b -> ST s ()
-- Lazy tuple pattern: same strictness as 'uncurry'.
zipWithDST_ g p !dx !dy = mapDST_ g (\ ~(x, y) -> p x y) (zipD dx dy)
-- | Run a two-argument 'DistST' action over corresponding elements of
--   two 'Dist's inside 'ST', collecting the results into a new 'Dist'.
--   Both 'Dist's are forced to WHNF first.
zipWithDST
  :: (DT a, DT b, DT c)
  => Gang
  -> (a -> b -> DistST s c) -> Dist a -> Dist b -> ST s (Dist c)
-- Lazy tuple pattern: same strictness as 'uncurry'.
zipWithDST g p !dx !dy = mapDST g (\ ~(x, y) -> p x y) (zipD dx dy)