{-# LANGUAGE BangPatterns, ExplicitForAll, ScopedTypeVariables, PatternGuards #-} module Data.Array.Repa.Internals.Select (selectChunkedS, selectChunkedP) where import Data.Array.Repa.Internals.Gang import Data.Array.Repa.Shape import Data.Vector.Unboxed as V import Data.Vector.Unboxed.Mutable as VM import GHC.Base (remInt, quotInt) import Prelude as P import Control.Monad as P import Data.IORef -- | Select indices matching a predicate selectChunkedS :: (Shape sh, Unbox a) => (sh -> Bool) -- ^ See if this predicate matches. -> (sh -> a) -- ^ .. and apply fn to the matching index -> IOVector a -- ^ .. then write the result into the vector. -> sh -- ^ Extent of indices to apply to predicate. -> IO Int -- ^ Number of elements written to destination array. {-# INLINE selectChunkedS #-} selectChunkedS match produce !vDst !shSize = fill 0 0 where lenSrc = size shSize lenDst = VM.length vDst fill !nSrc !nDst | nSrc >= lenSrc = return nDst | nDst >= lenDst = return nDst | ixSrc <- fromIndex shSize nSrc , match ixSrc = do VM.unsafeWrite vDst nDst (produce ixSrc) fill (nSrc + 1) (nDst + 1) | otherwise = fill (nSrc + 1) nDst -- | Select indices matching a predicate, in parallel. -- The array is chunked up, with one chunk being given to each thread. -- The number of elements in the result array depends on how many threads -- you're running the program with. selectChunkedP :: forall a . Unbox a => (Int -> Bool) -- ^ See if this predicate matches. -> (Int -> a) -- .. and apply fn to the matching index -> Int -- Extent of indices to apply to predicate. -> IO [IOVector a] -- Chunks containing array elements. {-# INLINE selectChunkedP #-} selectChunkedP !match !produce !len = do -- Make IORefs that the threads will write their result chunks to. -- We start with a chunk size proportial to the number of threads we have, -- but the threads themselves can grow the chunks if they run out of space. refs <- P.replicateM threads $ do vec <- VM.new $ len `div` threads newIORef vec -- Fire off a thread to fill each chunk. gangIO theGang $ \thread -> makeChunk (refs !! thread) (splitIx thread) (splitIx (thread + 1) - 1) -- Read the result chunks back from the IORefs. -- If a thread had to grow a chunk, then these might not be the same ones -- we created back in the first step. P.mapM readIORef refs where -- See how many threads we have available. !threads = gangSize theGang !chunkLen = len `quotInt` threads !chunkLeftover = len `remInt` threads -- Decide where to split the source array. {-# INLINE splitIx #-} splitIx thread | thread < chunkLeftover = thread * (chunkLen + 1) | otherwise = thread * chunkLen + chunkLeftover -- Fill the given chunk with elements selected from this range of indices. makeChunk :: IORef (IOVector a) -> Int -> Int -> IO () makeChunk !ref !ixSrc !ixSrcEnd = do vecDst <- VM.new (len `div` threads) vecDst' <- fillChunk ixSrc ixSrcEnd vecDst 0 (VM.length vecDst - 1) writeIORef ref vecDst' -- The main filling loop. fillChunk :: Int -> Int -> IOVector a -> Int -> Int -> IO (IOVector a) fillChunk !ixSrc !ixSrcEnd !vecDst !ixDst !ixDstEnd -- If we've finished selecting elements, then slice the vector down -- so it doesn't have any empty space at the end. | ixSrc >= ixSrcEnd = return $ VM.slice 0 ixDst vecDst -- If we've run out of space in the chunk then grow it some more. | ixDst >= ixDstEnd = do let ixDstEnd' = VM.length vecDst * 2 - 1 vecDst' <- VM.grow vecDst (ixDstEnd + 1) fillChunk (ixSrc + 1) ixSrcEnd vecDst' (ixDst + 1) ixDstEnd' -- We've got a maching element, so add it to the chunk. | match ixSrc = do VM.unsafeWrite vecDst ixDst (produce ixSrc) fillChunk (ixSrc + 1) ixSrcEnd vecDst (ixDst + 1) ixDstEnd -- The element doesnt match, so keep going. | otherwise = fillChunk (ixSrc + 1) ixSrcEnd vecDst ixDst ixDstEnd