{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.KernelBabysitting
( babysitKernels
, nonlinearInMemory
)
where
import Control.Arrow (first)
import Control.Monad.State.Strict
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Foldable
import Data.List
import Data.Maybe
import Data.Semigroup ((<>))
import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.Kernels
hiding (Prog, Body, Stm, Pattern, PatElem,
BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Tools
import Futhark.Pass
import Futhark.Util
babysitKernels :: Pass Kernels Kernels
babysitKernels = Pass "babysit kernels"
"Transpose kernel input arrays for better performance." $
intraproceduralTransformation transformFunDef
transformFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
transformFunDef fundec = do
(body', _) <- modifyNameSource $ runState (runBinderT m M.empty)
return fundec { funDefBody = body' }
where m = inScopeOf fundec $
transformBody mempty $ funDefBody fundec
type BabysitM = Binder Kernels
transformBody :: ExpMap -> Body Kernels -> BabysitM (Body Kernels)
transformBody expmap (Body () bnds res) = insertStmsM $ do
foldM_ transformStm expmap bnds
return $ resultBody res
type ExpMap = M.Map VName (Stm Kernels)
nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
case M.lookup name m of
Just (Let _ _ (BasicOp (Rearrange perm _))) -> Just $ Just $ rearrangeInverse perm
Just (Let _ _ (BasicOp (Reshape _ arr))) -> nonlinearInMemory arr m
Just (Let _ _ (BasicOp (Manifest perm _))) -> Just $ Just perm
Just (Let pat _ (Op (Kernel _ _ ts _))) ->
nonlinear =<< find ((==name) . patElemName . fst)
(zip (patternElements pat) ts)
_ -> Nothing
where nonlinear (pe, t)
| inner_r <- arrayRank t, inner_r > 0 = do
let outer_r = arrayRank (patElemType pe) - inner_r
return $ Just $ rearrangeInverse $ [inner_r..inner_r+outer_r-1] ++ [0..inner_r-1]
| otherwise = Nothing
transformStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap
transformStm expmap (Let pat aux (Op (Kernel desc space ts kbody))) = do
scope <- askScope
let thread_gids = map fst $ spaceDimensions space
thread_local = S.fromList $ spaceGlobalId space : spaceLocalId space : thread_gids
kbody'' <- evalStateT (traverseKernelBodyArrayIndexes
thread_local
(castScope scope <> scopeOfKernelSpace space)
(ensureCoalescedAccess expmap (spaceDimensions space) num_threads)
kbody)
mempty
let bnd' = Let pat aux $ Op $ Kernel desc space ts kbody''
addStm bnd'
return $ M.fromList [ (name, bnd') | name <- patternNames pat ] <> expmap
where num_threads = spaceNumThreads space
transformStm expmap (Let pat aux e) = do
e' <- mapExpM (transform expmap) e
let bnd' = Let pat aux e'
addStm bnd'
return $ M.fromList [ (name, bnd') | name <- patternNames pat ] <> expmap
transform :: ExpMap -> Mapper Kernels Kernels BabysitM
transform expmap =
identityMapper { mapOnBody = \scope -> localScope scope . transformBody expmap }
type ArrayIndexTransform m =
(VName -> Bool) ->
(SubExp -> Maybe SubExp) ->
Scope InKernel ->
VName -> Slice SubExp -> m (Maybe (VName, Slice SubExp))
traverseKernelBodyArrayIndexes :: (Applicative f, Monad f) =>
Names
-> Scope InKernel
-> ArrayIndexTransform f
-> KernelBody InKernel
-> f (KernelBody InKernel)
traverseKernelBodyArrayIndexes thread_variant outer_scope f (KernelBody () kstms kres) =
KernelBody () . stmsFromList <$>
mapM (onStm (varianceInStms mempty kstms,
mkSizeSubsts kstms,
outer_scope)) (stmsToList kstms) <*>
pure kres
where onLambda (variance, szsubst, scope) lam =
(\body' -> lam { lambdaBody = body' }) <$>
onBody (variance, szsubst, scope') (lambdaBody lam)
where scope' = scope <> scopeOfLParams (lambdaParams lam)
onStreamLambda (variance, szsubst, scope) lam =
(\body' -> lam { groupStreamLambdaBody = body' }) <$>
onBody (variance, szsubst, scope') (groupStreamLambdaBody lam)
where scope' = scope <> scopeOf lam
onBody (variance, szsubst, scope) (Body battr stms bres) = do
stms' <- stmsFromList <$> mapM (onStm (variance', szsubst', scope')) (stmsToList stms)
Body battr stms' <$> pure bres
where variance' = varianceInStms variance stms
szsubst' = mkSizeSubsts stms <> szsubst
scope' = scope <> scopeOf stms
onStm (variance, szsubst, _) (Let pat attr (BasicOp (Index arr is))) =
Let pat attr . oldOrNew <$> f isThreadLocal sizeSubst outer_scope arr is
where oldOrNew Nothing =
BasicOp $ Index arr is
oldOrNew (Just (arr', is')) =
BasicOp $ Index arr' is'
isThreadLocal v =
not $ S.null $
thread_variant `S.intersection`
M.findWithDefault (S.singleton v) v variance
sizeSubst (Constant v) = Just $ Constant v
sizeSubst (Var v)
| v `M.member` outer_scope = Just $ Var v
| Just v' <- M.lookup v szsubst = sizeSubst v'
| otherwise = Nothing
onStm (variance, szsubst, scope) (Let pat attr e) =
Let pat attr <$> mapExpM (mapper (variance, szsubst, scope)) e
mapper ctx = identityMapper { mapOnBody = const (onBody ctx)
, mapOnOp = onOp ctx
}
onOp ctx (GroupReduce w lam input) =
GroupReduce w <$> onLambda ctx lam <*> pure input
onOp ctx (GroupStream w maxchunk lam accs arrs) =
GroupStream w maxchunk <$> onStreamLambda ctx lam <*> pure accs <*> pure arrs
onOp _ stm = pure stm
mkSizeSubsts = fold . fmap mkStmSizeSubst
where mkStmSizeSubst (Let (Pattern [] [pe]) _ (Op (SplitSpace _ _ _ elems_per_i))) =
M.singleton (patElemName pe) elems_per_i
mkStmSizeSubst _ = mempty
type Replacements = M.Map (VName, Slice SubExp) VName
ensureCoalescedAccess :: MonadBinder m =>
ExpMap
-> [(VName,SubExp)]
-> SubExp
-> ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess expmap thread_space num_threads isThreadLocal sizeSubst outer_scope arr slice = do
seen <- gets $ M.lookup (arr, slice)
case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
(Just arr', _, _) ->
pure $ Just (arr', slice)
(Nothing, False, Just t)
| Just is <- sliceIndices slice,
length is == arrayRank t,
Just is' <- coalescedIndexes (map Var thread_gids) is,
Just perm <- is' `isPermutationOf` is ->
replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)
| Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
not $ null perm,
length slice >= length perm,
slice' <- map (\i -> slice !! i) perm,
DimFix inner_ind <- last slice',
not $ null thread_gids,
inner_ind == (Var $ last thread_gids) ->
return Nothing
| (is, rem_slice) <- splitSlice slice,
not $ null rem_slice,
allDimAreSlice rem_slice,
Nothing <- M.lookup arr expmap,
not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
not (null thread_gids || null is),
not ( S.member (last thread_gids) (S.union (freeIn is) (freeIn rem_slice)) ) ->
return Nothing
| (is, rem_slice) <- splitSlice slice,
not $ null rem_slice,
not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
any isThreadLocal (S.toList $ freeIn is) -> do
let perm = coalescingPermutation (length is) $ arrayRank t
replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)
| (is, rem_slice) <- splitSlice slice,
and $ zipWith (==) is $ map Var thread_gids,
DimSlice offset len (Constant stride):_ <- rem_slice,
isThreadLocalSubExp offset,
Just {} <- sizeSubst len,
oneIsh stride -> do
let num_chunks = if null is
then primExpFromSubExp int32 num_threads
else coerceIntPrimExp Int32 $
product $ map (primExpFromSubExp int32) $
drop (length is) thread_gdims
replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t) num_chunks arr)
| Just{} <- nonlinearInMemory arr expmap ->
case sliceIndices slice of
Just is | Just _ <- coalescedIndexes (map Var thread_gids) is ->
replace =<< lift (rowMajorArray arr)
| otherwise ->
return Nothing
_ -> replace =<< lift (rowMajorArray arr)
_ -> return Nothing
where (thread_gids, thread_gdims) = unzip thread_space
replace arr' = do
modify $ M.insert (arr, slice) arr'
return $ Just (arr', slice)
isThreadLocalSubExp (Var v) = isThreadLocal v
isThreadLocalSubExp Constant{} = False
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = fst . foldl comb (True,bs) . sliceDims
where comb (True, x) (Constant (IntValue (Int32Value d))) = (d*x < 4, d*x)
comb (_, x) _ = (False, x)
splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice [] = ([], [])
splitSlice (DimFix i:is) = first (i:) $ splitSlice is
splitSlice is = ([], is)
allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice [] = True
allDimAreSlice (DimFix _:_) = False
allDimAreSlice (_:is) = allDimAreSlice is
coalescedIndexes :: [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes tgids is
| any isCt is =
Nothing
| num_is > 0 && not (null tgids) && last is == last tgids =
Just is
| otherwise =
Just $ reverse $ foldl move (reverse is) $ zip [0..] (reverse tgids)
where num_is = length is
move is_rev (i, tgid)
| Just j <- elemIndex tgid is_rev, i /= j, i < num_is =
swap i j is_rev
| otherwise =
is_rev
swap i j l
| Just ix <- maybeNth i l,
Just jx <- maybeNth j l =
update i jx $ update j ix l
| otherwise =
error $ "coalescedIndexes swap: invalid indices" ++ show (i, j, l)
update 0 x (_:ys) = x : ys
update i x (y:ys) = y : update (i-1) x ys
update _ _ [] = error "coalescedIndexes: update"
isCt :: SubExp -> Bool
isCt (Constant _) = True
isCt (Var _) = False
coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank =
[num_is..rank-1] ++ [0..num_is-1]
rearrangeInput :: MonadBinder m =>
Maybe (Maybe [Int]) -> [Int] -> VName -> m VName
rearrangeInput (Just (Just current_perm)) perm arr
| current_perm == perm = return arr
rearrangeInput Nothing perm arr
| sort perm == perm = return arr
rearrangeInput (Just Just{}) perm arr
| sort perm == perm = rowMajorArray arr
rearrangeInput manifest perm arr = do
manifested <- if isJust manifest then rowMajorArray arr else return arr
letExp (baseString arr ++ "_coalesced") $
BasicOp $ Manifest perm manifested
rowMajorArray :: MonadBinder m =>
VName -> m VName
rowMajorArray arr = do
rank <- arrayRank <$> lookupType arr
letExp (baseString arr ++ "_rowmajor") $ BasicOp $ Manifest [0..rank-1] arr
rearrangeSlice :: MonadBinder m =>
Int -> SubExp -> PrimExp VName -> VName
-> m VName
rearrangeSlice d w num_chunks arr = do
num_chunks' <- letSubExp "num_chunks" =<< toExp num_chunks
(w_padded, padding) <- paddedScanReduceInput w num_chunks'
per_chunk <- letSubExp "per_chunk" $ BasicOp $ BinOp (SQuot Int32) w_padded num_chunks'
arr_t <- lookupType arr
arr_padded <- padArray w_padded padding arr_t
rearrange num_chunks' w_padded per_chunk (baseString arr) arr_padded arr_t
where padArray w_padded padding arr_t = do
let arr_shape = arrayShape arr_t
padding_shape = setDim d arr_shape padding
arr_padding <-
letExp (baseString arr <> "_padding") $
BasicOp $ Scratch (elemType arr_t) (shapeDims padding_shape)
letExp (baseString arr <> "_padded") $
BasicOp $ Concat d arr [arr_padding] w_padded
rearrange num_chunks' w_padded per_chunk arr_name arr_padded arr_t = do
let arr_dims = arrayDims arr_t
pre_dims = take d arr_dims
post_dims = drop (d+1) arr_dims
extradim_shape = Shape $ pre_dims ++ [num_chunks', per_chunk] ++ post_dims
tr_perm = [0..d-1] ++ map (+d) ([1] ++ [2..shapeRank extradim_shape-1-d] ++ [0])
arr_extradim <-
letExp (arr_name <> "_extradim") $
BasicOp $ Reshape (map DimNew $ shapeDims extradim_shape) arr_padded
arr_extradim_tr <-
letExp (arr_name <> "_extradim_tr") $
BasicOp $ Manifest tr_perm arr_extradim
arr_inv_tr <- letExp (arr_name <> "_inv_tr") $
BasicOp $ Reshape (map DimCoercion pre_dims ++ map DimNew (w_padded : post_dims))
arr_extradim_tr
letExp (arr_name <> "_inv_tr_init") =<<
eSliceArray d arr_inv_tr (eSubExp $ constant (0::Int32)) (eSubExp w)
paddedScanReduceInput :: MonadBinder m =>
SubExp -> SubExp
-> m (SubExp, SubExp)
paddedScanReduceInput w stride = do
w_padded <- letSubExp "padded_size" =<<
eRoundToMultipleOf Int32 (eSubExp w) (eSubExp stride)
padding <- letSubExp "padding" $ BasicOp $ BinOp (Sub Int32) w_padded w
return (w_padded, padding)
type VarianceTable = M.Map VName Names
varianceInStms :: VarianceTable -> Stms InKernel -> VarianceTable
varianceInStms t = foldl varianceInStm t . stmsToList
varianceInStm :: VarianceTable -> Stm InKernel -> VarianceTable
varianceInStm variance bnd =
foldl' add variance $ patternNames $ stmPattern bnd
where add variance' v = M.insert v binding_variance variance'
look variance' v = S.insert v $ M.findWithDefault mempty v variance'
binding_variance = mconcat $ map (look variance) $ S.toList (freeInStm bnd)