{-# LANGUAGE ImpredicativeTypes  #-}
{-# LANGUAGE QuasiQuotes         #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.CUDA.CodeGen.Stencil.Extra
-- Copyright   : [2008..2014] Manuel M T Chakravarty, Gabriele Keller
--               [2009..2014] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.CUDA.CodeGen.Stencil.Extra (

  stencilAccess,
  cinRange, cclamp, cmirror, cwrap,
  insideRegion, borderRegion,

) where

-- standard library
import Prelude                                          hiding ( and, zipWith, zipWith3 )
import Data.List                                        ( transpose )
import Control.Monad
import Control.Monad.State.Strict

-- language-c
import Language.C.Quote.CUDA
import qualified Language.C.Syntax                      as C

-- friends
import Data.Array.Accelerate.Error                      ( internalError )
import Data.Array.Accelerate.Type                       ( Boundary(..) )
import Data.Array.Accelerate.Array.Sugar                ( Array, Shape, Elt, shapeToList )
import Data.Array.Accelerate.Analysis.Shape

import Foreign.CUDA.Analysis
import Data.Array.Accelerate.CUDA.AST                   hiding ( stencil, stencilAccess )
import Data.Array.Accelerate.CUDA.CodeGen.Base
import Data.Array.Accelerate.CUDA.CodeGen.Type


-- Stencil Access
-- --------------

-- Build the pieces needed to read a stencil neighbourhood around a focal
-- point: the global definitions and kernel parameters for the source array,
-- together with a generator that, given the multidimensional index of the
-- focus, yields the C declarations and expressions for every stencil element.
--
stencilAccess
    :: forall aenv sh e. (Shape sh, Elt e)
    => DeviceProperties
    -> Bool                             -- can we use linear indexing?
    -> Bool                             -- do we need to do bounds checking?
    -> Name                             -- array group name
    -> Name                             -- group name for array shape (hax!)
    -> Name                             -- seed name for temporary variables
    -> Name                             -- linear index at the focus
    -> [sh]                             -- list of offset indices
    -> Boundary (CUExp aenv e)          -- stencil boundary condition
    -> Eliminate e                      -- dead code elimination flags
    -> ( [C.Definition]                 -- kernel texture reference definitions
       , [C.Param]                      -- kernel function arguments
       , Instantiate1 sh e )            -- access stencil at given multidimensional index
stencilAccess dev doLinearIndexing doBoundsChecks grp sh tmp centroid positions boundary dce
  = ( decls, params, stencil )
  where
    -- Array data is read through the 'grp' names, whereas the extent of the
    -- array is taken from the 'sh' names.
    (decls, params, _, getIn)   = readStencil dev grp (undefined :: Array sh e)
    (_, _, shIn, _)             = readStencil dev sh  (undefined :: Array sh e)

    -- Bind a linear index expression to a fresh temporary, then read the
    -- array components at that temporary.
    getInAt ix = do
      j <- fresh
      return ( [[citem| const $ty:cint $id:j = $exp:ix; |]], getIn j )

    -- Produce the reads for every position of the pattern, then drop those
    -- that the dead-code flags mark as unused.
    --
    stencil ix = withNameGen tmp $ do
      (envs, xs) <- mapAndUnzipM (access ix . shapeToList) positions

      -- Flag the flattened components, then regroup them per position. The
      -- 'zipWith' used below is this module's length-checking variant.
      let flagged       = unconcat (map length xs) (dce (concat xs))
          (envs', xs')  = unzip (eliminate (zipWith (,) envs flagged))

      return ( concat envs', concat xs' )

    -- Read the stencil element found at offset 'dx' from the focus 'ixs'.
    -- Besides the element expressions, this may return extra declarations
    -- (the environment) needed for the index calculation.
    --
    access :: Rvalue x => [x] -> [Int] -> Gen ([C.BlockItem], [C.Exp])
    access ixs dx = if doBoundsChecks then checked else direct
      where
        ix      = map rvalue ixs
        atFocus = all (==0) dx

        shift i d = [cexp| $exp:i + $int:d |]

        -- Multidimensional index of the element being read
        cursor
          | atFocus     = ix
          | otherwise   = zipWith shift ix (reverse dx)

        -- Read without any bounds checks
        direct
          | doLinearIndexing && atFocus = return ([], getIn centroid)
          | otherwise                   = getInAt (ctoIndex shIn cursor)

        -- Read subject to the requested boundary condition
        checked = case boundary of
          Clamp                 -> remapped cclamp
          Mirror                -> remapped cmirror
          Wrap                  -> remapped cwrap
          Constant (CUExp (_,c)) -> guarded c

        -- Boundary conditions that redirect the index back into the array
        remapped f
          | atFocus     = direct
          | otherwise   = getInAt (ctoIndex shIn (f shIn cursor))

        -- Boundary condition that substitutes the constant 'cs' whenever the
        -- cursor falls outside of the array
        guarded cs
          | atFocus     = direct
          | otherwise   = do
              (env, as) <- direct
              p         <- fresh
              return ( [citem| const int $id:p = $exp:(cinRange shIn cursor); |] : env
                     , zipWith (\a c -> [cexp| $id:p ? $exp:a : $exp:c |]) as cs )


-- Strip the usage flags from each stencil position. The environment bindings of
-- a position are shared by all of its tuple components, so they are discarded
-- only when none of the components at that position is flagged as used.
--
eliminate :: [ ([a], [(Bool,b)]) ] -> [ ([a],[b]) ]
eliminate = map prune
  where
    prune (env, tagged) = ( if any fst tagged then env else [], map snd tagged )


-- A minimal supply of fresh names: a base name paired with a counter
--
type Gen = State (Name,Int)

withNameGen :: Name -> Gen a -> a
withNameGen base f = evalState f (base,0)

fresh :: Gen Name
fresh = do
  (base,n) <- get
  put (base,n+1)
  return (base ++ show n)


-- Boundary conditions
-- -------------------

-- Test whether the given multidimensional index lies in the inside region of
-- the stencil.
--
-- The generated expression requires each index component to be at least the
-- border width away from both ends of the corresponding dimension. An empty
-- shape is reported as an internal error, in the same way as 'cinRange',
-- instead of failing inside 'foldl1'.
--
insideRegion
    :: [C.Exp]                  -- The shape of the array
    -> [Int]                    -- The width of the stencil in each direction
    -> [C.Exp]                  -- The index in question
    -> C.Exp
insideRegion []    []     []    = $internalError "insideRegion" "singleton index"
insideRegion shape border index = foldl1 and (zipWith3 inside shape border index)
  where
    inside sz dx i = [cexp| $exp:i >= $int:dx && $exp:i < $exp:sz - $int:dx |]
    and x y        = [cexp| $exp:x && $exp:y |]


-- Given a list of stencil offset positions, calculate the size of the border
-- region along each dimension.
--
-- The border is the largest absolute offset in each dimension, so that it
-- covers the reach of the stencil in both directions even when the offsets are
-- not symmetric about the focus. For a symmetric pattern this is the same as
-- the largest offset.
--
-- Note that this does not consider any positions of the stencil that are not
-- actually used. We assume the user is sensible and uses the minimally sized
-- stencil for their application, but this can still be problematic for
-- non-symmetric stencils. For example, a large stencil that uses elements from
-- only one quadrant.
--
borderRegion :: Shape sh => [sh] -> [Int]
borderRegion = reverse . map (maximum . map abs) . transpose . map shapeToList


-- Test whether an index lies within the boundaries of a shape (first argument).
-- Each component is bound to a local inside a statement expression so that the
-- index expression appears only once in the generated code.
--
cinRange :: [C.Exp] -> [C.Exp] -> C.Exp
cinRange []    []    = $internalError "cinRange" "singleton index"
cinRange shape index = foldl1 and (zipWith inside shape index)
  where
    inside sz i = [cexp| ({ const $ty:cint _i = $exp:i; _i >= 0 && _i < $exp:sz; }) |]
    and x y     = [cexp| $exp:x && $exp:y |]


-- Clamp an index to the boundary of the shape (first argument)
--
cclamp :: [C.Exp] -> [C.Exp] -> [C.Exp]
cclamp = zipWith f
  where
    f sz i = [cexp| max(($ty:cint) 0, min( $exp:i, $exp:sz - 1 )) |]


-- Indices out of bounds of the shape are mirrored back in range. Assumes that
-- the array is at least as large as the stencil.
--
cmirror :: [C.Exp] -> [C.Exp] -> [C.Exp]
cmirror = zipWith f
  where
    f sz i = [cexp| ({ const $ty:cint _i  = $exp:i;
                       const $ty:cint _sz = $exp:sz;
                       _i < 0 ? -_i : _i >= _sz ? _sz - (_i - _sz + 2) : _i; }) |]


-- Indices out of bounds are wrapped to the opposite edge of the shape
--
cwrap :: [C.Exp] -> [C.Exp] -> [C.Exp]
cwrap = zipWith f
  where
    f sz i = [cexp| ({ const $ty:cint _i  = $exp:i;
                       const $ty:cint _sz = $exp:sz;
                       _i < 0 ? _sz + _i : _i >= _sz ? _i - _sz : _i; }) |]


-- Kernel parameters
-- -----------------

-- Generate kernel parameters for input arrays. This is similar to 'readArray',
-- but we force compute 1.x devices to read through the texture cache as well.
--
-- Devices below compute capability 2.0 get global texture definitions and no
-- function arguments; all other devices get function arguments and no global
-- definitions.
--
readStencil
    :: forall sh e. (Shape sh, Elt e)
    => DeviceProperties
    -> Name                                     -- group names
    -> Array sh e                               -- dummy to fix the types
    -> ( [C.Definition]                         -- global definitions for stencils read via texture references (compute < 2.0)
       , [C.Param]                              -- function arguments for stencils read as arrays (compute >= 2.0)
       , [C.Exp]                                -- shape of the array
       , forall x. Rvalue x => x -> [C.Exp]     -- read elements from a linear index
       )
readStencil dev grp dummy
  = let (sh, arrs)      = namesOfArray grp (undefined :: e)
        (decl, args)
          | computeCapability dev < Compute 2 0 = arrayAsTex dummy grp
          | otherwise                           = ([], arrayAsArg dummy grp)

        dim             = expDim (undefined :: Exp aenv sh)
        sh'             = cshape dim sh

        -- one read per component array of the element type
        fetch ix        = zipWith (\t a -> indexArray dev t (cvar a) (rvalue ix)) (eltType (undefined :: e)) arrs
    in
    ( decl, args, sh', fetch )


-- Prelude'
-- --------

-- A version of 'zipWith' that requires the lists to be equal length
--
zipWith :: (a -> b -> c) -> [a] -> [b] -> [c]
zipWith f (x:xs) (y:ys) = f x y : zipWith f xs ys
zipWith _ []     []     = []
zipWith _ _      _      = $internalError "zipWith" "argument mismatch"

-- A version of 'zipWith3' that requires the lists to be equal length
--
zipWith3 :: (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 f (x:xs) (y:ys) (z:zs) = f x y z : zipWith3 f xs ys zs
zipWith3 _ []     []     []     = []
zipWith3 _ _      _      _      = $internalError "zipWith3" "argument mismatch"

-- Split a list into segments of given length. Elements beyond the total of the
-- requested lengths are dropped.
--
unconcat :: [Int] -> [a] -> [[a]]
unconcat []     _  = []
unconcat (n:ns) xs = let (h,t) = splitAt n xs in h : unconcat ns t