module Language.HERMIT.Primitive.GHC where
import GhcPlugins hiding (empty)
import qualified OccurAnal
import Control.Arrow
import Control.Monad
import qualified Data.Map as Map
import Data.List (nub, mapAccumL)
import Language.HERMIT.Primitive.Navigation
import Language.HERMIT.CoreExtra
import Language.HERMIT.Kure
import Language.HERMIT.Monad
import Language.HERMIT.External
import Language.HERMIT.Context
import qualified Language.HERMIT.GHC as GHC
import qualified Language.Haskell.TH as TH
import Prelude hiding (exp)
-- | Shell commands provided by this module, each paired with its help text
-- and tags.
--
-- Note: "apply-rule" is registered at two different types: with a rule-name
-- argument it rewrites using that rule, and with no argument it is a query
-- listing the available rules.
externals :: [External]
externals =
    [ external "let-subst" (promoteExprR letSubstR :: RewriteH Core)
        [ "Let substitution [via GHC]"
        , "let x = E1 in E2 ==> E2[E1/x], fails otherwise"
        , "only matches non-recursive lets" ] .+ Deep
    , external "safe-let-subst" (promoteExprR safeLetSubstR :: RewriteH Core)
        [ "Safe let substitution [via GHC]"
        , "let x = E1 in E2, safe to inline without duplicating work ==> E2[E1/x],"
        , "fails otherwise"
        , "only matches non-recursive lets" ] .+ Deep .+ Eval .+ Bash
    , external "safe-let-subst-plus" (promoteExprR safeLetSubstPlusR :: RewriteH Core)
        [ "Safe let substitution [via GHC]"
        , "let { x = E1, ... } in E2, "
        , " where safe to inline without duplicating work ==> E2[E1/x,...],"
        , "fails otherwise"
        , "only matches non-recursive lets" ] .+ Deep .+ Eval
    , external "free-ids" (promoteExprT freeIdsQuery :: TranslateH Core String)
        [ "List the free identifiers in this expression [via GHC]" ] .+ Query .+ Deep
    , external "deshadow-binds" (promoteProgramR deShadowBindsR :: RewriteH Core)
        [ "Deshadow a program " ] .+ Deep
    , external "apply-rule" (promoteExprR . rules :: String -> RewriteH Core)
        [ "apply a named GHC rule" ] .+ Shallow
    , external "apply-rule" (rules_help :: TranslateH Core String)
        [ "list rules that can be used" ] .+ Query
    , external "compare-values" compareValues
        [ "compares the rhs of two values" ] .+ Query .+ Predicate
    -- "add-rule" is registered exactly once (it used to be listed twice).
    , external "add-rule" (\ rule_name id_name -> promoteModGutsR (addCoreBindAsRule rule_name id_name))
        [ "add-rule \"rule-name\" <id> -- adds a new rule that freezes the right hand side of the <id>" ]
        .+ Introduce
    , external "cast-elim" (promoteExprR castElimination)
        [ "cast-elim removes casts" ]
        .+ Shallow .+ TODO
    , external "flatten-module" (promoteModGutsR flattenModule :: RewriteH Core)
        [ "Flatten all the top-level binding groups into a single recursive binding group."
        , "This can be useful if you intend to apply GHC RULES." ]
    , external "occur-analysis" (promoteExprR occurAnalyseExprR :: RewriteH Core)
        [ "Performs dependency analysis on a CoreExpr."
        , "This can be useful to simplify a recursive let to a non-recursive let." ] .+ Deep
    ]
-- | Substitute @e@ for @b@ in the current node, which may be either an
-- expression or a whole program; any other kind of node fails.
substR :: Id -> CoreExpr -> RewriteH Core
substR b e =
    setFailMsg "Can only perform substitution on Expr or CoreProgram forms." (onExpr <+ onProgram)
  where
    onExpr    = promoteExprR (substExprR b e)
    onProgram = promoteProgramR (substTopBindR b e)
-- | Substitute @e@ for @b@ throughout the current expression.
--
-- The in-scope set is the free variables of @let b = e in <expr>@.
-- When @b@ is a type variable, @e@ must be a @Type@ or a @Var@ (used as a
-- type variable); any other @e@ makes the rewrite fail.
substExprR :: Id -> CoreExpr -> RewriteH CoreExpr
substExprR b e = contextfreeT $ \ expr ->
    let scope = mkInScopeSet (exprFreeVars (Let (NonRec b e) expr))
        base  = mkEmptySubst scope
        mbSubst
          | not (isTyVar b) = Just (extendSubst base b e)
          | Type ty <- e    = Just (extendTvSubst base b ty)
          | Var tv  <- e    = Just (extendTvSubst base b (mkTyVarTy tv))
          | otherwise       = Nothing
    in maybe (fail "substExprR: Id argument is a TyVar, but the expression is not a Type.")
             (\ s -> return (substExpr (text "substR") s expr))
             mbSubst
-- | Substitute @e@ for @b@ in every binding group of a program.
--
-- When @b@ is a type variable, @e@ must be a @Type@ or a @Var@ (used as a
-- type variable); any other @e@ makes the rewrite fail.  Groups are processed
-- left to right with 'substBind'; the 'Subst' returned by each call is passed
-- on to the next.
--
-- The starting in-scope set holds the free variables of @e@.  GHC's
-- CoreSubst requires the in-scope set to cover the free variables of the
-- substitution's range.  Without that, a program binder that clashes with one
-- of them is not renamed and can capture it.
substTopBindR :: Id -> CoreExpr -> RewriteH CoreProgram
substTopBindR b e = contextfreeT $ \ binds ->
    let emptySub = mkEmptySubst (mkInScopeSet (exprFreeVars e))
        sub = if isTyVar b
                then case e of
                       Type bty -> Just $ extendTvSubst emptySub b bty
                       Var x    -> Just $ extendTvSubst emptySub b (mkTyVarTy x)
                       _        -> Nothing
                else Just $ extendSubst emptySub b e
    in case sub of
         Just sub' -> return $ snd (mapAccumL substBind sub' binds)
         Nothing   -> fail "substTopBindR: Id argument is a TyVar, but the expression is not a Type."
-- | Inline a non-recursive let: @let x = E1 in E2 ==> E2[E1/x]@.
--
-- The expression is first run through 'occurAnalyseExpr'.  The binder, its
-- right-hand side and the body of the analysed result are what get
-- substituted.  Anything other than a non-recursive 'Let' fails.
letSubstR :: RewriteH CoreExpr
letSubstR = prefixFailMsg "Let substition failed: " $
    rewrite $ \ ctx expr ->
        case occurAnalyseExpr expr of
            Let (NonRec v rhs) body -> apply (substExprR v rhs) ctx body
            _                       -> fail "expression is not a non-recursive Let."
-- | Apply 'letSubstR' to @n@ nested lets, innermost first.
--
-- For @n > 0@ this first runs @letSubstNR (n - 1)@ on child 1 of the current
-- node, then 'letSubstR' on the current node.  Child 1 is presumably the let
-- body; confirm against HERMIT's Core walker.  Any @n <= 0@ is the identity
-- rewrite, so a negative count cannot recurse forever.
letSubstNR :: Int -> RewriteH Core
letSubstNR n
    | n <= 0    = idR
    | otherwise = childR 1 (letSubstNR (n - 1)) >>> promoteExprR letSubstR
-- | 'letSubstR', but only when inlining cannot duplicate work.
--
-- After occurrence analysis, the expression must be a non-recursive let.
-- The substitution goes ahead when either:
--
--   * the binder is a type variable, or
--   * the binder is an Id and one of these holds:
--
--       - its right-hand side is a variable, a lambda, or an application with
--         fewer value arguments than the head's arity (or, for a non-variable
--         head, fewer arguments than its leading binders);
--       - the occurrence info marks it dead, or used once in one branch and
--         not inside a lambda.
safeLetSubstR :: RewriteH CoreExpr
safeLetSubstR = prefixFailMsg "Safe let-substition failed: " $
    translate $ \ env expr ->
        case occurAnalyseExpr expr of
            Let (NonRec v rhs) _
                | isTyVar v -> apply letSubstR env expr
                | isId v && (dupSafeRhs env rhs || occursSafely (occInfo (idInfo v)))
                    -> apply letSubstR env expr
                | otherwise -> fail "safety critera not met."
            _ -> fail "expression is not a non-recursive Let."
  where
    -- Right-hand sides whose duplication costs no extra evaluation.
    dupSafeRhs _ (Var {}) = True
    dupSafeRhs _ (Lam {}) = True
    dupSafeRhs env app@(App {}) =
        let (hd, args) = collectArgs app
        in case hd of
             Var f -> arityOf env f > length (filter (not . isTypeArg) args)
             _     -> length (fst (collectBinders hd)) > length args
    dupSafeRhs _ _ = False

    -- Occurrence patterns under which inlining is safe regardless of the rhs.
    occursSafely IAmDead                = True
    occursSafely (OneOcc inLam oneBr _) = not inLam && oneBr
    occursSafely _                      = False
-- | Apply 'safeLetSubstR' to a chain of nested lets.
--
-- If the current node is a let, this rewrite is first applied to its body.
-- That step is optional (wrapped in 'tryR').  Then 'safeLetSubstR' must
-- succeed on the current node.
safeLetSubstPlusR :: RewriteH CoreExpr
safeLetSubstPlusR = tryR intoBody >>> safeLetSubstR
  where
    -- Rebuild the let with its binding unchanged and the body rewritten.
    intoBody = letT idR safeLetSubstPlusR Let
-- | Describe, as a human-readable string, the free identifiers of the
-- current expression.
freeIdsQuery :: TranslateH CoreExpr String
freeIdsQuery = liftM2 render (constT getDynFlags) freeIdsT
  where
    render dynFlags frees = "Free identifiers are: " ++ showVars dynFlags frees
-- | Pretty-print a variable with the given 'DynFlags', then 'show' the
-- resulting string (so it is rendered quoted).
showVar :: DynFlags -> Var -> String
showVar dynFlags v = show (showPpr dynFlags v)

-- | Pretty-print each variable and 'show' the resulting list of strings.
showVars :: DynFlags -> [Var] -> String
showVars dynFlags vs = show [ showPpr dynFlags v | v <- vs ]
-- | Translation returning the free identifiers of an expression.
freeIdsT :: TranslateH CoreExpr [Id]
freeIdsT = contextfreeT (return . coreExprFreeIds)

-- | Translation returning the free variables of an expression.
freeVarsT :: TranslateH CoreExpr [Var]
freeVarsT = contextfreeT (return . coreExprFreeVars)

-- | Free variables of an expression (via 'exprFreeVars'), as a list.
coreExprFreeVars :: CoreExpr -> [Var]
coreExprFreeVars e = uniqSetToList (exprFreeVars e)

-- | Free identifiers of an expression (via 'exprFreeIds'), as a list.
coreExprFreeIds :: CoreExpr -> [Id]
coreExprFreeIds e = uniqSetToList (exprFreeIds e)
-- | Rewrite a program with GHC's 'deShadowBinds'.
deShadowBindsR :: RewriteH CoreProgram
deShadowBindsR = contextfreeT (return . deShadowBinds)
-- | Map each rule's name to a rewrite that tries only that rule.
-- When two rules share a name, the later one in the list wins
-- ('Map.fromList' keeps the last value for a key).
rulesToEnv :: [CoreRule] -> Map.Map String (RewriteH CoreExpr)
rulesToEnv = Map.fromList . map entry
  where
    entry r = (unpackFS (ruleName r), rulesToRewriteH [r])
-- | Build a rewrite that tries the given rules on an application whose head
-- is a variable, @f a1 .. an@.
--
-- On a match, the rule's result is applied to any arguments beyond the
-- rule's arity.  The rewrite fails in three cases:
--
--   * the expression is not such an application;
--   * no rule matches;
--   * the result mentions a variable that is neither bound in the HERMIT
--     context nor has a 'CoreUnfolding' (see 'inScope').
--
-- The activation filter given to 'lookupRule' is @const True@, so activation
-- phases are presumably ignored here.
rulesToRewriteH :: [CoreRule] -> RewriteH CoreExpr
rulesToRewriteH rs = translate $ \ c e ->
    case collectArgs e of
        (Var fn, args) ->
            -- Rule matching uses the expression's own free variables as the
            -- in-scope set.
            let in_scope = mkInScopeSet (exprFreeVars e)
            in case lookupRule (const True) (const NoUnfolding) in_scope fn args rs of
                   Nothing -> fail "rule not matched"
                   Just (rule, rhs) ->
                       let e' = mkApps rhs (drop (ruleArity rule) args)
                       in if all (inScope c) (coreExprFreeVars e')
                            then return e'
                            else fail $ unlines
                                   [ "Resulting expression after rule application contains variables that are not in scope."
                                   , "This can probably be solved by running the flatten-module command at the top level."
                                   ]
        _ -> fail "rule not matched: expression is not an application headed by a variable"
-- | True when the identifier is bound in the HERMIT context, or, failing
-- that, when its 'IdInfo' carries a 'CoreUnfolding'.
inScope :: Context -> Id -> Bool
inScope c i = case lookupHermitBinding i c of
    Just _  -> True
    Nothing -> hasCoreUnfolding (unfoldingInfo (idInfo i))
  where
    hasCoreUnfolding (CoreUnfolding {}) = True
    hasCoreUnfolding _                  = False
-- | Rewrite with the GHC rule of the given name.  Fails if no rule known to
-- 'getHermitRules' has that name.
rules :: String -> RewriteH CoreExpr
rules r = do
    available <- getHermitRules
    maybe (fail $ "failed to find rule: " ++ show r) rulesToRewriteH (lookup r available)
-- | All rules visible to HERMIT, each paired with its name.  Three sources
-- are concatenated in this order:
--
--   * the module's 'mg_rules';
--   * the 'idCoreRules' of every top-level binder;
--   * the contents of GHC's rule base.
getHermitRules :: (Generic a ~ Core) => TranslateH a [(String, [CoreRule])]
getHermitRules = translate $ \ env _ -> do
    ruleBase <- liftCoreM getRuleBase
    let guts        = hermitModGuts env
        binderRules = concatMap idCoreRules (concatMap groupBinders (mg_binds guts))
        allRules    = mg_rules guts ++ binderRules ++ concat (nameEnvElts ruleBase)
    return (map (\ r -> (unpackFS (ruleName r), [r])) allRules)
  where
    groupBinders (Rec prs)    = map fst prs
    groupBinders (NonRec b _) = [b]
-- | List the names of all available rules on one line, followed by GHC's
-- user-facing pretty-printing of the rules themselves.
rules_help :: TranslateH Core String
rules_help = do
    rulesEnv <- getHermitRules
    dynFlags <- constT getDynFlags
    let names   = show (map fst rulesEnv)
        ruleDoc = pprRulesForUser (concatMap snd rulesEnv)
    return (names ++ "\n" ++ showSDoc dynFlags ruleDoc)
-- | Build a 'CoreRule' named @rule_name@ whose head is @nm@, with no binders
-- and no argument patterns, rewriting to the given right-hand side.
--
-- Its activation is 'NeverActive'.  NOTE(review): presumably this stops
-- GHC's simplifier from firing it on its own.  The two leading Bool flags
-- passed to 'mkRule' are True and False; confirm their meaning against the
-- GHC version in use.
makeRule :: String -> Id -> CoreExpr -> CoreRule
makeRule rule_name nm rhs =
    mkRule True False ruleNameFS NeverActive (varName nm) [] [] rhs
  where
    ruleNameFS = mkFastString rule_name
-- | Append to the module's 'mg_rules' a rule (built by 'makeRule') that
-- rewrites the named top-level binder to its right-hand side.
--
-- Fails unless exactly one top-level binding matches the Template Haskell
-- name.
addCoreBindAsRule :: String -> TH.Name -> RewriteH ModGuts
addCoreBindAsRule rule_name nm = contextfreeT $ \ modGuts ->
    let candidates = [ pr | pr@(v, _) <- concatMap groupPairs (mg_binds modGuts)
                          , nm `GHC.cmpTHName2Id` v ]
    in case candidates of
         [(v, e)] -> return $ modGuts { mg_rules = mg_rules modGuts ++ [makeRule rule_name v e] }
         []       -> fail $ "can not find binding " ++ show nm
         _        -> fail $ "found multiple bindings for " ++ show nm
  where
    groupPairs (Rec prs)    = prs
    groupPairs (NonRec b e) = [(b, e)]
-- | Run 'mergeBinds' on a module through 'modGutsR'.  Per the
-- flatten-module help text, this flattens the module's top-level binding
-- groups into one recursive group.
flattenModule :: RewriteH ModGuts
flattenModule = modGutsR mergeBinds

-- | Merge all binding groups of a program into a single 'Rec' group,
-- keeping the original order of bindings.
--
-- Fails if the same binder occurs more than once.
mergeBinds :: RewriteH CoreProgram
mergeBinds = contextfreeT $ \ binds ->
    let allBinds = concatMap groupPairs binds
        -- Count distinct binders with an ordered map: O(n log n) rather than
        -- the quadratic 'nub'.  Assumes Var's Ord instance agrees with its Eq.
        distinct = Map.size (Map.fromList [ (b, ()) | (b, _) <- allBinds ])
    in if length allBinds == distinct
         then return [Rec allBinds]
         else fail "Module top level bindings contain multiple occurances of a name"
  where
    groupPairs (NonRec b e) = [(b, e)]
    groupPairs (Rec bds)    = bds
-- | Run GHC's occurrence analyser over an expression.
occurAnalyseExpr :: CoreExpr -> CoreExpr
occurAnalyseExpr expr = OccurAnal.occurAnalyseExpr expr

-- | 'occurAnalyseExpr' lifted to a context-free rewrite.
occurAnalyseExprR :: RewriteH CoreExpr
occurAnalyseExprR = arr occurAnalyseExpr
-- | Alpha-equivalence of two expressions ('eqExpr'), using the free
-- variables of both as the in-scope set.
exprEqual :: CoreExpr -> CoreExpr -> Bool
exprEqual e1 e2 = eqExpr scope e1 e2
  where
    scope = mkInScopeSet (exprsFreeVars [e1, e2])
-- | Compare two binding groups for alpha-equivalence.
--
-- Returns 'Nothing' when one group is recursive and the other is not.
-- Non-recursive groups compare only their right-hand sides.  Recursive
-- groups compare right-hand sides pairwise, after matching each group's
-- binders positionally with 'rnBndrs2'.
bindEqual :: CoreBind -> CoreBind -> Maybe Bool
bindEqual (Rec ps1) (Rec ps2)
    -- Groups of different sizes are never equal.  Checking this first also
    -- keeps 'rnBndrs2' from receiving binder lists of different lengths,
    -- which it pairs positionally.
    | length ps1 /= length ps2 = Just False
    | otherwise                = Just $ all2 (eqExprX id_unf env') rs1 rs2
  where
    id_unf _   = noUnfolding
    (bs1, rs1) = unzip ps1
    (bs2, rs2) = unzip ps2
    env        = mkInScopeSet $ exprsFreeVars (rs1 ++ rs2)
    env'       = rnBndrs2 (mkRnEnv2 env) bs1 bs2
bindEqual (NonRec _ e1) (NonRec _ e2) = Just $ exprEqual e1 e2
bindEqual _ _ = Nothing
-- | Compare two Core nodes of the same kind (expression, binding group or
-- definition).  Nodes of different kinds give 'Nothing'.  Definitions are
-- compared as singleton recursive groups.
coreEqual :: Core -> Core -> Maybe Bool
coreEqual c1 c2 = case (c1, c2) of
    (ExprCore e1, ExprCore e2) -> Just (exprEqual e1 e2)
    (BindCore b1, BindCore b2) -> bindEqual b1 b2
    (DefCore d1,  DefCore d2)  -> bindEqual (defToRecBind [d1]) (defToRecBind [d2])
    _                          -> Nothing
-- | Locate the bindings of two names and succeed only if they are
-- alpha-equivalent according to 'coreEqual'.  Fails with a message when they
-- differ or cannot be compared.
compareValues :: TH.Name -> TH.Name -> TranslateH Core ()
compareValues n1 n2 = do
    p1 <- onePathToT (namedBinding n1)
    p2 <- onePathToT (namedBinding n2)
    c1 <- pathT p1 idR
    c2 <- pathT p2 idR
    case coreEqual c1 c2 of
        Just True  -> return ()
        Just False -> failWith "are not equal"
        Nothing    -> failWith "are incomparable"
  where
    failWith msg = fail $ show n1 ++ " and " ++ show n2 ++ " " ++ msg
-- | Arity of an identifier.
--
-- If the HERMIT context has a binding for it, the arity comes from that
-- binding: 0 for a 'LAM' binding, or 'GHC.exprArity' of the expression
-- stored in a 'BIND' or 'CASE' binding.  Otherwise the Id's own 'idArity'
-- is used.
arityOf :: Context -> Id -> Int
arityOf env nm = maybe (idArity nm) bindingArity (lookupHermitBinding nm env)
  where
    bindingArity (LAM {})       = 0
    bindingArity (BIND _ _ rhs) = GHC.exprArity rhs
    bindingArity (CASE _ ex _)  = GHC.exprArity ex
-- | Replace @Cast e co@ with @e@, discarding the coercion.  Fails with an
-- explicit message on any other expression, instead of relying on a
-- do-pattern-match failure.
--
-- NOTE(review): the result has the un-cast type, so this is only
-- type-correct when the context tolerates that; confirm intended use.
castElimination :: RewriteH CoreExpr
castElimination = contextfreeT $ \ expr ->
    case expr of
        Cast e _ -> return e
        _        -> fail "cast-elim: expression is not a Cast."