{-# LANGUAGE ScopedTypeVariables #-}
-- | Distributed ST computations.
--
--  Computations of type 'DistST' are data-parallel computations which
--  are run on each thread of a gang. At the moment, they can only access the
--  element of a (possibly mutable) distributed value owned by the current
--  thread.
--
-- /TODO:/ Add facilities for implementing parallel scans etc.
module Data.Array.Parallel.Unlifted.Distributed.DistST (
  DistST, stToDistST, distST_, distST, runDistST, runDistST_seq, traceDistST,
  myIndex, myD, readMyMD, writeMyMD
) where
import Data.Array.Parallel.Base (
  ST, runST)
import Data.Array.Parallel.Unlifted.Distributed.Gang
import Data.Array.Parallel.Unlifted.Distributed.Types (
  DT(..), Dist, MDist)

import Control.Applicative (Applicative(..))
import Control.Monad (liftM)


-- | Data-parallel computations.
--
--   A 'DistST' computation is a function from a thread index to an ordinary
--   'ST' action. When applied to a thread gang, the computation implicitly
--   knows the index of the thread it's working on. Alternatively, if we know
--   the thread index then 'unDistST' turns it into a regular 'ST' computation.
newtype DistST s a = DistST { unDistST :: Int -> ST s a }

-- 'Functor' and 'Applicative' are superclasses of 'Monad' from base-4.8
-- (GHC 7.10) onwards, so both are spelled out to keep the 'Monad' instance
-- below acceptable to current compilers. They agree with the 'Monad'
-- instance: the same thread index is handed to every sub-computation.
instance Functor (DistST s) where
  {-# INLINE fmap #-}
  fmap f (DistST p) = DistST $ \i -> liftM f (p i)

instance Applicative (DistST s) where
  {-# INLINE pure #-}
  -- Ignore the thread index and yield the value in the underlying 'ST'.
  pure = DistST . const . return

  {-# INLINE (<*>) #-}
  DistST pf <*> DistST px = DistST $ \i -> do
                                             f <- pf i
                                             x <- px i
                                             return (f x)

instance Monad (DistST s) where
  {-# INLINE return #-}
  return         = pure

  {-# INLINE (>>=) #-}
  -- Run the first computation at index @i@, then run the continuation at
  -- the same index.
  DistST p >>= f = DistST $ \i -> do
                                    x <- p i
                                    unDistST (f x) i


-- | The index, within its gang, of the thread running this computation.
--   This is simply the 'Int' that the enclosing 'DistST' function receives.
myIndex :: DistST s Int
{-# INLINE myIndex #-}
myIndex = DistST $ \i -> return i


-- | Embed a plain 'ST' action in 'DistST'.
--   The action is run as-is regardless of the thread index, so it should
--   itself be data parallel.
stToDistST :: ST s a -> DistST s a
{-# INLINE stToDistST #-}
stToDistST p = DistST (const p)


-- | Look up, with 'indexD', the element of a 'Dist' that sits at the
--   current thread's index.
myD :: DT a => Dist a -> DistST s a
{-# NOINLINE myD #-}
myD dt = do
  i <- myIndex
  return (indexD dt i)


-- | Read, with 'readMD', the element of an 'MDist' that sits at the
--   current thread's index.
readMyMD :: DT a => MDist a s -> DistST s a
{-# NOINLINE readMyMD #-}
readMyMD mdt = myIndex >>= \i -> stToDistST (readMD mdt i)


-- | Store a value, with 'writeMD', into the element of an 'MDist' that sits
--   at the current thread's index.
writeMyMD :: DT a => MDist a s -> a -> DistST s ()
{-# NOINLINE writeMyMD #-}
writeMyMD mdt x = myIndex >>= \i -> stToDistST (writeMD mdt i x)


-- | Execute a data-parallel computation on a 'Gang' for its effects only.
--   The same 'DistST' computation runs on each thread: its underlying
--   index-taking function is handed to 'gangST'.
distST_ :: Gang -> DistST s () -> ST s ()
{-# INLINE distST_ #-}
distST_ g = \p -> gangST g (unDistST p)


-- | Execute a data-parallel computation on a 'Gang' and collect what it
--   returns as a distributed value.
--
--   A fresh mutable distributed value is allocated with 'newMD', the
--   computation is run with 'distST_' so that each result is written with
--   'writeMyMD', and the value is then frozen with 'unsafeFreezeMD'.
distST :: DT a => Gang -> DistST s a -> ST s (Dist a)
{-# INLINE distST #-}
distST g p = do
  results <- newMD g
  distST_ g (p >>= writeMyMD results)
  unsafeFreezeMD results


-- | Run a data-parallel computation, yielding the distributed result.
--
--   This is 'distST' wrapped in 'runST', giving a pure 'Dist'. The
--   computation must be polymorphic in the state thread @s@ because it is
--   passed through to 'runST'.
runDistST :: DT a => Gang -> (forall s. DistST s a) -> Dist a
{-# NOINLINE runDistST #-}
runDistST g p = runST (distST g p)

-- | Variant of 'runDistST' that does not go through 'distST'.
--
--   The computation is applied to each index from @0@ up to
--   @gangSize g - 1@, one after another inside a single 'ST' action, and
--   each result is written with 'writeMD' to the matching element of a
--   fresh mutable distributed value, which is finally frozen.
runDistST_seq :: forall a. DT a => Gang -> (forall s. DistST s a) -> Dist a
{-# NOINLINE runDistST_seq #-}
runDistST_seq g p = runST (do
    results <- newMD g
    mapM_ (fill results) [0 .. n - 1]
    unsafeFreezeMD results)
  where
    -- Number of elements to fill; forced up front, before the loop starts.
    !n = gangSize g

    -- Run the computation as thread @i@ and store its result in slot @i@.
    fill :: forall s. MDist a s -> Int -> ST s ()
    fill md i = writeMD md i =<< unDistST p i

-- | Emit a trace message through 'traceGangST', prefixed with
--   @"Worker \<index\>: "@ where the index is that of the current thread.
traceDistST :: String -> DistST s ()
traceDistST s = DistST report
  where
    report n = traceGangST (concat ["Worker ", show n, ": ", s])