module Language.HERMIT.Primitive.GHC
(
externals
, coreExprFreeIds
, coreExprFreeVars
, freeIdsT
, freeVarsT
, altFreeVarsT
, altFreeVarsExclWildT
, substR
, substExprR
, letSubstR
, safeLetSubstR
, safeLetSubstPlusR
, exprEqual
, exprsEqual
, coreEqual
, inScope
, showVars
, rule
, rules
, lintExprT
, lintProgramT
, lintModuleT
, equivalent
)
where
import GhcPlugins
import qualified Bag
import qualified CoreLint
import qualified OccurAnal
import IOEnv
import Control.Arrow
import Control.Monad
import Data.List (intercalate,mapAccumL,(\\))
import Data.Map (keys)
import Language.HERMIT.Core
import Language.HERMIT.Context
import Language.HERMIT.Monad
import Language.HERMIT.Kure
import Language.HERMIT.External
import Language.HERMIT.GHC
import Language.HERMIT.Primitive.Navigation hiding (externals)
import qualified Language.Haskell.TH as TH
-- | The shell commands provided by this module.
--
-- Each entry pairs a command name with the rewrite or translation it runs
-- and its help text, optionally tagged with '.+' (e.g. 'Deep', 'Query').
-- "apply-rule" is deliberately registered twice: once taking a rule name
-- (applies that rule) and once taking no argument (lists available rules).
externals :: [External]
externals =
    [ external "info" (info :: TranslateH Core String)
        [ "display information about the current node." ]
    , external "let-subst" (promoteExprR letSubstR :: RewriteH Core)
        [ "Let substitution"
        , "(let x = e1 in e2) ==> (e2[e1/x])"
        , "x must not be free in e1." ] .+ Deep
    , external "safe-let-subst" (promoteExprR safeLetSubstR :: RewriteH Core)
        [ "Safe let substitution"
        , "let x = e1 in e2, safe to inline without duplicating work ==> e2[e1/x],"
        , "x must not be free in e1." ] .+ Deep .+ Eval .+ Bash
    , external "safe-let-subst-plus" (promoteExprR safeLetSubstPlusR :: RewriteH Core)
        [ "Safe let substitution"
        , "let { x = e1, ... } in e2, "
        , " where safe to inline without duplicating work ==> e2[e1/x,...],"
        , "only matches non-recursive lets" ] .+ Deep .+ Eval
    , external "free-ids" (promoteExprT freeIdsQuery :: TranslateH Core String)
        [ "List the free identifiers in this expression." ] .+ Query .+ Deep
    , external "deshadow-prog" (promoteProgR deShadowProgR :: RewriteH Core)
        [ "Deshadow a program." ] .+ Deep
    , external "apply-rule" (promoteExprR . rule :: String -> RewriteH Core)
        [ "apply a named GHC rule" ] .+ Shallow
    , external "apply-rule" (rules_help :: TranslateH Core String)
        [ "list rules that can be used" ] .+ Query
    , external "apply-rules" (promoteExprR . rules :: [String] -> RewriteH Core)
        [ "apply named GHC rules, succeeds if any of the rules succeed" ] .+ Shallow
    , external "compare-values" compareValues
        ["compare the rhs of two values."] .+ Query .+ Predicate
    , external "add-rule" (\ rule_name id_name -> promoteModGutsR (addCoreBindAsRule rule_name id_name))
        ["add-rule \"rule-name\" <id> -- adds a new rule that freezes the right hand side of the <id>"]
        .+ Introduce
    , external "occur-analysis" (promoteExprR occurAnalyseExprR :: RewriteH Core)
        [ "Performs dependency analysis on a CoreExpr."
        , "This can be useful to simplify a recursive let to a non-recursive let." ] .+ Deep
    , external "lintExpr" (promoteExprT lintExprT :: TranslateH Core String)
        [ "Runs GHC's Core Lint, which typechecks the current expression."
        , "Note: this can miss several things that a whole-module core lint will find."
        , "For instance, running this on the RHS of a binding, the type of the RHS will"
        , "not be checked against the type of the binding. Running on the whole let expression"
        , "will catch that however." ] .+ Deep .+ Debug .+ Query
    , external "lintProg" (promoteProgT lintProgramT :: TranslateH Core String)
        [ "Runs GHC's Core Lint, which typechecks the top level list of bindings." ] .+ Deep .+ Debug .+ Query
    , external "lintModule" (promoteModGutsT lintModuleT :: TranslateH Core String)
        [ "Runs GHC's Core Lint, which typechecks the current module." ] .+ Deep .+ Debug .+ Query
    ]
-- | Substitute the expression @e@ for the variable @v@ in the current node.
--
-- Works on expressions (via 'substExprR') and on whole programs (via
-- 'substTopBindR'). Any other kind of node fails with an explanatory message.
substR :: Var -> CoreExpr -> RewriteH Core
substR v e = setFailMsg "Can only perform substitution on expressions or programs." onExprOrProg
  where
    onExprOrProg = promoteExprR (substExprR v e) <+ promoteProgR (substTopBindR v e)
-- | Substitute @e@ for @v@ throughout the current expression.
--
-- The substitution's in-scope set is seeded with the free variables of the
-- expression @let v = e in expr@, so it covers the free variables of both
-- the substituted expression and the target.
substExprR :: Var -> CoreExpr -> RewriteH CoreExpr
substExprR v e = contextfreeT $ \ expr ->
    let scopeVars = exprFreeVars (Let (NonRec v e) expr)
        sub       = extendSubst (mkEmptySubst (mkInScopeSet scopeVars)) v e
    in return (substExpr (text "substR") sub expr)
-- | Substitute @e@ for @v@ throughout every binding group of a program.
--
-- The substitution is threaded through the binding groups in order with
-- 'mapAccumL', so binders renamed by earlier groups are seen by later ones.
--
-- NOTE(review): this starts from 'emptySubst', so the free variables of @e@
-- are not in the in-scope set. Seeding them naively would make top-level
-- binders they mention get cloned. Confirm that capture cannot occur for the
-- programs this is applied to.
substTopBindR :: Var -> CoreExpr -> RewriteH CoreProg
substTopBindR v e = contextfreeT $ \ p ->
    let initialSub    = extendSubst emptySubst v e
        (_, newBinds) = mapAccumL substBind initialSub (progToBinds p)
    in return (bindsToProg newBinds)
-- | Inline a non-recursive let: @let x = e1 in e2@ ==> @e2[e1/x]@.
--
-- The expression is occurrence-analysed first. The binding's right-hand side
-- is then substituted into the let body using 'substExprR'. Fails unless the
-- analysed expression is a non-recursive let. All failure messages are
-- prefixed with "Let substitution failed: ".
letSubstR :: RewriteH CoreExpr
letSubstR = prefixFailMsg "Let substitution failed: " $
    rewrite $ \ c expr -> case occurAnalyseExpr expr of
        Let (NonRec b be) e -> apply (substExprR b be) c e
        _                   -> fail "expression is not a non-recursive Let."
-- | Inline a non-recursive let binding only when doing so cannot duplicate
-- work.
--
-- After occurrence analysis, @let b = be in body@ is handed to 'letSubstR'
-- when either of these holds:
--
--   * @b@ is a type variable, or
--   * @b@ is an identifier and either its right-hand side is cheap to
--     duplicate ('safeBind') or its occurrence info says inlining cannot
--     duplicate work ('safeSubst').
--
-- Otherwise it fails, with messages prefixed by "Safe let-substitution failed: ".
safeLetSubstR :: RewriteH CoreExpr
safeLetSubstR = prefixFailMsg "Safe let-substitution failed: " $
    translate $ \ env expr ->
        let
            -- Right-hand sides that are safe to duplicate are variables,
            -- lambdas, and applications with fewer arguments than the head
            -- can take. For a variable head, that means fewer value arguments
            -- than 'arityOf' reports. For any other head, it means fewer
            -- arguments than its leading binders.
            safeBind (Var {}) = True
            safeBind (Lam {}) = True
            safeBind e@(App {}) =
                case collectArgs e of
                    (Var f, args) -> arityOf env f > length (filter (not . isTyCoArg) args)
                    (other, args) -> case collectBinders other of
                        (bds, _) -> length bds > length args
            safeBind _ = False

            -- Occurrence info that makes inlining safe. One case is a dead
            -- binder. The other is a single occurrence that is not inside a
            -- lambda and appears in only one branch (GHC's 'OneOcc' fields).
            safeSubst NoOccInfo = False
            safeSubst IAmDead = True
            safeSubst (OneOcc inLam oneBr _) = not inLam && oneBr
            safeSubst _ = False
        in case occurAnalyseExpr expr of
            Let (NonRec b _) _
                | isTyVar b -> apply letSubstR env expr
            Let (NonRec b be) _
                | isId b && (safeBind be || safeSubst (occInfo (idInfo b)))
                    -> apply letSubstR env expr
                | otherwise -> fail "safety criteria not met."
            _ -> fail "expression is not a non-recursive Let."
-- | Like 'safeLetSubstR', but first tries to apply itself recursively to the
-- body of the current let (leaving the binding untouched). That inner step
-- is optional. 'safeLetSubstR' must then succeed on the resulting expression.
safeLetSubstPlusR :: RewriteH CoreExpr
safeLetSubstPlusR = insideFirst >>> safeLetSubstR
  where
    insideFirst = tryR (letT idR safeLetSubstPlusR Let)
-- | Describe the current node, one item per line.
--
-- The description covers its path, node kind, constructor, and the bindings
-- in scope. For expressions it also gives the type or kind and the free
-- variables.
info :: TranslateH Core String
info = translate $ \ c core -> do
    dynFlags <- getDynFlags
    let summary =
            [ "Path: " ++ show (absPath c)
            , "Node: " ++ coreNode core
            , "Constructor: " ++ coreConstructor core
            , "Bindings in Scope: " ++ show (map unqualifiedVarName $ boundVars c)
            ]
        exprDetails (ExprCore e) =
            [ "Type or Kind: " ++ showExprTypeOrKind dynFlags e
            , "Free Variables: " ++ showVars (coreExprFreeVars e)
            ]
        exprDetails _ = []
    return . intercalate "\n" $ summary ++ exprDetails core
-- | Pretty-print the type or kind of an expression, as computed by
-- 'exprTypeOrKind', using the given 'DynFlags'.
showExprTypeOrKind :: DynFlags -> CoreExpr -> String
showExprTypeOrKind dynFlags e = showPpr dynFlags (exprTypeOrKind e)
-- | A human-readable name for the kind of node a 'Core' value wraps.
coreNode :: Core -> String
coreNode core = case core of
    GutsCore _ -> "Module"
    ProgCore _ -> "Program"
    BindCore _ -> "Binding Group"
    DefCore _  -> "Recursive Definition"
    ExprCore _ -> "Expression"
    AltCore _  -> "Case Alternative"
-- | The name of the data constructor at the root of the wrapped value.
--
-- Case alternatives are triples, so they are reported as "(,,)".
coreConstructor :: Core -> String
coreConstructor core = case core of
    GutsCore _              -> "ModGuts"
    ProgCore ProgNil        -> "ProgNil"
    ProgCore (ProgCons _ _) -> "ProgCons"
    BindCore (Rec _)        -> "Rec"
    BindCore (NonRec _ _)   -> "NonRec"
    DefCore _               -> "Def"
    AltCore _               -> "(,,)"
    ExprCore expr           -> exprConstructor expr
  where
    exprConstructor :: CoreExpr -> String
    exprConstructor (Var _)        = "Var"
    exprConstructor (Type _)       = "Type"
    exprConstructor (Lit _)        = "Lit"
    exprConstructor (App _ _)      = "App"
    exprConstructor (Lam _ _)      = "Lam"
    exprConstructor (Let _ _)      = "Let"
    exprConstructor (Case _ _ _ _) = "Case"
    exprConstructor (Cast _ _)     = "Cast"
    exprConstructor (Tick _ _)     = "Tick"
    exprConstructor (Coercion _)   = "Coercion"
-- | Report the free identifiers of the current expression as a message.
freeIdsQuery :: TranslateH CoreExpr String
freeIdsQuery = liftM (("Free identifiers are: " ++) . showVars) freeIdsT
-- | Render a list of variables, by name, using 'show' on the list of names.
showVars :: [Var] -> String
showVars vs = show [ var2String v | v <- vs ]
-- | The free identifiers of the current expression, as computed by GHC's
-- 'exprFreeIds'.
freeIdsT :: TranslateH CoreExpr [Id]
freeIdsT = arr (uniqSetToList . exprFreeIds)
-- | The free variables of the current expression, as computed by GHC's
-- 'exprFreeVars'.
freeVarsT :: TranslateH CoreExpr [Var]
freeVarsT = arr (uniqSetToList . exprFreeVars)
-- | Free variables of an expression as a list. The order is whatever
-- 'uniqSetToList' produces.
coreExprFreeVars :: CoreExpr -> [Var]
coreExprFreeVars expr = uniqSetToList (exprFreeVars expr)
-- | Free identifiers of an expression as a list. The order is whatever
-- 'uniqSetToList' produces.
coreExprFreeIds :: CoreExpr -> [Id]
coreExprFreeIds expr = uniqSetToList (exprFreeIds expr)
-- | Free variables of a case alternative: the free variables of its
-- right-hand side minus the variables bound by the alternative.
altFreeVarsT :: TranslateH CoreAlt [Var]
altFreeVarsT = altT freeVarsT dropBound
  where
    dropBound _con bound rhsFvs = rhsFvs \\ bound
-- | Like 'altFreeVarsT', but returns a function that additionally removes a
-- supplied variable (presumably the case binder) from the result.
altFreeVarsExclWildT :: TranslateH CoreAlt (Id -> [Var])
altFreeVarsExclWildT = altT freeVarsT dropBoundAndWild
  where
    dropBoundAndWild _con bound rhsFvs = \ wild -> rhsFvs \\ (wild : bound)
-- | Rename binders throughout a program so that no binder shadows another,
-- using GHC's 'deShadowBinds'.
deShadowProgR :: RewriteH CoreProg
deShadowProgR = arr $ \ prog -> bindsToProg (deShadowBinds (progToBinds prog))
-- | Build a rewrite that applies one of the given GHC rules to the current
-- expression.
--
-- The expression must be an application headed by a variable @fn@. Only the
-- rules whose 'ru_fn' is @fn@'s name are offered to GHC's 'lookupRule',
-- which selects the rule to use. Arguments beyond the rule's arity are
-- re-applied to the rule's result.
--
-- Fails if the expression is not headed by a variable, if no rule matches,
-- or if the result has free variables that are not in scope ('inScope').
rulesToRewriteH :: [CoreRule] -> RewriteH CoreExpr
rulesToRewriteH rs = translate $ \ c e -> do
    (Var fn, args) <- return $ collectArgs e
    let in_scope   = mkInScopeSet (mkVarEnv [ (v,v) | v <- coreExprFreeVars e ])
        candidates = [ r | r <- rs, ru_fn r == idName fn ]
#if __GLASGOW_HASKELL__ > 706
    dflags <- getDynFlags
    case lookupRule dflags (const True) (const NoUnfolding) in_scope fn args candidates of
#else
    case lookupRule (const True) (const NoUnfolding) in_scope fn args candidates of
#endif
        Nothing -> fail "rule not matched"
        Just (r, expr) -> do
            -- Re-apply the arguments the rule did not consume.
            let e' = mkApps expr (drop (ruleArity r) args)
            ifM (liftM (all (inScope c)) $ apply freeVarsT c e')
                (return e')
                (fail $ unlines ["Resulting expression after rule application contains variables that are not in scope."
                                ,"This can probably be solved by running the flatten-module command at the top level."])
-- | Can the variable be referred to at this point?
--
-- True if it is bound in the context. Otherwise it must be an identifier
-- whose unfolding is a 'CoreUnfolding' or a 'DFunUnfolding'.
--
-- Callers pass the results of 'freeVarsT', which can include type variables.
-- GHC's 'idInfo' panics on non-identifiers, so the unfolding is only
-- inspected when 'isId' holds. Type variables count as in scope only when
-- the context binds them.
inScope :: HermitC -> Id -> Bool
inScope c v = (v `boundIn` c) ||
    (isId v && hasUsableUnfolding (unfoldingInfo (idInfo v)))
  where
    hasUsableUnfolding (CoreUnfolding {}) = True
    hasUsableUnfolding (DFunUnfolding {}) = True
    hasUsableUnfolding _                  = False
-- | Apply the GHC rule with the given name, looked up by name in
-- 'getHermitRules'. Fails if no rule of that name is known.
rule :: String -> RewriteH CoreExpr
rule r = getHermitRules >>= (maybe notFound rulesToRewriteH . lookup r)
  where
    notFound = fail $ "failed to find rule: " ++ show r
-- | Apply each of the named rules (via 'orR'); succeeds if any of them does.
rules :: [String] -> RewriteH CoreExpr
rules names = orR [ rule name | name <- names ]
-- | Collect every rule HERMIT can see, each keyed by its rule name.
--
-- The rules come, in this order, from:
--
--   * the module's own 'mg_rules',
--   * the rules attached to top-level binders of the module,
--   * the rule base from 'getRuleBase',
--   * the rule base in the external package state.
--
-- Every rule gets its own singleton entry, and entries with the same name
-- are not merged.
getHermitRules :: TranslateH a [(String, [CoreRule])]
getHermitRules = translate $ \ env _ -> do
    ruleBase <- liftCoreM getRuleBase
    hscEnv   <- liftCoreM getHscEnv
    eps      <- liftIO $ runIOEnv () $ readMutVar (hsc_EPS hscEnv)
    let guts = hermitModGuts env
        bindersOfGroup (Rec bnds)   = map fst bnds
        bindersOfGroup (NonRec b _) = [b]
        binderRules = concatMap idCoreRules (concatMap bindersOfGroup (mg_binds guts))
        allRules    = mg_rules guts
                      ++ binderRules
                      ++ concat (nameEnvElts ruleBase)
                      ++ concat (nameEnvElts (eps_rule_base eps))
    return [ (unpackFS (ruleName r), [r]) | r <- allRules ]
-- | Describe the available rules: the list of rule names, a newline, then
-- GHC's user-facing rendering of all the rules.
rules_help :: TranslateH Core String
rules_help = do
    known    <- getHermitRules
    dynFlags <- constT getDynFlags
    let nameList = show (map fst known)
        rendered = showSDoc dynFlags (pprRulesForUser (concatMap snd known))
    return (nameList ++ "\n" ++ rendered)
-- | Build a rule named @rule_name@ for the identifier @nm@, with no binders
-- and no argument patterns, rewriting @nm@ to @rhs@.
--
-- The rule is marked 'NeverActive', so presumably the simplifier never fires
-- it on its own.
-- NOTE(review): the two leading Bool arguments to 'mkRule' (True, False) are
-- presumably GHC's is_auto / is_local flags. Confirm against the GHC version
-- in use.
makeRule :: String -> Id -> CoreExpr -> CoreRule
makeRule rule_name nm rhs =
    mkRule True False (mkFastString rule_name) NeverActive (varName nm) [] [] rhs
-- | Add a rule named @rule_name@ whose right-hand side is the definition of
-- the top-level binding named @nm@ (see 'makeRule').
--
-- Fails if no top-level binding matches @nm@, or if more than one does.
addCoreBindAsRule :: String -> TH.Name -> RewriteH ModGuts
addCoreBindAsRule rule_name nm = contextfreeT $ \ modGuts ->
    let pairsOf (Rec bnds)   = bnds
        pairsOf (NonRec b e) = [(b, e)]
        matching = filter (\ (v, _) -> nm `cmpTHName2Var` v)
                          (concatMap pairsOf (mg_binds modGuts))
    in case matching of
        []       -> fail $ "cannot find binding " ++ show nm
        [(v, e)] -> return (modGuts { mg_rules = mg_rules modGuts ++ [makeRule rule_name v e] })
        _        -> fail $ "found multiple bindings for " ++ show nm
-- | Run GHC's occurrence analyser on an expression. 'safeLetSubstR' reads
-- the resulting binder occurrence info.
occurAnalyseExpr :: CoreExpr -> CoreExpr
occurAnalyseExpr expr = OccurAnal.occurAnalyseExpr expr
-- | Rewrite that runs GHC's occurrence analyser on the current expression.
occurAnalyseExprR :: RewriteH CoreExpr
occurAnalyseExprR = arr (\ expr -> occurAnalyseExpr expr)
-- | Alpha-equivalence of two expressions (GHC's 'eqExpr'), with the free
-- variables of both in scope.
exprEqual :: CoreExpr -> CoreExpr -> Bool
exprEqual e1 e2 = eqExpr scope e1 e2
  where
    scope = mkInScopeSet (exprsFreeVars [e1, e2])
-- | True if every expression is alpha-equivalent to the first one (trivially
-- True for an empty list).
exprsEqual :: [CoreExpr] -> Bool
exprsEqual es = equivalent exprEqual es
-- | @equivalent eq xs@ holds when every element of @xs@ is related to the
-- first element by @eq@, called as @eq first other@. The empty list is
-- trivially equivalent.
--
-- Only comparisons against the head are made, so for a non-transitive @eq@
-- this is weaker than checking every pair.
equivalent :: (a -> a -> Bool) -> [a] -> Bool
equivalent eq xs = case xs of
    []           -> True
    (first : rest) -> and [ eq first other | other <- rest ]
-- | Alpha-equivalence of two binding groups. Returns 'Nothing' if one is
-- recursive and the other is not.
--
-- * Recursive groups are equal if they have the same number of bindings and
--   their right-hand sides are pairwise alpha-equivalent. The binders of the
--   two groups are identified positionally.
-- * Non-recursive bindings compare only their right-hand sides. The binders
--   themselves are ignored.
bindEqual :: CoreBind -> CoreBind -> Maybe Bool
bindEqual (Rec ps1) (Rec ps2)
    -- 'rnBndrs2' needs binder lists of equal length (GHC panics otherwise),
    -- so report groups of different sizes as unequal before building env'.
    | length ps1 /= length ps2 = Just False
    | otherwise                = Just $ all2 (eqExprX id_unf env') rs1 rs2
  where
    -- Every identifier is treated as having no unfolding during comparison.
    id_unf _ = noUnfolding
    (bs1, rs1) = unzip ps1
    (bs2, rs2) = unzip ps2
    env  = mkInScopeSet $ exprsFreeVars (rs1 ++ rs2)
    env' = rnBndrs2 (mkRnEnv2 env) bs1 bs2
bindEqual (NonRec _ e1) (NonRec _ e2) = Just $ exprEqual e1 e2
bindEqual _ _ = Nothing
-- | Alpha-equivalence of two 'Core' nodes of the same kind.
--
-- Expressions, binding groups and recursive definitions (compared as
-- singleton recursive groups) are supported. Any other combination yields
-- 'Nothing'.
coreEqual :: Core -> Core -> Maybe Bool
coreEqual c1 c2 = case (c1, c2) of
    (ExprCore e1, ExprCore e2) -> Just (exprEqual e1 e2)
    (BindCore b1, BindCore b2) -> bindEqual b1 b2
    (DefCore d1, DefCore d2)   -> bindEqual (defsToRecBind [d1]) (defsToRecBind [d2])
    _                          -> Nothing
-- | Locate the bindings named @n1@ and @n2@ and check them with 'coreEqual'.
--
-- Succeeds with @()@ when they are equal. Fails when they are unequal or
-- cannot be compared.
compareValues :: TH.Name -> TH.Name -> TranslateH Core ()
compareValues n1 n2 = do
    (p1, p2) <- liftM2 (,) (onePathToT (namedBinding n1)) (onePathToT (namedBinding n2))
    (e1, e2) <- liftM2 (,) (pathT p1 idR) (pathT p2 idR)
    verdict (e1 `coreEqual` e2)
  where
    both = show n1 ++ " and " ++ show n2
    verdict Nothing      = fail $ both ++ " are incomparable."
    verdict (Just False) = fail $ both ++ " are not equal."
    verdict (Just True)  = return ()
-- | The arity of an identifier.
--
-- If the context has a binding for it, the arity comes from that binding:
-- 0 for a lambda binder, otherwise the arity of the bound expression.
-- Without a binding, GHC's recorded 'idArity' is used.
arityOf :: HermitC -> Id -> Int
arityOf env nm = maybe (idArity nm) bindingArity (lookupHermitBinding nm env)
  where
    bindingArity (LAM {})     = 0
    bindingArity (BIND _ _ e) = exprArity e
    bindingArity (CASE _ e _) = exprArity e
-- | Core Lint the whole module by linting its top-level bindings as a program.
lintModuleT :: TranslateH ModGuts String
lintModuleT = arr mg_binds >>> arr bindsToProg >>> lintProgramT
-- | Run GHC's Core Lint on the program's bindings.
--
-- The rendered messages are joined with newlines by 'Bag.foldBag', with the
-- status line at the end. If lint reports no errors, this succeeds with the
-- warnings and "Core Lint Passed". Otherwise it fails with the errors and
-- "Core Lint Failed".
lintProgramT :: TranslateH CoreProg String
lintProgramT = do
    binds  <- arr progToBinds
    dflags <- constT getDynFlags
    let (warnings, problems) = CoreLint.lintCoreBindings binds
        render status = Bag.foldBag (\ msg rest -> msg ++ ('\n' : rest)) (showSDoc dflags) status
    if Bag.isEmptyBag problems
        then return (render "Core Lint Passed" warnings)
        else fail (render "Core Lint Failed" problems)
-- | Run GHC's Core Lint on the current expression, treating the variables
-- bound in the HERMIT context as in scope.
--
-- Succeeds with "Core Lint Passed". Otherwise fails with the rendered lint
-- message.
lintExprT :: TranslateH CoreExpr String
lintExprT = translate $ \ c e -> do
    dflags <- getDynFlags
    case CoreLint.lintUnfolding noSrcLoc (keys $ hermitBindings c) e of
        Nothing  -> return "Core Lint Passed"
        Just msg -> fail (showSDoc dflags msg)