{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.KernelBabysitting
( babysitKernels
, nonlinearInMemory
)
where
import Control.Arrow (first)
import Control.Monad.State.Strict
import qualified Data.Map.Strict as M
import Data.Foldable
import Data.List
import Data.Maybe
import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.Kernels
hiding (Prog, Body, Stm, Pattern, PatElem,
BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Tools
import Futhark.Pass
import Futhark.Util
-- | The kernel babysitting pass.  It is assembled from a name, a
-- description, and 'transformFunDef' lifted to a whole-program
-- transformation by 'intraproceduralTransformation'.
babysitKernels :: Pass Kernels Kernels
babysitKernels =
  Pass
    "babysit kernels"
    "Transpose kernel input arrays for better performance."
    (intraproceduralTransformation transformFunDef)
-- | Rewrite the body of one function definition.
--
-- The body is processed by 'transformBody' starting from an empty
-- 'ExpMap', inside the scope of the function itself.  The binder is
-- run with an empty initial scope, and only the first component of
-- its result (the new body) is kept.
transformFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
transformFunDef fundec = do
  (new_body, _) <-
    modifyNameSource (runState (runBinderT rewritten_body M.empty))
  pure fundec { funDefBody = new_body }
  where
    rewritten_body =
      inScopeOf fundec (transformBody mempty (funDefBody fundec))
-- | The monad in which the rewriting functions of this module run: a
-- 'Binder' over the 'Kernels' representation.
type BabysitM = Binder Kernels
-- | Rewrite a body by feeding its statements, in order, through
-- 'transformStm' while threading the 'ExpMap' along.  The final map is
-- discarded; the original result is re-wrapped with 'resultBody' and
-- the whole action is run under 'insertStmsM'.
transformBody :: ExpMap -> Body Kernels -> BabysitM (Body Kernels)
transformBody expmap (Body () stms result) =
  insertStmsM $
    foldM_ transformStm expmap stms >> pure (resultBody result)
-- | Maps a variable name to the statement that binds it.  Entries are
-- added by 'transformStm' for every name in the pattern of each
-- statement it processes.
type ExpMap = M.Map VName (Stm Kernels)
-- | Inspect the statement that binds @name@ (if it is in the map) and
-- report a layout permutation for it:
--
-- * 'Opaque' of a variable and 'Reshape' defer to the underlying array;
--
-- * 'Rearrange' yields the inverse of its permutation;
--
-- * 'Manifest' yields its permutation as-is;
--
-- * a 'SegMap' yields a permutation only when the matching result type
--   has rank greater than zero;
--
-- * anything else (including names absent from the map) yields
--   'Nothing'.
nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m = M.lookup name m >>= layoutOf
  where
    layoutOf (Let pat _ e) =
      case e of
        BasicOp (Opaque (Var arr)) -> nonlinearInMemory arr m
        BasicOp (Rearrange perm _) -> Just (Just (rearrangeInverse perm))
        BasicOp (Reshape _ arr) -> nonlinearInMemory arr m
        BasicOp (Manifest perm _) -> Just (Just perm)
        Op (SegOp (SegMap _ _ ts _)) ->
          find ((== name) . patElemName . fst) (zip (patternElements pat) ts)
            >>= segMapLayout
        _ -> Nothing

    -- The pattern element is paired with the per-thread result type;
    -- the difference in rank is the number of outer dimensions.
    segMapLayout (pe, t)
      | inner_r > 0 =
          Just . Just . rearrangeInverse $
            [inner_r .. inner_r + outer_r - 1] ++ [0 .. inner_r - 1]
      | otherwise = Nothing
      where
        inner_r = arrayRank t
        outer_r = arrayRank (patElemType pe) - inner_r
-- | Rewrite one statement, emit it with 'addStm', and return the
-- 'ExpMap' extended with every name bound by the statement.
--
-- A 'SegOp' has its kernel bodies rewritten with 'transformKernelBody';
-- every other expression is traversed with the 'transform' mapper.
transformStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap
transformStm expmap (Let pat aux e) = do
  e' <- case e of
    Op (SegOp op) ->
      let onBodies =
            identitySegOpMapper
              { mapOnSegOpBody =
                  transformKernelBody expmap (segLevel op) (segSpace op)
              }
       in Op . SegOp <$> mapSegOpM onBodies op
    _ ->
      mapExpM (transform expmap) e
  let stm' = Let pat aux e'
  addStm stm'
  pure $ M.fromList [ (name, stm') | name <- patternNames pat ] <> expmap
-- | The identity mapper, except that nested bodies are rewritten with
-- 'transformBody' inside the scope handed to the body callback.
transform :: ExpMap -> Mapper Kernels Kernels BabysitM
transform expmap = identityMapper { mapOnBody = rewriteBody }
  where
    rewriteBody scope body = localScope scope (transformBody expmap body)
-- | Rewrite the array indexing statements of a kernel body with
-- 'ensureCoalescedAccess'.
--
-- The current scope is read first, then a @num_threads@ binding
-- (number of groups times group size) is emitted, and finally the
-- body is traversed with an initially empty replacement table.
transformKernelBody :: ExpMap -> SegLevel -> SegSpace -> KernelBody Kernels
                    -> BabysitM (KernelBody Kernels)
transformKernelBody expmap lvl space kbody = do
  scope <- askScope
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp (Mul Int32) (unCount (segNumGroups lvl)) (unCount (segGroupSize lvl))
  evalStateT
    ( traverseKernelBodyArrayIndexes
        free_ker_vars
        thread_local
        (scope <> space_scope)
        (ensureCoalescedAccess expmap (unSegSpace space) num_threads)
        kbody
    )
    mempty
  where
    space_scope = scopeOfSegSpace space
    -- The flat index of the space together with its per-dimension indices.
    thread_local = namesFromList (segFlat space : map fst (unSegSpace space))
    -- Free variables of the body, minus the names bound by the space.
    free_ker_vars =
      freeIn kbody `namesSubtract` namesFromList (M.keys space_scope)
-- | The type of a function that may replace one array indexing
-- operation.  'traverseKernelBodyArrayIndexes' calls it for every
-- 'Index' statement; a result of 'Nothing' keeps the original array
-- and slice, while 'Just' supplies the array and slice to use instead.
type ArrayIndexTransform m =
  -- Free variables of the kernel body (as passed to the traversal).
  Names ->
  -- Predicate passed by the traversal as @isThreadLocal@.
  (VName -> Bool) ->
  -- Predicate passed by the traversal as @isGidVariant@: given a
  -- thread index name and a subexpression.
  (VName -> SubExp -> Bool)->
  -- Size substitution function passed by the traversal as @sizeSubst@.
  (SubExp -> Maybe SubExp) ->
  -- The scope outside the kernel body.
  Scope Kernels ->
  -- The array being indexed and the slice it is indexed with.
  VName -> Slice SubExp -> m (Maybe (VName, Slice SubExp))
-- | Apply an 'ArrayIndexTransform' to every 'Index' statement in a
-- kernel body, including those nested in sub-bodies and in the lambdas
-- of 'OtherOp' SOACs.  All other statements are rebuilt unchanged
-- apart from their nested bodies.  The kernel result is left as-is.
--
-- Arguments, in order: the free variables of the kernel body, the set
-- of names considered thread-variant, the scope outside the body, the
-- transformation, and the kernel body itself.
traverseKernelBodyArrayIndexes :: (Applicative f, Monad f) =>
                                  Names
                               -> Names
                               -> Scope Kernels
                               -> ArrayIndexTransform f
                               -> KernelBody Kernels
                               -> f (KernelBody Kernels)
traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f (KernelBody () kstms kres) =
  -- The context threaded through the traversal is a triple of
  -- (variance table, size substitutions, scope).
  KernelBody () . stmsFromList <$>
  mapM (onStm (varianceInStms mempty kstms,
               mkSizeSubsts kstms,
               outer_scope)) (stmsToList kstms) <*>
  pure kres
  where -- Traverse the body of a lambda with its parameters added to
        -- the scope.
        onLambda (variance, szsubst, scope) lam =
          (\body' -> lam { lambdaBody = body' }) <$>
          onBody (variance, szsubst, scope') (lambdaBody lam)
          where scope' = scope <> scopeOfLParams (lambdaParams lam)

        -- Traverse a nested body, first extending all three parts of
        -- the context with information from the body's statements.
        onBody (variance, szsubst, scope) (Body battr stms bres) = do
          stms' <- stmsFromList <$> mapM (onStm (variance', szsubst', scope')) (stmsToList stms)
          Body battr stms' <$> pure bres
          where variance' = varianceInStms variance stms
                szsubst' = mkSizeSubsts stms <> szsubst
                scope' = scope <> scopeOf stms

        -- An index statement: ask 'f' whether to replace the array and
        -- slice.  Note that 'f' is always given 'outer_scope'; the
        -- scope component of the context is ignored here.
        onStm (variance, szsubst, _) (Let pat attr (BasicOp (Index arr is))) =
          Let pat attr . oldOrNew <$> f free_ker_vars isThreadLocal isGidVariant sizeSubst outer_scope arr is
          where oldOrNew Nothing =
                  BasicOp $ Index arr is
                oldOrNew (Just (arr', is')) =
                  BasicOp $ Index arr' is'

                -- A variable is variant in 'gid' if it is 'gid' itself
                -- or 'gid' occurs in its variance-table entry.
                -- Constants are never variant.
                isGidVariant gid (Var v) =
                  gid == v || nameIn gid (M.findWithDefault (oneName v) v variance)
                isGidVariant _ _ = False

                -- A name is thread-local if its variance-table entry
                -- (or just the name itself, when it has no entry)
                -- intersects 'thread_variant'.
                isThreadLocal v =
                  thread_variant `namesIntersect`
                  M.findWithDefault (oneName v) v variance

                -- Resolve a size to something usable outside the
                -- kernel: constants and names in 'outer_scope' are
                -- kept, names with a recorded substitution are
                -- followed recursively, anything else fails.
                sizeSubst (Constant v) = Just $ Constant v
                sizeSubst (Var v)
                  | v `M.member` outer_scope = Just $ Var v
                  | Just v' <- M.lookup v szsubst = sizeSubst v'
                  | otherwise = Nothing

        -- Any other statement: recurse into its expression.
        onStm (variance, szsubst, scope) (Let pat attr e) =
          Let pat attr <$> mapExpM (mapper (variance, szsubst, scope)) e

        -- Only 'OtherOp' operations are traversed (through their SOAC
        -- lambdas); every other operation is returned unchanged.
        onOp ctx (OtherOp soac) =
          OtherOp <$> mapSOACM identitySOACMapper{ mapOnSOACLambda = onLambda ctx } soac
        onOp _ op = return op

        mapper ctx = identityMapper { mapOnBody = const (onBody ctx)
                                    , mapOnOp = onOp ctx }

        -- Collect size substitutions: a statement binding exactly one
        -- value (and no context) to a 'SplitSpace' maps the bound name
        -- to the last argument of the 'SplitSpace'.
        mkSizeSubsts = fold . fmap mkStmSizeSubst
          where mkStmSizeSubst (Let (Pattern [] [pe]) _ (Op (SplitSpace _ _ _ elems_per_i))) =
                  M.singleton (patElemName pe) elems_per_i
                mkStmSizeSubst _ = mempty
-- | Remembers, for an (array, slice) pair that has already been
-- handled by 'ensureCoalescedAccess', the name of the array that
-- replaced the original one.
type Replacements = M.Map (VName, Slice SubExp) VName
-- | An 'ArrayIndexTransform' that decides, per indexing operation,
-- whether the indexed array should be replaced by a rearranged copy.
--
-- The first three arguments are the 'ExpMap', the thread space as
-- (index name, dimension size) pairs, and the total number of threads;
-- the remaining arguments are those of 'ArrayIndexTransform'.
--
-- Replacements are cached in the 'Replacements' state, keyed by the
-- (array, slice) pair.  The slice itself is never changed: every
-- 'Just' result carries the original slice.  Only arrays that are not
-- thread-local and that are present in @outer_scope@ are considered.
-- The guards below are tried in order; if none applies the access is
-- left alone.
ensureCoalescedAccess :: MonadBinder m =>
                         ExpMap
                      -> [(VName,SubExp)]
                      -> SubExp
                      -> ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess expmap thread_space num_threads free_ker_vars isThreadLocal
                      isGidVariant sizeSubst outer_scope arr slice = do
  seen <- gets $ M.lookup (arr, slice)

  case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
    -- This (array, slice) pair was handled before: reuse the result.
    (Just arr', _, _) ->
      pure $ Just (arr', slice)

    (Nothing, False, Just t)
      -- The slice consists of as many indices as the array has
      -- dimensions, and 'coalescedIndexes' proposes an ordering that
      -- is a permutation of the current one: rearrange the array by
      -- that permutation.
      | Just is <- sliceIndices slice,
        length is == arrayRank t,
        Just is' <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is,
        Just perm <- is' `isPermutationOf` is ->
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- The array is bound by a 'Rearrange'.  Permute the slice by the
      -- same permutation; if its last element is a fixed index that is
      -- variant in the last thread index, do nothing.
      -- NOTE(review): presumably this means the access is already
      -- coalesced - confirm.
      | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
        not $ null perm,
        not $ null thread_gids,
        inner_gid <- last thread_gids,
        length slice >= length perm,
        slice' <- map (slice !!) perm,
        DimFix inner_ind <- last slice',
        -- NOTE(review): this repeats the emptiness check made a few
        -- guards above.
        not $ null thread_gids,
        isGidVariant inner_gid inner_ind ->
          return Nothing

      -- The slice is some leading fixed indices followed by a
      -- non-empty remainder without any fixed index, the array has no
      -- entry in the 'ExpMap', the remainder is not too small, and
      -- the last thread index does not occur free in the slice: do
      -- nothing.
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        allDimAreSlice rem_slice,
        Nothing <- M.lookup arr expmap,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        not (null thread_gids || null is),
        not (last thread_gids `nameIn` (freeIn is <> freeIn rem_slice)) ->
          return Nothing

      -- The array is not fully indexed, the leading indices are not a
      -- proper prefix of the thread indices, and some free variable of
      -- the indices is thread-local: move the indexed dimensions
      -- behind the remaining ones (see 'coalescingPermutation').
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        any isThreadLocal (namesToList $ freeIn is) -> do
          let perm = coalescingPermutation (length is) $ arrayRank t
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- The leading indices agree pairwise with the thread indices,
      -- and the first remaining dimension is a slice whose stride is a
      -- constant accepted by 'oneIsh', with a thread-local offset and
      -- a length that 'sizeSubst' can resolve: rearrange that
      -- dimension with 'rearrangeSlice'.  The number of chunks is the
      -- total thread count when there are no leading indices, and
      -- otherwise the product of the thread dimensions that come after
      -- the indexed ones.
      | (is, rem_slice) <- splitSlice slice,
        and $ zipWith (==) is $ map Var thread_gids,
        DimSlice offset len (Constant stride):_ <- rem_slice,
        isThreadLocalSubExp offset,
        Just {} <- sizeSubst len,
        oneIsh stride -> do
          let num_chunks = if null is
                           then primExpFromSubExp int32 num_threads
                           else coerceIntPrimExp Int32 $
                                product $ map (primExpFromSubExp int32) $
                                drop (length is) thread_gdims
          replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t) num_chunks arr)

      -- 'nonlinearInMemory' reports a layout for the array.  When the
      -- slice is all indices, a row-major copy is made only if
      -- 'coalescedIndexes' succeeds on them; when 'sliceIndices' gives
      -- nothing, a row-major copy is always made.
      | Just{} <- nonlinearInMemory arr expmap ->
          case sliceIndices slice of
            Just is | Just _ <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is ->
                        replace =<< lift (rowMajorArray arr)
                    | otherwise ->
                        return Nothing
            _ -> replace =<< lift (rowMajorArray arr)

    -- Thread-local arrays, arrays outside 'outer_scope', and accesses
    -- matched by none of the guards above are left unchanged.
    _ -> return Nothing

  where (thread_gids, thread_gdims) = unzip thread_space

        -- Record the replacement for this (array, slice) pair and
        -- return it together with the unchanged slice.
        replace arr' = do
          modify $ M.insert (arr, slice) arr'
          return $ Just (arr', slice)

        isThreadLocalSubExp (Var v) = isThreadLocal v
        isThreadLocalSubExp Constant{} = False
-- | Starting from an element size @bs@, multiply in the dimensions of
-- the slice from left to right for as long as they are 32-bit integer
-- constants, and report whether the last product computed is below 4.
-- As soon as a dimension is not such a constant the answer is 'False'
-- (and it stays 'False' for the rest of the slice).  A slice with no
-- dimensions yields 'True'.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs slice = fst $ foldl' step (True, bs) (sliceDims slice)
  where
    step (still_small, size) dim =
      case (still_small, dim) of
        (True, Constant (IntValue (Int32Value d))) ->
          let size' = d * size
           in (size' < 4, size')
        _ ->
          (False, size)
-- | Split a slice into the indices of its leading run of 'DimFix'
-- entries and whatever follows that run (starting at the first entry
-- that is not a 'DimFix').
splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice (DimFix i : rest) =
  let (fixed, remaining) = splitSlice rest
   in (i : fixed, remaining)
splitSlice remaining = ([], remaining)
-- | 'True' exactly when the slice contains no 'DimFix' entry (which
-- includes the empty slice).
allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice = all notFixed
  where
    notFixed DimFix{} = False
    notFixed _ = True
-- | Given the free variables of the kernel, a variance predicate, the
-- thread indices and the indices of an array access, propose an
-- ordering of the access indices.
--
-- * 'Nothing' if any index is a constant, or if any index variable is
--   one of the free kernel variables.
--
-- * The indices unchanged if the last index is variant in the last
--   thread index (and that thread index is a variable).
--
-- * Otherwise the indices with thread indices swapped into place:
--   working on the reversed lists, the thread index at reversed
--   position @i@ is exchanged with whatever currently sits at position
--   @i@ of the reversed indices, provided it occurs among them
--   elsewhere and @i@ is within range.
coalescedIndexes :: Names -> (VName -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes free_ker_vars isGidVariant tgids is
  | any isCt is = Nothing
  | any (`nameIn` free_ker_vars) (mapMaybe mbVarId is) = Nothing
  | innermostAlreadyVariant = Just is
  | otherwise =
      Just $ reverse $ foldl' place (reverse is) $ zip [0 ..] (reverse tgids)
  where
    num_is = length is

    -- Both lists must be non-empty and the last thread index must be
    -- a variable for this shortcut to apply.
    innermostAlreadyVariant =
      case (reverse tgids, reverse is) of
        (Var innergid : _, innermost : _) -> isGidVariant innergid innermost
        _ -> False

    place rev_is (pos, tgid) =
      case elemIndex tgid rev_is of
        Just found
          | pos /= found, pos < num_is -> exchange pos found rev_is
        _ -> rev_is

    exchange i j l =
      case (maybeNth i l, maybeNth j l) of
        (Just ix, Just jx) -> setAt i jx (setAt j ix l)
        _ -> error $ "coalescedIndexes swap: invalid indices" ++ show (i, j, l)

    setAt 0 x (_ : ys) = x : ys
    setAt k x (y : ys) = y : setAt (k - 1) x ys
    setAt _ _ [] = error "coalescedIndexes: update"

    isCt :: SubExp -> Bool
    isCt Constant{} = True
    isCt Var{} = False

    mbVarId (Var v) = Just v
    mbVarId (Constant _) = Nothing
-- | The list @[num_is .. rank-1]@ followed by @[0 .. num_is-1]@.  When
-- @0 <= num_is <= rank@ this is the permutation that moves the first
-- @num_is@ dimensions behind all the others.
coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank = trailing <> leading
  where
    leading = enumFromTo 0 (num_is - 1)
    trailing = enumFromTo num_is (rank - 1)
-- | Produce an array laid out according to @perm@.
--
-- The first argument describes the current layout of @arr@ (callers
-- in this module pass the result of 'nonlinearInMemory').  The cases
-- are tried in order:
--
-- * the known current permutation equals @perm@: return @arr@;
--
-- * no layout information and @perm@ is sorted: return @arr@;
--
-- * a known permutation and @perm@ is sorted: make a row-major copy;
--
-- * otherwise: first make a row-major copy if there was any layout
--   information at all, then bind a 'Manifest' with @perm@.
rearrangeInput :: MonadBinder m =>
                  Maybe (Maybe [Int]) -> [Int] -> VName -> m VName
rearrangeInput manifest perm arr =
  case manifest of
    Just (Just current_perm)
      | current_perm == perm -> return arr
    Nothing
      | perm_is_sorted -> return arr
    Just Just{}
      | perm_is_sorted -> rowMajorArray arr
    _ -> do
      source <- if isJust manifest then rowMajorArray arr else return arr
      letExp (baseString arr ++ "_coalesced") $
        BasicOp $ Manifest perm source
  where
    perm_is_sorted = sort perm == perm
-- | Bind a 'Manifest' of the array with the identity permutation
-- @[0 .. rank-1]@, naming the result after the array with a
-- @_rowmajor@ suffix.
rowMajorArray :: MonadBinder m =>
                 VName -> m VName
rowMajorArray arr = do
  arr_t <- lookupType arr
  let identity_perm = [0 .. arrayRank arr_t - 1]
  letExp (baseString arr ++ "_rowmajor") (BasicOp (Manifest identity_perm arr))
-- | Rearrange dimension @d@ (of size @w@) of an array so that it is
-- split into @num_chunks@ chunks and the chunk dimension is moved
-- behind the within-chunk dimension.  The statements are emitted in
-- this order:
--
-- 1. bind the chunk count, the padded size (see
--    'paddedScanReduceInput') and the per-chunk size;
--
-- 2. pad the array along dimension @d@ with a 'Scratch' array;
--
-- 3. reshape so dimension @d@ becomes @[num_chunks, per_chunk]@;
--
-- 4. manifest with those two dimensions' roles rotated;
--
-- 5. reshape back to a single dimension of the padded size;
--
-- 6. slice the first @w@ elements of dimension @d@.
rearrangeSlice :: MonadBinder m =>
                  Int -> SubExp -> PrimExp VName -> VName
               -> m VName
rearrangeSlice d w num_chunks arr = do
  num_chunks' <- letSubExp "num_chunks" =<< toExp num_chunks
  (w_padded, padding) <- paddedScanReduceInput w num_chunks'
  per_chunk <-
    letSubExp "per_chunk" $
      BasicOp $ BinOp (SQuot Int32) w_padded num_chunks'
  arr_t <- lookupType arr

  -- Pad dimension d up to w_padded.
  let padding_shape = setDim d (arrayShape arr_t) padding
  arr_padding <-
    letExp (arr_name <> "_padding") $
      BasicOp $ Scratch (elemType arr_t) (shapeDims padding_shape)
  arr_padded <-
    letExp (arr_name <> "_padded") $
      BasicOp $ Concat d arr [arr_padding] w_padded

  -- Split dimension d in two, transpose, and merge again.
  let arr_dims = arrayDims arr_t
      pre_dims = take d arr_dims
      post_dims = drop (d + 1) arr_dims
      extradim_shape =
        Shape $ pre_dims ++ [num_chunks', per_chunk] ++ post_dims
      tr_perm =
        [0 .. d - 1]
          ++ map (+ d) (1 : [2 .. shapeRank extradim_shape - 1 - d] ++ [0])
  arr_extradim <-
    letExp (arr_name <> "_extradim") $
      BasicOp $ Reshape (map DimNew $ shapeDims extradim_shape) arr_padded
  arr_extradim_tr <-
    letExp (arr_name <> "_extradim_tr") $
      BasicOp $ Manifest tr_perm arr_extradim
  arr_inv_tr <-
    letExp (arr_name <> "_inv_tr") $
      BasicOp $
        Reshape
          (map DimCoercion pre_dims ++ map DimNew (w_padded : post_dims))
          arr_extradim_tr

  -- Drop the padding again.
  letExp (arr_name <> "_inv_tr_init")
    =<< eSliceArray d arr_inv_tr (eSubExp $ constant (0 :: Int32)) (eSubExp w)
  where
    arr_name = baseString arr
-- | Bind @padded_size@ as @w@ rounded to a multiple of @stride@ (via
-- 'eRoundToMultipleOf'), and @padding@ as the difference between the
-- padded size and @w@.  Returns the pair (padded size, padding).
paddedScanReduceInput :: MonadBinder m =>
                         SubExp -> SubExp
                      -> m (SubExp, SubExp)
paddedScanReduceInput w stride = do
  w_padded <-
    eRoundToMultipleOf Int32 (eSubExp w) (eSubExp stride)
      >>= letSubExp "padded_size"
  padding <- letSubExp "padding" (BasicOp (BinOp (Sub Int32) w_padded w))
  pure (w_padded, padding)
-- | Maps a bound name to the set of names it is variant in.  As built
-- by 'varianceInStm', the set for a name consists of the free
-- variables of its binding statement together with the entries those
-- variables already have in the table.
type VarianceTable = M.Map VName Names
-- | Extend a variance table with every statement of a sequence, in
-- order, using 'varianceInStm'.
varianceInStms :: VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms table stms = foldl' varianceInStm table (stmsToList stms)
-- | Extend a variance table with one statement.  Every name bound by
-- the statement is mapped (overwriting any existing entry) to the same
-- set: the free variables of the statement, each joined with its own
-- entry in the incoming table.
varianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
varianceInStm variance bnd =
  M.fromList [ (v, binding_variance) | v <- patternNames (stmPattern bnd) ]
    <> variance
  where
    binding_variance = foldMap dependencies (namesToList (freeIn bnd))
    dependencies v = oneName v <> M.findWithDefault mempty v variance