module HERMIT.Dictionary.Rules
(
externals
, RuleName(..)
, RuleNameListBox(..)
, foldRuleR
, foldRulesR
, unfoldRuleR
, unfoldRulesR
, compileRulesT
, ruleToQuantifiedT
, ruleNameToQuantifiedT
, getHermitRuleT
, getHermitRulesT
, specConstrR
, specialiseR
) where
import qualified SpecConstr
import qualified Specialise
import Control.Arrow
import Control.Monad
import Data.Dynamic (Typeable)
import Data.Function (on)
import Data.List (deleteFirstsBy,intercalate)
import Data.String (IsString(..))
import HERMIT.Context
import HERMIT.Core
import HERMIT.External
import HERMIT.GHC
import HERMIT.Kure
import HERMIT.Lemma
import HERMIT.Monad
import HERMIT.Dictionary.Fold (compileFold, CompiledFold, toEqualities)
import HERMIT.Dictionary.Kure (anyCallR)
import HERMIT.Dictionary.Reasoning hiding (externals)
import HERMIT.PrettyPrinter.Common
import IOEnv hiding (liftIO)
-- | Shell commands exported by this module.
--
-- The @-unsafe@ variants differ from their plain counterparts only in
-- passing 'UnsafeUsed' (rather than 'Obligation') to the underlying rewrite
-- and in carrying the 'Unsafe' tag.
externals :: [External]
externals =
    [ external "show-rules" (rulesHelpListT :: TransformH LCoreTC String)
        [ "List all the rules in scope." ] .+ Query
    , external "show-rule" (ruleHelpT :: PrettyPrinter -> RuleName -> TransformH LCoreTC DocH)
        [ "Display details on the named rule." ] .+ Query
    , external "fold-rule" (promoteExprR . foldRuleR Obligation :: RuleName -> RewriteH LCore)
        [ "Apply a named GHC rule right-to-left." ] .+ Shallow
    , external "fold-rules" (promoteExprR . foldRulesR Obligation :: [RuleName] -> RewriteH LCore)
        [ "Apply named GHC rules right-to-left, succeed if any of the rules succeed." ] .+ Shallow
    , external "unfold-rule" (promoteExprR . unfoldRuleR Obligation :: RuleName -> RewriteH LCore)
        [ "Apply a named GHC rule left-to-right." ] .+ Shallow
    , external "unfold-rule-unsafe" (promoteExprR . unfoldRuleR UnsafeUsed :: RuleName -> RewriteH LCore)
        [ "Apply a named GHC rule left-to-right." ] .+ Shallow .+ Unsafe
    , external "unfold-rules" (promoteExprR . unfoldRulesR Obligation :: [RuleName] -> RewriteH LCore)
        [ "Apply named GHC rules left-to-right, succeed if any of the rules succeed" ] .+ Shallow
    , external "unfold-rules-unsafe" (promoteExprR . unfoldRulesR UnsafeUsed :: [RuleName] -> RewriteH LCore)
        [ "Apply named GHC rules left-to-right, succeed if any of the rules succeed" ] .+ Shallow .+ Unsafe
    , external "rule-to-lemma" ruleToLemmaCmd
        [ "Create a lemma from a GHC RULE." ]
    , external "spec-constr" (promoteModGutsR specConstrR :: RewriteH LCore)
        [ "Run GHC's SpecConstr pass, which performs call pattern specialization."] .+ Deep
    , external "specialise" (promoteModGutsR specialiseR :: RewriteH LCore)
        [ "Run GHC's specialisation pass, which performs type and dictionary specialisation."] .+ Deep
    ]
  where
    -- Install the named rule as a lemma (named after the rule), then
    -- pretty-print that freshly inserted lemma.
    ruleToLemmaCmd :: PrettyPrinter -> RuleName -> TransformH LCore DocH
    ruleToLemmaCmd pp nm = do
        ruleToLemmaT nm
        liftPrettyH (pOptions pp) (showLemmaT (fromString (show nm)) pp)
-- | The name of a GHC rule. Lookups in this module compare it against
-- the unpacked 'ruleName' of each rule in scope.
newtype RuleName = RuleName String
    deriving (Eq, Typeable)

-- | A rule name is passed through the shell unboxed (its box is itself).
instance Extern RuleName where
    type Box RuleName = RuleName
    box   = id
    unbox = id

-- | Lets string literals stand in for rule names.
instance IsString RuleName where
    fromString s = RuleName s

-- | Renders the bare name, with no constructor and no quotes.
instance Show RuleName where
    show (RuleName nm) = nm
-- | Box that carries a list of 'RuleName's through the shell's 'Extern'
-- machinery.
newtype RuleNameListBox = RuleNameListBox [RuleName]
    deriving Typeable

instance Extern [RuleName] where
    type Box [RuleName] = RuleNameListBox
    box = RuleNameListBox
    unbox b = case b of
        RuleNameListBox nms -> nms
-- | Apply the named rule right-to-left (via 'backwardT' of its 'birewrite').
--
-- After a successful rewrite, 'verifyOrCreateT' is called with a
-- 'NotProven' lemma named after the rule, passing the 'Used' argument
-- through. The rewritten expression is then returned unchanged by 'idR'.
foldRuleR :: ( AddBindings c, ExtendPath c Crumb, HasCoreRules c, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
             , HasDynFlags m, HasHermitMEnv m, HasLemmas m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
          => Used -> RuleName -> Rewrite c m CoreExpr
foldRuleR u nm = do
    q <- ruleNameToQuantifiedT nm
    let lemma     = Lemma q NotProven u False
        recordUse = verifyOrCreateT u (fromString (show nm)) lemma >> idR
    backwardT (birewrite q) >>> recordUse
-- | Try 'foldRuleR' with each named rule in turn. Succeeds when any one of
-- them succeeds ('orR').
foldRulesR :: ( AddBindings c, ExtendPath c Crumb, HasCoreRules c, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
              , HasDynFlags m, HasHermitMEnv m, HasLemmas m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
           => Used -> [RuleName] -> Rewrite c m CoreExpr
foldRulesR u nms = orR [ foldRuleR u nm | nm <- nms ]
-- | Apply the named rule left-to-right (via 'forwardT' of its 'birewrite').
--
-- After a successful rewrite, 'verifyOrCreateT' is called with a
-- 'NotProven' lemma named after the rule, passing the 'Used' argument
-- through. The rewritten expression is then returned unchanged by 'idR'.
unfoldRuleR :: ( AddBindings c, ExtendPath c Crumb, HasCoreRules c, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
               , HasDynFlags m, HasHermitMEnv m, HasLemmas m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
            => Used -> RuleName -> Rewrite c m CoreExpr
unfoldRuleR u nm = do
    q <- ruleNameToQuantifiedT nm
    let lemma     = Lemma q NotProven u False
        recordUse = verifyOrCreateT u (fromString (show nm)) lemma >> idR
    forwardT (birewrite q) >>> recordUse
-- | Try 'unfoldRuleR' with each named rule in turn. Succeeds when any one
-- of them succeeds ('orR').
unfoldRulesR :: ( AddBindings c, ExtendPath c Crumb, HasCoreRules c, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
                , HasDynFlags m, HasHermitMEnv m, HasLemmas m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m, MonadUnique m )
             => Used -> [RuleName] -> Rewrite c m CoreExpr
unfoldRulesR u nms = orR [ unfoldRuleR u nm | nm <- nms ]
-- | Compile every in-scope rule whose name is in the given list into one
-- 'CompiledFold'.
--
-- Requested names that match no rule are skipped silently. The transform
-- fails only when none of the names matches, or when converting a matched
-- rule with 'ruleToQuantifiedT' fails.
compileRulesT :: (BoundVars c, HasCoreRules c, HasHermitMEnv m, LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m)
              => [RuleName] -> Transform c m a CompiledFold
compileRulesT nms = do
    allRules <- getHermitRulesT
    case [ r | (nm, r) <- allRules, nm `elem` nms ] of
        [] -> fail (notFound nms)
        rs -> do
            qs <- forM rs $ \ r -> return r >>> ruleToQuantifiedT
            return (compileFold (concatMap toEqualities qs))
  where
    suggestion = "If you think the rule exists, try running the flatten-module command at the top level."

    notFound []   = "no rule names supplied."
    notFound [nm] = "failed to find rule: " ++ show nm ++ ". " ++ suggestion
    notFound _    = "failed to find any rules named " ++ intercalate ", " (map show nms) ++ ". " ++ suggestion
-- | Every rule visible to HERMIT, paired with its name.
--
-- Rules come from four sources, concatenated in this order:
--
-- 1. rules held in the context ('hermitCoreRules');
-- 2. the module's own 'mg_rules';
-- 3. the rule base returned by 'getRuleBase';
-- 4. the 'eps_rule_base' read from the session's 'hsc_EPS'.
--
-- The input value is ignored; only the context is consulted.
getHermitRulesT :: (HasCoreRules c, HasHermitMEnv m, LiftCoreM m, MonadIO m) => Transform c m a [(RuleName, CoreRule)]
getHermitRulesT = contextonlyT $ \ ctx -> do
    homeRuleBase <- liftCoreM getRuleBase
    gutsRules    <- liftM mg_rules getModGuts
    env          <- liftCoreM getHscEnv
    eps          <- liftIO $ runIOEnv () $ readMutVar (hsc_EPS env)
    let fromBase base = concat (nameEnvElts base)
        everyRule     = hermitCoreRules ctx ++ gutsRules
                     ++ fromBase homeRuleBase ++ fromBase (eps_rule_base eps)
    return [ (RuleName (unpackFS (ruleName r)), r) | r <- everyRule ]
-- | Look up a single rule by exact name.
--
-- Fails when no rule has that name, or when more than one does.
getHermitRuleT :: (HasCoreRules c, HasHermitMEnv m, LiftCoreM m, MonadIO m) => RuleName -> Transform c m a CoreRule
getHermitRuleT name = do
    rulesEnv <- getHermitRulesT
    case [ r | (nm, r) <- rulesEnv, nm == name ] of
        [r] -> return r
        []  -> fail $ "failed to find rule: " ++ show name ++ ". If you think the rule exists, "
                   ++ "try running the flatten-module command at the top level."
        _   -> fail ("Rule name \"" ++ show name ++ "\" is ambiguous.")
-- | Newline-separated names of all rules in scope.
--
-- Names are listed in the reverse of the order produced by
-- 'getHermitRulesT'.
rulesHelpListT :: (HasCoreRules c, HasHermitMEnv m, LiftCoreM m, MonadIO m) => Transform c m a String
rulesHelpListT = liftM (intercalate "\n" . map (show . fst) . reverse) getHermitRulesT
-- | Pretty-print the named rule, after converting it to a 'Quantified'.
ruleHelpT :: (HasCoreRules c, ReadBindings c, ReadPath c Crumb) => PrettyPrinter -> RuleName -> Transform c HermitM a DocH
ruleHelpT pp nm = ruleNameToQuantifiedT nm >>> render
  where
    render = liftPrettyH (pOptions pp) (ppQuantifiedT pp)
-- | Look up the named rule with 'getHermitRuleT', then convert it with
-- 'ruleToQuantifiedT'.
ruleNameToQuantifiedT :: ( BoundVars c, HasCoreRules c, HasDynFlags m, HasHermitMEnv m
                         , LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m )
                      => RuleName -> Transform c m a Quantified
ruleNameToQuantifiedT name = ruleToQuantifiedT <<< getHermitRuleT name
-- | Convert a user-defined rule into a 'Quantified' over the rule's binders.
--
-- The left-hand side is the rule's head function (looked up by 'ru_fn')
-- applied to the rule's arguments. The right-hand side is the rule's RHS.
-- Built-in rules do not match the 'Rule' pattern; that monadic pattern
-- failure is reported with the message given to 'withPatFailMsg'.
ruleToQuantifiedT :: (BoundVars c, HasHermitMEnv m, MonadThings m, MonadCatch m)
                  => Transform c m CoreRule Quantified
ruleToQuantifiedT = withPatFailMsg "HERMIT cannot handle built-in rules yet." $ do
    -- Failable pattern: deliberately left as a do-bind so that
    -- withPatFailMsg can intercept the mismatch on built-in rules.
    Rule { ru_fn = fn, ru_bndrs = bndrs, ru_args = args, ru_rhs = rhs } <- idR
    headId <- lookupId fn
    return $ mkQuantified bndrs (mkCoreApps (varToCoreExpr headId) args) rhs
-- | Insert the named rule as a lemma of the same name.
--
-- The lemma is created with status 'NotProven' and usage 'NotUsed'.
ruleToLemmaT :: ( BoundVars c, HasCoreRules c, HasDynFlags m, HasHermitMEnv m, HasLemmas m
                , LiftCoreM m, MonadCatch m, MonadIO m, MonadThings m)
             => RuleName -> Transform c m a ()
ruleToLemmaT nm =
    ruleNameToQuantifiedT nm >>= \ q ->
        insertLemmaT (fromString (show nm)) (Lemma q NotProven NotUsed False)
-- | Run GHC's SpecConstr pass over the module, then apply the rules it
-- created with 'repeatR' and 'anyCallR' over the resulting module.
--
-- A rule counts as "created" when its name was not among the
-- specialisation rules collected before the pass. Fails (with the
-- @spec-constr failed: @ prefix) if no such rule appears.
specConstrR :: RewriteH ModGuts
specConstrR = prefixFailMsg "spec-constr failed: " $ do
    rulesBefore <- extractT specRules
    guts'       <- contextfreeT (liftCoreM . SpecConstr.specConstrProgram)
    rulesAfter  <- return guts' >>> extractT specRules
    let created = deleteFirstsBy ((==) `on` ru_name) rulesAfter rulesBefore
    guardMsg (notNull created) "no rules created."
    return guts' >>> extractR (repeatR (anyCallR (promoteExprR (rulesToRewrite created))))
-- | Run GHC's specialisation pass over the module, then apply the rules it
-- created with 'repeatR' and 'anyCallR' over the resulting module.
--
-- New rules are collected from two places, by name comparison against the
-- state before the pass: the module's 'mg_rules', and the specialisation
-- rules attached to binders ('specRules'). The names of the new rules are
-- printed to stdout. Fails (with the @specialisation failed: @ prefix) if
-- no new rule appears.
specialiseR :: RewriteH ModGuts
specialiseR = prefixFailMsg "specialisation failed: " $ do
    globalBefore <- arr mg_rules
    localBefore  <- extractT specRules
    guts         <- contextfreeT (liftCoreM . Specialise.specProgram)
    localAfter   <- return guts >>> extractT specRules
    let newRulesOf after before = deleteFirstsBy ((==) `on` ru_name) after before
        specRs = newRulesOf (mg_rules guts) globalBefore ++ newRulesOf localAfter localBefore
    guardMsg (notNull specRs) "no rules created."
    -- NOTE(review): this writes straight to stdout rather than through
    -- HERMIT's own messaging, and unlines + putStrLn leaves an extra blank
    -- line; confirm whether this output is intended or leftover debugging.
    liftIO (putStrLn (unlines [ unpackFS (ru_name r) | r <- specRs ]))
    return guts >>> extractR (repeatR (anyCallR (promoteExprR (rulesToRewrite specRs))))
-- | The rules stored in an identifier's 'specInfo'.
--
-- Fails on anything for which 'isId' is false (i.e. type variables).
idSpecRules :: TransformH Id [CoreRule]
idSpecRules = do
    guardMsgM (arr isId) "idSpecRules called on TyVar."
    contextfreeT (return . specInfoRules . idInfo)
  where
    -- Lazy projection, matching the lazy let-pattern used previously.
    specInfoRules info = let SpecInfo rs _ = specInfo info in rs
-- | Specialisation rules attached to the binders of a binding group.
--
-- For a recursive group, the rules of every binder are concatenated.
-- Otherwise, the single non-recursive binder's rules are returned.
bindSpecRules :: TransformH CoreBind [CoreRule]
bindSpecRules = recRules <+ nonRecRules
  where
    recRules    = recT (const (defT idSpecRules successT const)) concat
    nonRecRules = nonRecT idSpecRules successT const
-- | Collect, via a top-down 'crushtdT' traversal, the specialisation rules
-- of every binding group ('bindSpecRules') found in a Core fragment.
specRules :: TransformH Core [CoreRule]
specRules = crushtdT (promoteBindT bindSpecRules)
-- | Turn a list of rules into one rewrite that applies each rule
-- left-to-right.
--
-- Rules are tried in list order with 'catchesM', and the first that
-- succeeds is used. A rule that cannot be converted by 'ruleToQuantifiedT'
-- (e.g. a built-in rule) simply fails its alternative.
rulesToRewrite :: ( AddBindings c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
                  , HasDynFlags m, HasHermitMEnv m, MonadCatch m, MonadThings m, MonadUnique m )
               => [CoreRule] -> Rewrite c m CoreExpr
rulesToRewrite rs = catchesM $ map (\ r -> do
                                        q <- return r >>> ruleToQuantifiedT
                                        forwardT (birewrite q))
                                   rs