{-# LANGUAGE MagicHash #-}
-- | Evaluate an array by breaking it up into linear chunks and filling
--   each chunk in parallel.
module Data.Array.Repa.Eval.Chunked
	( fillChunkedP
	, fillChunkedS
	, fillChunkedIOP)
where
import Data.Array.Repa.Eval.Gang
import GHC.Exts
import Prelude		as P

-- | Fill something sequentially.
--
--   * Indices @0@ up to @len - 1@ are visited in ascending order, and for
--     each index @ix@ the action @write ix (getElem ix)@ is run.
--
--   * If the number of elements is zero or negative, nothing is written.
--
fillChunkedS
        :: Int                  -- ^ Number of elements.
        -> (Int -> a -> IO ())  -- ^ Update function to write into result buffer.
        -> (Int -> a)           -- ^ Fn to get the value at a given index.
        -> IO ()

{-# INLINE [0] fillChunkedS #-}
fillChunkedS !(I# len) write getElem
 = loop 0#
 where
        -- The bounds test is made on boxed 'Int's rather than with '>=#':
        -- the unboxed comparison primops return 'Bool' before GHC 7.8 but
        -- 'Int#' from GHC 7.8 onwards, so using one directly as a guard only
        -- typechecks on one side of that change.  The loop counter itself
        -- stays unboxed.
        loop !ix
         | I# ix >= I# len      = return ()
         | otherwise
         = do   let i = I# ix
                write i (getElem i)
                loop (ix +# 1#)


-- | Fill something in parallel.
--
--   * The array is split into linear chunks and each thread fills one chunk.
--
--   * Thread @t@ is given the half-open index range from its own split
--     point up to the split point of thread @t + 1@; within that range the
--     indices are written in ascending order.
--
fillChunkedP
        :: Int                  -- ^ Number of elements.
        -> (Int -> a -> IO ())  -- ^ Update function to write into result buffer.
        -> (Int -> a)           -- ^ Fn to get the value at a given index.
        -> IO ()

{-# INLINE [0] fillChunkedP #-}
fillChunkedP !(I# len) write getElem
 = gangIO theGang
    $ \(I# thread) ->
        let !start = splitIx thread
            !end   = splitIx (thread +# 1#)
        in  fill start end

 where
        -- Decide how to split the work across the threads.
        -- If the length of the vector doesn't divide evenly among the threads,
        -- then the first few get an extra element.
        -- NOTE(review): 'gangSize' is used as the divisor here, so it is
        -- assumed to be non-zero -- confirm against the gang implementation.
        !(I# threads)   = gangSize theGang
        !chunkLen       = len `quotInt#` threads
        !chunkLeftover  = len `remInt#`  threads

        -- Starting index of the chunk belonging to the given thread.
        -- The comparison is made on boxed 'Int's because the unboxed
        -- comparison primops return 'Bool' before GHC 7.8 but 'Int#' from
        -- GHC 7.8 onwards, so they cannot portably be used as guards.
        {-# INLINE splitIx #-}
        splitIx thread
         | I# thread < I# chunkLeftover = thread *# (chunkLen +# 1#)
         | otherwise                    = (thread *# chunkLen) +# chunkLeftover

        -- Evaluate the elements of a single chunk, from @ix@ up to but not
        -- including @end@.
        {-# INLINE fill #-}
        fill !ix !end
         | I# ix >= I# end      = return ()
         | otherwise
         = do   write (I# ix) (getElem (I# ix))
                fill (ix +# 1#) end


-- | Fill something in parallel, using a separate IO action for each thread.
--
--   * The index range is split into linear chunks, one per thread, in the
--     same way as for 'fillChunkedP'.
--
--   * For each thread, the element-getter is created by applying the third
--     argument to the thread number; this happens once, before any element
--     of that thread's chunk is written.
--
fillChunkedIOP
        :: Int                          -- ^ Number of elements.
        -> (Int -> a -> IO ())          -- ^ Update fn to write into result buffer.
        -> (Int -> IO (Int -> IO a))    -- ^ Create a fn to get the value at a given index.
                                        --   The first `Int` is the thread number, so you can do some
                                        --   per-thread initialisation.
        -> IO ()

{-# INLINE [0] fillChunkedIOP #-}
fillChunkedIOP !(I# len) write mkGetElem
 = gangIO theGang
    $ \(I# thread) ->
        let !start = splitIx thread
            !end   = splitIx (thread +# 1#)
        in  fillChunk thread start end

 where
        -- Decide how to split the work across the threads.
        -- If the length of the vector doesn't divide evenly among the threads,
        -- then the first few get an extra element.
        -- NOTE(review): 'gangSize' is used as the divisor here, so it is
        -- assumed to be non-zero -- confirm against the gang implementation.
        !(I# threads)   = gangSize theGang
        !chunkLen       = len `quotInt#` threads
        !chunkLeftover  = len `remInt#`  threads

        -- Starting index of the chunk belonging to the given thread.
        -- The comparison is made on boxed 'Int's because the unboxed
        -- comparison primops return 'Bool' before GHC 7.8 but 'Int#' from
        -- GHC 7.8 onwards, so they cannot portably be used as guards.
        {-# INLINE splitIx #-}
        splitIx thread
         | I# thread < I# chunkLeftover = thread *# (chunkLen +# 1#)
         | otherwise                    = (thread *# chunkLen) +# chunkLeftover

        -- Given the thread number, starting and ending indices:
        --      make a function to get each element for this chunk
        --      and call it for every index.
        {-# INLINE fillChunk #-}
        fillChunk !thread !ixStart !ixEnd
         = do   getElem <- mkGetElem (I# thread)
                fill getElem ixStart ixEnd

        -- Call the provided getElem function for every element
        --      in a chunk, and feed the result to the write function.
        --      The element is fetched first and then written, one index
        --      at a time in ascending order.
        {-# INLINE fill #-}
        fill !getElem !ix0 !end
         = go ix0
         where
                go !ix
                 | I# ix >= I# end  = return ()
                 | otherwise
                 = do   x <- getElem (I# ix)
                        write (I# ix) x
                        go (ix +# 1#)