{-# LANGUAGE CPP                 #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ImpredicativeTypes  #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE QuasiQuotes         #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators       #-}
-- |
-- Module      : Data.Array.Accelerate.CUDA.CodeGen.PrefixSum
-- Copyright   : [2008..2014] Manuel M T Chakravarty, Gabriele Keller
--               [2009..2014] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.CUDA.CodeGen.PrefixSum (

  -- skeletons
  mkScanl, mkScanl1, mkScanl',
  mkScanr, mkScanr1, mkScanr',

) where

import Data.Maybe
import Foreign.CUDA.Analysis
import Language.C.Quote.CUDA
import qualified Language.C.Syntax                      as C

import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.CUDA.AST
import Data.Array.Accelerate.CUDA.Analysis.Shape
import Data.Array.Accelerate.CUDA.CodeGen.Base


-- Text of the error raised when a scan is requested over an array whose outer
-- shape is not 'Z' (i.e. anything other than a vector).
--
-- This is a plain 'String': every use site wraps it in 'error'. Binding the
-- message itself to an 'error' call would make the exception raised at the use
-- site carry a message that throws a second exception as soon as it is shown.
--
errorMsg :: String
errorMsg = unlines
  [ "accelerate-cuda does not support rank-polymorphic scans. Please switch to accelerate-llvm-ptx instead."
  , ""
  , "*** https://hackage.haskell.org/package/accelerate-llvm-ptx ***"
  , "*** https://github.com/AccelerateHS/accelerate-llvm ***"
  ]


-- Wrappers
-- --------

-- Scans with an explicit initial (seed) element, left-to-right and
-- right-to-left respectively. Only vectors are supported; any other shape
-- raises 'errorMsg'.
--
mkScanl, mkScanr
    :: forall aenv sh e. (Shape sh, Elt e)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (e -> e -> e)
    -> CUExp aenv e
    -> CUDelayedAcc aenv (sh:.Int) e
    -> [CUTranslSkel aenv (Array (sh:.Int) e)]
mkScanl dev aenv f z a = scanSkeletons L dev aenv f (Just z) a
mkScanr dev aenv f z a = scanSkeletons R dev aenv f (Just z) a

-- Scans without an initial element, left-to-right and right-to-left
-- respectively. Only vectors are supported; any other shape raises 'errorMsg'.
--
mkScanl1, mkScanr1
    :: forall aenv sh e. (Shape sh, Elt e)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (e -> e -> e)
    -> CUDelayedAcc aenv (sh:.Int) e
    -> [CUTranslSkel aenv (Array (sh:.Int) e)]
mkScanl1 dev aenv f a = scanSkeletons L dev aenv f Nothing a
mkScanr1 dev aenv f a = scanSkeletons R dev aenv f Nothing a

-- Variants whose result type pairs the scanned array with the separately
-- returned final value. The generated code is exactly that of 'mkScanl' /
-- 'mkScanr'; only the type index of each skeleton is changed, via 'cast'.
--
mkScanl', mkScanr'
    :: forall aenv sh e. (Shape sh, Elt e)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (e -> e -> e)
    -> CUExp aenv e
    -> CUDelayedAcc aenv (sh:.Int) e
    -> [CUTranslSkel aenv (Array (sh:.Int) e, Array sh e)]
mkScanl' dev aenv f z
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = map cast . mkScanl dev aenv f z

  | otherwise
  = error errorMsg

mkScanr' dev aenv f z
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = map cast . mkScanr dev aenv f z

  | otherwise
  = error errorMsg

-- Shared body of the scan wrappers above. When the outer shape is 'Z' this
-- yields three skeletons, in order: the main scan kernel ('mkScan'), the first
-- upsweep step ('mkScanUp1') and the second upsweep step ('mkScanUp2'). The
-- optional seed is passed to the first and last of these. For any other shape
-- the call fails with 'errorMsg'.
--
scanSkeletons
    :: forall aenv sh e. (Shape sh, Elt e)
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (e -> e -> e)
    -> Maybe (CUExp aenv e)
    -> CUDelayedAcc aenv (sh:.Int) e
    -> [CUTranslSkel aenv (Array (sh:.Int) e)]
scanSkeletons dir dev aenv f mseed a
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = [ mkScan    dir dev aenv f mseed a
    , mkScanUp1 dir dev aenv f a
    , mkScanUp2 dir dev aenv f mseed
    ]

  | otherwise
  = error errorMsg

-- Change the type index of a skeleton. The entry point name and the generated
-- code are carried over untouched.
--
cast :: CUTranslSkel aenv a -> CUTranslSkel aenv b
cast (CUTranslSkel entry code) = CUTranslSkel entry code


-- Core implementation
-- -------------------

-- Scan direction. The 'Show' instance yields the single letter that is spliced
-- into the generated kernel names ("scan" ++ show dir ++ ...).
--
data Direction = L | R
  deriving Eq

instance Show Direction where
  show L = "l"
  show R = "r"


-- [OVERVIEW]
--
-- Data.List-style exclusive scan, with the additional restriction that the
-- first argument needs to be an /associative/ function to enable efficient
-- parallel implementation. The initial value may be arbitrary.
--
--   scanl :: Elt a
--         => (Exp a -> Exp a -> Exp a)
--         -> Exp a
--         -> Acc (Vector a)
--         -> Acc (Vector a)
--
--   > scanl (+) 10 (use xs)
--   >   where
--   >     xs = fromList (Z:.10) (cycle [1])
--   >
--   >   ==> Array (Z:.11) [10,11,12,13,14,15,16,17,18,19,20]
--
-- Data.List-style inclusive scan without an initial value
--
--   scanl1 :: Elt a
--          => (Exp a -> Exp a -> Exp a)
--          -> Acc (Vector a)
--          -> Acc (Vector a)
--
--   > scanl1 (+) (use xs)
--   >   where
--   >     xs = fromList (Z:.10) (cycle [1])
--   >
--   >   ==> Array (Z:.10) [1,2,3,4,5,6,7,8,9,10]
--
-- Variant of 'scanl' where the final result is returned separately.
--
--   scanl' :: Elt a
--          => (Exp a -> Exp a -> Exp a)
--          -> Exp a
--          -> Acc (Vector a)
--          -> (Acc (Vector a), Acc (Scalar a))
--
-- Denotationally, we have:
--
--   > scanl' f z xs = (init res, last res)
--   >   where
--   >     res = scanl f z xs
--
--
-- [IMPLEMENTATION]
--
-- This code handles all the above cases, in both left and right-handed
-- variants. This is the _downsweep_ phase of a multi-block scan procedure.
-- We require a work distribution such that there is a _single_ thread block for
-- each interval. For multi-block scans, we have an array of interval sums that
-- are used to determine the carry-in value from the previous interval. Note
-- that 'argBlk' will not be accessed by a single-block scan, so may be null.
--
-- We require some pointer manipulation from the calling code in order to
-- support all types of scans:
--
--   * scanl          : argSum should point to the last position of argOut
--   * scanr          : argSum should be the start of argOut, argOut should be incremented by one
--   * scanl1, scanr1 : no change (argSum is required, even though it will not be used Haskell-side)
--   * scanl', scanr' : no change
--
-- Arguments:
--
--   * dir   : scan direction; selects the kernel name and the index order
--   * dev   : device properties, forwarded to 'environment' and 'scanBlock'
--   * aenv  : array environment, forwarded to 'environment'
--   * fun   : the binary combination function
--   * mseed : 'Just' the initial element for the seeded scans, 'Nothing' for
--             the scanl1 / scanr1 style
--   * the delayed input array: its shape expression and its linear-index
--     function are spliced into the kernel
--
-- The kernel is named "scan" ++ show dir for seeded scans, with a "1" suffix
-- appended when no seed is given.
--
-- NOTE(review): the '#include' line emitted through '$esc' had no operand, so
-- the generated source could not be preprocessed. It now names
-- accelerate_cuda.h, the support header shipped with accelerate-cuda; confirm
-- this against the headers installed with the package.
--
mkScan
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (e -> e -> e)
    -> Maybe (CUExp aenv e)
    -> CUDelayedAcc aenv DIM1 e
    -> CUTranslSkel aenv (Vector e)
mkScan dir dev aenv fun@(CUFun2 _ _ combine) mseed (CUDelayed (CUExp shIn) _ (CUFun1 _ get))
  = CUTranslSkel scan [cunit|

    $esc:("#include <accelerate_cuda.h>")
    $edecls:texIn

    extern "C" __global__ void
    $id:scan
    (
        $params:argIn,
        $params:argOut,
        $params:argBlk,
        $params:(tail argSum)           // just the pointers, no shape information
    )
    {
        $decls:smem
        $decls:declt
        $decls:declx
        $decls:decly
        $decls:declz
        $items:(sh .=. shIn)

        const int shapeSize     = $exp:(csize sh);
        const int intervalSize  = (shapeSize + gridDim.x - 1) / gridDim.x;

        /*
         * Read in previous result partial sum. We store the carry value in
         * temporary value 'z' and read new values from the input array into
         * 'x', since 'scanBlock' will store its results into 'y' on completion.
         */
        int carryIn = 0;

        if ( threadIdx.x == 0 ) {
            $stm:initialise
        }

        const int start         = blockIdx.x * intervalSize;
        const int end           = min(start + intervalSize, shapeSize);
        const int numElements   = end - start;
        int seg;

        for ( seg = threadIdx.x
            ; seg < numElements
            ; seg += blockDim.x )
        {
            const int ix = $exp:firstIndex;

            /*
             * Generate the next set of values
             */
            $items:(x .=. get ix)

            /*
             * Carry in the result from the previous segment
             */
            if ( $exp:carryIn ) {
                $items:(t .=. combine z x)
                $items:(x .=. t)
            }

            /*
             * Store our input into shared memory and perform a cooperative
             * inclusive left scan.
             */
            $items:(sdata "threadIdx.x" .=. x)
            __syncthreads();

            $items:(scanBlock dev fun x y t sdata Nothing)

            /*
             * Exclusive scans write the result of the previous thread to global
             * memory. The first thread must reinstate the carry-in value which
             * is the result of the last thread from the previous interval, or
             * the carry-in/seed value for multi-block scans.
             */
            if ( $exp:(cbool (isJust mseed)) ) {
                if ( threadIdx.x == 0 ) {
                    $items:(x .=. z)
                } else {
                    $items:(x .=. sdata "threadIdx.x - 1")
                }
            }
            $items:(setOut "ix" .=. x)

            /*
             * Carry the final result of this block through the set 'z'. If this
             * is the final interval, this is the value to write out as the
             * reduction result
             */
            if ( threadIdx.x == 0 ) {
                const int last = min(numElements - seg, blockDim.x) - 1;
                $items:(z .=. sdata "last")
            }
            $items:setCarry
        }

        /*
         * Finally, exclusive scans set the overall scan result (reduction value)
         */
        $items:setFinal
    }
  |]
  where
    -- kernel entry point: "scanl" / "scanr" with a seed, "scanl1" / "scanr1" without
    scan                        = "scan" ++ show dir ++ maybe "1" (const []) mseed

    -- kernel parameters and the accessors used to write through them
    (texIn, argIn)              = environment dev aenv
    (argOut, _, setOut)         = writeArray "Out" (undefined :: Vector e)
    (argSum, _, totalSum)       = writeArray "Sum" (undefined :: Vector e)
    (argBlk, _, blkSum)         = writeArray "Blk" (undefined :: Vector e)

    -- kernel-local temporaries, one group per element value
    (_, t, declt)               = locals "t" (undefined :: e)
    (_, x, declx)               = locals "x" (undefined :: e)
    (_, y, decly)               = locals "y" (undefined :: e)
    (_, z, declz)               = locals "z" (undefined :: e)
    (sh, _, _)                  = locals "sh" (undefined :: DIM1)
    (smem, sdata)               = shared (undefined :: e) "sdata" [cexp| blockDim.x |] Nothing
    ix                          = [cvar "ix"]
    setSum                      = totalSum "0"

    -- depending on whether we are inclusive/exclusive scans
    setCarry
      | isNothing mseed         = [[citem| carryIn = 1; |]]
      | otherwise               = []

    setFinal
      | isNothing mseed         = []
      | otherwise               = [[citem|
            if ( threadIdx.x == 0 && blockIdx.x == $id:lastBlock ) {
                $items:(setSum .=. z)
            } |]]

    -- accessing neighbouring blocks
    firstBlock                  = if dir == L then "0" else "gridDim.x - 1"
    lastBlock                   = if dir == R then "0" else "gridDim.x - 1"
    prevBlock                   = if dir == L then "blockIdx.x - 1" else "blockIdx.x + 1"

    -- left scans walk the interval upwards from 'start', right scans downwards
    -- from 'end'
    firstIndex
      | dir == L                = [cexp| start + seg |]
      | otherwise               = [cexp| end - seg - 1 |]

    -- seeded scans always have a value in 'z' to combine with; unseeded scans
    -- only once the C-side 'carryIn' flag has been set
    carryIn
      | isJust mseed            = [cexp| threadIdx.x == 0 |]
      | otherwise               = [cexp| threadIdx.x == 0 && carryIn |]

    -- initialise the first thread with the results of the previous block sweep
    -- or exclusive scan element
    initialise
      | Just (CUExp seed) <- mseed
      = [cstm|  if ( gridDim.x > 1 ) {
                    $items:(z .=. blkSum "blockIdx.x")
                } else {
                    $items:(z .=. seed)
                }
        |]

      | otherwise
      = [cstm|  if ( blockIdx.x != $id:firstBlock ) {
                    $items:(z .=. blkSum prevBlock)
                    carryIn = 1;
                }
        |]


-- This computes the _upsweep_ phase of a multi-block scan. This is much like a
-- regular inclusive scan, except that only the final value for each interval is
-- output, rather than the entire body of the scan. Indeed, if the combination
-- function were commutative, this is equivalent to a parallel tree reduction.
--
-- The kernel is named "scan" ++ show dir ++ "Up". Thread zero of each block
-- writes that block's result to position 'blockIdx.x' of the output array.
--
-- NOTE(review): as in 'mkScan', the '#include' line emitted through '$esc' had
-- no operand and now names accelerate_cuda.h; confirm against the headers
-- installed with the package.
--
mkScanUp1
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (e -> e -> e)
    -> CUDelayedAcc aenv DIM1 e
    -> CUTranslSkel aenv (Vector e)
mkScanUp1 dir dev aenv fun@(CUFun2 _ _ combine) (CUDelayed (CUExp shIn) _ (CUFun1 _ get))
  = CUTranslSkel scan [cunit|

    $esc:("#include <accelerate_cuda.h>")
    $edecls:texIn

    extern "C" __global__ void
    $id:scan
    (
        $params:argIn,
        $params:argOut
    )
    {
        $decls:smem
        $decls:declt
        $decls:declx
        $decls:decly
        $items:(sh .=. shIn)

        const int shapeSize     = $exp:(csize sh);
        const int intervalSize  = (shapeSize + gridDim.x - 1) / gridDim.x;

        const int start         = blockIdx.x * intervalSize;
        const int end           = min(start + intervalSize, shapeSize);
        const int numElements   = end - start;
        int carryIn             = 0;
        int seg;

        for ( seg = threadIdx.x
            ; seg < numElements
            ; seg += blockDim.x )
        {
            const int ix = $exp:firstIndex ;

            /*
             * Read in new values, combine with carry-in
             */
            $items:(x .=. get ix)

            if ( threadIdx.x == 0 && carryIn ) {
                $items:(t .=. combine y x)
                $items:(x .=. t)
            }

            /*
             * Store in shared memory and cooperatively scan
             */
            $items:(sdata "threadIdx.x" .=. x)
            __syncthreads();

            $items:(scanBlock dev fun x y t sdata Nothing)

            /*
             * Store the final result of the block to be carried in
             */
            if ( threadIdx.x == 0 ) {
                const int last = min(numElements - seg, blockDim.x) - 1;
                $items:(y .=. sdata "last")
            }
            carryIn = 1;
        }

        /*
         * Finally, the first thread writes the result of this interval
         */
        if ( threadIdx.x == 0 ) {
            $items:(setOut "blockIdx.x" .=. y)
        }
    }
  |]
  where
    scan                        = "scan" ++ show dir ++ "Up"
    (texIn, argIn)              = environment dev aenv
    (argOut, _, setOut)         = writeArray "Out" (undefined :: Vector e)
    (_, x, declx)               = locals "x" (undefined :: e)
    (_, y, decly)               = locals "y" (undefined :: e)
    (_, t, declt)               = locals "t" (undefined :: e)
    (sh, _, _)                  = locals "sh" (undefined :: DIM1)
    (smem, sdata)               = shared (undefined :: e) "sdata" [cexp| blockDim.x |] Nothing
    ix                          = [cvar "ix"]

    -- same traversal order as the main scan kernel
    firstIndex
      | dir == L                = [cexp| start + seg |]
      | otherwise               = [cexp| end - seg - 1 |]


-- Second step of the upsweep phase: scan the interval sums to produce carry-in
-- values for each block of the final downsweep step
--
-- This is 'mkScan' applied to a delayed array that reads from the "Blk" array
-- (obtained from 'readArray') instead of from a user-supplied input.
--
mkScanUp2
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (e -> e -> e)
    -> Maybe (CUExp aenv e)
    -> CUTranslSkel aenv (Vector e)
mkScanUp2 dir dev aenv f z
  = let (_, _, get) = readArray "Blk" (undefined :: Vector e)
    in  mkScan dir dev aenv f z get


-- Block scans
-- ===========

-- Select the block-level scan implementation. 'shflOK' currently always
-- answers False, so the shared-memory tree scan is always the one generated;
-- the shuffle-based branch is a placeholder that would fail with "shfl-scan".
--
scanBlock
    :: forall aenv e. Elt e
    => DeviceProperties
    -> CUFun2 aenv (e -> e -> e)
    -> [C.Exp] -> [C.Exp] -> [C.Exp]
    -> (Name -> [C.Exp])
    -> Maybe C.Exp
    -> [C.BlockItem]
scanBlock dev f x0 x1 x2 sdata mlim
  | shflOK dev (undefined :: e) = error "shfl-scan"
  | otherwise                   = scanBlockTree dev f x0 x1 x2 sdata mlim


-- Use a thread block to scan values in shared memory. Each thread must have
-- already stored its initial value into shared memory. The final result for
-- this thread will be stored in x0 as well as the appropriate place in shared
-- memory.
--
-- One step is generated for every power of two from 1 up to the device's
-- maximum number of threads per block. Each step is guarded at run time by a
-- test on 'blockDim.x', so steps wider than the actual block do nothing.
--
scanBlockTree
    :: forall aenv e. Elt e
    => DeviceProperties
    -> CUFun2 aenv (e -> e -> e)
    -> [C.Exp] -> [C.Exp] -> [C.Exp]    -- input variables x0 and x1, plus a temporary to store the intermediate value
    -> (Name -> [C.Exp])                -- index elements from shared memory
    -> Maybe C.Exp                      -- partially full block bounds check?
    -> [C.BlockItem]
scanBlockTree dev (CUFun2 _ _ f) x0 x1 x2 sdata mlim = map (scan . pow2) [ 0 .. maxThreads ]
  where
    pow2 :: Int -> Int
    pow2 x      = 2 ^ x

    -- exponent of the largest power of two not exceeding the device limit
    maxThreads  = floor (logBase 2 (fromIntegral $ maxThreadsPerBlock dev :: Double))

    -- which threads take part in the step with offset n; the optional limit
    -- additionally excludes threads at or beyond it
    inrange n
      | Just m <- mlim  = [cexp| threadIdx.x >= $int:n && threadIdx.x < $exp:m |]
      | otherwise       = [cexp| threadIdx.x >= $int:n |]

    -- a single step: combine with the value n places to the left, then
    -- publish the new value to shared memory between two barriers
    scan n = [citem|
        if ( blockDim.x > $int:n ) {
            if ( $exp:(inrange n) ) {
                $items:(x1 .=. sdata ("threadIdx.x - " ++ show n))
                $items:(x2 .=. f x1 x0)
                $items:(x0 .=. x2)
            }
            __syncthreads();
            $items:(sdata "threadIdx.x" .=. x0)
            __syncthreads();
        } |]


-- Shuffle scan
-- ------------

-- Whether the shuffle-based scan may be used for this device and element
-- type. It is currently switched off unconditionally; the intended test is
-- kept in the comment below.
--
shflOK :: Elt e => DeviceProperties -> e -> Bool
shflOK _dev _ = False
-- shflOk dev dummy
--   = computeCapability dev >= Compute 3 0 && all (`elem` [4,8]) (eltSizeOf dummy)


{--
scanWarpShfl
    :: forall aenv e. Elt e
    => DeviceProperties
    -> CUFun2 aenv (e -> e -> e)
    -> [C.Exp] -> [C.Exp]               -- temporary variables x0 and x1
    -> Maybe C.Exp                      -- partially full block bounds check
    -> C.Exp                            -- thread identifier, usually lane or thread ID
    -> C.Stm
scanWarpShfl _dev (CUFun2 f) x0 x1 mlim tid
  = [cstm|  for ( int z = 1; z <= warpSize; z *= 2 ) {
                $items:(x0 .=. shfl_up x1)

                if ( $exp:inrange ) {
                    $items:(x1 .=. f x1 x0)
                }
            }
    |]
  where
    inrange
      | Just m <- mlim  = [cexp| $exp:tid >= z && $exp:tid < $exp:m |]
      | otherwise       = [cexp| $exp:tid >= z |]

    sizeof      = eltSizeOf (undefined :: e)
    shfl_up     = zipWith (\s x -> ccall (shfl s) [ x, cvar "z" ]) sizeof
      where
        shfl 4  = "shfl_up32"
        shfl 8  = "shfl_up64"
        shfl _  = INTERNAL_ERROR(error) "shfl_up" "I only know about 32- and 64-bit types"
--}