{-# LANGUAGE BangPatterns, ExplicitForAll, ScopedTypeVariables, PatternGuards #-}
module Data.Array.Repa.Eval.Selection
	(selectChunkedS, selectChunkedP)
where
import Data.Array.Repa.Eval.Gang
import Data.Array.Repa.Shape
import Data.Vector.Unboxed			as V
import Data.Vector.Unboxed.Mutable		as VM
import GHC.Base					(remInt, quotInt)
import Prelude					as P
import Control.Monad				as P
import Data.IORef


-- | Select indices matching a predicate, sequentially.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * Every linear index from @0@ up to @size extent - 1@ is visited in
--     increasing order and converted to a shape index with 'fromIndex'.
--
--   * Each matching element is written with the update function at the
--     /packed/ destination position @fromIndex extent k@, where @k@ is the
--     number of matches seen so far. After the call, the first @n@ linear
--     positions of the destination hold the results, where @n@ is the
--     returned count.
--
selectChunkedS
        :: Shape sh
        => (sh -> a -> IO ())   -- ^ Update function to write into result.
        -> (sh -> Bool)         -- ^ See if this predicate matches.
        -> (sh -> a)            -- ^  .. and apply fn to the matching index
        -> sh                   -- ^ Extent of indices to apply to predicate.
        -> IO Int               -- ^ Number of elements written to destination array.

{-# INLINE selectChunkedS #-}
selectChunkedS !fnWrite !fnMatch !fnProduce !shSize
 = fill 0 0
 where  lenSrc  = size shSize

        -- nSrc: next linear source index to examine.
        -- nDst: number of matches written so far (= next destination slot).
        fill !nSrc !nDst
         | nSrc >= lenSrc       = return nDst

         | ixSrc        <- fromIndex shSize nSrc
         , fnMatch ixSrc
         = do   -- Write at the packed destination position, not the source
                -- position, so the results are contiguous and nDst
                -- describes exactly the filled prefix.
                -- nDst <= nSrc < lenSrc, so this index is within the extent.
                fnWrite (fromIndex shSize nDst) (fnProduce ixSrc)
                fill (nSrc + 1) (nDst + 1)

         | otherwise
         =      fill (nSrc + 1) nDst


-- | Select indices matching a predicate, in parallel.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * The index range @[0, len)@ is split into contiguous chunks, with one
--     chunk being given to each gang thread. Chunk sizes differ by at most
--     one element.
--
--   * Each thread produces one result vector holding, in increasing index
--     order, @fnProduce ix@ for every @ix@ in its chunk that satisfies
--     @fnMatch@. Concatenating the returned vectors in order yields all the
--     selected elements in increasing index order.
--
--   * The number of result chunks depends on how many threads you're running
--     the program with.
--
selectChunkedP
        :: forall a
        .  Unbox a
        => (Int -> Bool)        -- ^ See if this predicate matches.
        -> (Int -> a)           -- ^  .. and apply fn to the matching index
        -> Int                  -- ^ Extent of indices to apply to predicate.
        -> IO [IOVector a]      -- ^ Chunks containing array elements.

{-# INLINE selectChunkedP #-}
selectChunkedP !fnMatch !fnProduce !len
 = do
        -- Make IORefs that the threads will write their result chunks to.
        -- Each starts with an initial buffer sized to the average chunk
        -- length. A thread takes that buffer and grows it if it runs out
        -- of space.
        refs    <- P.replicateM threads
                $ do    vec     <- VM.new $ len `div` threads
                        newIORef vec

        -- Fire off a thread to fill each chunk.
        -- Thread t handles the half-open source range
        -- [splitIx t, splitIx (t + 1)).
        gangIO theGang
         $ \thread -> makeChunk (refs !! thread)
                                (splitIx thread)
                                (splitIx (thread + 1))

        -- Read the result chunks back from the IORefs.
        -- If a thread had to grow a chunk, then these might not be the same
        -- vectors we created back in the first step.
        P.mapM readIORef refs

 where  -- See how many threads we have available.
        !threads        = gangSize theGang
        !chunkLen       = len `quotInt` threads
        !chunkLeftover  = len `remInt`  threads


        -- Decide where to split the source array. The first chunkLeftover
        -- chunks get one extra element; splitIx threads == len.
        {-# INLINE splitIx #-}
        splitIx thread
         | thread < chunkLeftover = thread * (chunkLen + 1)
         | otherwise              = thread * chunkLen  + chunkLeftover


        -- Fill the given chunk with elements selected from the half-open
        -- range [ixSrc, ixSrcEnd). The buffer preallocated in the IORef is
        -- reused rather than allocating a second one.
        makeChunk :: IORef (IOVector a) -> Int -> Int -> IO ()
        makeChunk !ref !ixSrc !ixSrcEnd
         = do   vecDst  <- readIORef ref
                vecDst' <- fillChunk ixSrc ixSrcEnd vecDst 0
                writeIORef ref vecDst'


        -- Return a vector with room to write at index ix. ix is always the
        -- next free slot, so a full vector has length == ix. In that case
        -- double the capacity, adding at least one slot so an empty buffer
        -- can still grow.
        ensureRoom :: IOVector a -> Int -> IO (IOVector a)
        ensureRoom !vec !ix
         | ix < VM.length vec   = return vec
         | otherwise            = VM.grow vec (max 1 (VM.length vec))


        -- The main filling loop.
        -- ixSrc: next source index; ixSrcEnd: exclusive end of the range.
        -- ixDst: number of elements written so far (= next free slot).
        fillChunk :: Int -> Int -> IOVector a -> Int -> IO (IOVector a)
        fillChunk !ixSrc !ixSrcEnd !vecDst !ixDst
         -- If we've finished selecting elements, then slice the vector down
         -- so it doesn't have any empty space at the end.
         | ixSrc >= ixSrcEnd
         =      return  $ VM.slice 0 ixDst vecDst

         -- We've got a matching element, so make room for it if needed and
         -- add it to the chunk.
         | fnMatch ixSrc
         = do   vecRoom <- ensureRoom vecDst ixDst
                VM.unsafeWrite vecRoom ixDst (fnProduce ixSrc)
                fillChunk (ixSrc + 1) ixSrcEnd vecRoom (ixDst + 1)

         -- The element doesn't match, so keep going.
         | otherwise
         =      fillChunk (ixSrc + 1) ixSrcEnd vecDst ixDst