{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ImpredicativeTypes  #-}
{-# LANGUAGE QuasiQuotes         #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module      : Data.Array.Accelerate.CUDA.CodeGen.Stencil
-- Copyright   : [2008..2014] Manuel M T Chakravarty, Gabriele Keller
--               [2009..2014] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.CUDA.CodeGen.Stencil (

  mkStencil, mkStencil2

) where

import Foreign.CUDA.Analysis
import Language.C.Quote.CUDA

import Data.Array.Accelerate.Type                       ( Boundary(..) )
import Data.Array.Accelerate.Array.Sugar                ( Array, Elt )
import Data.Array.Accelerate.Analysis.Stencil
import Data.Array.Accelerate.CUDA.AST                   hiding ( stencil, stencilAccess )
import Data.Array.Accelerate.CUDA.CodeGen.Base
import Data.Array.Accelerate.CUDA.CodeGen.Stencil.Extra


-- Map a stencil over an array. In contrast to 'map', the domain of a stencil
-- function is an entire /neighbourhood/ of each array element. Neighbourhoods
-- are sub-arrays centred around a focal point. They are not necessarily
-- rectangular, but they are symmetric and have an extent of at least three in
-- each dimension. Due to this symmetry requirement, the extent is necessarily
-- odd. The focal point is the array position that determines the single output
-- element for each application of the stencil.
--
-- For those array positions where the neighbourhood extends past the boundaries
-- of the source array, a boundary condition determines the contents of the
-- out-of-bounds neighbourhood positions.
--
-- stencil :: (Shape ix, Elt a, Elt b, Stencil ix a stencil)
--         => (stencil -> Exp b)                 -- stencil function
--         -> Boundary a                         -- boundary condition
--         -> Acc (Array ix a)                   -- source array
--         -> Acc (Array ix b)                   -- destination array
--
-- To improve performance on older (1.x series) devices, the input array(s) are
-- read through the texture cache.
--
-- Code generation: a single CUDA kernel named @stencil@ is produced (returned
-- as a one-element list). Every output index in @[threadIdx, shapeSize)@ is
-- visited with a stride of @gridSize@. On devices of compute capability 1.2 or
-- later, the warp vote @__all@ tests whether every active thread of the warp
-- has its index inside the interior region (outside 'borderRegion'). If so,
-- the accessors built with the @False@ flag to 'stencilAccess'
-- (@unsafeIndex@, presumably skipping boundary handling) are used. Otherwise,
-- and always on older devices, the @True@-flag accessors (@safeIndex@) are used.
--
mkStencil
    :: forall aenv sh stencil a b. (Stencil sh a stencil, Elt b)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun1 aenv (stencil -> b)
    -> Boundary (CUExp aenv a)
    -> [CUTranslSkel aenv (Array sh b)]
mkStencil dev aenv (CUFun1 dce f) boundary
  = return
  $ CUTranslSkel "stencil" [cunit|

    $esc:("#include <accelerate_cuda.h>")
    $edecls:texIn
    $edecls:texStencil

    extern "C" __global__ void
    stencil
    (
        $params:argIn,
        $params:argOut,
        $params:argStencil
    )
    {
        const int shapeSize = $exp:(csize shOut);
        const int gridSize  = $exp:(gridSize dev);
              int ix;

        for ( ix =  $exp:(threadIdx dev)
            ; ix <  shapeSize
            ; ix += gridSize )
        {
            $items:(sh .=. cfromIndex shOut "ix" "tmp")
            $items:stencilBody
        }
    }
  |]
  where
    -- Free variables of the stencil function, and the output array.
    (texIn, argIn)              = environment dev aenv
    (argOut, shOut, setOut)     = writeArray "Out" (undefined :: Array sh b)

    -- Local variables: the current multidimensional index and the
    -- neighbourhood elements handed to the stencil function.
    (sh, _, _)                  = locals "sh" (undefined :: sh)
    (xs, _, _)                  = locals "x"  (undefined :: stencil)

    -- Offsets of the neighbourhood elements relative to the focal point.
    dx                          = offsets (undefined :: Fun aenv (stencil -> b))
                                          (undefined :: OpenAcc aenv (Array sh a))

    -- Both accessor sets read from the same "Stencil" source. Only the first
    -- one's texture declarations and kernel parameters are emitted.
    (texStencil, argStencil, safeIndex) =
      stencilAccess dev True True  "Stencil" "Stencil" "w" "ix" dx boundary dce
    (_, _, unsafeIndex)                 =
      stencilAccess dev True False "Stencil" "Stencil" "w" "ix" dx boundary dce

    -- '__all' (warp vote) is not available below compute capability 1.2, so
    -- those devices always take the boundary-safe path.
    stencilBody
      | computeCapability dev < Compute 1 2 = with safeIndex
      | otherwise                           =
          [[citem| if ( __all( $exp:(insideRegion shOut (borderRegion dx) (map rvalue sh)) ) ) {
                       $items:(with unsafeIndex)
                   } else {
                       $items:(with safeIndex)
                   } |]]
      where
        -- Load the (live) neighbourhood with the given accessor, then apply the
        -- stencil function and write the result at linear index "ix".
        with access = (dce xs .=. access sh) ++
                      (setOut "ix" .=. f xs)


-- Map a binary stencil over two arrays. The extent of the resulting array is
-- the intersection of the extents of the two source arrays.
--
-- stencil2 :: (Shape ix, Elt a, Elt b, Elt c,
--              Stencil ix a stencil1,
--              Stencil ix b stencil2)
--          => (stencil1 -> stencil2 -> Exp c)  -- binary stencil function
--          -> Boundary a                       -- boundary condition #1
--          -> Acc (Array ix a)                 -- source array #1
--          -> Boundary b                       -- boundary condition #2
--          -> Acc (Array ix b)                 -- source array #2
--          -> Acc (Array ix c)                 -- destination array
--
-- Two kernel variants are generated. The first is built with
-- @sameExtent = False@ (source arrays may differ in extent). The second is
-- built with @sameExtent = True@ (both source arrays share one extent).
--
mkStencil2
    :: forall aenv sh stencil1 stencil2 a b c. (Stencil sh a stencil1, Stencil sh b stencil2, Elt c)
    => DeviceProperties
    -> Gamma aenv
    -> CUFun2 aenv (stencil1 -> stencil2 -> c)
    -> Boundary (CUExp aenv a)
    -> Boundary (CUExp aenv b)
    -> [CUTranslSkel aenv (Array sh c)]
mkStencil2 dev aenv stencil boundary1 boundary2 =
  [ mkStencil2' dev False aenv stencil boundary1 boundary2
  , mkStencil2' dev True  aenv stencil boundary1 boundary2
  ]


-- | Generate a single @stencil2@ kernel.
--
-- When @sameExtent@ is 'True', the second source array's shape is referred to
-- by the same name as the first ("Stencil1"), so both accessors share one set
-- of shape variables. Otherwise the second array uses "Stencil2". The interior
-- test on capable devices uses the element-wise maximum of the two stencils'
-- border regions.
--
mkStencil2'
    :: forall aenv sh stencil1 stencil2 a b c. (Stencil sh a stencil1, Stencil sh b stencil2, Elt c)
    => DeviceProperties
    -> Bool                 -- ^ are the source arrays the same extent?
    -> Gamma aenv
    -> CUFun2 aenv (stencil1 -> stencil2 -> c)
    -> Boundary (CUExp aenv a)
    -> Boundary (CUExp aenv b)
    -> CUTranslSkel aenv (Array sh c)
mkStencil2' dev sameExtent aenv (CUFun2 dce1 dce2 f) boundary1 boundary2
  = CUTranslSkel "stencil2" [cunit|

    $esc:("#include <accelerate_cuda.h>")
    $edecls:texIn
    $edecls:texS1
    $edecls:texS2

    extern "C" __global__ void
    stencil2
    (
        $params:argIn,
        $params:argOut,
        $params:argS1,
        $params:argS2
    )
    {
        const int shapeSize = $exp:(csize shOut);
        const int gridSize  = $exp:(gridSize dev);
              int ix;

        for ( ix =  $exp:(threadIdx dev)
            ; ix <  shapeSize
            ; ix += gridSize )
        {
            $items:(sh .=. cfromIndex shOut "ix" "tmp")
            $items:stencilBody
        }
    }
  |]
  where
    -- Free variables of the stencil function, and the output array.
    (texIn, argIn)              = environment dev aenv
    (argOut, shOut, setOut)     = writeArray "Out" (undefined :: Array sh c)

    -- Local variables: the current index and the two neighbourhoods.
    (sh, _, _)                  = locals "sh" (undefined :: sh)
    (xs, _, _)                  = locals "x"  (undefined :: stencil1)
    (ys, _, _)                  = locals "y"  (undefined :: stencil2)

    grp1                        = "Stencil1"
    grp2                        = "Stencil2"

    -- If the source arrays are the same extent, twiddle the names a bit so that
    -- code generation refers to the same set of shape variables. Then, if there
    -- are any duplicate calculations, hope that the CUDA compiler is smart
    -- enough and spots this.
    --
    sh1                         = grp1
    sh2 | sameExtent            = sh1
        | otherwise             = grp2

    -- Neighbourhood offsets of each stencil relative to the focal point.
    (dx1, dx2)                  = offsets2 (undefined :: Fun aenv (stencil1 -> stencil2 -> c))
                                           (undefined :: OpenAcc aenv (Array sh a))
                                           (undefined :: OpenAcc aenv (Array sh b))

    -- The interior region must be clear of the border of both stencils.
    border                      = zipWith max (borderRegion dx1) (borderRegion dx2)

    (texS1, argS1, safeIndex1)  =
      stencilAccess dev sameExtent True  grp1 sh1 "w" "ix" dx1 boundary1 dce1
    (_, _, unsafeIndex1)        =
      stencilAccess dev sameExtent False grp1 sh1 "w" "ix" dx1 boundary1 dce1

    (texS2, argS2, safeIndex2)  =
      stencilAccess dev sameExtent True  grp2 sh2 "z" "ix" dx2 boundary2 dce2
    (_, _, unsafeIndex2)        =
      stencilAccess dev sameExtent False grp2 sh2 "z" "ix" dx2 boundary2 dce2

    -- '__all' (warp vote) is not available below compute capability 1.2, so
    -- those devices always take the boundary-safe path.
    stencilBody
      | computeCapability dev < Compute 1 2 = with safeIndex1 safeIndex2
      | otherwise                           =
          [[citem| if ( __all( $exp:(insideRegion shOut border (map rvalue sh)) ) ) {
                       $items:(with unsafeIndex1 unsafeIndex2)
                   } else {
                       $items:(with safeIndex1 safeIndex2)
                   } |]]
      where
        -- Load both (live) neighbourhoods with the given accessors, then apply
        -- the binary stencil function and write the result at index "ix".
        with access1 access2 = (dce1 xs .=. access1 sh) ++
                               (dce2 ys .=. access2 sh) ++
                               (setOut "ix" .=. f xs ys)