-- | Move variables as much as possible upwards in a program.
module Futhark.Optimise.MemoryBlockMerging.CrudeMovingUp
  ( moveUpInFunDef
  ) where

import qualified Data.Set as S
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe (mapMaybe)
import Control.Monad
import Control.Monad.RWS
import Control.Monad.Writer

import Futhark.Representation.AST
import Futhark.Representation.ExplicitMemory (ExplicitMemory)
import qualified Futhark.Representation.ExplicitMemory as ExpMem
import Futhark.Optimise.MemoryBlockMerging.Miscellaneous

import Control.Monad.State
import Control.Monad.Identity


-- | Index of a statement in the statement list of a body (0-based, as
-- produced by @zip [0..]@ in 'bodyBindingMap').
type Line = Int

-- | Where a binding comes from.
--
-- 'FromFParam' marks names that are never moved: function parameters, names
-- from the given scope, all bindings of an enclosing body (see
-- 'hoistRecursivelyStm'), and kernel-space identifiers.  'FromLine' records
-- the line of the binding statement together with its expression.
--
-- The derived 'Ord' instance places 'FromFParam' before every 'FromLine' and
-- orders 'FromLine' values by line number first.  'moveLetUpwards' relies on
-- this ordering, both to sort dependencies and to find the latest one.
data Origin = FromFParam
            | FromLine Line (Exp ExplicitMemory)
  deriving (Eq, Ord, Show)

-- | The dependencies and the location of a binding.
data PrimBinding = PrimBinding
  { pbFrees :: Names
    -- ^ Free variables of the statement, minus names bound inside a kernel
    -- space ('boundInExpExtra').
  , _pbConsumed :: Names
    -- ^ Variables consumed by the statement (currently only the source array
    -- of an in-place 'Update').
  , pbOrigin :: Origin
    -- ^ Where the binding lives.
  }
  deriving (Show)

-- | A mapping from names to 'PrimBinding'.  The key is a collection of names,
-- since a statement can bind several pattern elements.  This is an association
-- list: 'lookupPrimBinding' uses the first entry whose key contains the name.
type BindingMap = [(Names, PrimBinding)]

-- | Call 'findHoistees' for every body, and then hoist every one of the found
-- hoistees (variables).
--
-- The callback receives the body together with @Just@ the function parameters
-- for the function body, and 'Nothing' for every nested body.
moveUpInFunDef :: FunDef ExplicitMemory
               -> (Body ExplicitMemory -> Maybe [FParam ExplicitMemory] -> [VName])
               -> FunDef ExplicitMemory
moveUpInFunDef fundef findHoistees =
  let scope_new = scopeOf fundef
      bindingmap_cur = []
      body' = hoistInBody scope_new bindingmap_cur (Just (funDefParams fundef))
              findHoistees (funDefBody fundef)
      fundef' = fundef { funDefBody = body' }
  in fundef'

-- | Find the binding information of a name: the first 'BindingMap' entry whose
-- key set contains it.  Fails with an internal error if there is none.
lookupPrimBinding :: VName -> State BindingMap PrimBinding
lookupPrimBinding vname =
  gets $ snd . fromJust (pretty vname ++ " was not found in BindingMap."
                         ++ " This should not happen!")
  . L.find ((vname `S.member`) . fst)

-- | All names bound by entries whose free variables mention the given name,
-- i.e. the names of the statements that use it.
namesDependingOn :: VName -> State BindingMap Names
namesDependingOn v =
  gets $ S.unions . map fst . filter (\(_, pb) -> v `S.member` pbFrees pb)

-- | A name from the surrounding scope has no dependencies and cannot move.
scopeBindingMap :: (VName, NameInfo ExplicitMemory) -> BindingMap
scopeBindingMap (x, _) = [(S.singleton x, PrimBinding S.empty S.empty FromFParam)]

-- | Find all variables bound in a KernelSpace.
boundInKernelSpace :: ExpMem.KernelSpace -> Names
boundInKernelSpace space =
  -- This might do too much: variables that only appear as thread-space sizes
  -- are included as well.
  S.fromList ([ ExpMem.spaceGlobalId space
              , ExpMem.spaceLocalId space
              , ExpMem.spaceGroupId space
              ]
              ++ (case ExpMem.spaceStructure space of
                    ExpMem.FlatThreadSpace ts ->
                      map fst ts ++ mapMaybe (subExpVar . snd) ts
                    ExpMem.NestedThreadSpace ts ->
                      map (\(x, _, _, _) -> x) ts
                      ++ mapMaybe (subExpVar . (\(_, x, _, _) -> x)) ts
                      ++ map (\(_, _, x, _) -> x) ts
                      ++ mapMaybe (subExpVar . (\(_, _, _, x) -> x)) ts
                 ))

-- | Names bound by kernel spaces anywhere inside an expression, including
-- nested bodies.  These are subtracted from the free variables of a statement.
--
-- FIXME: The results of this should maybe go in the core 'freeIn' function, or
-- perhaps the ExplicitMemory module, instead of this arbitrary module.
boundInExpExtra :: Exp ExplicitMemory -> Names
boundInExpExtra = execWriter . inExp
  where inExp :: Exp ExplicitMemory -> Writer Names ()
        inExp e = case e of
          Op (ExpMem.Inner (ExpMem.Kernel _ space _ _)) ->
            tell $ boundInKernelSpace space
          _ -> walkExpM walker e

        walker = identityWalker
          { walkOnBody = mapM_ (inExp . stmExp) . bodyStms }

-- | Build the 'BindingMap' entries for the statements of a single body.  Each
-- statement gets three entries: its bound names, the variables used in the
-- shapes of its pattern, and (for kernels) its kernel-space identifiers.
--
-- We do not need to run this recursively on any sub-bodies, since this will be
-- run for every call to 'hoistInBody', which *does* run recursively on
-- sub-bodies.
bodyBindingMap :: [Stm ExplicitMemory] -> BindingMap
bodyBindingMap stms = concatMap createBindingStmt $ zip [0..] stms
  where createBindingStmt :: (Line, Stm ExplicitMemory) -> BindingMap
        createBindingStmt (line, stmt@(Let (Pattern patctxelems patvalelems) _ e)) =
          let stmt_vars = S.fromList (map patElemName (patctxelems ++ patvalelems))
              frees = freeInStm stmt
              -- An in-place update consumes its source array.
              consumed = case e of
                BasicOp (Update src _ _) -> S.singleton src
                _ -> mempty
              bound_extra = boundInExpExtra e
              frees' = frees `S.difference` bound_extra
              vars_binding = (stmt_vars, PrimBinding frees' consumed (FromLine line e))

              -- Some variables exist only in a shape declaration.
              shape_sizes = S.fromList $ concatMap shapeSizes (patctxelems ++ patvalelems)
              sizes_binding = (shape_sizes, PrimBinding frees' consumed (FromLine line e))

              -- Some expressions contain special identifiers that are used in
              -- a body.  This should go somewhere else than here.
              param_vars = case e of
                Op (ExpMem.Inner (ExpMem.Kernel _ space _ _)) -> boundInKernelSpace space
                _ -> S.empty
              params_binding = (param_vars, PrimBinding S.empty S.empty FromFParam)

              bmap = [vars_binding, sizes_binding, params_binding]
          in bmap

        shapeSizes (PatElem _ (ExpMem.MemArray _ shape _ _)) =
          mapMaybe subExpVar $ shapeDims shape
        shapeSizes _ = []

-- | Hoist the hoistees of a body as far up as possible, then recurse into
-- every nested body of the resulting statements.
hoistInBody :: Scope ExplicitMemory
            -> BindingMap
            -> Maybe [FParam ExplicitMemory]
            -> (Body ExplicitMemory -> Maybe [FParam ExplicitMemory] -> [VName])
            -> Body ExplicitMemory
            -> Body ExplicitMemory
hoistInBody scope_new bindingmap_old params findHoistees body =
  let hoistees = findHoistees body params

      -- We use the possibly non-empty scope to extend our BindingMap.
      bindingmap_fromscope = concatMap scopeBindingMap $ M.toList scope_new
      bindingmap_body = bodyBindingMap $ stmsToList $ bodyStms body
      bindingmap = bindingmap_old ++ bindingmap_fromscope ++ bindingmap_body

      -- Create a new body where all hoistees have been moved as much upwards
      -- in the statement list as possible.  The BindingMap is threaded
      -- through, so each hoist sees the line numbers left by the previous one.
      -- A strict left fold avoids building a chain of suspended hoists.
      (Body () bnds res, bindingmap') =
        L.foldl' (\(body0, lbindingmap) -> hoist lbindingmap body0)
                 (body, bindingmap) hoistees

      -- Touch upon any subbodies.
      bnds' = fmap (hoistRecursivelyStm bindingmap' findHoistees) bnds
      body' = Body () bnds' res
  in body'

-- | Run 'hoistInBody' on every body nested directly inside a statement's
-- expression.
hoistRecursivelyStm :: BindingMap
                    -> (Body ExplicitMemory -> Maybe [FParam ExplicitMemory] -> [VName])
                    -> Stm ExplicitMemory
                    -> Stm ExplicitMemory
hoistRecursivelyStm bindingmap findHoistees (Let pat aux e) =
  runIdentity (Let pat aux <$> mapExpM transform e)
  where transform = identityMapper { mapOnBody = mapper }
        mapper scope_new = return . hoistInBody scope_new bindingmap' Nothing findHoistees
        -- The nested body cannot move to any of its locations of its parent's
        -- body, so we say that all its parent's bindings are parameters.
        bindingmap' = map (\(ns, PrimBinding frees consumed _) ->
                             (ns, PrimBinding frees consumed FromFParam))
                      bindingmap

-- | Hoist the statement denoted by 'hoistee' as much upwards as possible in
-- 'body', and return the new body together with the updated BindingMap.
hoist :: BindingMap -> Body ExplicitMemory -> VName -> (Body ExplicitMemory, BindingMap)
hoist bindingmap_cur body hoistee =
  -- NOTE(review): this appends a fresh copy of the body's entries on every
  -- call, although 'hoistInBody' already put them in 'bindingmap_cur'.
  -- 'lookupPrimBinding' hits the earlier copies first, so the appended ones
  -- look redundant (and grow the list per hoistee).  Confirm before removing.
  let bindingmap = bindingmap_cur <> bodyBindingMap (stmsToList $ bodyStms body)
      body' = runState (moveLetUpwards hoistee body) bindingmap
  in body'

-- | Move a statement as much up as possible.
--
-- The dependencies of the statement are moved up first, starting with the one
-- closest to the beginning of the body.  The statement is then placed directly
-- after its latest dependency, or at line 0 if it only depends on 'FromFParam'
-- names.  Statements originating from parameters, loops and kernels are left
-- where they are.
moveLetUpwards :: VName -> Body ExplicitMemory
               -> State BindingMap (Body ExplicitMemory)
moveLetUpwards letname body = do
  PrimBinding deps consumed letorig <- lookupPrimBinding letname

  -- Extend the dependencies with all those statements that use the consumed
  -- variables of this statement, except the current statement.
  deps' <- S.delete letname
           <$> (S.union deps
                <$> (S.unions <$> mapM namesDependingOn (S.toList consumed)))

  case letorig of
    FromFParam -> return body
    FromLine line_cur exp_cur -> case exp_cur of
      -- We do not want to change the structure of the program too much, so we
      -- restrict the aggressive hoister to *stop* and not hoist loops and
      -- kernels, as hoisting these expressions might actually make a
      -- hoisting-dependent optimisation *poorer* because of some assumptions
      -- about the structure.  FIXME: Do this nicer in a way where it is easy
      -- to argue for it.
      DoLoop{} -> return body
      Op ExpMem.Inner{} -> return body
      _ -> do
        -- Sort by how close they are to the beginning of the body.  The
        -- closest one should be the first one to hoist, so that the other ones
        -- can maybe exploit it.
        deps'' <- sortByKeyM (fmap pbOrigin . lookupPrimBinding) $ S.toList deps'
        body' <- foldM (flip moveLetUpwards) body deps''

        -- Re-read the origins: moving the dependencies changed their lines.
        origins <- mapM (fmap pbOrigin . lookupPrimBinding) deps''
        let line_dest = case L.foldl' max FromFParam origins of
              FromFParam -> 0
              FromLine n _e -> n + 1

        -- Moving dependencies (all above this statement) must not have
        -- shifted the statement itself.
        PrimBinding _ _ letorig' <- lookupPrimBinding letname
        when (letorig' /= letorig) $ error "Assertion: This should not happen."

        stms' <- moveLetToLine letname line_cur line_dest $ stmsToList $ bodyStms body'
        return body' { bodyStms = stmsFromList stms' }

-- | Both move the statement to the new line and update the BindingMap.
--
-- The statement at 'line_cur' is removed and reinserted so that it ends up at
-- index 'line_dest'.  Every 'FromLine' entry between the two positions is
-- renumbered, in either direction, and the moved statement's own entries
-- (those whose key contains 'stm_cur_name') get 'line_dest'.
moveLetToLine :: VName -> Line -> Line -> [Stm ExplicitMemory]
              -> State BindingMap [Stm ExplicitMemory]
moveLetToLine stm_cur_name line_cur line_dest stms
  | line_cur == line_dest = return stms
  | otherwise = do
      modify $ map renumber
      r <- lookupPrimBinding stm_cur_name
      case r of
        PrimBinding frees consumed (FromLine _ exp_cur) ->
          modify $ replaceWhere stm_cur_name
                   (PrimBinding frees consumed (FromLine line_dest exp_cur))
        _ -> error "moveLetToLine: unhandled case" -- fixme
      return stms2
  where
    (before, after) = splitAt line_cur stms
    (stm_cur, rest) = case after of
      s : ss -> (s, ss)
      [] -> error "moveLetToLine: statement index out of range."
    stms1 = before ++ rest
    stms2 = take line_dest stms1 ++ [stm_cur] ++ drop line_dest stms1

    -- Moving up shifts [line_dest, line_cur) down by one line (+1); moving
    -- down shifts (line_cur, line_dest] up by one line (-1).  Entries at
    -- 'line_cur' are left alone here.
    -- NOTE(review): an entry at 'line_cur' whose key does not contain
    -- 'stm_cur_name' (e.g. the moved statement's shape-size entry) keeps the
    -- old line number; confirm that nothing relies on it.
    renumber t@(ns, PrimBinding frees consumed orig) = case orig of
      FromLine l e
        | line_dest < line_cur && l >= line_dest && l < line_cur ->
            (ns, PrimBinding frees consumed (FromLine (l + 1) e))
        | line_dest > line_cur && l > line_cur && l <= line_dest ->
            (ns, PrimBinding frees consumed (FromLine (l - 1) e))
      _ -> t

-- | Replace the binding information of every entry whose key contains 'n'.
replaceWhere :: VName -> PrimBinding -> BindingMap -> BindingMap
replaceWhere n pb1 = map (\(ns, pb) -> (ns, if n `S.member` ns then pb1 else pb))