module Language.SequentCore.Simpl (plugin) where
import Language.SequentCore.Pretty (pprTopLevelBinds)
import Language.SequentCore.Simpl.Env
import Language.SequentCore.Simpl.Monad
import Language.SequentCore.Syntax
import Language.SequentCore.Translate
import Language.SequentCore.Util
import BasicTypes
import Coercion ( isCoVar )
import CoreMonad ( Plugin(..), SimplifierMode(..), Tick(..), CoreToDo(..),
CoreM, defaultPlugin, reinitializeGlobals, errorMsg,
isZeroSimplCount, putMsg
)
import CoreSyn ( isRuntimeVar, isCheapUnfolding )
import CoreUnfold ( smallEnoughToInline )
import DynFlags ( gopt, GeneralFlag(..) )
import FastString
import Id
import HscTypes ( ModGuts(..) )
import OccurAnal ( occurAnalysePgm )
import Outputable
import Type ( mkFunTy )
import Var
import VarEnv
import VarSet
import Control.Applicative ( (<$>), (<*>) )
import Control.Exception ( assert )
import Control.Monad ( when )
import Data.Maybe
-- | Debug switch: when 'True', 'runSimplifier' prints (via 'putMsg') the
-- Sequent Core program before and after every simplifier iteration, the Core
-- produced from it, and a message when a fixed point is reached.
tracing :: Bool
tracing = False
-- | The plugin entry point. Every GHC simplifier pass in the pipeline,
-- including those nested inside 'CoreDoPasses' groups, is swapped for the
-- Sequent Core simplifier ("SeqSimpl"), keeping the pass's iteration limit
-- and simplifier mode. All other passes are left untouched.
plugin :: Plugin
plugin = defaultPlugin { installCoreToDos = install }
  where
    install _opts todos = do
      reinitializeGlobals
      return (replace todos)

    -- Rewrite a pass list element-wise.
    replace = map rewrite

    rewrite (CoreDoSimplify iters mode) = newPass iters mode
    rewrite (CoreDoPasses inner)        = CoreDoPasses (replace inner)
    rewrite todo                        = todo

    newPass iters mode
      = CoreDoPluginPass "SeqSimpl" (runSimplifier iters mode)
-- | Run the Sequent Core simplifier over a module for at most @iters@
-- iterations.
--
-- Each iteration:
--
-- 1. runs the occurrence analyser on the current Core bindings (no rules, no
--    vectorisation info);
-- 2. translates them to Sequent Core and simplifies them;
-- 3. translates the result back to Core.
--
-- Iteration stops early once an iteration reports a zero simplifier count.
-- If the limit is exceeded, an error message is emitted and the last result
-- is returned.
runSimplifier :: Int -> SimplifierMode -> ModGuts -> CoreM ModGuts
runSimplifier iters mode = loop 1
  where
    -- Emit a debugging message only when 'tracing' is switched on.
    traceMsg doc = when tracing (putMsg doc)

    globalEnv = SimplGlobalEnv { sg_mode = mode }

    loop n guts
      | n > iters
      = do
          errorMsg $ text "Ran out of gas after"
                       <+> int iters
                       <+> text "iterations."
          return guts
      | otherwise
      = do
          let binds = fromCoreBinds (runOccurAnal (mg_module guts) (mg_binds guts))
          traceMsg $ text "BEFORE" <+> int n
                       $$ text "--------" $$ pprTopLevelBinds binds
          (binds', count) <- runSimplM globalEnv (simplModule binds)
          traceMsg $ text "AFTER" <+> int n
                       $$ text "-------" $$ pprTopLevelBinds binds'
          let coreBinds' = bindsToCore binds'
              guts'      = guts { mg_binds = coreBinds' }
          traceMsg $ text "CORE AFTER" <+> int n
                       $$ text "------------" $$ ppr coreBinds'
          if isZeroSimplCount count
            then do
              traceMsg $ text "Done after" <+> int n <+> text "iterations"
              return guts'
            else loop (n + 1) guts'

    -- Occurrence analysis with no active rules, no rules and no
    -- vectorisation declarations.
    runOccurAnal thisMod core
      = occurAnalysePgm thisMod (const False) [] [] emptyVarSet core
-- | Simplify a module's top-level bindings, starting from 'initialEnv'. The
-- final environment is discarded; only the surviving bindings are returned.
simplModule :: [InBind] -> SimplM [OutBind]
simplModule binds
  = do
      (_finalEnv, binds') <- simplBinds initialEnv binds TopLevel
      return binds'
-- | Simplify a command. Its local bindings are simplified first (as
-- non-top-level bindings). The cut between its value and continuation is then
-- simplified in the environment those bindings produced. The surviving
-- bindings are wrapped back around the result.
simplCommand :: SimplEnv -> InCommand -> SimplM OutCommand
simplCommand env Command { cmdLet = localBinds, cmdValue = val, cmdCont = k }
  = do
      (bodyEnv, localBinds') <- simplBinds env localBinds NotTopLevel
      body <- simplCut bodyEnv val (staticPart bodyEnv) k
      return (addLets localBinds' body)
-- | Simplify a value on its own. The environment's continuation is zapped,
-- the value is cut against 'Return', and the resulting command is turned back
-- into a value with 'mkCompute'.
simplValue :: SimplEnv -> InValue -> SimplM OutValue
simplValue env v
  = do
      cmd <- simplCut valEnv v (staticPart valEnv) Return
      return (mkCompute cmd)
  where
    valEnv = zapCont env
-- | Simplify a sequence of bindings left to right, threading the environment.
-- Each binding's right-hand side uses the static part of the environment
-- current at that point. Bindings that 'simplBind' drops (returns 'Nothing'
-- for) are omitted from the result.
simplBinds :: SimplEnv -> [InBind] -> TopLevelFlag
           -> SimplM (SimplEnv, [OutBind])
simplBinds env0 binds0 level = walk env0 binds0
  where
    walk env [] = return (env, [])
    walk env (bind : rest)
      = do
          (envAfter, mbBind) <- simplBind env (staticPart env) bind level
          (envFinal, rest') <- walk envAfter rest
          return (envFinal, mbBind `consMaybe` rest')
-- | Simplify one binding group. A non-recursive binding goes to
-- 'simplNonRec'. A recursive group goes to 'simplRec' and is removed entirely
-- if none of its pairs survive.
simplBind :: SimplEnv -> StaticEnv -> InBind -> TopLevelFlag
          -> SimplM (SimplEnv, Maybe OutBind)
simplBind env_x env_c bind level
  = case bind of
      NonRec x c -> simplNonRec env_x x env_c c level
      Rec xcs -> do
        (env', xcs') <- simplRec env_x env_c xcs level
        let mbRec | null xcs' = Nothing
                  | otherwise = Just (Rec xcs')
        return (env', mbRec)
-- | Simplify a single non-recursive binding of @x@ to @v@, where @v@ is
-- closed over the static environment @env_v@.
--
-- * Type variable bound to a type: the type is substituted and recorded in
--   the type-variable substitution; no binding is emitted.
-- * Coercion variable bound to a coercion: likewise, via the coercion
--   substitution.
-- * Otherwise, if 'preInlineUnconditionally' agrees, @x@ is replaced by a
--   suspension of the unsimplified value and no binding is emitted. If it does
--   not, @v@ is simplified and handed to 'completeBind', which may still drop
--   the binding.
simplNonRec :: SimplEnv -> InVar -> StaticEnv -> InValue -> TopLevelFlag
            -> SimplM (SimplEnv, Maybe OutBind)
simplNonRec env_x x env_v v level
  | isTyVar x
  , Type ty <- assert (isTypeValue v) $ v
  = return (withTyBinding ty, Nothing)
  | isCoVar x
  , Coercion co <- assert (isCoValue v) $ v
  = return (withCoBinding co, Nothing)
  | otherwise
  = do
      inlineNow <- preInlineUnconditionally env_x x env_v v level
      if inlineNow then substituteAway else bindNormally
  where
    withTyBinding ty
      = env_x { se_tvSubst = extendVarEnv (se_tvSubst env_x) x
                                          (substTyStatic env_v ty) }
    withCoBinding co
      = env_x { se_cvSubst = extendVarEnv (se_cvSubst env_x) x
                                          (substCoStatic env_v co) }
    -- Replace occurrences of x by the value, still paired with the static
    -- environment it has to be simplified in.
    substituteAway
      = do
          tick (PreInlineUnconditionally x)
          return (extendIdSubst env_x x (mkSuspension env_v v), Nothing)
    -- The RHS does not see x: its substitution comes from env_v, while the
    -- in-scope information comes from the environment after entering x.
    bindNormally
      = do
          let (env_scoped, x') = enterScope env_x x
          v' <- simplValue (env_scoped `setStaticPart` env_v) v
          (env_done, mbPair) <- completeBind env_scoped x x' v' level
          return (env_done, fmap (uncurry NonRec) mbPair)
-- | Finish a binding whose right-hand side has already been simplified.
--
-- If 'postInlineUnconditionally' agrees, @x@ is substituted by the simplified
-- value and no binding is produced. Otherwise:
--
-- * the output binder gets the input binder's 'IdInfo';
-- * it is added to the in-scope set;
-- * it is recorded in 'se_defs' as 'BoundTo' the value at this level;
-- * the binding pair is returned.
completeBind :: SimplEnv -> InVar -> OutVar -> OutValue -> TopLevelFlag
             -> SimplM (SimplEnv, Maybe (OutVar, OutValue))
completeBind env x x' v level
  = do
      inlineNow <- postInlineUnconditionally env x v level
      if inlineNow
        then do
          tick (PostInlineUnconditionally x)
          return (extendIdSubst env x (DoneVal v), Nothing)
        else
          let finalVar = x' `setIdInfo` idInfo x
              finalEnv = env { se_inScope = extendInScopeSet (se_inScope env) finalVar
                             , se_defs    = extendVarEnv (se_defs env) finalVar
                                                         (BoundTo v level) }
          in return (finalEnv, Just (finalVar, v))
-- | Simplify a recursive binding group.
--
-- All binders are entered into scope up front (possibly being renamed),
-- because every right-hand side may mention any binder of the group. The
-- pairs are then processed in order, threading the environment. Each pair is
-- either pre-inlined unconditionally (the binder is substituted by a
-- suspension and no binding is produced) or simplified and finished with
-- 'completeBind'.
--
-- Every right-hand side is simplified, and every suspension is built, in the
-- /current/ threaded environment. That environment carries the group's own
-- binders and the substitutions for earlier members that were pre-inlined
-- (as GHC's own @simplRecBind@ does). Using the fixed @env_v@ instead would
-- lose those substitutions and could leave references to binders that were
-- inlined away.
simplRec :: SimplEnv -> StaticEnv -> [(InVar, InValue)] -> TopLevelFlag
         -> SimplM (SimplEnv, [(OutVar, OutValue)])
simplRec env_x env_v xvs level
  = go env0_x (zipWith (\(x, v) x' -> (x, x', v)) xvs xs') []
  where
    -- Enter all binders into an environment whose static part is the one
    -- given for the right-hand sides.
    (env0_x, xs') = enterScopes (env_x `setStaticPart` env_v) (map fst xvs)

    go env [] acc = return (env, reverse acc)
    go env ((x, x', v) : triples) acc
      = do
          let env_rhs = staticPart env
          preInline <- preInlineUnconditionally env x env_rhs v level
          if preInline
            then do
              tick (PreInlineUnconditionally x)
              go (extendIdSubst env x (mkSuspension env_rhs v)) triples acc
            else do
              v' <- simplValue env v
              (env', bind') <- completeBind env x x' v' level
              go env' triples (bind' `consMaybe` acc)
-- | Simplify a cut: the value @val@ (whose static environment is carried by
-- the first 'SimplEnv' argument) meeting the continuation @cont@ (whose static
-- environment is @env_k@). Each equation handles one shape of value or
-- continuation. Equation order matters: beta reduction, case-of-case and
-- known-branch are tried before the generic per-value equations further down.
simplCut :: SimplEnv -> InValue -> StaticEnv -> InCont -> SimplM OutCommand
-- Type, coercion and continuation values are asserted to meet only a Return
-- continuation. They are substituted/simplified and returned as a value
-- command.
simplCut env_v (Type ty) _env_k cont
  = assert (isReturnCont cont) $
    let ty' = substTy env_v ty
    in return $ valueCommand (Type ty')
simplCut env (Coercion co) _env_k cont
  = assert (isReturnCont cont) $
    let co' = substCo env co
    in return $ valueCommand (Coercion co')
simplCut env (Cont k) _env_k cont
  = assert (isReturnCont cont) $ do
      k' <- simplCont env k
      return $ valueCommand (Cont k')
-- Beta reduction. A lambda meets an application frame. The argument (closed
-- over the continuation's static env) is bound to the parameter via
-- simplNonRec. The body is then simplified with the rest of the continuation
-- installed by bindCont.
simplCut env_v (Lam x c) env_k (App arg cont)
  = do
      tick (BetaReduction x)
      (env_v', newBind) <- simplNonRec env_v x env_k arg NotTopLevel
      c' <- simplCommand (bindCont env_v' env_k cont) c
      return $ addLets (maybeToList newBind) c'
-- Any other continuation for a lambda. Simplify the body under the (possibly
-- renamed) binder with the continuation zapped, then simplify the
-- continuation against the rebuilt lambda.
simplCut env_v (Lam x c) env_k cont
  = do
      let (env_v', x') = enterScope env_v x
      c' <- simplCommand (zapCont env_v') c
      simplContWith (env_v `setStaticPart` env_k) (Lam x' c') cont
-- Case-of-case. The continuation following this case frame is itself a case
-- frame. That outer continuation is bound to a fresh continuation variable
-- ("k"), and the inner case is re-simplified with a Jump to that variable as
-- its continuation.
simplCut env_v val env_k (Case x ty alts cont@Case {})
  = do
      tick (CaseOfCase x)
      let contTy = ty `mkFunTy` contOuterType ty cont
      contVar <- asContId <$> mkSysLocalM (fsLit "k") contTy
      (env_k', newBind) <- simplNonRec (env_v `setStaticPart` env_k)
                             contVar env_k (Cont cont) NotTopLevel
      let env_v' = env_k' `setStaticPart` staticPart env_v
      comm <- simplCut env_v' val (staticPart env_k') (Case x ty alts $ Jump contVar)
      return $ addLets (maybeToList newBind) comm
-- Known branch. matchCase selected an alternative. The case binder is bound to
-- the scrutinee value and the alternative's binders to the matched pieces.
-- The selected body is then simplified with the case's continuation pushed.
simplCut env_v val env_k (Case x _ alts cont)
  | Just (pairs, body) <- matchCase env_v val alts
  = do
      tick (KnownBranch x)
      (env', binds) <- go (env_v `setStaticPart` env_k) ((x, val) : pairs) []
      comm <- simplCommand (env' `pushCont` cont) body
      return $ addLets binds comm
  where
    -- Bind each (binder, value) pair in turn, threading the environment. The
    -- values are closed over the scrutinee's static env (staticPart env_v).
    go env [] acc
      = return (env, reverse acc)
    go env ((x, v) : pairs) acc
      = do
          (env', maybe_xv') <- simplNonRec env x (staticPart env_v) v NotTopLevel
          go env' pairs (maybe_xv' `consMaybe` acc)
-- Variable: consult the substitution.
--   DoneId:  a plain renaming; simplify the continuation.
--   DoneVal: an already-simplified value; re-cut it with the substitution
--            envs zapped.
--   SuspVal: an unsimplified value with its own static env; re-cut it there.
simplCut env_v (Var x) env_k cont
  = case substId env_v x of
      DoneId x'
        -> simplContWith (env_v `setStaticPart` env_k) (Var x') cont
      DoneVal v
        -> simplCut (zapSubstEnvs env_v) v env_k cont
      SuspVal stat v
        -> simplCut (env_v `setStaticPart` stat) v env_k cont
-- Literal: nothing to simplify in the value itself.
simplCut env_v val@(Lit _) env_k cont
  = simplContWith (env_v `setStaticPart` env_k) val cont
-- Constructor application: simplify each argument, then the continuation.
simplCut env_v (Cons ctor args) env_k cont
  = do
      args' <- mapM (simplValue env_v) args
      simplContWith (env_v `setStaticPart` env_k) (Cons ctor args') cont
-- Computation: simplify the inner command with the continuation installed by
-- bindCont.
simplCut env_v (Compute c) env_k cont
  = simplCommand (bindCont env_v env_k cont) c
-- | Pick the alternative of a @case@ that a known scrutinee value selects
-- (case-of-known-constructor / known literal).
--
-- Returns the alternative's binders paired with the constructor arguments
-- they name, together with the selected body. Returns 'Nothing' when no
-- alternative is known to be selected.
--
-- A DEFAULT alternative is chosen only when the scrutinee is a known WHNF
-- value (literal, constructor application or runtime lambda) that no other
-- alternative matches. For an unknown scrutinee (a variable, a computation,
-- ...) no branch can be chosen statically. Picking DEFAULT there would be
-- wrong when other alternatives exist, and would also drop the evaluation the
-- case performs.
matchCase :: SimplEnv -> InValue -> [InAlt]
          -> Maybe ([(InVar, InValue)], InCommand)
matchCase _env_v (Lit lit) (Alt (LitAlt lit') xs body : _alts)
  | assert (null xs) True
  , lit == lit'
  = Just ([], body)
matchCase _env_v (Cons ctor args) (Alt (DataAlt ctor') xs body : _alts)
  -- Compare constructors first, so the arity check only concerns the
  -- alternative that actually matches.
  | ctor == ctor'
  , assert (length args >= length xs) True
  -- Any surplus leading arguments have no binder in the alternative. This is
  -- presumably the universally-quantified type arguments, if 'Cons' carries
  -- them; it is a no-op when the lengths already agree. NOTE(review): confirm
  -- against the 'Cons' representation in Language.SequentCore.Syntax.
  = Just (zip xs (drop (length args - length xs) args), body)
matchCase env_v val (Alt DEFAULT xs body : alts)
  | assert (null xs) True
  , isKnownValue val
  , Nothing <- matchCase env_v val alts
  = Just ([], body)
  where
    isKnownValue (Lit _)    = True
    isKnownValue (Cons _ _) = True
    isKnownValue (Lam b _)  = isRuntimeVar b
    isKnownValue _          = False
matchCase env_v val (_ : alts)
  = matchCase env_v val alts
matchCase _ _ []
  = Nothing
-- | Simplify a continuation frame by frame.
--
-- The simplified frames are accumulated in @kc@, a function that wraps the
-- simplified remainder of the continuation. Each equation of @go@ handles one
-- frame, extends @kc@ with its simplified version, and moves on to the tail.
-- The finished continuation is produced when a Jump to a variable that is not
-- substituted away, or a Return with nothing to restore, is reached.
simplCont :: SimplEnv -> InCont -> SimplM OutCont
simplCont env cont
  = go env cont (\k -> k)
  where
    -- Application frame: simplify the argument value.
    go env (App arg cont) kc
      = do
          arg' <- simplValue env arg
          go env cont (kc . App arg')
    -- Cast frame: the coercion is kept as it is.
    go env (Cast co cont) kc
      = go env cont (kc . Cast co)
    -- Case frame. If the frame's tail is a Jump, the Jump is installed in the
    -- environment used for the alternatives (via bindCont) and the frame's
    -- tail becomes Return. The case binder and then each alternative's binders
    -- are entered into scope, and every alternative is simplified. Afterwards
    -- the tail (cont') is simplified in the original environment (env_orig).
    go env (Case x ty alts cont) kc
      = doCase env'' x' ty' alts []
      where
        (env', cont') | Jump {} <- cont
                      = (bindCont env (staticPart env) cont, Return)
                      | otherwise
                      = (env, cont)
        (env'', x') = enterScope env' x
        ty' = substTy env'' ty
        env_orig = env
        doCase _env x ty [] alt_acc
          = go env_orig cont' (kc . Case x ty (reverse alt_acc))
        doCase env x ty (Alt con xs c : alts) alt_acc
          = do
              let (env', xs') = enterScopes env xs
              c' <- simplCommand env' c
              doCase env x ty alts (Alt con xs' c' : alt_acc)
    -- Tick frame: kept as it is.
    go env (Tick ti cont) kc
      = go env cont (kc . Tick ti)
    -- Jump to a continuation variable. It is either renamed, or replaced by
    -- the continuation it is bound to (already simplified: DoneVal; or
    -- suspended with its own static env: SuspVal), which is then simplified
    -- in its place.
    go env (Jump x) kc
      = case substId env x of
          DoneId x'
            -> return $ kc (Jump x')
          DoneVal (Cont k)
            -> go (zapSubstEnvs env) k kc
          SuspVal stat (Cont k)
            -> go (env `setStaticPart` stat) k kc
          _
            -> error "jump to non-continuation"
    -- Return. If the environment has a saved continuation (restoreEnv),
    -- continue simplifying that; otherwise the continuation ends here.
    go env Return kc
      | Just (env', cont) <- restoreEnv env
      = go env' cont kc
      | otherwise
      = return $ kc Return
-- | Pair an already-simplified value with the simplification of an input
-- continuation, producing a command with no local bindings.
simplContWith :: SimplEnv -> OutValue -> InCont -> SimplM OutCommand
simplContWith env val cont
  = do
      cont' <- simplCont env cont
      return (mkCommand [] val cont')
-- | Decide whether a binder can be substituted by its unsimplified
-- right-hand side before that RHS is simplified.
--
-- The answer is never "yes" when any of these hold:
--
-- * the binder's inline activation is not active in the current phase;
-- * pre-inlining is disabled (-fno-simpl-pre-inlining, 'Opt_SimplPreInlining');
-- * the binder is a bottoming top-level Id, is exported, or is a coercion
--   variable.
--
-- Otherwise the occurrence info decides. Dead binders are always inlined. A
-- single occurrence in one branch is inlined if:
--
-- * it is outside a lambda, and either the binding is not top-level or the
--   phase is not phase 0; or
-- * it is inside a lambda, in an interesting context, and the RHS is safe to
--   duplicate work for (a literal, a lambda with a runtime binder, or such a
--   value wrapped in a computation).
preInlineUnconditionally :: SimplEnv -> InVar -> StaticEnv -> InValue
                         -> TopLevelFlag -> SimplM Bool
preInlineUnconditionally _env_x x _env_rhs rhs level
  = decide <$> getMode <*> getDynFlags
  where
    decide mode dflags
      | not (isActive (sm_phase mode) (idInlineActivation x)) = False
      | not (gopt Opt_SimplPreInlining dflags)                = False
      | TopLevel <- level, isBottomingId x                    = False
      | isExportedId x                                        = False
      | isCoVar x                                             = False
      | otherwise
      = case idOccInfo x of
          IAmDead                  -> True
          OneOcc inLam True intCxt -> onceOnly inLam intCxt
          _                        -> False
      where
        onceOnly inLam intCxt
          | inLam     = intCxt && valueOkInLam rhs
          | otherwise = isNotTopLevel level || earlyPhase

        earlyPhase = case sm_phase mode of
          Phase 0 -> False
          _       -> True

        valueOkInLam (Lit _)     = True
        valueOkInLam (Lam b c)   = isRuntimeVar b || commandOkInLam c
        valueOkInLam (Compute c) = commandOkInLam c
        valueOkInLam _           = False

        commandOkInLam c = maybe False valueOkInLam (asValueCommand c)
-- | Decide whether a binder can be substituted by its already-simplified
-- right-hand side @v@, so that no binding is emitted.
--
-- The answer is never "yes" when any of these hold:
--
-- * the binder's inline activation is not active in the current phase;
-- * the binder is a (weak) loop breaker;
-- * the binder is exported;
-- * the binding is top-level.
--
-- Otherwise:
--
-- * trivial values are always inlined;
-- * dead binders are always inlined;
-- * a single occurrence is inlined when the binder's unfolding is small enough
--   and, if the occurrence is inside a lambda, the unfolding is also cheap and
--   the context interesting.
--
-- NOTE(review): the unfolding consulted is 'idUnfolding' of the binder passed
-- in (completeBind passes the input binder). Whether that unfolding is
-- populated at this point is not established here; if it is empty, the
-- single-occurrence case will not fire. Confirm.
postInlineUnconditionally :: SimplEnv -> OutVar -> OutValue -> TopLevelFlag
                          -> SimplM Bool
postInlineUnconditionally _env x v level
  = decide <$> getMode <*> getDynFlags
  where
    occ = idOccInfo x
    unf = idUnfolding x

    decide mode dflags
      | not (isActive (sm_phase mode) (idInlineActivation x)) = False
      | isWeakLoopBreaker occ                                 = False
      | isExportedId x                                        = False
      | isTopLevel level                                      = False
      | isTrivialValue v                                      = True
      | otherwise
      = case occ of
          OneOcc inLam _oneBranch intCxt
            -> smallEnoughToInline dflags unf
                 && (not inLam || (isCheapUnfolding unf && intCxt))
          IAmDead -> True
          _       -> False