{-# LANGUAGE FlexibleContexts #-} -- | This module implements a compiler pass for inlining functions, -- then removing those that have become dead. module Futhark.Optimise.InliningDeadFun ( inlineAndRemoveDeadFunctions , removeDeadFunctions ) where import Control.Monad.Identity import Data.List import Data.Loc import Data.Maybe import qualified Data.Map.Strict as M import qualified Data.Set as S import Futhark.Representation.SOACS import Futhark.Representation.SOACS.Simplify (simplifyFun) import Futhark.Transform.Rename import Futhark.Analysis.CallGraph import Futhark.Binder import Futhark.Pass aggInlining :: MonadFreshNames m => CallGraph -> [FunDef SOACS] -> m [FunDef SOACS] aggInlining cg = fmap (filter keep) . recurse where noInterestingCalls :: S.Set Name -> FunDef SOACS -> Bool noInterestingCalls interesting fundec = case M.lookup (funDefName fundec) cg of Just calls | not $ any (`elem` interesting') calls -> True _ -> False where interesting' = funDefName fundec `S.insert` interesting -- We apply simplification after every round of inlining, -- because it is more efficient to shrink the program as soon -- as possible, rather than wait until it has balooned after -- full inlining. recurse funs = do let interesting = S.fromList $ map funDefName funs (to_be_inlined, to_inline_in) = partition (noInterestingCalls interesting) funs inlined_but_entry_points = filter (isJust . funDefEntryPoint) to_be_inlined if null to_be_inlined then return funs else do let onFun = simplifyFun <=< renameFun . (`doInlineInCaller` to_be_inlined) to_inline_in' <- recurse =<< mapM onFun to_inline_in return $ inlined_but_entry_points ++ to_inline_in' keep fundec = isJust (funDefEntryPoint fundec) || callsRecursive fundec callsRecursive fundec = maybe False (any recursive) $ M.lookup (funDefName fundec) cg recursive fname = case M.lookup fname cg of Just calls -> fname `elem` calls Nothing -> False -- | @doInlineInCaller caller inlcallees@ inlines in @calleer@ the functions -- in @inlcallees@. At this point the preconditions are that if @inlcallees@ -- is not empty, and, more importantly, the functions in @inlcallees@ do -- not call any other functions. Further extensions that transform a -- tail-recursive function to a do or while loop, should do the transformation -- first and then do the inlining. doInlineInCaller :: FunDef SOACS -> [FunDef SOACS] -> FunDef SOACS doInlineInCaller (FunDef entry name rtp args body) inlcallees = let body' = inlineInBody inlcallees body in FunDef entry name rtp args body' inlineInBody :: [FunDef SOACS] -> Body -> Body inlineInBody inlcallees (Body attr stms res) = Body attr stms' res where stms' = stmsFromList (concatMap inline $ stmsToList stms) inline (Let pat aux (Apply fname args _ (safety,loc,locs))) | fun:_ <- filter ((== fname) . funDefName) inlcallees = let param_stms = zipWith reshapeIfNecessary (map paramIdent $ funDefParams fun) (map fst args) body_stms = stmsToList $ addLocations safety (filter notNoLoc (loc:locs)) $ bodyStms $ funDefBody fun res_stms = map (certify $ stmAuxCerts aux) $ zipWith reshapeIfNecessary (patternIdents pat) $ bodyResult $ funDefBody fun in param_stms ++ body_stms ++ res_stms inline stm = [inlineInStm inlcallees stm] reshapeIfNecessary ident se | t@Array{} <- identType ident, Var v <- se = mkLet [] [ident] $ shapeCoerce (arrayDims t) v | otherwise = mkLet [] [ident] $ BasicOp $ SubExp se notNoLoc :: SrcLoc -> Bool notNoLoc = (/=NoLoc) . locOf inliner :: Monad m => [FunDef SOACS] -> Mapper SOACS SOACS m inliner funs = identityMapper { mapOnBody = const $ return . inlineInBody funs , mapOnOp = return . inlineInSOAC funs } inlineInSOAC :: [FunDef SOACS] -> SOAC SOACS -> SOAC SOACS inlineInSOAC inlcallees = runIdentity . mapSOACM identitySOACMapper { mapOnSOACLambda = return . inlineInLambda inlcallees } inlineInStm :: [FunDef SOACS] -> Stm -> Stm inlineInStm inlcallees (Let pat aux e) = Let pat aux $ mapExp (inliner inlcallees) e inlineInLambda :: [FunDef SOACS] -> Lambda -> Lambda inlineInLambda inlcallees (Lambda params body ret) = Lambda params (inlineInBody inlcallees body) ret addLocations :: Safety -> [SrcLoc] -> Stms SOACS -> Stms SOACS addLocations caller_safety more_locs = fmap onStm where onStm stm = stm { stmExp = onExp $ stmExp stm } onExp (Apply fname args t (safety, loc,locs)) = Apply fname args t (min caller_safety safety, loc,locs++more_locs) onExp (BasicOp (Assert cond desc (loc,locs))) = case caller_safety of Safe -> BasicOp $ Assert cond desc (loc,locs++more_locs) Unsafe -> BasicOp $ SubExp $ Constant Checked onExp (Op soac) = Op $ runIdentity $ mapSOACM identitySOACMapper { mapOnSOACLambda = return . onLambda } soac onExp e = mapExp identityMapper { mapOnBody = const $ return . onBody } e onBody body = body { bodyStms = addLocations caller_safety more_locs $ bodyStms body } onLambda :: Lambda -> Lambda onLambda lam = lam { lambdaBody = onBody $ lambdaBody lam } -- | A composition of 'inlineAggressively' and 'removeDeadFunctions', -- to avoid the cost of type-checking the intermediate stage. inlineAndRemoveDeadFunctions :: Pass SOACS SOACS inlineAndRemoveDeadFunctions = Pass { passName = "Inline and remove dead functions" , passDescription = "Inline and remove resulting dead functions." , passFunction = pass } where pass prog = do let cg = buildCallGraph prog Prog <$> aggInlining cg (progFunctions prog) -- | @removeDeadFunctions prog@ removes the functions that are unreachable from -- the main function from the program. removeDeadFunctions :: Pass SOACS SOACS removeDeadFunctions = Pass { passName = "Remove dead functions" , passDescription = "Remove the functions that are unreachable from the main function" , passFunction = return . pass } where pass prog = let cg = buildCallGraph prog live_funs = filter (isFunInCallGraph cg) (progFunctions prog) in Prog live_funs isFunInCallGraph cg fundec = isJust $ M.lookup (funDefName fundec) cg