module Data.Array.Accelerate.CUDA.CodeGen.Stencil (
mkStencil, mkStencil2
) where
import Foreign.CUDA.Analysis
import Language.C.Quote.CUDA
import Data.Array.Accelerate.Type ( Boundary(..) )
import Data.Array.Accelerate.Array.Sugar ( Array, Elt )
import Data.Array.Accelerate.Analysis.Stencil
import Data.Array.Accelerate.CUDA.AST hiding ( stencil, stencilAccess )
import Data.Array.Accelerate.CUDA.CodeGen.Base
import Data.Array.Accelerate.CUDA.CodeGen.Stencil.Extra
-- | Generate the CUDA skeleton for a stencil over a single source array.
--
-- The result is a one-element list holding a 'CUTranslSkel' named
-- @"stencil"@. The generated kernel iterates a linear index @ix@ from
-- @threadIdx dev@ up to the size of the output shape in steps of
-- @gridSize dev@. For every index it recovers the shape index with
-- 'cfromIndex', assigns the stencil locals from the stencil accessor, and
-- stores the result of the stencil function @f@ into the output array.
--
-- Two accessors are built with 'stencilAccess', differing only in the second
-- 'Bool' argument. Devices with compute capability below 1.2 always use the
-- one built with 'True'; all others pick between the two at run time based on
-- an @__all(insideRegion ...)@ test.
--
-- NOTE(review): the 'True' accessor presumably performs boundary handling and
-- the 'False' one skips it -- confirm against 'stencilAccess'.
--
mkStencil
    :: forall aenv sh stencil a b. (Stencil sh a stencil, Elt b)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun1 aenv (stencil -> b)
    -> Boundary (CUExp aenv a)
    -> [CUTranslSkel aenv (Array sh b)]
mkStencil dev aenv (CUFun1 dce f) boundary =
  [ CUTranslSkel "stencil" kernel ]
  where
    kernel = [cunit|

      $esc:("#include <accelerate_cuda.h>")
      $edecls:envDecls
      $edecls:stencilDecls

      extern "C" __global__ void
      stencil
      (
          $params:envParams,
          $params:outParams,
          $params:stencilParams
      )
      {
          const int shapeSize = $exp:(csize outShape);
          const int gridSize  = $exp:(gridSize dev);
                int ix;

          for ( ix =  $exp:(threadIdx dev)
              ; ix <  shapeSize
              ; ix += gridSize )
          {
              $items:(shIx .=. cfromIndex outShape "ix" "tmp")
              $items:loopBody
          }
      }
    |]

    -- Top-level declarations and kernel parameters for the array environment.
    (envDecls, envParams)           = environment dev aenv

    -- Kernel parameters, shape expression and element setter of the output.
    (outParams, outShape, storeOut) = writeArray "Out" (undefined :: Array sh b)

    -- Local variables: the current shape index and the stencil elements.
    (shIx, _, _)                    = locals "sh" (undefined :: sh)
    (window, _, _)                  = locals "x" (undefined :: stencil)

    -- Offsets of the stencil pattern, derived from the types only.
    offs = offsets (undefined :: Fun aenv (stencil -> b))
                   (undefined :: OpenAcc aenv (Array sh a))

    -- The same accessor built twice; only the second Bool flag differs.
    (stencilDecls, stencilParams, safeAccess) =
      stencilAccess dev True True  "Stencil" "Stencil" "w" "ix" offs boundary dce
    (_, _, unsafeAccess) =
      stencilAccess dev True False "Stencil" "Stencil" "w" "ix" offs boundary dce

    -- Body of the loop: unconditional safe access below compute 1.2,
    -- otherwise a run-time choice guarded by the insideRegion vote.
    loopBody =
      if computeCapability dev < Compute 1 2
        then emit safeAccess
        else
          [[citem| if ( __all( $exp:(insideRegion outShape (borderRegion offs) (map rvalue shIx)) ) ) {
                       $items:(emit unsafeAccess)
                   } else {
                       $items:(emit safeAccess)
                   } |]]
      where
        -- Read the stencil elements with the given accessor, then apply the
        -- function and store its result at linear index "ix".
        emit access = (dce window .=. access shIx) ++
                      (storeOut "ix" .=. f window)
-- | Generate the CUDA skeletons for a stencil over two source arrays.
--
-- Two variants are returned, in this order: the first is built by
-- 'mkStencil2'' with its @sameExtent@ flag set to 'False', the second with
-- the flag set to 'True'.
--
-- NOTE(review): presumably the caller picks a variant depending on whether
-- both source arrays have the same extent -- confirm at the use site.
--
mkStencil2
    :: forall aenv sh stencil1 stencil2 a b c.
       (Stencil sh a stencil1, Stencil sh b stencil2, Elt c)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (stencil1 -> stencil2 -> c)
    -> Boundary (CUExp aenv a)
    -> Boundary (CUExp aenv b)
    -> [CUTranslSkel aenv (Array sh c)]
mkStencil2 dev aenv stencil boundary1 boundary2 =
  [ mkStencil2' dev sameExtent aenv stencil boundary1 boundary2
  | sameExtent <- [False, True]
  ]
-- | Build one variant of the two-array stencil skeleton, named @"stencil2"@.
--
-- The generated kernel iterates a linear index @ix@ from @threadIdx dev@ up
-- to the size of the output shape in steps of @gridSize dev@. For every index
-- it recovers the shape index with 'cfromIndex', assigns both sets of stencil
-- locals from their accessors, and stores the result of the binary stencil
-- function @f@ into the output array.
--
-- The 'Bool' argument (@sameExtent@) is forwarded to 'stencilAccess' and also
-- decides which shape name the second stencil is given: the first stencil's
-- name when 'True', its own group name otherwise.
--
-- Devices with compute capability below 1.2 always use the accessors built
-- with the second 'stencilAccess' flag set to 'True'; all others choose at
-- run time based on an @__all(insideRegion ...)@ test against the combined
-- border of both stencils.
--
-- NOTE(review): the 'True' accessors presumably perform boundary handling and
-- the 'False' ones skip it -- confirm against 'stencilAccess'.
--
mkStencil2'
    :: forall aenv sh stencil1 stencil2 a b c.
       (Stencil sh a stencil1, Stencil sh b stencil2, Elt c)
    => DeviceProperties
    -> Bool
    -> Gamma aenv
    -> CUFun2 aenv (stencil1 -> stencil2 -> c)
    -> Boundary (CUExp aenv a)
    -> Boundary (CUExp aenv b)
    -> CUTranslSkel aenv (Array sh c)
mkStencil2' dev sameExtent aenv (CUFun2 dce1 dce2 f) boundary1 boundary2 =
  CUTranslSkel "stencil2" kernel
  where
    kernel = [cunit|

      $esc:("#include <accelerate_cuda.h>")
      $edecls:envDecls
      $edecls:decls1
      $edecls:decls2

      extern "C" __global__ void
      stencil2
      (
          $params:envParams,
          $params:outParams,
          $params:params1,
          $params:params2
      )
      {
          const int shapeSize = $exp:(csize outShape);
          const int gridSize  = $exp:(gridSize dev);
                int ix;

          for ( ix =  $exp:(threadIdx dev)
              ; ix <  shapeSize
              ; ix += gridSize )
          {
              $items:(shIx .=. cfromIndex outShape "ix" "tmp")
              $items:loopBody
          }
      }
    |]

    -- Top-level declarations and kernel parameters for the array environment.
    (envDecls, envParams)           = environment dev aenv

    -- Kernel parameters, shape expression and element setter of the output.
    (outParams, outShape, storeOut) = writeArray "Out" (undefined :: Array sh c)

    -- Local variables: the current shape index and both sets of stencil
    -- elements.
    (shIx, _, _)                    = locals "sh" (undefined :: sh)
    (window1, _, _)                 = locals "x" (undefined :: stencil1)
    (window2, _, _)                 = locals "y" (undefined :: stencil2)

    -- Group names passed to stencilAccess for each source array.
    group1 = "Stencil1"
    group2 = "Stencil2"

    -- Shape names passed to stencilAccess. With sameExtent set, the second
    -- stencil is handed the first stencil's shape name instead of its own.
    shape1 = group1
    shape2 = if sameExtent then shape1 else group2

    -- Offsets of both stencil patterns, derived from the types only.
    (offs1, offs2) =
      offsets2 (undefined :: Fun aenv (stencil1 -> stencil2 -> c))
               (undefined :: OpenAcc aenv (Array sh a))
               (undefined :: OpenAcc aenv (Array sh b))

    -- Element-wise maximum of the two stencils' border regions.
    combinedBorder = zipWith max (borderRegion offs1) (borderRegion offs2)

    -- Each accessor is built twice; only the second Bool flag differs.
    (decls1, params1, safeAccess1) =
      stencilAccess dev sameExtent True  group1 shape1 "w" "ix" offs1 boundary1 dce1
    (_, _, unsafeAccess1) =
      stencilAccess dev sameExtent False group1 shape1 "w" "ix" offs1 boundary1 dce1
    (decls2, params2, safeAccess2) =
      stencilAccess dev sameExtent True  group2 shape2 "z" "ix" offs2 boundary2 dce2
    (_, _, unsafeAccess2) =
      stencilAccess dev sameExtent False group2 shape2 "z" "ix" offs2 boundary2 dce2

    -- Body of the loop: unconditional safe access below compute 1.2,
    -- otherwise a run-time choice guarded by the insideRegion vote.
    loopBody =
      if computeCapability dev < Compute 1 2
        then emit safeAccess1 safeAccess2
        else
          [[citem| if ( __all( $exp:(insideRegion outShape combinedBorder (map rvalue shIx)) ) ) {
                       $items:(emit unsafeAccess1 unsafeAccess2)
                   } else {
                       $items:(emit safeAccess1 safeAccess2)
                   } |]]
      where
        -- Read both sets of stencil elements with the given accessors, then
        -- apply the function and store its result at linear index "ix".
        emit access1 access2 =
          (dce1 window1 .=. access1 shIx) ++
          (dce2 window2 .=. access2 shIx) ++
          (storeOut "ix" .=. f window1 window2)