#include "fusion-phases.h"
module Data.Array.Parallel.Unlifted.Distributed.Primitive.Operators
( generateD
, generateD_cheap
, imapD'
, foldD
, scanD)
where
import Data.Array.Parallel.Base ( ST, runST)
import Data.Array.Parallel.Unlifted.Distributed.Primitive.DistST
import Data.Array.Parallel.Unlifted.Distributed.Primitive.DT
import Data.Array.Parallel.Unlifted.Distributed.Primitive.Gang
import qualified Data.Array.Parallel.Unlifted.Distributed.What as W
import Debug.Trace
-- | Qualify a function name with this module's name. The result is used as
--   the location tag passed to 'indexD' and 'checkGangD' (presumably shown
--   in their error messages).
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Primitive.Operators." ++ s
-- | Build a distributed value by applying @f@ to each gang thread's index.
--
--   The work is run through 'runDistST' and tagged @W.CGen False what@.
--   Compare 'generateD_cheap', which uses 'runDistST_seq' instead.
generateD
        :: DT a
        => W.What       -- ^ Description of the computation.
        -> Gang         -- ^ Gang to run on.
        -> (Int -> a)   -- ^ Generator, given the thread index.
        -> Dist a
generateD what gang f
        = runDistST (W.CGen False what) gang
        $ do ix <- myIndex
             return (f ix)
-- | Variant of 'generateD' intended for generators that are cheap to run.
--
--   It runs via 'runDistST_seq' rather than 'runDistST'. The name suggests
--   the per-thread work is done sequentially instead of being spread over
--   the gang; confirm against its definition. It also emits a 'traceEvent'
--   describing the computation, tagged @W.CGen True what@.
generateD_cheap
        :: DT a
        => W.What       -- ^ Description of the computation.
        -> Gang         -- ^ Gang whose shape the result follows.
        -> (Int -> a)   -- ^ Generator, given the thread index.
        -> Dist a
generateD_cheap what g f
        = traceEvent (show (W.CGen True what))
        $ runDistST_seq g (myIndex >>= \ix -> return (f ix))
-- | Map an index-aware function over a distributed value.
--
--   Each thread applies @f@ to its own index and its own element of @d@.
--   The per-thread result goes through 'deepSeqD' before it is returned.
--   Presumably this fully evaluates it on the worker thread instead of
--   leaving a thunk for the consumer; confirm against 'deepSeqD'.
imapD' :: (DT a, DT b)
       => W.What -> Gang -> (Int -> a -> b) -> Dist a -> Dist b
imapD' what gang f !d
        = runDistST (W.CMap what) gang
        $ do ix  <- myIndex
             val <- myD d
             let res = f ix val
             res `deepSeqD` return ()
             return res
-- | Left-fold the per-thread values of a distributed value with @f@.
--
--   The element of thread 0 is the initial accumulator, so for a gang of
--   size @n@ the result is @f (... (f (f d0 d1) d2) ...) d(n-1)@. The fold
--   is a plain recursive loop; no gang computation is launched. A
--   'traceEvent' describing the fold is emitted first.
--
--   'checkGangD' is given this function's location tag. It presumably
--   verifies that @d@ matches @gang@; confirm against its definition.
foldD :: DT a => W.What -> Gang -> (a -> a -> a) -> Dist a -> a
foldD what gang f !d
        = traceEvent (show (W.CFold what))
          -- Pass the qualified location tag, like every other call here,
          -- not the literal string "here foldD".
        $ checkGangD (here "foldD") gang d
        $ fold 1 (indexD (here "foldD") d 0)
  where
        !n = gangSize gang

        -- Combine elements 1 .. n-1 into the running accumulator.
        fold i x
         | i == n       = x
         | otherwise    = fold (i + 1) (f x $ indexD (here "foldD") d i)
-- | Exclusive left scan over the per-thread values of a distributed value.
--
--   Slot @i@ of the resulting distributed value holds @z@ combined with the
--   elements of threads @0 .. i-1@, so slot 0 holds @z@. The second
--   component is @z@ combined with all @n@ elements, where @n@ is the gang
--   size.
--
--   The running accumulator is forced to WHNF at every step. The loop runs
--   inside 'runST', filling a fresh mutable distributed value that is
--   frozen with 'unsafeFreezeMD' at the end.
scanD :: forall a. DT a => W.What -> Gang -> (a -> a -> a) -> a -> Dist a -> (Dist a, a)
scanD what gang f z !d
        = traceEvent (show (W.CScan what))
        $ checkGangD (here "scanD") gang d
        $ runST (do
                out    <- newMD gang
                total  <- go out 0 z
                frozen <- unsafeFreezeMD out
                return (frozen, total))
  where
        !n = gangSize gang

        -- Store the accumulator into slot i, then step it with element i.
        go :: forall s. MDist a s -> Int -> a -> ST s a
        go out i !acc
          = if i == n
              then return acc
              else do writeMD out i acc
                      go out (i + 1) (f acc (indexD (here "scanD") d i))