module Data.Array.Accelerate.CUDA.CodeGen.PrefixSum (
mkScanl, mkScanl1, mkScanl',
mkScanr, mkScanr1, mkScanr',
) where
import Data.Maybe
import Foreign.CUDA.Analysis
import Language.C.Quote.CUDA
import qualified Language.C.Syntax as C
import Data.Array.Accelerate.Array.Sugar ( Vector, Scalar, Elt, DIM1 )
import Data.Array.Accelerate.CUDA.AST
import Data.Array.Accelerate.CUDA.CodeGen.Base
-- | Kernel list for a seeded (exclusive) left ('mkScanl') or right
-- ('mkScanr') scan. The list holds, in order:
--
--   * the main scan kernel ('mkScan') over the input, given the seed;
--   * the per-interval reduction kernel ('mkScanUp1') over the input;
--   * the scan ('mkScanUp2') over the per-interval "Blk" array, given the seed.
mkScanl, mkScanr
  :: Elt e
  => DeviceProperties
  -> Gamma aenv
  -> CUFun2 aenv (e -> e -> e)
  -> CUExp aenv e
  -> CUDelayedAcc aenv DIM1 e
  -> [CUTranslSkel aenv (Vector e)]
mkScanl dev aenv f z a =
  let seed = Just z
  in  [ mkScan    L dev aenv f seed a
      , mkScanUp1 L dev aenv f a
      , mkScanUp2 L dev aenv f seed
      ]
mkScanr dev aenv f z a =
  let seed = Just z
  in  [ mkScan    R dev aenv f seed a
      , mkScanUp1 R dev aenv f a
      , mkScanUp2 R dev aenv f seed
      ]
-- | Kernel list for an unseeded (inclusive) left ('mkScanl1') or right
-- ('mkScanr1') scan. It has the same three-kernel layout as 'mkScanl' and
-- 'mkScanr', but every kernel is generated without a seed.
mkScanl1, mkScanr1
  :: Elt e
  => DeviceProperties
  -> Gamma aenv
  -> CUFun2 aenv (e -> e -> e)
  -> CUDelayedAcc aenv DIM1 e
  -> [CUTranslSkel aenv (Vector e)]
mkScanl1 dev aenv f a = kernels
  where
    kernels =
      [ mkScan    L dev aenv f Nothing a
      , mkScanUp1 L dev aenv f a
      , mkScanUp2 L dev aenv f Nothing
      ]
mkScanr1 dev aenv f a = kernels
  where
    kernels =
      [ mkScan    R dev aenv f Nothing a
      , mkScanUp1 R dev aenv f a
      , mkScanUp2 R dev aenv f Nothing
      ]
-- | Same kernels as 'mkScanl' / 'mkScanr', re-typed to the pair result
-- @(Vector e, Scalar e)@. Each skeleton is re-tagged with 'cast'; the
-- generated code is unchanged.
mkScanl', mkScanr'
  :: Elt e
  => DeviceProperties
  -> Gamma aenv
  -> CUFun2 aenv (e -> e -> e)
  -> CUExp aenv e
  -> CUDelayedAcc aenv DIM1 e
  -> [CUTranslSkel aenv (Vector e, Scalar e)]
mkScanl' dev aenv f z a = cast <$> mkScanl dev aenv f z a
mkScanr' dev aenv f z a = cast <$> mkScanr dev aenv f z a
-- | Reinterpret a kernel skeleton at a different result type. The entry name
-- and the generated code are carried over untouched.
cast :: CUTranslSkel aenv a -> CUTranslSkel aenv b
cast skel = case skel of
  CUTranslSkel name code -> CUTranslSkel name code
-- | Which end a scan starts from: 'L' for a left scan, 'R' for a right scan.
data Direction = L | R
  deriving Eq

-- | Renders as the lower-case letter appended to generated kernel names
-- (e.g. @"scan" ++ show dir@).
instance Show Direction where
  show dir = case dir of
    L -> "l"
    R -> "r"
-- | Generate the kernel that performs the scan proper over a one-dimensional
-- input.
--
-- Layout of the work:
--
--   * The input is split into @gridDim.x@ contiguous intervals of
--     @intervalSize@ elements, one per thread block.
--   * Each block walks its interval in chunks of @blockDim.x@ elements,
--     front to back for 'L' and back to front for 'R'.
--
-- Seeded vs. unseeded:
--
--   * With a seed ('Just'), the kernel is named @scanl@/@scanr@. Each thread
--     writes the value of its predecessor (an exclusive scan). The last block
--     in scan order additionally writes the overall result to @Sum[0]@.
--   * Without a seed ('Nothing'), the kernel is named @scanl1@/@scanr1@ and
--     writes the inclusive scan.
--
-- Carry-in values are read from the @Blk@ array. Seeded scans read
-- @Blk[blockIdx.x]@ when the grid has more than one block. Unseeded scans
-- read the previous block's entry for every block but the first.
-- NOTE(review): @Blk@ is presumably filled by the 'mkScanUp1' / 'mkScanUp2'
-- kernels listed alongside this one — confirm against the caller.
mkScan :: forall aenv e. Elt e
       => Direction
       -> DeviceProperties
       -> Gamma aenv
       -> CUFun2 aenv (e -> e -> e)
       -> Maybe (CUExp aenv e)
       -> CUDelayedAcc aenv DIM1 e
       -> CUTranslSkel aenv (Vector e)
mkScan dir dev aenv fun@(CUFun2 _ _ combine) mseed (CUDelayed (CUExp shIn) _ (CUFun1 _ get)) =
  CUTranslSkel scan [cunit|
    $esc:("#include <accelerate_cuda.h>")
    $edecls:texIn

    extern "C" __global__ void
    $id:scan
    (
        $params:argIn,
        $params:argOut,
        $params:argBlk,
        $params:(drop 1 argSum)     // just the pointers, no shape information
    )
    {
        $decls:smem
        $decls:declt
        $decls:declx
        $decls:decly
        $decls:declz

        $items:(sh .=. shIn)

        const int shapeSize     = $exp:(csize sh);
        const int intervalSize  = (shapeSize + gridDim.x - 1) / gridDim.x;

        /*
         * Read in previous result partial sum. We store the carry value in
         * temporary value 'z' and read new values from the input array into
         * 'x', since 'scanBlock' will store its results into 'y' on completion.
         */
        int carryIn = 0;

        if ( threadIdx.x == 0 ) {
            $stm:initialise
        }

        const int start         = blockIdx.x * intervalSize;
        const int end           = min(start + intervalSize, shapeSize);
        const int numElements   = end - start;
        int seg;

        for ( seg = threadIdx.x
            ; seg < numElements
            ; seg += blockDim.x )
        {
            const int ix = $exp:firstIndex;

            /*
             * Generate the next set of values
             */
            $items:(x .=. get ix)

            /*
             * Carry in the result from the previous segment
             */
            if ( $exp:carryIn ) {
                $items:(t .=. combine z x)
                $items:(x .=. t)
            }

            /*
             * Store our input into shared memory and perform a cooperative
             * inclusive left scan.
             */
            $items:(sdata "threadIdx.x" .=. x)
            __syncthreads();

            $items:(scanBlock dev fun x y t sdata Nothing)

            /*
             * Exclusive scans write the result of the previous thread to global
             * memory. The first thread must reinstate the carry-in value which
             * is the result of the last thread from the previous interval, or
             * the carry-in/seed value for multiblock scans.
             */
            if ( $exp:(cbool (isJust mseed)) ) {
                if ( threadIdx.x == 0 ) {
                    $items:(x .=. z)
                } else {
                    $items:(x .=. sdata "threadIdx.x - 1")
                }
            }
            $items:(setOut "ix" .=. x)

            /*
             * Carry the final result of this block through the set 'z'. If this
             * is the final interval, this is the value to write out as the
             * reduction result
             */
            if ( threadIdx.x == 0 ) {
                const int last = min(numElements - seg, blockDim.x) - 1;
                $items:(z .=. sdata "last")
            }

            $items:setCarry

            /*
             * Shared memory was read above (sdata[threadIdx.x - 1] and
             * sdata[last]) after the last barrier. Wait for all threads before
             * the next iteration overwrites it, otherwise a thread in another
             * warp may clobber a slot that is still being read.
             */
            __syncthreads();
        }

        /*
         * Finally, exclusive scans set the overall scan result (reduction value)
         */
        $items:setFinal
    }
  |]
  where
    -- Kernel name: "scanl"/"scanr" when seeded, "scanl1"/"scanr1" otherwise.
    scan                    = "scan" ++ show dir ++ maybe "1" (const []) mseed
    (texIn, argIn)          = environment dev aenv
    (argOut, _, setOut)     = writeArray "Out" (undefined :: Vector e)
    (argSum, _, totalSum)   = writeArray "Sum" (undefined :: Vector e)
    (argBlk, _, blkSum)     = writeArray "Blk" (undefined :: Vector e)
    (_, t, declt)           = locals "t" (undefined :: e)
    (_, x, declx)           = locals "x" (undefined :: e)
    (_, y, decly)           = locals "y" (undefined :: e)
    (_, z, declz)           = locals "z" (undefined :: e)
    (sh, _, _)              = locals "sh" (undefined :: DIM1)
    (smem, sdata)           = shared (undefined :: e) "sdata" [cexp| blockDim.x |] Nothing
    ix                      = [cvar "ix"]

    setSum                  = totalSum "0"

    -- Unseeded scans combine the carried value into every chunk after the
    -- first; seeded scans always combine with 'z' (see 'carryIn' below).
    setCarry
      | isNothing mseed     = [[citem| carryIn = 1; |]]
      | otherwise           = []

    -- Only seeded (exclusive) scans write the total to the Sum array, from
    -- thread 0 of the last block in scan order.
    setFinal
      | isNothing mseed     = []
      | otherwise           = [[citem| if ( threadIdx.x == 0 && blockIdx.x == $id:lastBlock ) {
                                          $items:(setSum .=. z)
                                       } |]]

    firstBlock              = if dir == L then "0" else "gridDim.x - 1"
    lastBlock               = if dir == R then "0" else "gridDim.x - 1"
    prevBlock               = if dir == L then "blockIdx.x - 1" else "blockIdx.x + 1"

    -- Index of this thread's element: walk forwards for L, backwards for R.
    firstIndex
      | dir == L            = [cexp| start + seg |]
      | otherwise           = [cexp| end - seg - 1 |]

    -- Guard for combining the carried value 'z' into thread 0's element.
    carryIn
      | isJust mseed        = [cexp| threadIdx.x == 0 |]
      | otherwise           = [cexp| threadIdx.x == 0 && carryIn |]

    -- Thread 0's initial carry value 'z' (and carry-in flag for unseeded scans).
    initialise
      | Just (CUExp seed) <- mseed
      = [cstm| if ( gridDim.x > 1 ) {
                   $items:(z .=. blkSum "blockIdx.x")
               } else {
                   $items:(z .=. seed)
               }
             |]
      | otherwise
      = [cstm| if ( blockIdx.x != $id:firstBlock ) {
                   $items:(z .=. blkSum prevBlock)
                   carryIn = 1;
               }
             |]
-- | Generate the kernel that reduces each interval of the input to a single
-- value.
--
-- The input is split into intervals exactly as in 'mkScan'. Each block scans
-- its interval chunk by chunk, carrying the running result in thread 0's
-- 'y'. Once the interval is done, thread 0 writes that result to
-- @Out[blockIdx.x]@. The kernel is named @scanlUp@ or @scanrUp@.
--
-- NOTE(review): the per-interval results are presumably the @Blk@ array
-- consumed by 'mkScanUp2' and 'mkScan' — confirm against the caller.
mkScanUp1
    :: forall aenv e. Elt e
    => Direction
    -> DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (e -> e -> e)
    -> CUDelayedAcc aenv DIM1 e
    -> CUTranslSkel aenv (Vector e)
mkScanUp1 dir dev aenv fun@(CUFun2 _ _ combine) (CUDelayed (CUExp shIn) _ (CUFun1 _ get)) =
  CUTranslSkel scan [cunit|
    $esc:("#include <accelerate_cuda.h>")
    $edecls:texIn

    extern "C" __global__ void
    $id:scan
    (
        $params:argIn,
        $params:argOut
    )
    {
        $decls:smem
        $decls:declt
        $decls:declx
        $decls:decly

        $items:(sh .=. shIn)

        const int shapeSize     = $exp:(csize sh);
        const int intervalSize  = (shapeSize + gridDim.x - 1) / gridDim.x;

        const int start         = blockIdx.x * intervalSize;
        const int end           = min(start + intervalSize, shapeSize);
        const int numElements   = end - start;

        int carryIn = 0;
        int seg;

        for ( seg = threadIdx.x
            ; seg < numElements
            ; seg += blockDim.x )
        {
            const int ix = $exp:firstIndex ;

            /*
             * Read in new values, combine with carryin
             */
            $items:(x .=. get ix)

            if ( threadIdx.x == 0 && carryIn ) {
                $items:(t .=. combine y x)
                $items:(x .=. t)
            }

            /*
             * Store in shared memory and cooperatively scan
             */
            $items:(sdata "threadIdx.x" .=. x)
            __syncthreads();

            $items:(scanBlock dev fun x y t sdata Nothing)

            /*
             * Store the final result of the block to be carried in
             */
            if ( threadIdx.x == 0 ) {
                const int last = min(numElements - seg, blockDim.x) - 1;
                $items:(y .=. sdata "last")
            }

            carryIn = 1;

            /*
             * Thread 0 read sdata[last] above, after the last barrier. Wait
             * before the next iteration overwrites shared memory.
             */
            __syncthreads();
        }

        /*
         * Finally, the first thread writes the result of this interval
         */
        if ( threadIdx.x == 0 ) {
            $items:(setOut "blockIdx.x" .=. y)
        }
    }
  |]
  where
    scan                = "scan" ++ show dir ++ "Up"
    (texIn, argIn)      = environment dev aenv
    (argOut, _, setOut) = writeArray "Out" (undefined :: Vector e)
    (_, x, declx)       = locals "x" (undefined :: e)
    (_, y, decly)       = locals "y" (undefined :: e)
    (_, t, declt)       = locals "t" (undefined :: e)
    (sh, _, _)          = locals "sh" (undefined :: DIM1)
    (smem, sdata)       = shared (undefined :: e) "sdata" [cexp| blockDim.x |] Nothing
    ix                  = [cvar "ix"]

    -- Index of this thread's element: walk forwards for L, backwards for R.
    firstIndex
      | dir == L        = [cexp| start + seg |]
      | otherwise       = [cexp| end - seg - 1 |]
-- | Scan the per-interval results: reuse 'mkScan' with the same direction,
-- combining function and optional seed. Its input is obtained from
-- @readArray "Blk"@ rather than the user's array.
mkScanUp2
  :: forall aenv e. Elt e
  => Direction
  -> DeviceProperties
  -> Gamma aenv
  -> CUFun2 aenv (e -> e -> e)
  -> Maybe (CUExp aenv e)
  -> CUTranslSkel aenv (Vector e)
mkScanUp2 dir dev aenv f z = mkScan dir dev aenv f z blockInput
  where
    (_, blockInput) = readArray "Blk" (undefined :: Vector e)
-- | Emit a cooperative in-block scan. A warp-shuffle variant is reserved
-- behind 'shflOK' but not implemented (it calls 'error'). Otherwise this
-- delegates to the shared-memory tree scan 'scanBlockTree'.
scanBlock
  :: forall aenv e. Elt e
  => DeviceProperties
  -> CUFun2 aenv (e -> e -> e)
  -> [C.Exp] -> [C.Exp] -> [C.Exp]
  -> (Name -> [C.Exp])
  -> Maybe C.Exp
  -> [C.BlockItem]
scanBlock dev f x0 x1 x2 sdata mlim =
  if shflOK dev (undefined :: e)
    then error "shfl-scan"
    else scanBlockTree dev f x0 x1 x2 sdata mlim
-- | Cooperative inclusive scan over the shared-memory array @sdata@.
--
-- It is emitted as a fixed sequence of steps with offsets 1, 2, 4, ..., up to
-- 2^floor(log2 maxThreadsPerBlock). Each step runs only when
-- @blockDim.x > offset@. Within a step:
--
--   * each thread whose index passes the range guard reads
--     @sdata[threadIdx.x - offset]@ into @x1@;
--   * it combines that value (as left operand) with its running value @x0@
--     via @x2@;
--   * after a barrier, every thread writes @x0@ back to
--     @sdata[threadIdx.x]@, followed by another barrier.
--
-- When a limit is supplied, only threads with index below it combine.
scanBlockTree
  :: forall aenv e. Elt e
  => DeviceProperties
  -> CUFun2 aenv (e -> e -> e)
  -> [C.Exp] -> [C.Exp] -> [C.Exp]
  -> (Name -> [C.Exp])
  -> Maybe C.Exp
  -> [C.BlockItem]
scanBlockTree dev (CUFun2 _ _ f) x0 x1 x2 sdata mlim =
  [ step (2 ^ k) | k <- [0 .. depth] ]
  where
    -- Largest exponent k such that 2^k fits the device's block-size limit.
    depth :: Int
    depth = floor (logBase 2 (fromIntegral (maxThreadsPerBlock dev) :: Double))

    -- Which threads take part in the combine at offset n.
    guardFor :: Int -> C.Exp
    guardFor n = case mlim of
      Just m  -> [cexp| threadIdx.x >= $int:n && threadIdx.x < $exp:m |]
      Nothing -> [cexp| threadIdx.x >= $int:n |]

    -- One doubling step of the scan at offset n.
    step :: Int -> C.BlockItem
    step n = [citem|
      if ( blockDim.x > $int:n ) {
          if ( $exp:(guardFor n) ) {
              $items:(x1 .=. sdata ("threadIdx.x - " ++ show n))
              $items:(x2 .=. f x1 x0)
              $items:(x0 .=. x2)
          }
          __syncthreads();
          $items:(sdata "threadIdx.x" .=. x0)
          __syncthreads();
      }
    |]
-- | Whether a warp-shuffle scan may be used for this device and element
-- type. Currently always 'False', so 'scanBlock' always uses the tree scan.
shflOK :: Elt e => DeviceProperties -> e -> Bool
shflOK _ = const False