module Data.Repa.Eval.Generic.Par.Chunked
        ( fillChunked
        , fillChunkedIO)
where
import Data.Repa.Eval.Gang
import GHC.Exts


-------------------------------------------------------------------------------
-- | Fill something in parallel.
--
--   * The index range @[0, len)@ is split into contiguous linear chunks,
--     one per thread in the gang, and each thread fills its own chunk
--     in ascending index order.
--
--   * If @len@ does not divide evenly by the gang size, the first
--     @len \`remInt#\` gangSize@ threads each receive one extra element.
--
fillChunked
        :: Gang                  -- ^ Gang to run the operation on.
        -> (Int# -> a -> IO ())  -- ^ Update function to write into result buffer.
        -> (Int# -> a)           -- ^ Function to get the value at a given index.
        -> Int#                  -- ^ Number of elements.
        -> IO ()

fillChunked gang write getElem len
 = gangIO gang
 $ \thread ->
    let !start = splitIx thread
        !end   = splitIx (thread +# 1#)
    in  fill start end

 where
        -- Decide now how to split the work across the threads.
        -- If the length doesn't divide evenly among the threads,
        -- then the first few get an extra element.
        !threads        = gangSize gang
        !chunkLen       = len `quotInt#` threads
        !chunkLeftover  = len `remInt#`  threads

        -- Index of the first element owned by the given thread.
        -- Threads below 'chunkLeftover' own (chunkLen + 1) elements each;
        -- the rest own chunkLen. Applied to (thread + 1) it gives the
        -- exclusive end of that thread's chunk, and for thread == threads
        -- it evaluates to len.
        splitIx :: Int# -> Int#
        splitIx thread
         | 1# <- thread <# chunkLeftover = thread *# (chunkLen +# 1#)
         | otherwise                     = thread *# chunkLen  +# chunkLeftover
        {-# INLINE splitIx #-}

        -- Compute and write each element of a single chunk,
        -- from 'ix' up to but not including 'end'.
        fill :: Int# -> Int# -> IO ()
        fill !ix !end
         | 1# <- ix >=# end = return ()
         | otherwise
         = do   write ix (getElem ix)
                fill (ix +# 1#) end
        {-# INLINE fill #-}

{-# INLINE [0] fillChunked #-}


-------------------------------------------------------------------------------
-- | Fill something in parallel, using a separate IO action for each thread.
--
--   * The index range @[0, len)@ is split into contiguous linear chunks,
--     one per thread in the gang, and each thread fills its own chunk
--     in ascending index order.
--
--   * Before filling, each thread runs the supplied action once with its
--     thread number to obtain its own element getter, which allows
--     per-thread initialisation.
--
fillChunkedIO
        :: Gang  -- ^ Gang to run the operation on.
        -> (Int# -> a -> IO ())
                 -- ^ Update function to write into result buffer.
        -> (Int# -> IO (Int# -> IO a))
                 -- ^ Create a function to get the value at a given index.
                 --   The first argument is the thread number, so you can do some
                 --   per-thread initialisation.
        -> Int#  -- ^ Number of elements.
        -> IO ()

fillChunkedIO gang write mkGetElem len
 = gangIO gang
 $ \thread ->
    let !start = splitIx thread
        !end   = splitIx (thread +# 1#)
    in  fillChunk thread start end

 where
        -- Decide now how to split the work across the threads.
        -- If the length doesn't divide evenly among the threads,
        -- then the first few get an extra element.
        !threads        = gangSize gang
        !chunkLen       = len `quotInt#` threads
        !chunkLeftover  = len `remInt#`  threads

        -- Index of the first element owned by the given thread.
        -- Threads below 'chunkLeftover' own (chunkLen + 1) elements each;
        -- the rest own chunkLen. Applied to (thread + 1) it gives the
        -- exclusive end of that thread's chunk, and for thread == threads
        -- it evaluates to len.
        splitIx :: Int# -> Int#
        splitIx thread
         | 1# <- thread <# chunkLeftover = thread *# (chunkLen +# 1#)
         | otherwise                     = thread *# chunkLen  +# chunkLeftover
        {-# INLINE splitIx #-}

        -- Given the thread number and the chunk's [start, end) indices,
        -- build this thread's element getter, then call it for every
        -- index in the chunk and feed each result to 'write'.
        --
        -- The loop closes over the getter rather than taking it as an
        -- argument, so no local signature has to mention the outer type
        -- variable 'a'. That avoids depending on ScopedTypeVariables.
        fillChunk :: Int# -> Int# -> Int# -> IO ()
        fillChunk !thread !ixStart !ixEnd
         = do   !getElem <- mkGetElem thread
                let go :: Int# -> IO ()
                    go !ix
                     | 1# <- ix >=# ixEnd = return ()
                     | otherwise
                     = do   x <- getElem ix
                            write ix x
                            go (ix +# 1#)
                go ixStart
        {-# INLINE fillChunk #-}

{-# INLINE [0] fillChunkedIO #-}