{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Do various kernel optimisations - mostly related to coalescing.
module Futhark.Pass.KernelBabysitting
       ( babysitKernels
       , nonlinearInMemory )
       where

import Control.Arrow (first)
import Control.Monad.State.Strict
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Foldable
import Data.List
import Data.Maybe

import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.Kernels
       hiding (Prog, Body, Stm, Pattern, PatElem,
               BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Tools
import Futhark.Pass
import Futhark.Util

babysitKernels :: Pass Kernels Kernels
babysitKernels = Pass "babysit kernels"
                 "Transpose kernel input arrays for better performance." $
                 intraproceduralTransformation transformFunDef

transformFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
transformFunDef fundec = do
  (body', _) <- modifyNameSource $ runState (runBinderT m M.empty)
  return fundec { funDefBody = body' }
  where m = inScopeOf fundec $
            transformBody mempty $ funDefBody fundec

type BabysitM = Binder Kernels

transformBody :: ExpMap -> Body Kernels -> BabysitM (Body Kernels)
transformBody expmap (Body () bnds res) = insertStmsM $ do
  foldM_ transformStm expmap bnds
  return $ resultBody res

-- | Map from variable names to defining expression.  We use this to
-- hackily determine whether something is transposed or otherwise
-- funky in memory (and we'd prefer it not to be).  If we cannot find
-- it in the map, we just assume it's all good.  HACK and FIXME, I
-- suppose.  We really should do this at the memory level.
type ExpMap = M.Map VName (Stm Kernels)

nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
  case M.lookup name m of
    Just (Let _ _ (BasicOp (Rearrange perm _))) -> Just $ Just $ rearrangeInverse perm
    Just (Let _ _ (BasicOp (Reshape _ arr))) -> nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Manifest perm _))) -> Just $ Just perm
    Just (Let pat _ (Op (Kernel _ _ ts _))) ->
      nonlinear =<< find ((==name) . patElemName . fst)
                    (zip (patternElements pat) ts)
    _ -> Nothing
  where nonlinear (pe, t)
          | inner_r <- arrayRank t, inner_r > 0 = do
              let outer_r = arrayRank (patElemType pe) - inner_r
              return $ Just $ rearrangeInverse $
                [inner_r..inner_r+outer_r-1] ++ [0..inner_r-1]
          | otherwise = Nothing
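-- As an illustrative sketch (the name @xs_tr@ is hypothetical): if
-- @xs_tr@ was bound by @let xs_tr = rearrange (1, 0) xs@, then
-- @nonlinearInMemory xs_tr expmap@ yields @Just (Just [1, 0])@ - the
-- inverse of the rearrangement - while a name that is absent from
-- the map yields @Nothing@, i.e. we assume it is fine in memory.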
transformStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap
transformStm expmap (Let pat aux ke@(Op (Kernel desc space ts kbody))) = do
  -- Go spelunking for accesses to arrays that are defined outside
  -- the kernel body and where the indices are kernel thread indices.
  scope <- askScope
  let thread_gids = map fst $ spaceDimensions space
      thread_local = S.fromList $ spaceGlobalId space : spaceLocalId space : thread_gids
      free_ker_vars = freeInExp ke `S.difference` getKerVariantIds space

  kbody'' <- evalStateT (traverseKernelBodyArrayIndexes
                         free_ker_vars
                         thread_local
                         (castScope scope <> scopeOfKernelSpace space)
                         (ensureCoalescedAccess expmap (spaceDimensions space) num_threads)
                         kbody)
             mempty

  let bnd' = Let pat aux $ Op $ Kernel desc space ts kbody''
  addStm bnd'
  return $ M.fromList [ (name, bnd') | name <- patternNames pat ] <> expmap
  where num_threads = spaceNumThreads space

        getKerVariantIds (KernelSpace glb_id loc_id grp_id _ _ _ (FlatThreadSpace strct)) =
          let (gids, _) = unzip strct
          in S.fromList $ [glb_id, loc_id, grp_id] ++ gids
        getKerVariantIds (KernelSpace glb_id loc_id grp_id _ _ _ (NestedThreadSpace strct)) =
          let (gids, _, lids, _) = unzip4 strct
          in S.fromList $ [glb_id, loc_id, grp_id] ++ gids ++ lids

transformStm expmap (Let pat aux e) = do
  e' <- mapExpM (transform expmap) e
  let bnd' = Let pat aux e'
  addStm bnd'
  return $ M.fromList [ (name, bnd') | name <- patternNames pat ] <> expmap

transform :: ExpMap -> Mapper Kernels Kernels BabysitM
transform expmap =
  identityMapper { mapOnBody = \scope -> localScope scope . transformBody expmap }

type ArrayIndexTransform m =
  Names ->
  (VName -> Bool) ->            -- thread local?
  (VName -> SubExp -> Bool) ->  -- variant to a certain gid (given as first param)?
  (SubExp -> Maybe SubExp) ->   -- split substitution?
  Scope InKernel ->             -- type environment
  VName -> Slice SubExp ->
  m (Maybe (VName, Slice SubExp))
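-- A minimal sketch of the contract (the name @noopTransform@ is
-- hypothetical, not part of this module): a transform that rewrites
-- nothing ignores all the context and returns @Nothing@, which makes
-- the traversal below keep the original 'Index' expression:
--
-- > noopTransform :: Monad m => ArrayIndexTransform m
-- > noopTransform _free _local _variant _subst _scope _arr _slice =
-- >   pure Nothing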
traverseKernelBodyArrayIndexes :: (Applicative f, Monad f) =>
                                  Names
                               -> Names
                               -> Scope InKernel
                               -> ArrayIndexTransform f
                               -> KernelBody InKernel
                               -> f (KernelBody InKernel)
traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f (KernelBody () kstms kres) =
  KernelBody () . stmsFromList <$>
  mapM (onStm (varianceInStms mempty kstms,
               mkSizeSubsts kstms,
               outer_scope)) (stmsToList kstms) <*>
  pure kres
  where onLambda (variance, szsubst, scope) lam =
          (\body' -> lam { lambdaBody = body' }) <$>
          onBody (variance, szsubst, scope') (lambdaBody lam)
          where scope' = scope <> scopeOfLParams (lambdaParams lam)

        onStreamLambda (variance, szsubst, scope) lam =
          (\body' -> lam { groupStreamLambdaBody = body' }) <$>
          onBody (variance, szsubst, scope') (groupStreamLambdaBody lam)
          where scope' = scope <> scopeOf lam

        onBody (variance, szsubst, scope) (Body battr stms bres) = do
          stms' <- stmsFromList <$> mapM (onStm (variance', szsubst', scope')) (stmsToList stms)
          Body battr stms' <$> pure bres
          where variance' = varianceInStms variance stms
                szsubst' = mkSizeSubsts stms <> szsubst
                scope' = scope <> scopeOf stms

        onStm (variance, szsubst, _) (Let pat attr (BasicOp (Index arr is))) =
          Let pat attr . oldOrNew <$>
          f free_ker_vars isThreadLocal isGidVariant sizeSubst outer_scope arr is
          where oldOrNew Nothing =
                  BasicOp $ Index arr is
                oldOrNew (Just (arr', is')) =
                  BasicOp $ Index arr' is'

                isGidVariant gid (Var v) =
                  gid == v ||
                  S.member gid (M.findWithDefault (S.singleton v) v variance)
                isGidVariant _ _ = False

                isThreadLocal v =
                  not $ S.null $
                  thread_variant `S.intersection`
                  M.findWithDefault (S.singleton v) v variance

                sizeSubst (Constant v) = Just $ Constant v
                sizeSubst (Var v)
                  | v `M.member` outer_scope = Just $ Var v
                  | Just v' <- M.lookup v szsubst = sizeSubst v'
                  | otherwise = Nothing

        onStm (variance, szsubst, scope) (Let pat attr e) =
          Let pat attr <$> mapExpM (mapper (variance, szsubst, scope)) e

        mapper ctx = identityMapper { mapOnBody = const (onBody ctx)
                                    , mapOnOp = onOp ctx
                                    }

        onOp ctx (GroupReduce w lam input) =
          GroupReduce w <$> onLambda ctx lam <*> pure input
        onOp ctx (GroupStream w maxchunk lam accs arrs) =
          GroupStream w maxchunk <$> onStreamLambda ctx lam <*> pure accs <*> pure arrs
        onOp _ stm = pure stm

        mkSizeSubsts = fold . fmap mkStmSizeSubst
          where mkStmSizeSubst (Let (Pattern [] [pe]) _ (Op (SplitSpace _ _ _ elems_per_i))) =
                  M.singleton (patElemName pe) elems_per_i
                mkStmSizeSubst _ = mempty

-- Not a hashmap, as SubExp is not hashable.
type Replacements = M.Map (VName, Slice SubExp) VName
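-- As an illustration: the replacement table is a per-kernel cache.
-- If the kernel indexes some free array @xs@ (a hypothetical name)
-- with the same slice in two places, the first call to the transform
-- manifests a coalesced copy and records it under @(xs, slice)@; the
-- second lookup hits the cache and reuses the copy instead of
-- manifesting it again.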
ensureCoalescedAccess :: MonadBinder m =>
                         ExpMap
                      -> [(VName,SubExp)]
                      -> SubExp
                      -> ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess expmap thread_space num_threads
                      free_ker_vars isThreadLocal isGidVariant sizeSubst outer_scope
                      arr slice = do
  seen <- gets $ M.lookup (arr, slice)

  case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
    -- Already took care of this case elsewhere.
    (Just arr', _, _) ->
      pure $ Just (arr', slice)

    (Nothing, False, Just t)
      -- We are fully indexing the array with thread IDs, but the
      -- indices are in a permuted order.
      | Just is <- sliceIndices slice,
        length is == arrayRank t,
        Just is' <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is,
        Just perm <- is' `isPermutationOf` is ->
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- Check whether the access is already coalesced because of a
      -- previous rearrange being applied to the current array:
      -- 1. get the permutation of the source-array rearrange
      -- 2. apply it to the slice
      -- 3. check that the innermost index is actually the gid
      --    of the innermost kernel dimension.
      -- If so, the access is already coalesced, nothing to do!
      -- (Cosmin's Heuristic.)
      | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
        not $ null perm,
        not $ null thread_gids,
        inner_gid <- last thread_gids,
        length slice >= length perm,
        slice' <- map (slice !!) perm,
        DimFix inner_ind <- last slice',
        isGidVariant inner_gid inner_ind ->
          return Nothing

      -- We are not fully indexing an array, but the remaining slice
      -- is invariant to the innermost kernel dimension.  We assume
      -- the remaining slice will be sequentially streamed, hence
      -- tiling will be applied later and will solve coalescing.
      -- Hence nothing to do at this point.  (Cosmin's Heuristic.)
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        allDimAreSlice rem_slice,
        Nothing <- M.lookup arr expmap,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) ||
          length is == length thread_gids,
        not (null thread_gids || null is),
        not (S.member (last thread_gids)
              (S.union (freeIn is) (freeIn rem_slice))) ->
          return Nothing

      -- We are not fully indexing the array, and the indices are not
      -- a proper prefix of the thread indices, and some indices are
      -- thread local, so we assume (HEURISTIC!) that the remaining
      -- dimensions will be traversed sequentially.
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) ||
          length is == length thread_gids,
        any isThreadLocal (S.toList $ freeIn is) -> do
          let perm = coalescingPermutation (length is) $ arrayRank t
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- We are taking a slice of the array with a unit stride.  We
      -- assume that the slice will be traversed sequentially.
      --
      -- We will really want to treat the sliced dimension like two
      -- dimensions so we can transpose them.  This may require
      -- padding.
      | (is, rem_slice) <- splitSlice slice,
        and $ zipWith (==) is $ map Var thread_gids,
        DimSlice offset len (Constant stride):_ <- rem_slice,
        isThreadLocalSubExp offset,
        Just {} <- sizeSubst len,
        oneIsh stride -> do
          let num_chunks = if null is
                           then primExpFromSubExp int32 num_threads
                           else coerceIntPrimExp Int32 $
                                product $ map (primExpFromSubExp int32) $
                                drop (length is) thread_gdims
          replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t)
                            num_chunks arr)

      -- Everything is fine... assuming that the array is in
      -- row-major order!  Make sure that is the case.
      | Just{} <- nonlinearInMemory arr expmap ->
          case sliceIndices slice of
            Just is
              | Just _ <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is ->
                  replace =<< lift (rowMajorArray arr)
              | otherwise -> return Nothing
            _ -> replace =<< lift (rowMajorArray arr)

    _ -> return Nothing

  where (thread_gids, thread_gdims) = unzip thread_space

        replace arr' = do
          modify $ M.insert (arr, slice) arr'
          return $ Just (arr', slice)

        isThreadLocalSubExp (Var v) = isThreadLocal v
        isThreadLocalSubExp Constant{} = False

-- Heuristic for avoiding rearranging too small arrays.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = fst . foldl comb (True, bs) . sliceDims
  where comb (True, x) (Constant (IntValue (Int32Value d))) = (d*x < 4, d*x)
        comb (_, x)    _                                    = (False, x)

splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice [] = ([], [])
splitSlice (DimFix i:is) = first (i:) $ splitSlice is
splitSlice is = ([], is)

allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice [] = True
allDimAreSlice (DimFix _:_) = False
allDimAreSlice (_:is) = allDimAreSlice is
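-- For example (with @i@, @j@, @n@ standing for arbitrary
-- subexpressions), @splitSlice@ peels off the leading fixed indices:
--
-- > splitSlice [DimFix i, DimFix j, DimSlice 0 n 1]
-- >   == ([i, j], [DimSlice 0 n 1])
--
-- and @allDimAreSlice@ holds for the remaining slice, since it
-- contains no 'DimFix'.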
-- Try to move thread indexes into their proper position.
coalescedIndexes :: Names -> (VName -> SubExp -> Bool) -> [SubExp] -> [SubExp]
                 -> Maybe [SubExp]
coalescedIndexes free_ker_vars isGidVariant tgids is
  -- Return Nothing if:
  -- 1. any of the indices is a constant or a kernel free variable
  --    (because it would transpose a bigger array than needed - big
  --    overhead);
  -- 2. the innermost index is variant to the innermost-thread gid
  --    (because the access is likely to be already coalesced).
  | any isCt is = Nothing
  | any (`S.member` free_ker_vars) (mapMaybe mbVarId is) = Nothing
  | not (null tgids), not (null is),
    Var innergid <- last tgids,
    num_is > 0 && isGidVariant innergid (last is) = Just is
  -- 3. Otherwise try to fix coalescing.
  | otherwise = Just $ reverse $ foldl move (reverse is) $ zip [0..] (reverse tgids)
  where num_is = length is

        move is_rev (i, tgid)
          -- If tgid is in is_rev anywhere but at position i, and
          -- position i exists, we move it to position i instead.
          | Just j <- elemIndex tgid is_rev,
            i /= j, i < num_is = swap i j is_rev
          | otherwise = is_rev

        swap i j l
          | Just ix <- maybeNth i l,
            Just jx <- maybeNth j l =
              update i jx $ update j ix l
          | otherwise =
              error $ "coalescedIndexes swap: invalid indices: " ++ show (i, j, l)

        update 0 x (_:ys) = x : ys
        update i x (y:ys) = y : update (i-1) x ys
        update _ _ [] = error "coalescedIndexes: update"

        isCt :: SubExp -> Bool
        isCt (Constant _) = True
        isCt (Var _) = False

        mbVarId (Constant _) = Nothing
        mbVarId (Var v) = Just v

coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank =
  [num_is..rank-1] ++ [0..num_is-1]
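-- A worked example straight from the definition: with two leading
-- thread indices into a rank-4 array,
--
-- > coalescingPermutation 2 4 == [2,3] ++ [0,1] == [2,3,0,1]
--
-- i.e. the sequentially traversed inner dimensions are moved
-- outermost, so that the thread-indexed dimensions become innermost
-- and consecutive threads access consecutive elements of the
-- manifested array.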
rearrangeInput :: MonadBinder m =>
                  Maybe (Maybe [Int]) -> [Int] -> VName -> m VName
rearrangeInput (Just (Just current_perm)) perm arr
  | current_perm == perm = return arr -- Already has desired representation.
rearrangeInput Nothing perm arr
  | sort perm == perm = return arr -- We don't know the current
                                   -- representation, but the indexing
                                   -- is linear, so let's hope the
                                   -- array is too.
rearrangeInput (Just Just{}) perm arr
  | sort perm == perm = rowMajorArray arr -- We just want a row-major array, no tricks.
rearrangeInput manifest perm arr = do
  -- We may first manifest the array to ensure that it is flat in
  -- memory.  This is sometimes unnecessary, in which case the copy
  -- will hopefully be removed by the simplifier.
  manifested <- if isJust manifest then rowMajorArray arr else return arr
  letExp (baseString arr ++ "_coalesced") $
    BasicOp $ Manifest perm manifested

rowMajorArray :: MonadBinder m =>
                 VName -> m VName
rowMajorArray arr = do
  rank <- arrayRank <$> lookupType arr
  letExp (baseString arr ++ "_rowmajor") $ BasicOp $ Manifest [0..rank-1] arr

rearrangeSlice :: MonadBinder m =>
                  Int -> SubExp -> PrimExp VName -> VName
               -> m VName
rearrangeSlice d w num_chunks arr = do
  num_chunks' <- letSubExp "num_chunks" =<< toExp num_chunks

  (w_padded, padding) <- paddedScanReduceInput w num_chunks'
  per_chunk <- letSubExp "per_chunk" $
               BasicOp $ BinOp (SQuot Int32) w_padded num_chunks'
  arr_t <- lookupType arr
  arr_padded <- padArray w_padded padding arr_t
  rearrange num_chunks' w_padded per_chunk (baseString arr) arr_padded arr_t

  where padArray w_padded padding arr_t = do
          let arr_shape = arrayShape arr_t
              padding_shape = setDim d arr_shape padding
          arr_padding <-
            letExp (baseString arr <> "_padding") $
            BasicOp $ Scratch (elemType arr_t) (shapeDims padding_shape)
          letExp (baseString arr <> "_padded") $
            BasicOp $ Concat d arr [arr_padding] w_padded

        rearrange num_chunks' w_padded per_chunk arr_name arr_padded arr_t = do
          let arr_dims = arrayDims arr_t
              pre_dims = take d arr_dims
              post_dims = drop (d+1) arr_dims
              extradim_shape = Shape $
                pre_dims ++ [num_chunks', per_chunk] ++ post_dims
              tr_perm = [0..d-1] ++
                        map (+d) ([1] ++ [2..shapeRank extradim_shape-1-d] ++ [0])
          arr_extradim <-
            letExp (arr_name <> "_extradim") $
            BasicOp $ Reshape (map DimNew $ shapeDims extradim_shape) arr_padded
          arr_extradim_tr <-
            letExp (arr_name <> "_extradim_tr") $
            BasicOp $ Manifest tr_perm arr_extradim
          arr_inv_tr <-
            letExp (arr_name <> "_inv_tr") $
            BasicOp $ Reshape (map DimCoercion pre_dims ++
                               map DimNew (w_padded : post_dims))
            arr_extradim_tr
          letExp (arr_name <> "_inv_tr_init") =<<
            eSliceArray d arr_inv_tr (eSubExp $ constant (0::Int32)) (eSubExp w)

paddedScanReduceInput :: MonadBinder m =>
                         SubExp -> SubExp
                      -> m (SubExp, SubExp)
paddedScanReduceInput w stride = do
  w_padded <- letSubExp "padded_size" =<<
              eRoundToMultipleOf Int32 (eSubExp w) (eSubExp stride)
  padding <- letSubExp "padding" $ BasicOp $ BinOp (Sub Int32) w_padded w
  return (w_padded, padding)

--- Computing variance.

type VarianceTable = M.Map VName Names

varianceInStms :: VarianceTable -> Stms InKernel -> VarianceTable
varianceInStms t = foldl varianceInStm t . stmsToList

varianceInStm :: VarianceTable -> Stm InKernel -> VarianceTable
varianceInStm variance bnd =
  foldl' add variance $ patternNames $ stmPattern bnd
  where add variance' v = M.insert v binding_variance variance'
        look variance' v = S.insert v $ M.findWithDefault mempty v variance'
        binding_variance = mconcat $ map (look variance) $ S.toList (freeInStm bnd)
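-- A small worked example of the variance computation (with
-- hypothetical names): given the statements
--
-- > let b = a + 1
-- > let c = b * 2
--
-- the table maps @b@ to @{a}@ and @c@ to @{b, a}@, because @look@
-- adds each free variable itself along with everything that
-- variable is already known to be variant to.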