{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Sequential and parallel selection of the indices that satisfy a
--   predicate. These are building blocks for filtering functions.
module Data.Array.Repa.Eval.Selection
  ( selectChunkedS
  , selectChunkedP
  ) where

import Data.Array.Repa.Eval.Gang
import Data.Array.Repa.Shape
import Data.Vector.Unboxed as V
import Data.Vector.Unboxed.Mutable as VM
import GHC.Base (remInt, quotInt)
import Prelude as P
import Control.Monad as P
import Data.IORef


-- | Select indices matching a predicate.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * Linear indices @0 .. size shSize - 1@ are visited in increasing order
--     and converted to shape indices with 'fromIndex'.
--
--   * For every matching index the update function is called with that
--     (source) index and the produced value.
--
selectChunkedS
  :: Shape sh
  => (sh -> a -> IO ())   -- ^ Update function to write into result.
  -> (sh -> Bool)         -- ^ See if this predicate matches.
  -> (sh -> a)            -- ^ .. and apply fn to the matching index.
  -> sh                   -- ^ Extent of indices to apply to predicate.
  -> IO Int               -- ^ Number of elements written to destination array.
{-# INLINE selectChunkedS #-}
selectChunkedS fnWrite fnMatch fnProduce !shSize = fill 0 0
  where
    -- Total number of linear indices covered by the extent.
    lenSrc = size shSize

    -- nSrc is the next linear source index to test,
    -- nDst counts the matches seen so far.
    fill !nSrc !nDst
      | nSrc >= lenSrc = return nDst

      | ixSrc <- fromIndex shSize nSrc
      , fnMatch ixSrc = do
          fnWrite ixSrc (fnProduce ixSrc)
          fill (nSrc + 1) (nDst + 1)

      | otherwise = fill (nSrc + 1) nDst


-- | Select indices matching a predicate, in parallel.
--
--   * This primitive can be useful for writing filtering functions.
--
--   * The index range @0 .. len - 1@ is split into contiguous linear chunks,
--     with one chunk being given to each thread of the gang.
--
--   * One result chunk is returned per gang thread, in thread order, so the
--     number of chunks (and how the selected elements are distributed among
--     them) depends on how many threads you're running the program with.
--
selectChunkedP
  :: forall a
   . Unbox a
  => (Int -> Bool)        -- ^ See if this predicate matches.
  -> (Int -> a)           -- ^ .. and apply fn to the matching index.
  -> Int                  -- ^ Extent of indices to apply to predicate.
  -> IO [IOVector a]      -- ^ Chunks containing array elements.
{-# INLINE selectChunkedP #-}
selectChunkedP fnMatch fnProduce !len = do
  -- Make IORefs that the threads will write their result chunks to.
  -- We start with a chunk size proportional to the number of threads we
  -- have, but the threads themselves can grow the chunks if they run out
  -- of space.
  refs <- P.replicateM threads $ do
    vec <- VM.new $ len `div` threads
    newIORef vec

  -- Fire off a thread to fill each chunk.
  -- NOTE(review): assumes gangIO only passes thread indices in
  -- [0, gangSize theGang), so the list index below is in range -- confirm
  -- against Data.Array.Repa.Eval.Gang.
  gangIO theGang $ \thread ->
    makeChunk (refs !! thread) (splitIx thread) (splitIx (thread + 1) - 1)

  -- Read the result chunks back from the IORefs.
  -- These are slices of the filled vectors, and if a thread had to grow a
  -- chunk they are not backed by the vectors we created in the first step.
  P.mapM readIORef refs
  where
    -- See how many threads we have available.
    !threads       = gangSize theGang
    !chunkLen      = len `quotInt` threads
    !chunkLeftover = len `remInt` threads

    -- Decide where to split the source array.
    -- The first chunkLeftover threads each take one extra element.
    {-# INLINE splitIx #-}
    splitIx thread
      | thread < chunkLeftover = thread * (chunkLen + 1)
      | otherwise              = thread * chunkLen + chunkLeftover

    -- Fill the given chunk with elements selected from this (inclusive)
    -- range of indices.
    makeChunk :: IORef (IOVector a) -> Int -> Int -> IO ()
    makeChunk !ref !ixSrc !ixSrcEnd
      -- Empty source range: the result chunk is empty.
      | ixSrc > ixSrcEnd = do
          vecDst <- VM.new 0
          writeIORef ref vecDst

      -- Start from the vector that was pre-allocated for this thread
      -- instead of allocating a second one of the same size.
      | otherwise = do
          vecDst  <- readIORef ref
          vecDst' <- fillChunk ixSrc ixSrcEnd vecDst 0 (VM.length vecDst)
          writeIORef ref vecDst'

    -- The main filling loop.
    -- Invariant: ixDstLen is the length of vecDst, and ixDst <= ixDstLen.
    fillChunk :: Int -> Int -> IOVector a -> Int -> Int -> IO (IOVector a)
    fillChunk !ixSrc !ixSrcEnd !vecDst !ixDst !ixDstLen
      -- If we've finished selecting elements, then slice the vector down
      -- so it doesn't have any empty space at the end.
      | ixSrc > ixSrcEnd = return $ VM.slice 0 ixDst vecDst

      -- If we've run out of space in the chunk then grow it some more.
      -- VM.grow *adds* the given number of elements, so the new capacity
      -- is taken from the grown vector rather than from the increment.
      | ixDst >= ixDstLen = do
          vecDst' <- VM.grow vecDst ((VM.length vecDst + 1) * 2)
          fillChunk ixSrc ixSrcEnd vecDst' ixDst (VM.length vecDst')

      -- We've got a matching element, so add it to the chunk.
      -- The unchecked write is in bounds because ixDst < ixDstLen here.
      | fnMatch ixSrc = do
          VM.unsafeWrite vecDst ixDst (fnProduce ixSrc)
          fillChunk (ixSrc + 1) ixSrcEnd vecDst (ixDst + 1) ixDstLen

      -- The element doesn't match, so keep going.
      | otherwise = fillChunk (ixSrc + 1) ixSrcEnd vecDst ixDst ixDstLen