module Data.Array.Accelerate.CUDA.CodeGen.IndexSpace (
mkGenerate,
mkTransform, mkPermute,
) where
import Language.C.Quote.CUDA
import Foreign.CUDA.Analysis.Device
import qualified Language.C.Syntax as C
import Data.Array.Accelerate.Array.Sugar ( Array, Shape, Elt, ignore, shapeToList )
import Data.Array.Accelerate.Error ( internalError )
import Data.Array.Accelerate.CUDA.AST ( Gamma )
import Data.Array.Accelerate.CUDA.CodeGen.Base
-- | Build the single CUDA kernel implementing @generate@.
--
-- Every output element is computed independently. The kernel loops over the
-- linear index space of the output array, starting at @threadIdx dev@ and
-- stepping by @gridSize dev@. These are presumably the global thread index and
-- the total number of launched threads; confirm against the CodeGen.Base
-- helpers. For each linear index the kernel:
--
-- * recovers the multidimensional index;
-- * applies the generator function to it;
-- * stores the result at that position of the output array.
mkGenerate
    :: forall aenv sh e. (Shape sh, Elt e)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun1 aenv (sh -> e)
    -> [CUTranslSkel aenv (Array sh e)]
mkGenerate dev aenv (CUFun1 dce f) =
  [ CUTranslSkel "generate" [cunit|

      $esc:("#include <accelerate_cuda.h>")
      $edecls:envDecls

      extern "C" __global__ void
      generate
      (
          $params:envParams,
          $params:outParams
      )
      {
          const int shapeSize = $exp:(csize outShape);
          const int gridSize  = $exp:(gridSize dev);
                int ix;

          for ( ix =  $exp:(threadIdx dev)
              ; ix <  shapeSize
              ; ix += gridSize )
          {
              $items:(dce ixVars .=. cfromIndex outShape "ix" "tmp")
              $items:(writeOut "ix" .=. f ixVars)
          }
      }
    |]
  ]
  where
    -- Top-level declarations (spliced with $edecls) and kernel parameters
    -- that describe the array environment 'aenv'.
    (envDecls, envParams)           = environment dev aenv

    -- Output array: its kernel parameters, its shape, and an element writer.
    (outParams, outShape, writeOut) = writeArray "Out" (undefined :: Array sh e)

    -- C locals holding the components of the current multidimensional index.
    (ixVars, _, _)                  = locals "sh" (undefined :: sh)
-- | Build the single CUDA kernel implementing @transform@.
--
-- This is a fused backwards permutation followed by a map. For every linear
-- index of the output array the kernel:
--
-- * recovers the destination index (of shape @sh'@);
-- * maps it back to a source index (of shape @sh@) with the index function;
-- * reads the source element through the delayed array's
--   multidimensional-index accessor;
-- * applies the element function;
-- * writes the result to the output array.
--
-- Each generated C binding is pruned with the dead-code filter (@dce_*@) of
-- the function that consumes it, so only the components that are actually
-- used get assigned.
mkTransform
    :: forall aenv sh sh' a b. (Shape sh, Shape sh', Elt a, Elt b)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun1 aenv (sh' -> sh)
    -> CUFun1 aenv (a -> b)
    -> CUDelayedAcc aenv sh a
    -> [CUTranslSkel aenv (Array sh' b)]
mkTransform dev aenv (CUFun1 dce_p p) (CUFun1 dce_f f) arr
  | CUDelayed _ (CUFun1 dce_g get) _ <- arr
  = [ CUTranslSkel "transform" [cunit|

        $esc:("#include <accelerate_cuda.h>")
        $edecls:envDecls

        extern "C" __global__ void
        transform
        (
            $params:envParams,
            $params:outParams
        )
        {
            const int shapeSize = $exp:(csize outShape);
            const int gridSize  = $exp:(gridSize dev);
                  int ix;

            for ( ix =  $exp:(threadIdx dev)
                ; ix <  shapeSize
                ; ix += gridSize )
            {
                $items:(dce_p dstIx .=. cfromIndex outShape "ix" "tmp")
                $items:(dce_g srcIx .=. p dstIx)
                $items:(dce_f elemIn .=. get srcIx)
                $items:(writeOut "ix" .=. f elemIn)
            }
        }
      |]
    ]
  where
    -- Top-level declarations (spliced with $edecls) and kernel parameters
    -- that describe the array environment 'aenv'.
    (envDecls, envParams)           = environment dev aenv

    -- Output array: its kernel parameters, its shape, and an element writer.
    (outParams, outShape, writeOut) = writeArray "Out" (undefined :: Array sh' b)

    -- C locals for the destination index, the source index and the element
    -- that is read from the source.
    (dstIx, _, _)                   = locals "sh_" (undefined :: sh')
    (srcIx, _, _)                   = locals "sh"  (undefined :: sh)
    (elemIn, _, _)                  = locals "x"   (undefined :: a)
-- | Build the single CUDA kernel implementing @permute@ (forward permutation).
--
-- The kernel loops over the linear index space of the /source/ array. For
-- each source element it:
--
--   1. recovers the multidimensional source index;
--   2. applies the permutation function to get the destination index;
--   3. drops the element if the destination equals the magic 'ignore' index;
--   4. otherwise combines the source value with the value currently stored at
--      the destination, and writes the result back.
--
-- The source element is read through the delayed array's linear-index
-- accessor, applied to the loop variable @ix@.
--
-- Step 4 is a read-modify-write of the output array, and several source
-- elements may target the same destination. It is therefore guarded by a
-- spin lock on @lock[jx]@, where @lock@ is an extra kernel parameter and @jx@
-- is the linear destination index. The lock is elided when the combination
-- function never reads the old value.
--
-- Generation fails with an internal error in two cases:
--
-- * the destination is a singleton (rank-0) array;
-- * locking is required but the device's compute capability is below 1.1.
mkPermute
    :: forall aenv sh sh' e. (Shape sh, Shape sh', Elt e)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (e -> e -> e)
    -> CUFun1 aenv (sh -> sh')
    -> CUDelayedAcc aenv sh e
    -> [CUTranslSkel aenv (Array sh' e)]
mkPermute dev aenv (CUFun2 dce_x dce_y combine) (CUFun1 dce_p prj) arr
  | CUDelayed (CUExp shIn) _ (CUFun1 _ get) <- arr
  = return
  $ CUTranslSkel "permute" [cunit|

    $esc:("#include <accelerate_cuda.h>")
    $edecls:texIn

    extern "C" __global__ void
    permute
    (
        $params:argIn,
        $params:argOut,
        typename Int32 * __restrict__ lock
    )
    {
        /*
         * The input shape might be a complex expression. Evaluate it first to reuse the result.
         */
        $items:(sh .=. shIn)

        const int shapeSize = $exp:(csize sh);
        const int gridSize  = $exp:(gridSize dev);
              int ix;

        for ( ix =  $exp:(threadIdx dev)
            ; ix <  shapeSize
            ; ix += gridSize )
        {
            $items:(dce_p src .=. cfromIndex sh "ix" "srcTmp")
            $items:(dst .=. prj src)

            if ( ! $exp:(cignore dst) )
            {
                $items:(jx .=. ctoIndex shOut dst)
                $items:(dce_x x .=. get ix)

                $items:(atomically jx
                    [ dce_y y .=. setOut jx
                    , setOut jx .=. combine x y ]
                )
            }
        }
    }
  |]
  where
    (texIn, argIn)           = environment dev aenv
    (argOut, shOut, setOut)  = writeArray "Out" (undefined :: Array sh' e)
    (x, _, _)                = locals "x" (undefined :: e)
    (y, _, _)                = locals "y" (undefined :: e)
    (sh, _, _)               = locals "shIn" (undefined :: sh)
    (src, _, _)              = locals "sh" (undefined :: sh)
    (dst, _, _)              = locals "sh_" (undefined :: sh')
    ([jx], _, _)             = locals "jx" (undefined :: Int)
    ix                       = [cvar "ix"]
    sm                       = computeCapability dev

    -- C condition that is true when every component of the destination index
    -- equals the corresponding component of 'ignore'. Such elements are
    -- dropped. A rank-0 destination has no components to compare, so it is
    -- rejected at code-generation time.
    cignore :: Rvalue x => [x] -> C.Exp
    cignore []  = $internalError "permute" "singleton arrays not supported"
    cignore xs  = foldl1 (\a b -> [cexp| $exp:a && $exp:b |])
                $ zipWith (\a b -> [cexp| $exp:(rvalue a) == $int:b |]) xs
                $ shapeToList (ignore :: sh')

    -- Locking is only needed when the combination function actually reads the
    -- old destination value 'y'. The dead-code filter reports, per component,
    -- whether it is used.
    mustLock = any fst (dce_y y)

    -- Wrap the read-modify-write of the destination element in a spin lock on
    -- @lock[jx]@, where 0 means free and 1 means held.
    --
    -- * The lock is taken with @atomicExch(.., 1)@; a returned 0 means the
    --   lock was acquired.
    -- * The body runs only in that branch.
    -- * The lock is released with @atomicExch(.., 0)@.
    -- * Threads that did not acquire the lock retry until @done@ is set.
    --
    -- Keeping acquire, body and release inside the same branch is presumably
    -- meant to avoid intra-warp deadlock from divergent spinning; confirm.
    --
    -- 'atomicExch' does not order the surrounding ordinary memory accesses, so
    -- the critical section is bracketed by two fences:
    --
    -- * one directly after acquiring the lock, so that the read of the old
    --   value observes the previous holder's write;
    -- * one directly before releasing it, so that our write to the output is
    --   visible before another thread can acquire the lock.
    --
    -- Previously the only fence sat before the acquisition attempt, which
    -- ordered neither of these.
    --
    -- NOTE(review): the old value is still read with an ordinary global load.
    -- On devices whose L1 cache is not coherent across SMs, this may also need
    -- an L1-bypassing (volatile / ld.cg) load. Confirm.
    --
    -- The lock array is assumed to be zero-initialised by the caller; verify.
    atomically :: (C.Type, Name) -> [[C.BlockItem]] -> [C.BlockItem]
    atomically (_, i) (concat -> body)
      | not mustLock     = body
      | sm < Compute 1 1 = $internalError "permute" "Requires at least compute compatibility 1.1"
      | otherwise        =
          [ [citem| typename Int32 done = 0; |]
          , [citem| do {
                        if ( atomicExch(&lock[ $exp:(cvar i) ], 1) == 0 ) {
                            __threadfence();
                            $items:body
                            __threadfence();
                            done = 1;
                            atomicExch(&lock[ $exp:(cvar i) ], 0);
                        }
                    } while (done == 0);
                  |]
          ]