module Data.Array.Accelerate.CUDA.CodeGen.Stencil.Extra (
stencilAccess,
cinRange, cclamp, cmirror, cwrap, insideRegion, borderRegion,
) where
import Prelude hiding ( and, zipWith, zipWith3 )
import Data.List ( transpose )
import Control.Monad
import Control.Monad.State.Strict
import Language.C.Quote.CUDA
import qualified Language.C.Syntax as C
import Data.Array.Accelerate.Error ( internalError )
import Data.Array.Accelerate.Type ( Boundary(..) )
import Data.Array.Accelerate.Array.Sugar ( Array, Shape, Elt, shapeToList )
import Data.Array.Accelerate.Analysis.Shape
import Foreign.CUDA.Analysis
import Data.Array.Accelerate.CUDA.AST hiding ( stencil, stencilAccess )
import Data.Array.Accelerate.CUDA.CodeGen.Base
import Data.Array.Accelerate.CUDA.CodeGen.Type
-- | Generate the code that reads the neighbourhood of a stencil element.
--
-- Returns the global definitions and kernel parameters produced by
-- 'readStencil' for the array group @grp@. It also returns an instantiation
-- function. Given an index expression, that function yields the local block
-- items (declarations) and the element expressions for every offset in
-- @positions@, after dead-code elimination with @dce@.
--
stencilAccess
    :: forall aenv sh e. (Shape sh, Elt e)
    => DeviceProperties                         -- ^ target device; passed through to 'readStencil'
    -> Bool                                     -- ^ linear indexing? if set, the focus element is read directly at @centroid@
    -> Bool                                     -- ^ apply the boundary condition (bounds checks)?
    -> Name                                     -- ^ array group name of the stencil input
    -> Name                                     -- ^ name from which the input's shape expressions are derived
    -> Name                                     -- ^ base name from which fresh local names are generated
    -> Name                                     -- ^ linear index of the focus element (used with linear indexing)
    -> [sh]                                     -- ^ offset of each stencil element relative to the focus
    -> Boundary (CUExp aenv e)                  -- ^ boundary condition used when bounds checks are enabled
    -> Eliminate e                              -- ^ dead-code elimination applied to all generated reads
    -> ( [C.Definition]                         -- global definitions
       , [C.Param]                              -- kernel parameters
       , Instantiate1 sh e )                    -- the stencil function
stencilAccess dev doLinearIndexing doBoundsChecks grp sh tmp centroid positions boundary dce =
  ( decls, params, stencil )
  where
    -- Definitions, parameters and the element reader come from the group @grp@,
    -- while the shape expressions used for index arithmetic come from @sh@.
    (decls, params, _, getIn) = readStencil dev grp (undefined :: Array sh e)
    (_, _, shIn, _) = readStencil dev sh (undefined :: Array sh e)

    -- Bind an index expression to a fresh constant and read the element there.
    getInAt ix = fresh >>= \j ->
      return ( [[citem| const $ty:cint $id:j = $exp:ix; |]], getIn j )

    -- The stencil function. It generates the reads for every offset and runs
    -- @dce@ over all element expressions at once. It then regroups them per
    -- offset with 'unconcat'. 'eliminate' drops the declarations of any offset
    -- whose elimination flags are all False.
    -- NOTE(review): presumably the flags mark live components; confirm
    -- against the definition of 'Eliminate'.
    stencil ix = withNameGen tmp $ do
      (envs, xs) <- mapAndUnzipM (access ix . shapeToList) positions
      let (envs', xs') = unzip
                       $ eliminate
                       $ zipWith (,) envs
                       $ unconcat (map length xs)
                       $ dce (concat xs)
      return ( concat envs', concat xs' )

    -- Generate the declarations and element expressions for one offset @dx@
    -- from the index @ix@.
    access :: Rvalue x => [x] -> [Int] -> Gen ([C.BlockItem], [C.Exp])
    access (map rvalue -> ix) dx
      | doBoundsChecks = safeAccess
      | otherwise = unsafeAccess
      where
        -- an all-zero offset denotes the focus element itself
        focus = all (==0) dx
        -- neighbour index: @dx@ is added component-wise to @ix@ in reverse order
        cursor | focus = ix
               | otherwise = zipWith (\i d -> [cexp| $exp:i + $int:d |]) ix (reverse dx)

        -- read without boundary handling; with linear indexing the focus
        -- element is read directly at @centroid@
        unsafeAccess
          | doLinearIndexing && focus = return $ ([], getIn centroid)
          | otherwise = getInAt (ctoIndex shIn cursor)

        -- read honouring the boundary condition
        safeAccess = case boundary of
          Clamp -> bounded cclamp
          Mirror -> bounded cmirror
          Wrap -> bounded cwrap
          Constant (CUExp (_,c)) -> inrange c

        -- remap the cursor with @f@ before reading; the focus element (zero
        -- offset) is read without remapping
        bounded f
          | focus = unsafeAccess
          | otherwise = getInAt (ctoIndex shIn (f shIn cursor))

        -- read at the cursor, but select the constant components @cs@ whenever
        -- 'cinRange' reports the cursor is outside the shape
        inrange cs
          | focus = unsafeAccess
          | otherwise = do
              (env, as) <- unsafeAccess
              p <- fresh
              return ( [citem| const int $id:p = $exp:(cinRange shIn cursor); |] : env
                     , zipWith (\a c -> [cexp| $id:p ? $exp:a : $exp:c |]) as cs )
-- | Strip the liveness flags from each group of expressions.
--
-- A group's environment (its declarations) is kept only if at least one of
-- its expressions is flagged 'True'; otherwise the environment is replaced by
-- the empty list. The expressions themselves are always kept.
eliminate :: [ ([a], [(Bool,b)]) ] -> [ ([a],[b]) ]
eliminate = foldr step []
  where
    step (env, flagged) rest = (kept, map snd flagged) : rest
      where
        kept = if any fst flagged then env else []
-- | A supply of fresh names: a base name paired with the next counter value.
type Gen = State (Name,Int)

-- | Run a name-generating computation whose names are numbered from zero
-- after the given base name.
withNameGen :: Name -> Gen a -> a
withNameGen base gen = fst (runState gen (base, 0))

-- | Return the base name followed by the current counter, then advance the
-- counter.
fresh :: Gen Name
fresh = do
  (base, n) <- get
  put (base, n + 1)
  return (base ++ show n)
-- | Test whether an index lies in the interior of a shape. In every dimension
-- the index must be at least @dx@ elements from the lower edge and strictly
-- less than @sz - dx@. The per-dimension tests are joined with @&&@.
insideRegion
    :: [C.Exp]      -- ^ extent of each dimension
    -> [Int]        -- ^ border width of each dimension
    -> [C.Exp]      -- ^ index to test
    -> C.Exp
insideRegion shape border index = foldl1 and (zipWith3 inside shape border index)
  where
    -- The upper bound is the extent minus the border width. Without the
    -- subtraction the quasi-quote is two adjacent expressions and invalid C.
    inside sz dx i = [cexp| $exp:i >= $int:dx && $exp:i < $exp:sz - $int:dx |]
    and x y        = [cexp| $exp:x && $exp:y |]
-- | The largest offset in each dimension over a list of stencil offsets. The
-- dimensions are listed in the reverse of 'shapeToList' order.
borderRegion :: Shape sh => [sh] -> [Int]
borderRegion offsets =
  reverse [ maximum column | column <- transpose (map shapeToList offsets) ]
-- | Test whether an index (second argument) lies within the bounds of a shape
-- (first argument). Every component must satisfy @0 <= i < sz@, and the
-- per-dimension tests are joined with @&&@. Two empty lists are reported as an
-- internal error.
cinRange :: [C.Exp] -> [C.Exp] -> C.Exp
cinRange [] [] = $internalError "inRange" "singleton index"
cinRange shape index = conjunction (zipWith bounded shape index)
  where
    conjunction   = foldl1 (\p q -> [cexp| $exp:p && $exp:q |])
    bounded sz ix = [cexp| ({ const $ty:cint _i = $exp:ix; _i >= 0 && _i < $exp:sz; }) |]
-- | Clamp each index component (second argument) into @[0, sz - 1]@ for the
-- corresponding shape extent @sz@ (first argument).
cclamp :: [C.Exp] -> [C.Exp] -> [C.Exp]
cclamp = zipWith f
  where
    -- The upper bound is the last valid position, @sz - 1@. The missing
    -- subtraction previously produced invalid C.
    f sz i = [cexp| max(($ty:cint) 0, min( $exp:i, $exp:sz - 1 )) |]
-- | Mirror out-of-bounds index components back inside the shape (first
-- argument) without repeating the edge element: @-1@ maps to @1@ and @sz@
-- maps to @sz - 2@. In-bounds components are returned unchanged.
--
-- NOTE(review): a single reflection is applied, so offsets larger than the
-- extent still land out of bounds.
cmirror :: [C.Exp] -> [C.Exp] -> [C.Exp]
cmirror = zipWith f
  where
    -- The negative branch must negate the index. Returning @_i@ itself would
    -- read at a negative position.
    f sz i = [cexp| ({ const $ty:cint _i  = $exp:i;
                       const $ty:cint _sz = $exp:sz;
                       _i < 0    ? -_i
                     : _i >= _sz ? _sz - (_i - _sz + 2)
                     : _i; }) |]
-- | Wrap out-of-bounds index components around to the opposite side of the
-- shape (first argument): @-1@ maps to @sz - 1@ and @sz@ maps to @0@.
-- In-bounds components are returned unchanged.
--
-- NOTE(review): a single wrap is applied, so offsets larger than the extent
-- still land out of bounds.
cwrap :: [C.Exp] -> [C.Exp] -> [C.Exp]
cwrap = zipWith f
  where
    -- The upper branch must subtract the extent. @_i _sz@ without the
    -- subtraction is invalid C.
    f sz i = [cexp| ({ const $ty:cint _i  = $exp:i;
                       const $ty:cint _sz = $exp:sz;
                       _i <  0   ? _sz + _i
                     : _i >= _sz ? _i - _sz
                     : _i; }) |]
-- | Definitions, parameters, shape and element reader for the stencil input
-- array named by the group @grp@.
--
-- On devices with compute capability below 2.0 the definitions and parameters
-- come from 'arrayAsTex'. Otherwise there are no global definitions and the
-- parameters come from 'arrayAsArg'. The reader returns one expression per
-- component of the element type @e@ at the given index.
--
readStencil
    :: forall sh e. (Shape sh, Elt e)
    => DeviceProperties                 -- ^ target device; its compute capability selects the access method
    -> Name                             -- ^ array group name
    -> Array sh e                       -- ^ type witness (callers in this file pass 'undefined')
    -> ( [C.Definition]                 -- global definitions (empty on compute capability >= 2.0)
       , [C.Param]                      -- kernel parameters
       , [C.Exp]                        -- shape expressions of the array
       , forall x. Rvalue x => x -> [C.Exp]   -- component reads at an index
       )
readStencil dev grp dummy
  = let (sh, arrs) = namesOfArray grp (undefined :: e)
        -- pre-2.0 devices use 'arrayAsTex', which may contribute global
        -- definitions; newer devices need only the 'arrayAsArg' parameters
        (decl, args)
          | computeCapability dev < Compute 2 0 = arrayAsTex dummy grp
          | otherwise = ([], arrayAsArg dummy grp)
        dim = expDim (undefined :: Exp aenv sh)
        sh' = cshape dim sh
        -- pair each component type of @e@ with its array name and index it
        fetch ix = zipWith (\t a -> indexArray dev t (cvar a) (rvalue ix)) (eltType (undefined :: e)) arrs
    in
    ( decl, args, sh', fetch )
-- | Length-checked 'Prelude.zipWith'. Reaching the end of one list before the
-- other raises an internal error instead of silently truncating.
zipWith :: (a -> b -> c) -> [a] -> [b] -> [c]
zipWith f = go
  where
    go (x:xs) (y:ys) = f x y : go xs ys
    go []     []     = []
    go _      _      = $internalError "zipWith" "argument mismatch"
-- | Length-checked 'Prelude.zipWith3'. All three lists must end together;
-- otherwise an internal error is raised instead of silently truncating.
zipWith3 :: (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 f = loop
  where
    loop (x:xs) (y:ys) (z:zs) = f x y z : loop xs ys zs
    loop []     []     []     = []
    loop _      _      _      = $internalError "zipWith3" "argument mismatch"
-- | Split a list into consecutive chunks of the given lengths, the inverse of
-- 'concat' for matching lengths. Elements left after the last length are
-- dropped. If the list runs out early, the remaining chunks are shorter or
-- empty.
unconcat :: [Int] -> [a] -> [[a]]
unconcat lengths xs = go lengths xs
  where
    go []     _    = []
    go (n:ns) rest = chunk : go ns rest'
      where
        (chunk, rest') = splitAt n rest