{-# LANGUAGE CPP, FlexibleContexts #-} module HERMIT.Dictionary.Rules ( -- * GHC Rewrite Rules and Specialisation externals -- ** Rules , RuleNameString , ruleR , rulesR , ruleToEqualityT , ruleNameToEqualityT , getSingletonHermitRuleT -- , verifyCoreRuleT -- , verifyRuleT -- , ruleLhsIntroR -- , ruleRhsIntroR -- ** Specialisation , specConstrR ) where import IOEnv hiding (liftIO) import qualified SpecConstr import qualified Specialise import Control.Arrow import Control.Monad import Data.Function (on) import Data.List (deleteFirstsBy,intercalate) import HERMIT.Core import HERMIT.Context import HERMIT.Monad import HERMIT.Kure import HERMIT.External import HERMIT.GHC import HERMIT.Dictionary.Common (findIdT,inScope,callT) import HERMIT.Dictionary.GHC (dynFlagsT) -- import HERMIT.Dictionary.Induction import HERMIT.Dictionary.Kure (anyCallR) import HERMIT.Dictionary.Reasoning hiding (externals) import HERMIT.Dictionary.Unfold (cleanupUnfoldR) ------------------------------------------------------------------------ -- | Externals that reflect GHC functions, or are derived from GHC functions. externals :: [External] externals = [ external "rules-help-list" (rulesHelpListT :: TranslateH CoreTC String) [ "List all the rules in scope." ] .+ Query , external "rule-help" (ruleHelpT :: RuleNameString -> TranslateH CoreTC String) [ "Display details on the named rule." ] .+ Query , external "apply-rule" (promoteExprR . ruleR :: RuleNameString -> RewriteH Core) [ "Apply a named GHC rule" ] .+ Shallow , external "apply-rules" (promoteExprR . rulesR :: [RuleNameString] -> RewriteH Core) [ "Apply named GHC rules, succeed if any of the rules succeed" ] .+ Shallow , external "add-rule" ((\ rule_name id_name -> promoteModGutsR (addCoreBindAsRule rule_name id_name)) :: String -> String -> RewriteH Core) [ "add-rule \"rule-name\" -- adds a new rule that freezes the right hand side of the "] .+ Introduce , external "unfold-rule" ((\ nm -> promoteExprR (ruleR nm >>> cleanupUnfoldR)) :: String -> RewriteH Core) [ "Unfold a named GHC rule" ] .+ Deep .+ Context .+ TODO -- TODO: does not work with rules with no arguments , external "spec-constr" (promoteModGutsR specConstrR :: RewriteH Core) [ "Run GHC's SpecConstr pass, which performs call pattern specialization."] .+ Deep , external "specialise" (promoteModGutsR specialise :: RewriteH Core) [ "Run GHC's specialisation pass, which performs type and dictionary specialisation."] .+ Deep ] ------------------------------------------------------------------------ {- lookupRule :: (Activation -> Bool) -- When rule is active -> IdUnfoldingFun -- When Id can be unfolded -> InScopeSet -> Id -> [CoreExpr] -> [CoreRule] -> Maybe (CoreRule, CoreExpr) GHC HEAD: type InScopeEnv = (InScopeSet, IdUnfoldingFun) lookupRule :: DynFlags -> InScopeEnv -> (Activation -> Bool) -- When rule is active -> Id -> [CoreExpr] -> [CoreRule] -> Maybe (CoreRule, CoreExpr) -} -- Neil: Commented this out as it's not (currently) used. -- rulesToEnv :: [CoreRule] -> Map.Map String (Rewrite c m CoreExpr) -- rulesToEnv rs = Map.fromList -- [ ( unpackFS (ruleName r), rulesToRewrite c m [r] ) -- | r <- rs -- ] type RuleNameString = String -- TODO: deprecate this (and related functions) in favor of 'biRuleUnsafeR' #if __GLASGOW_HASKELL__ > 706 rulesToRewriteH :: (ReadBindings c, HasDynFlags m, MonadCatch m) => [CoreRule] -> Rewrite c m CoreExpr #else rulesToRewriteH :: (ReadBindings c, MonadCatch m) => [CoreRule] -> Rewrite c m CoreExpr #endif rulesToRewriteH rs = prefixFailMsg "RulesToRewrite failed: " $ withPatFailMsg "rule not matched." $ do (Var fn, args) <- callT translate $ \ c e -> do let in_scope = mkInScopeSet (mkVarEnv [ (v,v) | v <- varSetElems (localFreeVarsExpr e) ]) #if __GLASGOW_HASKELL__ > 706 dflags <- getDynFlags case lookupRule dflags (in_scope, const NoUnfolding) (const True) fn args [r | r <- rs, ru_fn r == idName fn] of #else case lookupRule (const True) (const NoUnfolding) in_scope fn args [r | r <- rs, ru_fn r == idName fn] of #endif Nothing -> fail "rule not matched" Just (r, expr) -> do let e' = mkApps expr (drop (ruleArity r) args) if all (inScope c) $ varSetElems $ localFreeVarsExpr e' -- TODO: The problem with this check, is that it precludes the case where this is an intermediate transformation. I can imagine situations where some variables would be out-of-scope at this point, but in scope again after a subsequent transformation. then return e' else fail $ unlines ["Resulting expression after rule application contains variables that are not in scope." ,"This can probably be solved by running the flatten-module command at the top level."] -- | Lookup a rule and attempt to construct a corresponding rewrite. ruleR :: (ReadBindings c, HasCoreRules c) => RuleNameString -> Rewrite c HermitM CoreExpr ruleR r = do theRules <- getHermitRulesT case lookup r theRules of Nothing -> fail $ "failed to find rule: " ++ show r Just rr -> rulesToRewriteH rr rulesR :: (ReadBindings c, HasCoreRules c) => [RuleNameString] -> Rewrite c HermitM CoreExpr rulesR = orR . map ruleR getHermitRulesT :: HasCoreRules c => Translate c HermitM a [(RuleNameString, [CoreRule])] getHermitRulesT = contextonlyT $ \ c -> do rb <- liftCoreM getRuleBase hscEnv <- liftCoreM getHscEnv rb' <- liftM eps_rule_base $ liftIO $ runIOEnv () $ readMutVar (hsc_EPS hscEnv) return [ ( unpackFS (ruleName r), [r] ) | r <- hermitCoreRules c ++ concat (nameEnvElts rb) ++ concat (nameEnvElts rb') ] getHermitRuleT :: HasCoreRules c => RuleNameString -> Translate c HermitM a [CoreRule] getHermitRuleT name = do rulesEnv <- getHermitRulesT case filter ((name ==) . fst) rulesEnv of [] -> fail ("Rule \"" ++ name ++ "\" not found.") [(_,rus)] -> return rus _ -> fail ("Rule name \"" ++ name ++ "\" is ambiguous.") getSingletonHermitRuleT :: HasCoreRules c => RuleNameString -> Translate c HermitM a CoreRule getSingletonHermitRuleT name = do rus <- getHermitRuleT name case rus of [] -> fail "No rules with that name." [ru] -> return ru _ -> fail "Multiple rules with that name." rulesHelpListT :: HasCoreRules c => Translate c HermitM a String rulesHelpListT = do rulesEnv <- getHermitRulesT return (intercalate "\n" $ map fst rulesEnv) ruleHelpT :: HasCoreRules c => RuleNameString -> Translate c HermitM a String ruleHelpT name = showSDoc <$> dynFlagsT <*> (pprRulesForUser <$> getHermitRuleT name) -- Too much information. -- rulesHelpT :: HasCoreRules c => Translate c HermitM a String -- rulesHelpT = do -- rulesEnv <- getHermitRulesT -- dynFlags <- dynFlagsT -- return $ (show (map fst rulesEnv) ++ "\n") ++ -- showSDoc dynFlags (pprRulesForUser $ concatMap snd rulesEnv) makeRule :: RuleNameString -> Id -> CoreExpr -> CoreRule makeRule rule_name nm = mkRule True -- auto-generated False -- local (mkFastString rule_name) NeverActive -- because we need to call for these (varName nm) [] [] -- TODO: check if a top-level binding addCoreBindAsRule :: Monad m => RuleNameString -> String -> Rewrite c m ModGuts addCoreBindAsRule rule_name nm = contextfreeT $ \ modGuts -> case [ (v,e) | bnd <- mg_binds modGuts , (v,e) <- bindToVarExprs bnd , nm `cmpString2Var` v ] of [] -> fail $ "cannot find binding " ++ nm [(v,e)] -> return $ modGuts { mg_rules = mg_rules modGuts ++ [makeRule rule_name v e] } _ -> fail $ "found multiple bindings for " ++ nm -- | Returns the universally quantified binders, the LHS, and the RHS. ruleToEqualityT :: (BoundVars c, HasGlobalRdrEnv c, HasDynFlags m, MonadThings m, MonadCatch m) => Translate c m CoreRule CoreExprEquality ruleToEqualityT = withPatFailMsg "HERMIT cannot handle built-in rules yet." $ do r@Rule{} <- idR -- other possibility is "BuiltinRule" f <- findIdT (getOccString $ ru_fn r) -- TODO: refactor name lookup functions (like findIdT) to avoid intermediate String here return $ CoreExprEquality (ru_bndrs r) (mkCoreApps (Var f) (ru_args r)) (ru_rhs r) ruleNameToEqualityT :: (BoundVars c, HasGlobalRdrEnv c, HasCoreRules c) => RuleNameString -> Translate c HermitM a CoreExprEquality ruleNameToEqualityT name = getSingletonHermitRuleT name >>> ruleToEqualityT ------------------------------------------------------------------------ -- | Run GHC's specConstr pass, and apply any rules generated. specConstrR :: RewriteH ModGuts specConstrR = prefixFailMsg "spec-constr failed: " $ do rs <- extractT specRules e' <- contextfreeT $ liftCoreM . SpecConstr.specConstrProgram rs' <- return e' >>> extractT specRules let specRs = deleteFirstsBy ((==) `on` ru_name) rs' rs guardMsg (notNull specRs) "no rules created." return e' >>> extractR (repeatR (anyCallR (promoteExprR $ rulesToRewriteH specRs))) -- | Run GHC's specialisation pass, and apply any rules generated. specialise :: RewriteH ModGuts specialise = prefixFailMsg "specialisation failed: " $ do gRules <- arr mg_rules lRules <- extractT specRules #if __GLASGOW_HASKELL__ <= 706 dflags <- dynFlagsT guts <- contextfreeT $ liftCoreM . Specialise.specProgram dflags #else guts <- contextfreeT $ liftCoreM . Specialise.specProgram #endif lRules' <- return guts >>> extractT specRules -- spec rules on bindings in this module let gRules' = mg_rules guts -- plus spec rules on imported bindings gSpecRs = deleteFirstsBy ((==) `on` ru_name) gRules' gRules lSpecRs = deleteFirstsBy ((==) `on` ru_name) lRules' lRules specRs = gSpecRs ++ lSpecRs guardMsg (notNull specRs) "no rules created." liftIO $ putStrLn $ unlines $ map (unpackFS . ru_name) specRs return guts >>> extractR (repeatR (anyCallR (promoteExprR $ rulesToRewriteH specRs))) -- | Get all the specialization rules on a binding. -- These are created by SpecConstr and other GHC passes. idSpecRules :: TranslateH Id [CoreRule] idSpecRules = do guardMsgM (arr isId) "idSpecRules called on TyVar." -- idInfo panics on TyVars contextfreeT $ \ i -> let SpecInfo rs _ = specInfo (idInfo i) in return rs -- | Promote 'idSpecRules' to CoreBind. bindSpecRules :: TranslateH CoreBind [CoreRule] bindSpecRules = recT (\_ -> defT idSpecRules successT const) concat <+ nonRecT idSpecRules successT const -- | Find all specialization rules in a Core fragment. specRules :: TranslateH Core [CoreRule] specRules = crushtdT $ promoteBindT bindSpecRules ------------------------------------------------------------------------