module Data.Array.Parallel.Unlifted.Distributed.DistST (
DistST, stToDistST, distST_, distST, runDistST, runDistST_seq, traceDistST,
myIndex, myD, readMyMD, writeMyMD
) where
import Control.Applicative (Applicative(..))
import Control.Monad (liftM)

import Data.Array.Parallel.Base (
    ST, runST)
import Data.Array.Parallel.Unlifted.Distributed.Gang
import Data.Array.Parallel.Unlifted.Distributed.Types (
    DT(..), Dist, MDist)
-- | A computation executed by a gang worker. The wrapped function receives
-- the index of the worker running it (retrieved with 'myIndex') and yields
-- an 'ST' action in state thread @s@.
newtype DistST s a = DistST { unDistST :: Int -> ST s a }

-- Functor and Applicative are superclasses of Monad since GHC 7.10 (the
-- Applicative-Monad Proposal); without these instances the Monad instance
-- below is rejected by modern compilers.
instance Functor (DistST s) where
  fmap f (DistST p) = DistST $ \i -> fmap f (p i)

instance Applicative (DistST s) where
  -- Ignore the worker index and just yield the value.
  pure x = DistST $ \_ -> return x
  -- Run the function action and then the argument action, both with the
  -- same worker index (the same sequencing as '>>=').
  DistST pf <*> DistST px = DistST $ \i -> do
    f <- pf i
    x <- px i
    return (f x)

instance Monad (DistST s) where
  return = pure
  -- Thread the same worker index through both halves of the bind.
  DistST p >>= f = DistST $ \i -> do
    x <- p i
    unDistST (f x) i
-- | Yields the index of the worker that is running the current 'DistST'
-- computation.
myIndex :: DistST s Int
myIndex = DistST $ \workerIx -> return workerIx
-- | Lift a plain 'ST' action into 'DistST'. The action is the same for
-- every worker; the worker index is ignored.
stToDistST :: ST s a -> DistST s a
stToDistST p = DistST (const p)
-- | Project the current worker's component out of an immutable distributed
-- value, using 'indexD' at this worker's index.
myD :: DT a => Dist a -> DistST s a
myD dt = liftM lookupMine myIndex
  where
    -- Select the slot belonging to the given worker index.
    lookupMine ix = indexD dt ix
-- | Read the current worker's slot of a mutable distributed value
-- ('readMD' at this worker's index).
readMyMD :: DT a => MDist a s -> DistST s a
readMyMD mdt = myIndex >>= stToDistST . readMD mdt
-- | Store a value into the current worker's slot of a mutable distributed
-- value ('writeMD' at this worker's index).
writeMyMD :: DT a => MDist a s -> a -> DistST s ()
writeMyMD mdt x = myIndex >>= \ix -> stToDistST (writeMD mdt ix x)
-- | Run a unit-returning 'DistST' computation on gang @g@. The underlying
-- per-index function is handed to 'gangST', which presumably invokes it once
-- per gang member (see the Gang module).
distST_ :: Gang -> DistST s () -> ST s ()
distST_ g p = gangST g (unDistST p)
-- | Run a 'DistST' computation on gang @g@ and collect the results.
--
-- A fresh mutable distributed value is allocated with 'newMD'. Each worker
-- stores its own result in its slot via 'writeMyMD'. The result is then
-- frozen with 'unsafeFreezeMD'.
distST :: DT a => Gang -> DistST s a -> ST s (Dist a)
distST g p = do
  slots <- newMD g
  distST_ g (p >>= writeMyMD slots)
  unsafeFreezeMD slots
-- | Run a 'DistST' computation on gang @g@ to completion and return the
-- frozen distributed result. This is 'distST' wrapped in 'runST'; the
-- rank-2 argument type lets the state thread be chosen by 'runST'.
runDistST :: DT a => Gang -> (forall s. DistST s a) -> Dist a
runDistST g p = runST (distST g p)
-- | Sequential counterpart of 'runDistST'.
--
-- Instead of going through the gang, this runs @p@ in the current thread
-- once for each index @0 .. gangSize g - 1@, in increasing order. Each result
-- is written to its slot of a fresh mutable distributed value, which is then
-- frozen.
runDistST_seq :: forall a. DT a => Gang -> (forall s. DistST s a) -> Dist a
runDistST_seq g p = runST fill
  where
    -- Gang size, forced strictly before the computation starts.
    !n = gangSize g

    -- Allocate the result, populate every slot, then freeze it.
    fill :: forall s. ST s (Dist a)
    fill = do
      mdist <- newMD g
      loop mdist 0
      unsafeFreezeMD mdist

    -- Run worker @ix@'s computation and store its result, for each index
    -- from the starting index up to (but excluding) @n@.
    loop :: forall s. MDist a s -> Int -> ST s ()
    loop mdist ix
      | ix >= n   = return ()
      | otherwise = do
          v <- unDistST p ix
          writeMD mdist ix v
          loop mdist (ix + 1)
-- | Emit a trace message through 'traceGangST'. The message is prefixed
-- with @"Worker <index>: "@, where @<index>@ is the current worker's index.
traceDistST :: String -> DistST s ()
traceDistST msg = DistST emit
  where
    emit ix = traceGangST (concat ["Worker ", show ix, ": ", msg])