{-# LANGUAGE ParallelListComp #-}

-- |
-- Module      : Language.SequentCore.Simpl
-- Description : Simplifier reimplementation using Sequent Core
-- Maintainer  : maurerl@cs.uoregon.edu
-- Stability   : experimental
--
-- A proof of concept to demonstrate that the Sequent Core syntax can be used
-- for basic optimization in the style of GHC's simplifier. In some ways, it is
-- easier to use Sequent Core for these, as the continuations are expressed
-- directly in the program syntax rather than needing to be built up on the fly.

module Language.SequentCore.Simpl (plugin) where

import Language.SequentCore.Pretty (pprTopLevelBinds)
import Language.SequentCore.Simpl.Env
import Language.SequentCore.Simpl.Monad
import Language.SequentCore.Syntax
import Language.SequentCore.Translate
import Language.SequentCore.Util

import BasicTypes
import Coercion    ( isCoVar )
import CoreMonad   ( Plugin(..), SimplifierMode(..), Tick(..), CoreToDo(..),
                     CoreM, defaultPlugin, reinitializeGlobals, errorMsg,
                     isZeroSimplCount, putMsg )
import CoreSyn     ( isRuntimeVar, isCheapUnfolding )
import CoreUnfold  ( smallEnoughToInline )
import DynFlags    ( gopt, GeneralFlag(..) )
import FastString
import Id
import HscTypes    ( ModGuts(..) )
import OccurAnal   ( occurAnalysePgm )
import Outputable
import Type        ( mkFunTy )
import Var
import VarEnv
import VarSet

import Control.Applicative ( (<$>), (<*>) )
import Control.Exception   ( assert )
import Control.Monad       ( when )
import Data.Maybe

-- | When 'True', 'runSimplifier' prints the bindings before and after every
-- iteration (both in Sequent Core and in Core form).
tracing :: Bool
tracing = False

-- | Plugin data. The initializer replaces all instances of the original
-- simplifier with the new one.
plugin :: Plugin
plugin = defaultPlugin
  { installCoreToDos = \_ todos -> do
      reinitializeGlobals
      return (replace todos)
  }
  where
    -- Walk the pass list, descending into nested 'CoreDoPasses' groups, and
    -- swap every 'CoreDoSimplify' for our plugin pass.  Everything else is
    -- left untouched.
    replace (CoreDoSimplify maxIters mode : todos)
      = newPass maxIters mode : replace todos
    replace (CoreDoPasses todos1 : todos2)
      = CoreDoPasses (replace todos1) : replace todos2
    replace (todo : todos)
      = todo : replace todos
    replace []
      = []

    newPass maxIters mode
      = CoreDoPluginPass "SeqSimpl" (runSimplifier maxIters mode)

-- | Run the simplifier over a module until an iteration performs no ticks
-- ('isZeroSimplCount') or until @iters@ iterations have been completed.  In
-- the latter case an error message is emitted and the bindings produced so
-- far are returned.
--
-- Each iteration runs the occurrence analyser on the Core bindings, converts
-- them with 'fromCoreBinds', simplifies, and converts back with
-- 'bindsToCore'.
runSimplifier :: Int -> SimplifierMode -> ModGuts -> CoreM ModGuts
runSimplifier iters mode guts
  = go 1 guts
  where
    go n guts
      | n > iters
      = do
          errorMsg $ text "Ran out of gas after" <+> int iters
                       <+> text "iterations."
          return guts
      | otherwise
      = do
          let globalEnv = SimplGlobalEnv { sg_mode = mode }
              mod       = mg_module guts
              coreBinds = mg_binds guts
              occBinds  = runOccurAnal mod coreBinds
              binds     = fromCoreBinds occBinds
          when tracing $ putMsg
            $ text "BEFORE" <+> int n
                $$ text "--------"
                $$ pprTopLevelBinds binds
          (binds', count) <- runSimplM globalEnv $ simplModule binds
          when tracing $ putMsg
            $ text "AFTER" <+> int n
                $$ text "-------"
                $$ pprTopLevelBinds binds'
          let coreBinds' = bindsToCore binds'
              guts'      = guts { mg_binds = coreBinds' }
          when tracing $ putMsg
            $ text "CORE AFTER" <+> int n
                $$ text "------------"
                $$ ppr coreBinds'
          if isZeroSimplCount count
            then do
              when tracing $ putMsg
                $ text "Done after" <+> int n <+> text "iterations"
              return guts'
            else go (n+1) guts'

    -- The occurrence analyser is run with no active rules, no rules at all,
    -- and no vectorisation information.
    runOccurAnal mod core
      = let isRuleActive = const False
            rules        = []
            vects        = []
            vectVars     = emptyVarSet
        in  occurAnalysePgm mod isRuleActive rules vects vectVars core

-- | Simplify all top-level bindings of a module, starting from 'initialEnv'.
-- The final environment is discarded.
simplModule :: [InBind] -> SimplM [OutBind]
simplModule binds
  = snd <$> simplBinds initialEnv binds TopLevel

-- | Simplify a command: first its let-bindings (as non-top-level bindings),
-- then the cut of its value against its continuation in the resulting
-- environment.  Surviving bindings are re-attached with 'addLets'.
simplCommand :: SimplEnv -> InCommand -> SimplM OutCommand
simplCommand env (Command { cmdLet = binds, cmdValue = val, cmdCont = cont })
  = do
      (env', binds') <- simplBinds env binds NotTopLevel
      cmd' <- simplCut env' val (staticPart env') cont
      return $ addLets binds' cmd'

-- | Simplify a value on its own, i.e. cut against 'Return' in an environment
-- passed through 'zapCont', and wrap the resulting command with 'mkCompute'.
simplValue :: SimplEnv -> InValue -> SimplM OutValue
simplValue env v
  = mkCompute <$> simplCut env' v (staticPart env') Return
  where
    env' = zapCont env

-- | Simplify a sequence of bindings left to right, threading the environment
-- through.  Bindings that were eliminated (inlined or substituted away) do
-- not appear in the result.
simplBinds :: SimplEnv -> [InBind] -> TopLevelFlag
           -> SimplM (SimplEnv, [OutBind])
simplBinds env [] _
  = return (env, [])
simplBinds env (b : bs) level
  = do
      (env', b')   <- simplBind env (staticPart env) b level
      (env'', bs') <- simplBinds env' bs level
      return (env'', b' `consMaybe` bs')

-- | Simplify a single binding group.  The first environment is the one being
-- extended with the binder(s); the static environment is the one the
-- right-hand side(s) are interpreted in.  Returns 'Nothing' when nothing of
-- the binding survives (for a 'Rec' group: when every member was removed).
simplBind :: SimplEnv -> StaticEnv -> InBind -> TopLevelFlag
          -> SimplM (SimplEnv, Maybe OutBind)
--simplBind env level bind
--  | pprTrace "simplBind" (text "Binding" <+> parens (ppr level) <> colon <+>
--                          ppr bind) False
--  = undefined
simplBind env_x env_c (NonRec x c) level
  = simplNonRec env_x x env_c c level
simplBind env_x env_c (Rec xcs) level
  = do
      (env', xcs') <- simplRec env_x env_c xcs level
      return (env', if null xcs' then Nothing else Just $ Rec xcs')

-- | Simplify a non-recursive binding of @x@ (scoped in @env_x@) to @v@
-- (interpreted in @env_v@).
--
-- * Type and coercion variables are never kept as bindings: the substituted
--   right-hand side is recorded in the type/coercion substitution.
-- * Otherwise the binding is either pre-inlined (recorded as a suspension in
--   the identifier substitution) or simplified and handed to 'completeBind'.
simplNonRec :: SimplEnv -> InVar -> StaticEnv -> InValue -> TopLevelFlag
            -> SimplM (SimplEnv, Maybe OutBind)
simplNonRec env_x x env_v v level
  | isTyVar x
  , Type ty <- assert (isTypeValue v) $ v
  = let ty'  = substTyStatic env_v ty
        tvs' = extendVarEnv (se_tvSubst env_x) x ty'
    in  return (env_x { se_tvSubst = tvs' }, Nothing)
  | isCoVar x
  , Coercion co <- assert (isCoValue v) $ v
  = let co'  = substCoStatic env_v co
        cvs' = extendVarEnv (se_cvSubst env_x) x co'
    in  return (env_x { se_cvSubst = cvs' }, Nothing)
  | otherwise
  = do
      preInline <- preInlineUnconditionally env_x x env_v v level
      if preInline
        then do
          tick (PreInlineUnconditionally x)
          let rhs  = mkSuspension env_v v
              env' = extendIdSubst env_x x rhs
          return (env', Nothing)
        else do
          let (env', x') = enterScope env_x x
          v' <- simplValue (env' `setStaticPart` env_v) v
          (env'', maybeNewPair) <- completeBind env' x x' v' level
          return (env'', uncurry NonRec <$> maybeNewPair)

-- | Finish a binding whose right-hand side has already been simplified.
-- Either post-inlines it (substituting the simplified value for the original
-- binder and dropping the binding) or records the new binder in the in-scope
-- set and the definition map and returns the binding pair.
completeBind :: SimplEnv -> InVar -> OutVar -> OutValue -> TopLevelFlag
             -> SimplM (SimplEnv, Maybe (OutVar, OutValue))
completeBind env x x' v level
  = do
      postInline <- postInlineUnconditionally env x v level
      if postInline
        then do
          tick (PostInlineUnconditionally x)
          -- Nevermind about substituting x' for x; we'll substitute v instead
          let env' = extendIdSubst env x (DoneVal v)
          return (env', Nothing)
        else do
          -- TODO Eta-expansion goes here
          let ins   = se_inScope env
              defs  = se_defs env
              -- The output binder carries the IdInfo of the input binder.
              x''   = x' `setIdInfo` idInfo x
              ins'  = extendInScopeSet ins x''
              defs' = extendVarEnv defs x'' (BoundTo v level)
          return (env { se_inScope = ins', se_defs = defs' }, Just (x'', v))

-- | Simplify a recursive binding group.  All binders are brought into scope
-- up front with 'enterScopes'; each right-hand side is then either
-- pre-inlined or simplified and completed, in order.  The returned list holds
-- only the bindings that survive, in their original order.
simplRec :: SimplEnv -> StaticEnv -> [(InVar, InValue)] -> TopLevelFlag
         -> SimplM (SimplEnv, [(OutVar, OutValue)])
simplRec env_x env_v xvs level
  = go env0_x [ (x, x', v) | (x, v) <- xvs | x' <- xs' ] []
  where
    -- @acc@ is built in reverse.
    go env_x [] acc
      = return (env_x, reverse acc)
    go env_x ((x, x', v) : triples) acc
      = do
          preInline <- preInlineUnconditionally env_x x env_v v level
          if preInline
            then do
              tick (PreInlineUnconditionally x)
              let rhs  = mkSuspension env_v v
                  env' = extendIdSubst env_x x rhs
              go env' triples acc
            else do
              v' <- simplValue (env_x `setStaticPart` env_v) v
              (env', bind') <- completeBind env_x x x' v' level
              go env' triples (bind' `consMaybe` acc)

    (env0_x, xs') = enterScopes env_x (map fst xvs)

-- | Simplify the cut of a value against a continuation.  The value is
-- interpreted in the first environment and the continuation in the given
-- static environment.
--
-- The equations are tried in order, so the order matters: beta-reduction
-- comes before the general lambda case, and case-of-case and known-branch
-- come before the per-constructor fallbacks for 'Var', 'Lit', 'Cons' and
-- 'Compute'.
--
-- TODO Deal with casts, i.e. implement the congruence rules from the
-- System FC paper
simplCut :: SimplEnv -> InValue -> StaticEnv -> InCont -> SimplM OutCommand
{-
simplCut env_v v env_k cont
  | pprTrace "simplCut" (
      ppr env_v $$ ppr v $$ ppr env_k $$ ppr cont
    ) False
  = undefined
-}
-- Types, coercions and continuation values are only expected to meet a
-- return continuation (checked by assertion); they are substituted and
-- returned as value commands.
simplCut env_v (Type ty) _env_k cont
  = assert (isReturnCont cont) $
    let ty' = substTy env_v ty
    in  return $ valueCommand (Type ty')
simplCut env (Coercion co) _env_k cont
  = assert (isReturnCont cont) $
    let co' = substCo env co
    in  return $ valueCommand (Coercion co')
simplCut env (Cont k) _env_k cont
  = assert (isReturnCont cont) $ do
      k' <- simplCont env k
      return $ valueCommand (Cont k')
-- Beta-reduction: a lambda meeting an application.
simplCut env_v (Lam x c) env_k (App arg cont)
  = do
      tick (BetaReduction x)
      (env_v', newBind) <- simplNonRec env_v x env_k arg NotTopLevel
      -- Effectively, here we bind the covariable in the lambda to the current
      -- continuation before proceeding
      c' <- simplCommand (bindCont env_v' env_k cont) c
      return $ addLets (maybeToList newBind) c'
-- A lambda meeting anything else: simplify the body, then the continuation.
simplCut env_v (Lam x c) env_k cont
  = do
      let (env_v', x') = enterScope env_v x
      c' <- simplCommand (zapCont env_v') c
      simplContWith (env_v `setStaticPart` env_k) (Lam x' c') cont
-- Case-of-case: name the outer case continuation with a fresh continuation
-- variable and make the inner case jump to it.
simplCut env_v val env_k (Case x ty alts cont@Case {})
  = do
      tick (CaseOfCase x)
      let contTy = ty `mkFunTy` contOuterType ty cont
      contVar <- asContId <$> mkSysLocalM (fsLit "k") contTy
      (env_k', newBind) <-
        simplNonRec (env_v `setStaticPart` env_k) contVar env_k (Cont cont)
                    NotTopLevel
      let env_v' = env_k' `setStaticPart` staticPart env_v
      comm <- simplCut env_v' val (staticPart env_k')
                       (Case x ty alts $ Jump contVar)
      return $ addLets (maybeToList newBind) comm
-- Known branch: the scrutinee's head form selects an alternative statically.
-- The case binder and the pattern variables become ordinary bindings.
simplCut env_v val env_k (Case x _ alts cont)
  | Just (pairs, body) <- matchCase env_v val alts
  = do
      tick (KnownBranch x)
      (env', binds) <- go (env_v `setStaticPart` env_k) ((x, val) : pairs) []
      comm <- simplCommand (env' `pushCont` cont) body
      return $ addLets binds comm
  where
    go env [] acc
      = return (env, reverse acc)
    go env ((y, v) : rest) acc
      = do
          (env', maybe_yv') <- simplNonRec env y (staticPart env_v) v NotTopLevel
          go env' rest (maybe_yv' `consMaybe` acc)
-- A variable: consult the substitution and continue with whatever it maps to.
simplCut env_v (Var x) env_k cont
  = case substId env_v x of
      DoneId x'
        -> simplContWith (env_v `setStaticPart` env_k) (Var x') cont
      DoneVal v
        -> simplCut (zapSubstEnvs env_v) v env_k cont
      SuspVal stat v
        -> simplCut (env_v `setStaticPart` stat) v env_k cont
simplCut env_v val@(Lit _) env_k cont
  = simplContWith (env_v `setStaticPart` env_k) val cont
simplCut env_v (Cons ctor args) env_k cont
  = do
      args' <- mapM (simplValue env_v) args
      simplContWith (env_v `setStaticPart` env_k) (Cons ctor args') cont
simplCut env_v (Compute c) env_k cont
  = simplCommand (bindCont env_v env_k cont) c

-- | Try to pick a case alternative statically from the syntactic form of the
-- scrutinee.  On success, returns the bindings for the pattern variables and
-- the body of the chosen alternative.
--
-- * A literal selects the 'LitAlt' with an equal literal.
-- * A constructor application selects the 'DataAlt' for the same
--   constructor.  The arity assertion is checked only after the constructors
--   have been found equal, since alternatives for other constructors may
--   legitimately bind a different number of variables.
-- * The 'DEFAULT' alternative is selected only if no other alternative
--   matches /and/ the scrutinee is syntactically in head normal form (a
--   literal, a constructor application, or a lambda over a runtime variable).
--   For a bare variable or a computation the eventual value is unknown, so
--   choosing 'DEFAULT' could pick the wrong branch and would drop the
--   evaluation that the case performs.
--
-- NOTE(review): the 'DataAlt' case zips the pattern variables with /all/
-- constructor arguments; confirm that 'Cons' never carries arguments (e.g.
-- type arguments) that the alternative does not bind.
--
-- TODO Somehow handle updating Definitions with NotAmong values?
matchCase :: SimplEnv -> InValue -> [InAlt]
          -> Maybe ([(InVar, InValue)], InCommand)
-- TODO First, handle variables with substitutions/unfoldings
matchCase _env_v (Lit lit) (Alt (LitAlt lit') xs body : _alts)
  | assert (null xs) True
  , lit == lit'
  = Just ([], body)
matchCase _env_v (Cons ctor args) (Alt (DataAlt ctor') xs body : _alts)
  | ctor == ctor'
  , assert (length args == length xs) True
  = Just (zip xs args, body)
matchCase env_v val (Alt DEFAULT xs body : alts)
  | assert (null xs) True
  , isKnownHNF val
  , Nothing <- matchCase env_v val alts
  = Just ([], body)
  where
    -- Conservative, purely syntactic check; anything not listed is treated
    -- as possibly unevaluated.
    isKnownHNF (Lit _)    = True
    isKnownHNF (Cons _ _) = True
    isKnownHNF (Lam y _)  = isRuntimeVar y
    isKnownHNF _          = False
matchCase env_v val (_ : alts)
  = matchCase env_v val alts
matchCase _ _ []
  = Nothing

-- | Simplify a continuation.  The worker walks the continuation from the
-- outside in, accumulating in @kc@ a function that rebuilds the frames
-- processed so far around whatever comes next.
simplCont :: SimplEnv -> InCont -> SimplM OutCont
{-
simplCont env cont
  | pprTrace "simplCont" (
      ppr env $$ ppr cont
    ) False
  = undefined
-}
simplCont env cont
  = go env cont (\k -> k)
  where
    {-
    go env cont _
      | pprTrace "simplCont::go" (
          ppr cont
        ) False
      = undefined
    -}
    go env (App arg cont) kc
      = do
          arg' <- simplValue env arg
          go env cont (kc . App arg')
    go env (Cast co cont) kc
      -- TODO Simplify coercions
      = go env cont (kc . Cast co)
    go env (Case x ty alts cont) kc
      -- TODO A whole lot - cases are important
      = doCase env'' x' ty' alts []
      where
        -- If the case is followed directly by a jump, the jump is bound as
        -- the environment's continuation for the alternatives and the case
        -- itself is followed by 'Return'.
        (env', cont')
          | Jump {} <- cont
          = (bindCont env (staticPart env) cont, Return)
          | otherwise
          = (env, cont)
        (env'', x') = enterScope env' x
        ty'         = substTy env'' ty
        env_orig    = env

        -- Simplify the alternatives one by one (accumulated in reverse),
        -- then carry on with the rest of the continuation in the
        -- environment from before the case binder was entered.
        doCase _env y ty1 [] alt_acc
          = go env_orig cont' (kc . Case y ty1 (reverse alt_acc))
        doCase env1 y ty1 (Alt con xs c : rest) alt_acc
          = do
              let (env2, xs') = enterScopes env1 xs
              c' <- simplCommand env2 c
              doCase env1 y ty1 rest (Alt con xs' c' : alt_acc)
    go env (Tick ti cont) kc
      = go env cont (kc . Tick ti)
    go env (Jump x) kc
      -- TODO Consider call-site inline
      = case substId env x of
          DoneId x'
            -> return $ kc (Jump x')
          DoneVal (Cont k)
            -> go (zapSubstEnvs env) k kc
          SuspVal stat (Cont k)
            -> go (env `setStaticPart` stat) k kc
          _
            -> error "jump to non-continuation"
    go env Return kc
      -- If 'restoreEnv' yields a saved continuation, keep going with it;
      -- otherwise this really is the end.
      | Just (env', cont) <- restoreEnv env
      = go env' cont kc
      | otherwise
      = return $ kc Return

-- | Simplify a continuation and cut an already-simplified value against it,
-- producing a command with no let-bindings.
simplContWith :: SimplEnv -> OutValue -> InCont -> SimplM OutCommand
simplContWith env val cont
  = mkCommand [] val <$> simplCont env cont

-- | Decide whether a binding should be inlined before its right-hand side is
-- simplified.
--
-- Based on preInlineUnconditionally in SimplUtils; see comments there
preInlineUnconditionally :: SimplEnv -> InVar -> StaticEnv -> InValue
                         -> TopLevelFlag -> SimplM Bool
-- preInlineUnconditionally _ _ _ _ _ = return False
preInlineUnconditionally _env_x x _env_rhs rhs level
  = do
      ans <- go <$> getMode <*> getDynFlags
      --liftCoreM $ putMsg $ "preInline" <+> ppr x <> colon <+> text (show ans))
      return ans
  where
    go mode dflags
      | not active                          = False
      | not enabled                         = False
      | TopLevel <- level, isBottomingId x  = False
      -- TODO Somehow GHC can pre-inline an exported thing? We can't, anyway
      | isExportedId x                      = False
      | isCoVar x                           = False
      | otherwise
      = case idOccInfo x of
          IAmDead                  -> True
          OneOcc inLam True intCxt -> try_once inLam intCxt
          _                        -> False
      where
        active  = isActive (sm_phase mode) act
        act     = idInlineActivation x
        enabled = gopt Opt_SimplPreInlining dflags

        try_once inLam intCxt
          | not inLam = isNotTopLevel level || early_phase
          | otherwise = intCxt && canInlineValInLam rhs

        canInlineInLam c
          | Just v <- asValueCommand c = canInlineValInLam v
          | otherwise                  = False

        canInlineValInLam (Lit _)     = True
        canInlineValInLam (Lam y c)   = isRuntimeVar y || canInlineInLam c
        canInlineValInLam (Compute c) = canInlineInLam c
        canInlineValInLam _           = False

        early_phase = case sm_phase mode of
                        Phase 0 -> False
                        _       -> True

-- | Decide whether a binding should be inlined after its right-hand side has
-- been simplified.
--
-- Based on postInlineUnconditionally in SimplUtils; see comments there
postInlineUnconditionally :: SimplEnv -> OutVar -> OutValue -> TopLevelFlag
                          -> SimplM Bool
postInlineUnconditionally _env x v level
  = do
      ans <- go <$> getMode <*> getDynFlags
      -- liftCoreM $ putMsg $ "postInline" <+> ppr x <> colon <+> text (show ans)
      return ans
  where
    go mode dflags
      | not active                 = False
      | isWeakLoopBreaker occ_info = False
      | isExportedId x             = False
      | isTopLevel level           = False
      | isTrivialValue v           = True
      | otherwise
      = case occ_info of
          OneOcc in_lam _one_br int_cxt
            -- TODO Actually update unfoldings so that this makes sense
            ->     smallEnoughToInline dflags unfolding
               && (not in_lam ||
                    (isCheapUnfolding unfolding && int_cxt))
          IAmDead -> True
          _       -> False
      where
        occ_info  = idOccInfo x
        active    = isActive (sm_phase mode) (idInlineActivation x)
        unfolding = idUnfolding x