{-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- | Distributed ST computations.
--
--  Computations of type 'DistST' are data-parallel computations which
--  are run on each thread of a gang. At the moment, they can only access the
--  element of a (possibly mutable) distributed value owned by the current
--  thread.
--
-- /TODO:/ Add facilities for implementing parallel scans etc.
module Data.Array.Parallel.Unlifted.Distributed.DistST 
        ( DistST
        , stToDistST
        , distST_, distST
        , runDistST, runDistST_seq
        , traceDistST
        , myIndex
        , myD
        , readMyMD, writeMyMD)
where
import Data.Array.Parallel.Base (ST, runST)
import Data.Array.Parallel.Unlifted.Distributed.Gang
import Data.Array.Parallel.Unlifted.Distributed.Types (DT(..), Dist, MDist)

import qualified Control.Applicative as A
import Control.Monad (liftM)


-- | Data-parallel computations.
--   When applied to a thread gang, the computation implicitly knows the index
--   of the thread it's working on. Alternatively, if we know the thread index
--   then we can make a regular ST computation.
--
--   Concretely, a 'DistST' is a function from the thread index to an 'ST'
--   action; 'unDistST' exposes that function.
newtype DistST s a = DistST { unDistST :: Int -> ST s a }

-- 'Functor' and 'Applicative' are superclasses of 'Monad' from GHC 7.10
-- onwards, so the 'Monad' instance below is rejected without them.
-- 'Applicative' is referred to through the qualified import so that the
-- module also builds with compilers whose Prelude does not export it, and
-- without a redundant-import warning on those that do.

-- | Maps a pure function over the result; the thread index is passed through
--   unchanged.
instance Functor (DistST s) where
  {-# INLINE fmap #-}
  fmap f (DistST p) = DistST $ \i -> liftM f (p i)

-- | 'pure' ignores the thread index; '<*>' runs the function computation and
--   then the argument computation, both with the same thread index.
instance A.Applicative (DistST s) where
  {-# INLINE pure #-}
  pure           = DistST . const . return

  {-# INLINE (<*>) #-}
  DistST pf <*> DistST px = DistST $ \i -> do
                                    f <- pf i
                                    x <- px i
                                    return (f x)

-- | Sequencing hands the same thread index to both computations.
instance Monad (DistST s) where
  {-# INLINE return #-}
  return         = A.pure

  {-# INLINE (>>=) #-}
  DistST p >>= f = DistST $ \i -> do
                                    x <- p i
                                    unDistST (f x) i


-- | The position of the running thread inside its gang, i.e. the index that
--   the enclosing 'DistST' computation was started with.
myIndex :: DistST s Int
myIndex = DistST $ \ix -> return ix
{-# INLINE myIndex #-}


-- | Embeds a plain 'ST' action in 'DistST'. The action is run as-is and never
--   sees the thread index, so it should itself be data parallel.
stToDistST :: ST s a -> DistST s a
stToDistST p = DistST (const p)
{-# INLINE stToDistST #-}


-- | Looks up, in a 'Dist', the element that belongs to the current thread.
--   The lookup is @indexD \"myD\" dt@ applied to the result of 'myIndex'.
myD :: DT a => Dist a -> DistST s a
myD dt = do
  ix <- myIndex
  return (indexD "myD" dt ix)
{-# NOINLINE myD #-}


-- | Reads, from an 'MDist', the element that belongs to the current thread
--   (the slot at 'myIndex').
readMyMD :: DT a => MDist a s -> DistST s a
readMyMD mdt = myIndex >>= \ix -> stToDistST (readMD mdt ix)
{-# NOINLINE readMyMD #-}


-- | Stores a value into the 'MDist' slot that belongs to the current thread
--   (the slot at 'myIndex').
writeMyMD :: DT a => MDist a s -> a -> DistST s ()
writeMyMD mdt x = myIndex >>= \ix -> stToDistST (writeMD mdt ix x)
{-# NOINLINE writeMyMD #-}


-- | Execute a data-parallel computation on a 'Gang'.
--   The same 'DistST' computation runs on each thread; its index-taking
--   function is handed to 'gangST'.
distST_ :: Gang -> DistST s () -> ST s ()
-- Only the gang is bound on the left-hand side, as in the point-free
-- original, so the INLINE pragma applies at the same arity.
distST_ gang = \comp -> gangST gang (unDistST comp)
{-# INLINE distST_ #-}


-- | Execute a data-parallel computation and collect the per-thread results.
--   A fresh mutable distributed value is allocated for the gang, each thread
--   writes its own result into its slot, and the value is then frozen with
--   'unsafeFreezeMD' and returned.
distST :: DT a => Gang -> DistST s a -> ST s (Dist a)
distST g p = do
  md <- newMD g
  distST_ g (p >>= writeMyMD md)
  unsafeFreezeMD md
{-# INLINE distST #-}


-- | Pure interface to 'distST': runs the computation on the gang inside
--   'runST' and returns the distributed result.
runDistST :: DT a => Gang -> (forall s. DistST s a) -> Dist a
runDistST gang comp = runST (distST gang comp)
{-# NOINLINE runDistST #-}


-- | Like 'runDistST', but does not hand the computation to 'gangST'.
--   Instead the computation is run once for every index from @0@ up to
--   @gangSize g - 1@, one after another, and each result is written to the
--   slot with that index before the distributed value is frozen.
runDistST_seq :: forall a. DT a => Gang -> (forall s. DistST s a) -> Dist a
runDistST_seq g p = runST (
  do
     md <- newMD g
     -- An empty range (gang size of zero or less) performs no writes.
     mapM_ (\i -> unDistST p i >>= writeMD md i) [0 .. n - 1]
     unsafeFreezeMD md)
  where
    !n = gangSize g
{-# NOINLINE runDistST_seq #-}


-- | Passes a message to 'traceGangST', prefixed with @Worker \<index\>: @
--   where the index is that of the current thread.
traceDistST :: String -> DistST s ()
traceDistST s = DistST $ \ix ->
  traceGangST (concat ["Worker ", show ix, ": ", s])