{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-- | Do various kernel optimisations - mostly related to coalescing.
module Futhark.Pass.KernelBabysitting
       ( babysitKernels
       , nonlinearInMemory )
       where

import Control.Arrow (first)
import Control.Monad.State.Strict
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Foldable
import Data.List
import Data.Maybe
import Data.Semigroup ((<>))

import Futhark.MonadFreshNames
import Futhark.Representation.AST
import Futhark.Representation.Kernels
       hiding (Prog, Body, Stm, Pattern, PatElem,
               BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Tools
import Futhark.Pass
import Futhark.Util

babysitKernels :: Pass Kernels Kernels
babysitKernels = Pass "babysit kernels"
                 "Transpose kernel input arrays for better performance." $
                 intraproceduralTransformation transformFunDef

transformFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
transformFunDef fundec = do
  (body', _) <- modifyNameSource $ runState (runBinderT m M.empty)
  return fundec { funDefBody = body' }
  where m = inScopeOf fundec $
            transformBody mempty $ funDefBody fundec

type BabysitM = Binder Kernels

transformBody :: ExpMap -> Body Kernels -> BabysitM (Body Kernels)
transformBody expmap (Body () bnds res) = insertStmsM $ do
  foldM_ transformStm expmap bnds
  return $ resultBody res

-- | Map from variable names to defining expression.  We use this to
-- hackily determine whether something is transposed or otherwise
-- funky in memory (and we'd prefer it not to be).  If we cannot find
-- it in the map, we just assume it's all good.  HACK and FIXME, I
-- suppose.  We really should do this at the memory level.
type ExpMap = M.Map VName (Stm Kernels)

nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
  case M.lookup name m of
    Just (Let _ _ (BasicOp (Rearrange perm _))) -> Just $ Just $ rearrangeInverse perm
    Just (Let _ _ (BasicOp (Reshape _ arr))) -> nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Manifest perm _))) -> Just $ Just perm
    Just (Let pat _ (Op (Kernel _ _ ts _))) ->
      nonlinear =<< find ((==name) . patElemName . fst)
                    (zip (patternElements pat) ts)
    _ -> Nothing
  where nonlinear (pe, t)
          | inner_r <- arrayRank t, inner_r > 0 = do
              let outer_r = arrayRank (patElemType pe) - inner_r
              return $ Just $ rearrangeInverse $
                [inner_r..inner_r+outer_r-1] ++ [0..inner_r-1]
          | otherwise = Nothing
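-- For intuition, a small sketch of what 'nonlinearInMemory' reports,
-- assuming 'rearrangeInverse' computes the inverse permutation (the
-- names @xs@, @xs_tr@ and @expmap@ below are purely illustrative):
--
-- >>> -- given a binding:  let xs_tr = rearrange [2,0,1] xs
-- >>> nonlinearInMemory xs_tr expmap
-- Just (Just [1,2,0])
--
-- A name that is not in the map yields 'Nothing', which the rest of
-- this module optimistically reads as "already row-major".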
transformStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap

transformStm expmap (Let pat aux (Op (Kernel desc space ts kbody))) = do
  -- Go spelunking for accesses to arrays that are defined outside
  -- the kernel body and where the indices are kernel thread indices.
  scope <- askScope
  let thread_gids = map fst $ spaceDimensions space
      thread_local = S.fromList $
        spaceGlobalId space : spaceLocalId space : thread_gids

  kbody'' <- evalStateT (traverseKernelBodyArrayIndexes
                         thread_local
                         (castScope scope <> scopeOfKernelSpace space)
                         (ensureCoalescedAccess expmap (spaceDimensions space) num_threads)
                         kbody)
             mempty

  let bnd' = Let pat aux $ Op $ Kernel desc space ts kbody''
  addStm bnd'
  return $ M.fromList [ (name, bnd') | name <- patternNames pat ] <> expmap
  where num_threads = spaceNumThreads space

transformStm expmap (Let pat aux e) = do
  e' <- mapExpM (transform expmap) e
  let bnd' = Let pat aux e'
  addStm bnd'
  return $ M.fromList [ (name, bnd') | name <- patternNames pat ] <> expmap

transform :: ExpMap -> Mapper Kernels Kernels BabysitM
transform expmap =
  identityMapper { mapOnBody = \scope -> localScope scope . transformBody expmap }

type ArrayIndexTransform m =
  (VName -> Bool) ->           -- thread local?
  (SubExp -> Maybe SubExp) ->  -- split substitution?
  Scope InKernel ->            -- type environment
  VName -> Slice SubExp -> m (Maybe (VName, Slice SubExp))

traverseKernelBodyArrayIndexes :: (Applicative f, Monad f) =>
                                  Names
                               -> Scope InKernel
                               -> ArrayIndexTransform f
                               -> KernelBody InKernel
                               -> f (KernelBody InKernel)
traverseKernelBodyArrayIndexes thread_variant outer_scope f (KernelBody () kstms kres) =
  KernelBody () . stmsFromList <$>
  mapM (onStm (varianceInStms mempty kstms,
               mkSizeSubsts kstms,
               outer_scope))
       (stmsToList kstms) <*>
  pure kres
  where onLambda (variance, szsubst, scope) lam =
          (\body' -> lam { lambdaBody = body' }) <$>
          onBody (variance, szsubst, scope') (lambdaBody lam)
          where scope' = scope <> scopeOfLParams (lambdaParams lam)

        onStreamLambda (variance, szsubst, scope) lam =
          (\body' -> lam { groupStreamLambdaBody = body' }) <$>
          onBody (variance, szsubst, scope') (groupStreamLambdaBody lam)
          where scope' = scope <> scopeOf lam

        onBody (variance, szsubst, scope) (Body battr stms bres) = do
          stms' <- stmsFromList <$>
                   mapM (onStm (variance', szsubst', scope')) (stmsToList stms)
          Body battr stms' <$> pure bres
          where variance' = varianceInStms variance stms
                szsubst' = mkSizeSubsts stms <> szsubst
                scope' = scope <> scopeOf stms

        onStm (variance, szsubst, _) (Let pat attr (BasicOp (Index arr is))) =
          Let pat attr . oldOrNew <$> f isThreadLocal sizeSubst outer_scope arr is
          where oldOrNew Nothing = BasicOp $ Index arr is
                oldOrNew (Just (arr', is')) = BasicOp $ Index arr' is'

                isThreadLocal v =
                  not $ S.null $
                  thread_variant `S.intersection`
                  M.findWithDefault (S.singleton v) v variance

                sizeSubst (Constant v) = Just $ Constant v
                sizeSubst (Var v)
                  | v `M.member` outer_scope = Just $ Var v
                  | Just v' <- M.lookup v szsubst = sizeSubst v'
                  | otherwise = Nothing

        onStm (variance, szsubst, scope) (Let pat attr e) =
          Let pat attr <$> mapExpM (mapper (variance, szsubst, scope)) e

        mapper ctx = identityMapper { mapOnBody = const (onBody ctx)
                                    , mapOnOp = onOp ctx }

        onOp ctx (GroupReduce w lam input) =
          GroupReduce w <$> onLambda ctx lam <*> pure input
        onOp ctx (GroupStream w maxchunk lam accs arrs) =
          GroupStream w maxchunk <$> onStreamLambda ctx lam <*> pure accs <*> pure arrs
        onOp _ stm = pure stm

        mkSizeSubsts = fold . fmap mkStmSizeSubst
          where mkStmSizeSubst (Let (Pattern [] [pe]) _ (Op (SplitSpace _ _ _ elems_per_i))) =
                  M.singleton (patElemName pe) elems_per_i
                mkStmSizeSubst _ = mempty
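-- For intuition: the transform @f@ above is only ever invoked on
-- 'Index' statements.  For a kernel-body statement such as
--
--   let y = xs[gtid, j]
--
-- it is called roughly as @f isThreadLocal sizeSubst outer_scope xs
-- [DimFix gtid, DimFix j]@.  Returning @Just (xs', slice')@ rewrites
-- the statement to read from @xs'@ with @slice'@ instead; returning
-- @Nothing@ leaves the access untouched.  (The names @xs@, @y@,
-- @gtid@ and @j@ are illustrative only.)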
-- Not a hashmap, as SubExp is not hashable.
type Replacements = M.Map (VName, Slice SubExp) VName

ensureCoalescedAccess :: MonadBinder m =>
                         ExpMap
                      -> [(VName,SubExp)]
                      -> SubExp
                      -> ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess expmap thread_space num_threads
                      isThreadLocal sizeSubst outer_scope arr slice = do
  seen <- gets $ M.lookup (arr, slice)

  case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
    -- Already took care of this case elsewhere.
    (Just arr', _, _) -> pure $ Just (arr', slice)

    (Nothing, False, Just t)
      -- We are fully indexing the array with thread IDs, but the
      -- indices are in a permuted order.
      | Just is <- sliceIndices slice,
        length is == arrayRank t,
        Just is' <- coalescedIndexes (map Var thread_gids) is,
        Just perm <- is' `isPermutationOf` is ->
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- Check whether the access is already coalesced because of a
      -- previous rearrange being applied to the current array:
      -- 1. get the permutation of the source-array rearrange
      -- 2. apply it to the slice
      -- 3. check that the innermost index is actually the gid
      --    of the innermost kernel dimension.
      -- If so, the access is already coalesced, nothing to do!
      -- (Cosmin's Heuristic.)
      | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
        ---- Just (Just perm) <- nonlinearInMemory arr expmap,
        not $ null perm,
        length slice >= length perm,
        slice' <- map (\i -> slice !! i) perm,
        DimFix inner_ind <- last slice',
        not $ null thread_gids,
        inner_ind == (Var $ last thread_gids) ->
          return Nothing

      -- We are not fully indexing an array, but the remaining slice
      -- is invariant to the innermost kernel dimension.  We assume
      -- the remaining slice will be sequentially streamed, hence
      -- tiling will be applied later and will solve coalescing.
      -- Hence nothing to do at this point.  (Cosmin's Heuristic.)
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        allDimAreSlice rem_slice,
        Nothing <- M.lookup arr expmap,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        not (null thread_gids || null is),
        not (S.member (last thread_gids) (S.union (freeIn is) (freeIn rem_slice))) ->
          return Nothing

      -- We are not fully indexing the array, and the indices are not
      -- a proper prefix of the thread indices, and some indices are
      -- thread local, so we assume (HEURISTIC!) that the remaining
      -- dimensions will be traversed sequentially.
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        any isThreadLocal (S.toList $ freeIn is) -> do
          let perm = coalescingPermutation (length is) $ arrayRank t
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- We are taking a slice of the array with a unit stride.  We
      -- assume that the slice will be traversed sequentially.
      --
      -- We really want to treat the sliced dimension as two
      -- dimensions so that we can transpose them.  This may require
      -- padding.
      | (is, rem_slice) <- splitSlice slice,
        and $ zipWith (==) is $ map Var thread_gids,
        DimSlice offset len (Constant stride):_ <- rem_slice,
        isThreadLocalSubExp offset,
        Just {} <- sizeSubst len,
        oneIsh stride -> do
          let num_chunks = if null is
                           then primExpFromSubExp int32 num_threads
                           else coerceIntPrimExp Int32 $
                                product $ map (primExpFromSubExp int32) $
                                drop (length is) thread_gdims
          replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t) num_chunks arr)

      -- Everything is fine... assuming that the array is in
      -- row-major order!  Make sure that is the case.
      | Just{} <- nonlinearInMemory arr expmap ->
          case sliceIndices slice of
            Just is | Just _ <- coalescedIndexes (map Var thread_gids) is ->
                        replace =<< lift (rowMajorArray arr)
                    | otherwise -> return Nothing
            _ -> replace =<< lift (rowMajorArray arr)

    _ -> return Nothing

  where (thread_gids, thread_gdims) = unzip thread_space

        replace arr' = do
          modify $ M.insert (arr, slice) arr'
          return $ Just (arr', slice)

        isThreadLocalSubExp (Var v) = isThreadLocal v
        isThreadLocalSubExp Constant{} = False
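-- A rough worked example of the first guard above (names are
-- illustrative only).  Consider a one-dimensional kernel with thread
-- index @gtid@ whose body reads @xss[gtid, j]@ in a sequential loop
-- over @j@: consecutive threads then touch consecutive rows, which
-- does not coalesce.  Here @coalescedIndexes [gtid] [gtid, j]@
-- suggests the order @[j, gtid]@, so @xss@ is manifested with
-- permutation [1,0] (a transposed copy) and the read is redirected to
-- that copy.  In the copy, neighbouring values of @gtid@ sit next to
-- each other in memory, so the accesses of a warp coalesce.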
-- Heuristic for avoiding rearranging too small arrays.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = fst . foldl comb (True,bs) . sliceDims
  where comb (True, x) (Constant (IntValue (Int32Value d))) = (d*x < 4, d*x)
        comb (_, x) _ = (False, x)

splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice [] = ([], [])
splitSlice (DimFix i:is) = first (i:) $ splitSlice is
splitSlice is = ([], is)

allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice [] = True
allDimAreSlice (DimFix _:_) = False
allDimAreSlice (_:is) = allDimAreSlice is

-- Try to move thread indexes into their proper position.
coalescedIndexes :: [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes tgids is
  -- Do nothing if:
  -- 1. the innermost index is the innermost thread id
  --    (because the access is already coalesced), or
  -- 2. any of the indices is a constant, i.e. invariant to the kernel
  --    (because it would transpose a bigger array than needed - big overhead).
  | any isCt is = Nothing
  | num_is > 0 && not (null tgids) && last is == last tgids = Just is
  -- Otherwise, try to fix coalescing.
  | otherwise = Just $ reverse $ foldl move (reverse is) $ zip [0..] (reverse tgids)
  where num_is = length is

        move is_rev (i, tgid)
          -- If tgid is in is_rev anywhere but at position i, and
          -- position i exists, we move it to position i instead.
          | Just j <- elemIndex tgid is_rev, i /= j, i < num_is = swap i j is_rev
          | otherwise = is_rev

        swap i j l
          | Just ix <- maybeNth i l,
            Just jx <- maybeNth j l = update i jx $ update j ix l
          | otherwise = error $ "coalescedIndexes swap: invalid indices " ++ show (i, j, l)

        update 0 x (_:ys) = x : ys
        update i x (y:ys) = y : update (i-1) x ys
        update _ _ [] = error "coalescedIndexes: update"

isCt :: SubExp -> Bool
isCt (Constant _) = True
isCt (Var _) = False

coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank = [num_is..rank-1] ++ [0..num_is-1]
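-- Two small examples of the helpers above, following directly from
-- their definitions (the names @gtid_y@, @gtid_x@ and @i@ are
-- illustrative):
--
-- >>> coalescingPermutation 2 4
-- [2,3,0,1]
--
-- i.e. with two thread indices fixed into a rank-4 array, the
-- manifested copy puts the thread-indexed dimensions innermost and
-- the sequentially traversed ones outermost.
--
-- >>> coalescedIndexes [Var gtid_y, Var gtid_x] [Var gtid_x, Var i, Var gtid_y]
-- Just [Var i, Var gtid_y, Var gtid_x]
--
-- so the innermost position ends up holding the innermost thread
-- index, which is what makes the access coalesce.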
rearrangeInput :: MonadBinder m =>
                  Maybe (Maybe [Int]) -> [Int] -> VName -> m VName
rearrangeInput (Just (Just current_perm)) perm arr
  | current_perm == perm = return arr -- Already has desired representation.
rearrangeInput Nothing perm arr
  | sort perm == perm = return arr -- We don't know the current
                                   -- representation, but the indexing
                                   -- is linear, so let's hope the
                                   -- array is too.
rearrangeInput (Just Just{}) perm arr
  | sort perm == perm = rowMajorArray arr -- We just want a row-major array, no tricks.
rearrangeInput manifest perm arr = do
  -- We may first manifest the array to ensure that it is flat in
  -- memory.  This is sometimes unnecessary, in which case the copy
  -- will hopefully be removed by the simplifier.
  manifested <- if isJust manifest then rowMajorArray arr else return arr
  letExp (baseString arr ++ "_coalesced") $
    BasicOp $ Manifest perm manifested

rowMajorArray :: MonadBinder m => VName -> m VName
rowMajorArray arr = do
  rank <- arrayRank <$> lookupType arr
  letExp (baseString arr ++ "_rowmajor") $ BasicOp $ Manifest [0..rank-1] arr

rearrangeSlice :: MonadBinder m =>
                  Int -> SubExp -> PrimExp VName -> VName -> m VName
rearrangeSlice d w num_chunks arr = do
  num_chunks' <- letSubExp "num_chunks" =<< toExp num_chunks

  (w_padded, padding) <- paddedScanReduceInput w num_chunks'
  per_chunk <- letSubExp "per_chunk" $ BasicOp $ BinOp (SQuot Int32) w_padded num_chunks'
  arr_t <- lookupType arr
  arr_padded <- padArray w_padded padding arr_t
  rearrange num_chunks' w_padded per_chunk (baseString arr) arr_padded arr_t

  where padArray w_padded padding arr_t = do
          let arr_shape = arrayShape arr_t
              padding_shape = setDim d arr_shape padding
          arr_padding <-
            letExp (baseString arr <> "_padding") $
            BasicOp $ Scratch (elemType arr_t) (shapeDims padding_shape)
          letExp (baseString arr <> "_padded") $
            BasicOp $ Concat d arr [arr_padding] w_padded

        rearrange num_chunks' w_padded per_chunk arr_name arr_padded arr_t = do
          let arr_dims = arrayDims arr_t
              pre_dims = take d arr_dims
              post_dims = drop (d+1) arr_dims
              extradim_shape = Shape $ pre_dims ++ [num_chunks', per_chunk] ++ post_dims
              tr_perm = [0..d-1] ++
                        map (+d) ([1] ++ [2..shapeRank extradim_shape-1-d] ++ [0])
          arr_extradim <-
            letExp (arr_name <> "_extradim") $
            BasicOp $ Reshape (map DimNew $ shapeDims extradim_shape) arr_padded
          arr_extradim_tr <-
            letExp (arr_name <> "_extradim_tr") $
            BasicOp $ Manifest tr_perm arr_extradim
          arr_inv_tr <-
            letExp (arr_name <> "_inv_tr") $
            BasicOp $ Reshape (map DimCoercion pre_dims ++ map DimNew (w_padded : post_dims))
                              arr_extradim_tr
          letExp (arr_name <> "_inv_tr_init") =<<
            eSliceArray d arr_inv_tr (eSubExp $ constant (0::Int32)) (eSubExp w)

paddedScanReduceInput :: MonadBinder m =>
                         SubExp -> SubExp -> m (SubExp, SubExp)
paddedScanReduceInput w stride = do
  w_padded <- letSubExp "padded_size" =<<
              eRoundToMultipleOf Int32 (eSubExp w) (eSubExp stride)
  padding <- letSubExp "padding" $ BasicOp $ BinOp (Sub Int32) w_padded w
  return (w_padded, padding)

--- Computing variance.

type VarianceTable = M.Map VName Names

varianceInStms :: VarianceTable -> Stms InKernel -> VarianceTable
varianceInStms t = foldl varianceInStm t . stmsToList

varianceInStm :: VarianceTable -> Stm InKernel -> VarianceTable
varianceInStm variance bnd =
  foldl' add variance $ patternNames $ stmPattern bnd
  where add variance' v = M.insert v binding_variance variance'
        look variance' v = S.insert v $ M.findWithDefault mempty v variance'
        binding_variance = mconcat $ map (look variance) $ S.toList (freeInStm bnd)
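-- For intuition (illustrative names only): given the kernel-body
-- statements
--
--   let a = xs[gtid]
--   let c = a + b
--
-- the table maps @a@ to {xs, gtid} and @c@ to {a, b} plus the
-- variance of @a@ and @b@, i.e. {a, b, xs, gtid}.  A variable is then
-- treated as thread-variant by 'isThreadLocal' in
-- 'traverseKernelBodyArrayIndexes' exactly when its entry intersects
-- the set of thread-index names.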