{-# LANGUAGE BangPatterns, ExplicitForAll, ScopedTypeVariables, PatternGuards #-}
module Data.Array.Repa.Eval.Selection
	(selectChunkedS, selectChunkedP)
where
import Data.Array.Repa.Eval.Gang
import Data.Array.Repa.Shape
import Data.Vector.Unboxed			as V
import Data.Vector.Unboxed.Mutable		as VM
import GHC.Base					(remInt, quotInt)
import Prelude					as P
import Control.Monad				as P
import Data.IORef


-- | Sequentially select the indices of an extent that satisfy a predicate.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * Indices are visited in increasing linear (row-major) order, from
--     linear index @0@ up to @size extent - 1@.
--
--   * For every matching index the write function is invoked with that
--     /source/ index and the value produced for it. It is not given a packed
--     destination position; the running count of matches is only reported
--     through the result.
--
selectChunkedS
        :: Shape sh
        => (sh -> a -> IO ())   -- ^ Update function to write into result.
        -> (sh -> Bool)         -- ^ See if this predicate matches.
        -> (sh -> a)            -- ^  .. and apply fn to the matching index
        -> sh                   -- ^ Extent of indices to apply to predicate.
        -> IO Int               -- ^ Number of elements written to destination array.

{-# INLINE selectChunkedS #-}
selectChunkedS fnWrite fnMatch fnProduce !shSize
 = go 0 0
 where
        -- Total number of linear indices covered by the extent.
        total   = size shSize

        -- Walk the linear index space, counting the writes performed so far.
        go !srcPos !written
         | srcPos >= total
         = return written

         | otherwise
         = let  ix = fromIndex shSize srcPos
           in   if fnMatch ix
                  then do
                        fnWrite ix (fnProduce ix)
                        go (srcPos + 1) (written + 1)
                  else  go (srcPos + 1) written


-- | Select indices matching a predicate, in parallel.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * The index range @[0 .. len - 1]@ is split into contiguous linear
--     chunks, with one chunk being given to each thread of the gang.
--
--   * The result holds one chunk per gang thread, in thread order, so the
--     number of chunks depends on how many threads you're running the
--     program with. Each chunk is trimmed to exactly the elements that were
--     selected from its part of the range, and may be empty.
--
selectChunkedP
        :: forall a
        .  Unbox a
        => (Int -> Bool)        -- ^ See if this predicate matches.
        -> (Int -> a)           -- ^  .. and apply fn to the matching index
        -> Int                  -- ^ Extent of indices to apply to predicate.
        -> IO [IOVector a]      -- ^ Chunks containing array elements.

{-# INLINE selectChunkedP #-}
selectChunkedP fnMatch fnProduce !len
 = do
        -- Make IORefs that the threads will write their result chunks to.
        -- Each ref starts out holding an empty chunk. A worker allocates its
        -- own buffer and stores the filled chunk here once it is done, so
        -- there is no point in pre-allocating real storage at this stage.
        refs    <- P.replicateM threads
                $ do    vec     <- VM.new 0
                        newIORef vec

        -- Fire off a thread to fill each chunk.
        -- Thread @t@ handles the inclusive index range
        -- @[splitIx t .. splitIx (t + 1) - 1]@.
        gangIO theGang
         $ \thread -> makeChunk (refs !! thread)
                        (splitIx thread)
                        (splitIx (thread + 1) - 1)

        -- Read the result chunks back from the IORefs.
        P.mapM readIORef refs

 where  -- See how many threads we have available.
        !threads        = gangSize theGang
        !chunkLen       = len `quotInt` threads
        !chunkLeftover  = len `remInt`  threads


        -- Decide where to split the source array.
        -- The first 'chunkLeftover' threads each take one extra element, so
        -- the whole range is covered without gaps or overlap.
        {-# INLINE splitIx #-}
        splitIx thread
         | thread < chunkLeftover = thread * (chunkLen + 1)
         | otherwise              = thread * chunkLen  + chunkLeftover


        -- Fill the given chunk with elements selected from this (inclusive)
        -- range of indices.
        makeChunk :: IORef (IOVector a) -> Int -> Int -> IO ()
        makeChunk !ref !ixSrc !ixSrcEnd
         -- Nothing to select from: the ref already holds an empty chunk.
         | ixSrc > ixSrcEnd
         = return ()

         -- Start with room for an average-sized chunk; 'fillChunk' grows the
         -- buffer if this range happens to yield more elements than that.
         | otherwise
         = do   vecDst  <- VM.new (len `div` threads)
                vecDst' <- fillChunk ixSrc ixSrcEnd vecDst 0 (VM.length vecDst)
                writeIORef ref vecDst'


        -- The main filling loop.
        -- 'ixDst' is the next free slot in 'vecDst', and 'ixDstLen' is the
        -- current capacity of 'vecDst'.
        fillChunk :: Int -> Int -> IOVector a -> Int -> Int -> IO (IOVector a)
        fillChunk !ixSrc !ixSrcEnd !vecDst !ixDst !ixDstLen
         -- If we've finished selecting elements, then slice the vector down
         -- so it doesn't have any empty space at the end.
         | ixSrc > ixSrcEnd
         =      return  $ VM.slice 0 ixDst vecDst

         -- If we've run out of space in the chunk then grow it some more.
         -- 'VM.grow' takes the number of elements to ADD, so adding
         -- (length + 2) brings the capacity to (length + 1) * 2. The new
         -- capacity is read back from the grown vector so that the
         -- bookkeeping always agrees with the real buffer size.
         | ixDst >= ixDstLen
         = do   vecDst' <- VM.grow vecDst (VM.length vecDst + 2)
                fillChunk ixSrc ixSrcEnd vecDst' ixDst (VM.length vecDst')

         -- We've got a matching element, so add it to the chunk.
         -- The unchecked write is in bounds because the guard above
         -- established ixDst < ixDstLen.
         | fnMatch ixSrc
         = do   VM.unsafeWrite vecDst ixDst (fnProduce ixSrc)
                fillChunk (ixSrc + 1) ixSrcEnd vecDst (ixDst + 1) ixDstLen

         -- The element doesn't match, so keep going.
         | otherwise
         =      fillChunk (ixSrc + 1) ixSrcEnd vecDst ixDst ixDstLen