{-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE InstanceSigs #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NoImplicitPrelude #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeFamilies #-} module HERMIT.Dictionary.Reasoning ( -- * Equational Reasoning externals , EqualityProof , eqLhsIntroR , eqRhsIntroR , birewrite , extensionalityR , getLemmasT , getLemmaByNameT , insertLemmaT , insertLemmasT , lemmaBiR , lemmaConsequentR , markLemmaUsedT , markLemmaProvenT , modifyLemmaT , showLemmaT , showLemmasT , ppLemmaT , ppClauseT , ppLCoreTCT -- ** Lifting transformations over 'Clause' , lhsT , rhsT , bothT , lhsR , rhsR , bothR , verifyClauseT , lemmaR , quantIdentitiesR , verifyOrCreateT , verifyEqualityLeftToRightT , verifyEqualityCommonTargetT , verifyIsomorphismT , verifyRetractionT , reflexivityR , simplifyClauseR , retractionBR , unshadowClauseR , instantiateDictsR , instantiateClauseVarR , abstractClauseR -- * Constructing Composite Lemmas , ($$) , ($$$) , (==>) , (-->) , (===) , (/\) , (\/) , ToCoreExpr(..) , newLemma ) where import Control.Arrow hiding ((<+>)) import Control.Monad ((>=>), forM, liftM) import Data.Either (partitionEithers) import Data.List (isInfixOf, nubBy) import qualified Data.Map as Map import Data.Maybe (fromMaybe) import Data.Monoid import HERMIT.Context import HERMIT.Core import HERMIT.External import HERMIT.GHC hiding ((<>), (<+>), nest, ($+$), ($$)) import HERMIT.Kure import HERMIT.Lemma import HERMIT.Monad import HERMIT.Name import HERMIT.ParserCore import HERMIT.ParserType import HERMIT.PrettyPrinter.Common import HERMIT.Utilities import HERMIT.Dictionary.Common import HERMIT.Dictionary.Fold hiding (externals) import HERMIT.Dictionary.Function hiding (externals) import HERMIT.Dictionary.GHC hiding (externals) import HERMIT.Dictionary.Local.Let (nonRecIntroR) import Prelude.Compat hiding ((<$>), (<*>)) import qualified Text.PrettyPrint.MarkedHughesPJ as PP ------------------------------------------------------------------------------ externals :: [External] externals = [ external "retraction" ((\ f g r -> promoteExprBiR $ retraction (Just r) f g) :: CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore) [ "Given f :: X -> Y and g :: Y -> X, and a proof that f (g y) ==> y, then" , "f (g y) <==> y." ] .+ Shallow , external "retraction-unsafe" ((\ f g -> promoteExprBiR $ retraction Nothing f g) :: CoreString -> CoreString -> BiRewriteH LCore) [ "Given f :: X -> Y and g :: Y -> X, then" , "f (g y) <==> y." , "Note that the precondition (f (g y) == y) is expected to hold." ] .+ Shallow .+ PreCondition , external "unshadow-quantified" (promoteClauseR unshadowClauseR :: RewriteH LCoreTC) [ "Unshadow a quantified clause." ] , external "merge-quantifiers" (\n1 n2 -> promoteR (mergeQuantifiersR (cmpHN2Var n1) (cmpHN2Var n2)) :: RewriteH LCore) [ "Merge quantifiers from two clauses if they have the same type." , "Example:" , "(forall (x::Int). foo x = x) ^ (forall (y::Int). bar y y = 5)" , "merge-quantifiers 'x 'y" , "forall (x::Int). (foo x = x) ^ (bar x x = 5)" , "Note: if only one quantifier matches, it will be floated if possible." ] , external "float-left" (\n1 -> promoteR (mergeQuantifiersR (cmpHN2Var n1) (const False)) :: RewriteH LCore) [ "Float quantifier out of left-hand side." ] , external "float-right" (\n1 -> promoteR (mergeQuantifiersR (const False) (cmpHN2Var n1)) :: RewriteH LCore) [ "Float quantifier out of right-hand side." ] , external "conjunct" (\n1 n2 n3 -> conjunctLemmasT n1 n2 n3 :: TransformH LCore ()) [ "conjunt new-name lhs-name rhs-name" ] , external "disjunct" (\n1 n2 n3 -> disjunctLemmasT n1 n2 n3 :: TransformH LCore ()) [ "disjunt new-name lhs-name rhs-name" ] , external "imply" (\n1 n2 n3 -> implyLemmasT n1 n2 n3 :: TransformH LCore ()) [ "imply new-name antecedent-name consequent-name" ] , external "lemma-birewrite" (promoteExprBiR . lemmaBiR Obligation :: LemmaName -> BiRewriteH LCore) [ "Generate a bi-directional rewrite from a lemma." ] , external "lemma-forward" (forwardT . promoteExprBiR . lemmaBiR Obligation :: LemmaName -> RewriteH LCore) [ "Generate a rewrite from a lemma, left-to-right." ] , external "lemma-backward" (backwardT . promoteExprBiR . lemmaBiR Obligation :: LemmaName -> RewriteH LCore) [ "Generate a rewrite from a lemma, right-to-left." ] , external "lemma-consequent" (promoteClauseR . lemmaConsequentR Obligation :: LemmaName -> RewriteH LCore) [ "Match the current lemma with the consequent of an implication lemma." , "Upon success, replaces with antecedent of the implication, properly instantiated." ] , external "lemma-consequent-birewrite" (promoteExprBiR . lemmaConsequentBiR Obligation :: LemmaName -> BiRewriteH LCore) [ "Generate a bi-directional rewrite from the consequent of an implication lemma." , "The antecedent is instantiated and introduced as an unproven obligation." ] , external "lemma-lhs-intro" (promoteCoreR . lemmaLhsIntroR :: LemmaName -> RewriteH LCore) [ "Introduce the LHS of a lemma as a non-recursive binding, in either an expression or a program." , "body ==> let v = lhs in body" ] .+ Introduce .+ Shallow , external "lemma-rhs-intro" (promoteCoreR . lemmaRhsIntroR :: LemmaName -> RewriteH LCore) [ "Introduce the RHS of a lemma as a non-recursive binding, in either an expression or a program." , "body ==> let v = rhs in body" ] .+ Introduce .+ Shallow , external "inst-lemma" (\ nm v cs -> modifyLemmaT nm id (instantiateClauseVarR (cmpHN2Var v) cs) id id :: TransformH LCore ()) [ "Instantiate one of the universally quantified variables of the given lemma," , "with the given Core expression, creating a new lemma. Instantiating an" , "already proven lemma will result in the new lemma being considered proven." ] , external "inst-dictionaries" (promoteClauseR instantiateDictsR :: RewriteH LCore) [ "Instantiate all of the universally quantified dictionaries of the given lemma." ] , external "abstract-forall" ((\nm -> promoteClauseR . abstractClauseR nm . csInQBodyT) :: String -> CoreString -> RewriteH LCore) [ "Weaken a lemma by abstracting an expression to a new quantifier." ] , external "abstract-forall" ((\nm rr -> promoteClauseR $ abstractClauseR nm $ extractT rr >>> setFailMsg "path must focus on an expression" projectT) :: String -> RewriteH LCore -> RewriteH LCore) [ "Weaken a lemma by abstracting an expression to a new quantifier." ] , external "copy-lemma" (\ nm newName -> modifyLemmaT nm (const newName) idR id id :: TransformH LCore ()) [ "Copy a given lemma, with a new name." ] , external "modify-lemma" ((\ nm rr -> modifyLemmaT nm id (extractR rr) (const NotProven) (const NotUsed)) :: LemmaName -> RewriteH LCore -> TransformH LCore ()) [ "Modify a given lemma. Resets proven status to Not Proven and used status to Not Used." ] , external "query-lemma" ((\ nm t -> getLemmaByNameT nm >>> arr lemmaC >>> extractT t) :: LemmaName -> TransformH LCore String -> TransformH LCore String) [ "Apply a transformation to a lemma, returning the result." ] , external "show-lemma" ((\pp n -> showLemmaT n pp) :: PrettyPrinter -> LemmaName -> PrettyH LCore) [ "Display a lemma." ] , external "show-lemmas" ((\pp n -> showLemmasT (Just n) pp) :: PrettyPrinter -> LemmaName -> PrettyH LCore) [ "List lemmas whose names match search string." ] , external "show-lemmas" (showLemmasT Nothing :: PrettyPrinter -> PrettyH LCore) [ "List lemmas." ] , external "extensionality" (promoteR . extensionalityR . Just :: String -> RewriteH LCore) [ "Given a name 'x, then" , "f == g ==> forall x. f x == g x" ] , external "extensionality" (promoteR (extensionalityR Nothing) :: RewriteH LCore) [ "f == g ==> forall x. f x == g x" ] , external "lhs" (promoteClauseT . lhsT :: TransformH LCore String -> TransformH LCore String) [ "Apply a transformation to the LHS of a quantified clause." ] , external "lhs" (promoteClauseR . lhsR :: RewriteH LCore -> RewriteH LCore) [ "Apply a rewrite to the LHS of a quantified clause." ] , external "rhs" (promoteClauseT . rhsT :: TransformH LCore String -> TransformH LCore String) [ "Apply a transformation to the RHS of a quantified clause." ] , external "rhs" (promoteClauseR . rhsR :: RewriteH LCore -> RewriteH LCore) [ "Apply a rewrite to the RHS of a quantified clause." ] , external "both" (promoteClauseR . bothR :: RewriteH LCore -> RewriteH LCore) [ "Apply a rewrite to both sides of an equality, succeeding if either succeed." ] , external "both" ((\t -> do (r,s) <- promoteClauseT (bothT t); return (unlines [r,s])) :: TransformH LCore String -> TransformH LCore String) [ "Apply a transformation to both sides of a quantified clause." ] , external "reflexivity" (promoteClauseR (forallR idR reflexivityR <+ reflexivityR) :: RewriteH LCore) [ "Rewrite alpha-equivalence to true." ] , external "simplify-lemma" (simplifyClauseR :: RewriteH LCore) [ "Reduce a proof by applying reflexivity and logical operator identities." ] , external "split-antecedent" (promoteClauseR splitAntecedentR :: RewriteH LCore) [ "Split an implication of the form (q1 ^ q2) => q3 into q1 => (q2 => q3)" ] , external "lemma" (promoteClauseR . lemmaR Obligation :: LemmaName -> RewriteH LCore) [ "Rewrite clause to true using given lemma." ] , external "lemma-unsafe" (promoteClauseR . lemmaR UnsafeUsed :: LemmaName -> RewriteH LCore) [ "Rewrite clause to true using given lemma." ] .+ Unsafe ] ------------------------------------------------------------------------------ type EqualityProof c m = (Rewrite c m CoreExpr, Rewrite c m CoreExpr) -- | f == g ==> forall x. f x == g x extensionalityR :: (AddBindings c, ExtendPath c Crumb, ReadPath c Crumb) => Maybe String -> Rewrite c HermitM Clause extensionalityR mn = prefixFailMsg "extensionality failed: " $ do (vs,(lhs,rhs)) <- forallT idR (equivT idR idR (,)) (,) <+ equivT idR idR (\l r -> ([],(l,r))) let tyL = exprKindOrType lhs tyR = exprKindOrType rhs guardMsg (tyL `typeAlphaEq` tyR) "type mismatch between sides of equality. This shouldn't happen, so is probably a bug." -- TODO: use the fresh-name-generator in AlphaConversion to avoid shadowing. (_,argTy,_) <- splitFunTypeM tyL v <- constT $ newVarH (fromMaybe "x" mn) argTy let x = varToCoreExpr v return $ Forall (vs ++ [v]) $ Equiv (mkCoreApp lhs x) (mkCoreApp rhs x) ------------------------------------------------------------------------------ -- | @e@ ==> @let v = lhs in e@ eqLhsIntroR :: Clause -> Rewrite c HermitM Core eqLhsIntroR (Forall bs (Equiv lhs _)) = nonRecIntroR "lhs" (mkCoreLams bs lhs) eqLhsIntroR _ = fail "compound lemmas not supported." -- | @e@ ==> @let v = rhs in e@ eqRhsIntroR :: Clause -> Rewrite c HermitM Core eqRhsIntroR (Forall bs (Equiv _ rhs)) = nonRecIntroR "rhs" (mkCoreLams bs rhs) eqRhsIntroR _ = fail "compound lemmas not supported." ------------------------------------------------------------------------------ -- | Create a 'BiRewrite' from a 'Clause'. birewrite :: ( AddBindings c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c , ReadPath c Crumb, MonadCatch m, MonadUnique m ) => Clause -> BiRewrite c m CoreExpr birewrite cl = bidirectional (foldUnfold "left" id) (foldUnfold "right" flipEquality) where foldUnfold side f = transform $ \ c -> maybeM ("expression did not match "++side++"-hand side") . fold (map f (toEqualities cl)) c ------------------------------------------------------------------------------ -- TODO: deprecate these? -- Yes, but later. They're in the paper now. -- We should be using "childR crumb", really. -- | Lift a transformation over 'LCoreTC' into a transformation over the left-hand side of a 'Clause'. lhsT :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m) => Transform c m LCore a -> Transform c m Clause a lhsT t = extractT $ catchesT [ f (childT cr t) | cr <- [Conj_Lhs, Disj_Lhs, Impl_Lhs, Eq_Lhs] , f <- [childT Forall_Body, id] ] -- | Lift a transformation over 'LCoreTC' into a transformation over the right-hand side of a 'Clause'. rhsT :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m) => Transform c m LCore a -> Transform c m Clause a rhsT t = extractT $ catchesT [ f (childT cr t) | cr <- [Conj_Rhs, Disj_Rhs, Impl_Rhs, Eq_Rhs] , f <- [childT Forall_Body, id] ] -- | Lift a transformation over 'LCoreTC' into a transformation over both sides of a 'Clause'. bothT :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m) => Transform c m LCore a -> Transform c m Clause (a, a) bothT t = (,) <$> lhsT t <*> rhsT t -- | Lift a rewrite over 'LCoreTC' into a rewrite over the left-hand side of a 'Clause'. lhsR :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m) => Rewrite c m LCore -> Rewrite c m Clause lhsR r = extractR $ catchesT [ f (childR cr r) | cr <- [Conj_Lhs, Disj_Lhs, Impl_Lhs, Eq_Lhs] , f <- [childR Forall_Body, id] ] -- | Lift a rewrite over 'LCoreTC' into a rewrite over the right-hand side of a 'Clause'. rhsR :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m) => Rewrite c m LCore -> Rewrite c m Clause rhsR r = extractR $ catchesT [ f (childR cr r) | cr <- [Conj_Rhs, Disj_Rhs, Impl_Rhs, Eq_Rhs] , f <- [childR Forall_Body, id] ] -- | Lift a rewrite over 'LCoreTC' into a rewrite over both sides of a 'Clause'. bothR :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m) => Rewrite c m LCore -> Rewrite c m Clause bothR r = lhsR r >+> rhsR r ------------------------------------------------------------------------------ showLemmasT :: Maybe LemmaName -> PrettyPrinter -> PrettyH a showLemmasT mnm pp = do ls <- getLemmasT let ls' = Map.toList $ Map.filterWithKey (maybe (\ _ _ -> True) (\ nm n _ -> show nm `isInfixOf` show n) mnm) ls ds <- forM ls' $ \(nm,l) -> return l >>> ppLemmaT pp nm return $ PP.vcat ds showLemmaT :: LemmaName -> PrettyPrinter -> PrettyH a showLemmaT nm pp = getLemmaByNameT nm >>> ppLemmaT pp nm ppLemmaT :: PrettyPrinter -> LemmaName -> PrettyH Lemma ppLemmaT pp nm = do Lemma q p _u <- idR qDoc <- return q >>> ppClauseT pp let hDoc = PP.text (show nm) PP.<+> PP.text ("(" ++ show p ++ ")") return $ hDoc PP.$+$ PP.nest 2 qDoc ppLCoreTCT :: PrettyPrinter -> PrettyH LCoreTC ppLCoreTCT pp = promoteT (ppClauseT pp) <+ promoteT (pCoreTC pp) ppClauseT :: PrettyPrinter -> PrettyH Clause ppClauseT pp = do p <- absPathT let parenify = ppClauseT pp >>^ \ d -> syntaxColor (PP.text "(") PP.<> d PP.<> syntaxColor (PP.text ")") (forallT (pForall pp) (ppClauseT pp) (\ d1 d2 -> PP.sep [d1,d2]) <+ conjT parenify parenify (\ d1 d2 -> PP.sep [d1,syntaxColor (specialSymbol p ConjSymbol),d2]) <+ disjT parenify parenify (\ d1 d2 -> PP.sep [d1,syntaxColor (specialSymbol p DisjSymbol),d2]) <+ implT parenify parenify (\ _nm d1 d2 -> PP.sep [d1,syntaxColor (specialSymbol p ImplSymbol),d2]) <+ equivT (extractT $ pCoreTC pp) (extractT $ pCoreTC pp) (\ d1 d2 -> PP.sep [d1,specialSymbol p EquivSymbol,d2]) <+ return (syntaxColor $ PP.text "true")) ------------------------------------------------------------------------------ verifyClauseT :: (AddBindings c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m) => Transform c m Clause () verifyClauseT = setFailMsg "verification failed: clause must be true (perhaps try reflexivity first)" $ do CTrue <- idR return () lemmaR :: (LemmaContext c, HasLemmas m, MonadCatch m) => Used -> LemmaName -> Rewrite c m Clause lemmaR used nm = prefixFailMsg "verification failed: " $ do Lemma cl _ _ <- getLemmaByNameT nm eq <- arr (cl `proves`) guardMsg eq "lemmas are not equivalent." markLemmaUsedT nm used return CTrue verifyOrCreateT :: ( AddBindings c, ExtendPath c Crumb, HasCoreRules c, LemmaContext c, ReadBindings c, ReadPath c Crumb , HasHermitMEnv m, HasLemmas m, LiftCoreM m, MonadCatch m ) => Used -> LemmaName -> Clause -> Transform c m a () verifyOrCreateT u nm cl = do exists <- testM $ getLemmaByNameT nm if exists then return cl >>> lemmaR u nm >>> verifyClauseT else contextonlyT $ \ c -> sendKEnvMessage $ AddObligation (toHermitC c) nm $ Lemma cl NotProven u reflexivityR :: Monad m => Rewrite c m Clause reflexivityR = do Equiv lhs rhs <- idR guardMsg (exprAlphaEq lhs rhs) "the two sides are not alpha-equivalent." return CTrue simplifyClauseR :: (AddBindings c, ExtendPath c Crumb, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, MonadCatch m) => Rewrite c m LCore simplifyClauseR = anybuR (promoteR quantIdentitiesR <+ promoteR reflexivityR) quantIdentitiesR :: MonadCatch m => Rewrite c m Clause quantIdentitiesR = trueConjLR <+ trueConjRR <+ trueDisjLR <+ trueDisjRR <+ trueImpliesR <+ impliesTrueR <+ aImpliesAR <+ forallTrueR trueConjLR :: Monad m => Rewrite c m Clause trueConjLR = do Conj CTrue cl <- idR return cl trueConjRR :: Monad m => Rewrite c m Clause trueConjRR = do Conj cl CTrue <- idR return cl trueDisjLR :: Monad m => Rewrite c m Clause trueDisjLR = do Disj CTrue _ <- idR return CTrue trueDisjRR :: Monad m => Rewrite c m Clause trueDisjRR = do Disj _ CTrue <- idR return CTrue trueImpliesR :: Monad m => Rewrite c m Clause trueImpliesR = do Impl _ CTrue cl <- idR return cl impliesTrueR :: Monad m => Rewrite c m Clause impliesTrueR = do Impl _ _ CTrue <- idR return CTrue forallTrueR :: Monad m => Rewrite c m Clause forallTrueR = do Forall _ CTrue <- idR return CTrue aImpliesAR :: Monad m => Rewrite c m Clause aImpliesAR = do Impl _ a c <- idR guardMsg (a `proves` c) "antecedent does not prove consequent." return CTrue splitAntecedentR :: MonadCatch m => Rewrite c m Clause splitAntecedentR = prefixFailMsg "antecedent split failed: " $ withPatFailMsg (wrongExprForm "(ante1 ^ ante2) => con") $ do Impl nm (Conj c1 c2) con <- idR return $ Impl (nm <> "0") c1 $ Impl (nm <> "1") c2 con ------------------------------------------------------------------------------ -- TODO: everything between here and instantiateDictsR needs to be rethought/removed -- TODO: this is used in century plugin, but otherwise should be removed -- | Given two expressions, and a rewrite from the former to the latter, verify that rewrite. verifyEqualityLeftToRightT :: MonadCatch m => CoreExpr -> CoreExpr -> Rewrite c m CoreExpr -> Transform c m a () verifyEqualityLeftToRightT sourceExpr targetExpr r = prefixFailMsg "equality verification failed: " $ do resultExpr <- r <<< return sourceExpr guardMsg (exprAlphaEq targetExpr resultExpr) "result of running proof on lhs of equality does not match rhs of equality." -- | Given two expressions, and a rewrite to apply to each, verify that the resulting expressions are equal. verifyEqualityCommonTargetT :: MonadCatch m => CoreExpr -> CoreExpr -> EqualityProof c m -> Transform c m a () verifyEqualityCommonTargetT lhs rhs (l,r) = prefixFailMsg "equality verification failed: " $ do lhsResult <- l <<< return lhs rhsResult <- r <<< return rhs guardMsg (exprAlphaEq lhsResult rhsResult) "results of running proofs on both sides of equality do not match." ------------------------------------------------------------------------------ -- Note: We use global Ids for verification to avoid out-of-scope errors. -- | Given f :: X -> Y and g :: Y -> X, verify that f (g y) ==> y and g (f x) ==> x. verifyIsomorphismT :: CoreExpr -> CoreExpr -> Rewrite c HermitM CoreExpr -> Rewrite c HermitM CoreExpr -> Transform c HermitM a () verifyIsomorphismT f g fgR gfR = prefixFailMsg "Isomorphism verification failed: " $ do (tyX, tyY) <- funExprsWithInverseTypes f g x <- constT (newGlobalIdH "x" tyX) y <- constT (newGlobalIdH "y" tyY) verifyEqualityLeftToRightT (App f (App g (Var y))) (Var y) fgR verifyEqualityLeftToRightT (App g (App f (Var x))) (Var x) gfR -- | Given f :: X -> Y and g :: Y -> X, verify that f (g y) ==> y. verifyRetractionT :: CoreExpr -> CoreExpr -> Rewrite c HermitM CoreExpr -> Transform c HermitM a () verifyRetractionT f g r = prefixFailMsg "Retraction verification failed: " $ do (_tyX, tyY) <- funExprsWithInverseTypes f g y <- constT (newGlobalIdH "y" tyY) let lhs = App f (App g (Var y)) rhs = Var y verifyEqualityLeftToRightT lhs rhs r ------------------------------------------------------------------------------ -- | Given f :: X -> Y and g :: Y -> X, and a proof that f (g y) ==> y, then f (g y) <==> y. retractionBR :: forall c. Maybe (Rewrite c HermitM CoreExpr) -> CoreExpr -> CoreExpr -> BiRewrite c HermitM CoreExpr retractionBR mr f g = beforeBiR (prefixFailMsg "Retraction failed: " $ do whenJust (verifyRetractionT f g) mr y <- idR (_, tyY) <- funExprsWithInverseTypes f g guardMsg (exprKindOrType y `typeAlphaEq` tyY) "type of expression does not match given retraction components." return y ) (\ y -> bidirectional retractionL (return $ App f (App g y)) ) where retractionL :: Rewrite c HermitM CoreExpr retractionL = prefixFailMsg "Retraction failed: " $ withPatFailMsg (wrongExprForm "App f (App g y)") $ do App f' (App g' y) <- idR guardMsg (exprAlphaEq f f' && exprAlphaEq g g') "given retraction components do not match current expression." return y -- | Given @f :: X -> Y@ and @g :: Y -> X@, and a proof that @f (g y)@ ==> @y@, then @f (g y)@ <==> @y@. retraction :: Maybe (RewriteH LCore) -> CoreString -> CoreString -> BiRewriteH CoreExpr retraction mr = parse2beforeBiR (retractionBR (extractR <$> mr)) ------------------------------------------------------------------------------ -- TODO: revisit this for binder re-ordering issue instantiateDictsR :: RewriteH Clause instantiateDictsR = prefixFailMsg "Dictionary instantiation failed: " $ do bs <- forallT idR successT const let dArgs = filter (\b -> isId b && isDictTy (varType b)) bs uniqDs = nubBy (\ b1 b2 -> eqType (varType b1) (varType b2)) dArgs guardMsg (not (null uniqDs)) "no universally quantified dictionaries can be instantiated." ds <- forM uniqDs $ \ b -> constT $ do (i,bnds) <- buildDictionary b let dExpr = case bnds of -- the common case that we would have gotten a single non-recursive let [NonRec v e] | i == v -> e _ -> mkCoreLets bnds (varToCoreExpr i) return (b,dExpr) let buildSubst :: Monad m => Var -> m (Var, CoreExpr) buildSubst b = case [ (b,e) | (b',e) <- ds, eqType (varType b) (varType b') ] of [] -> fail "cannot find equivalent dictionary expression (impossible!)" [t] -> return t _ -> fail "multiple dictionary expressions found (impossible!)" lookup2 :: Var -> [(Var,CoreExpr)] -> (Var,CoreExpr) lookup2 v l = head [ t | t@(v',_) <- l, v == v' ] allDs <- forM dArgs $ \ b -> constT $ do if b `elem` uniqDs then return $ lookup2 b ds else buildSubst b transform (\ c -> instsClause (boundVars c) allDs) >>> arr redundantDicts ------------------------------------------------------------------------------ conjunctLemmasT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> LemmaName -> LemmaName -> Transform c m a () conjunctLemmasT new lhs rhs = do Lemma ql pl _ <- getLemmaByNameT lhs Lemma qr pr _ <- getLemmaByNameT rhs insertLemmaT new $ Lemma (Conj ql qr) (pl `andP` pr) NotUsed disjunctLemmasT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> LemmaName -> LemmaName -> Transform c m a () disjunctLemmasT new lhs rhs = do Lemma ql pl _ <- getLemmaByNameT lhs Lemma qr pr _ <- getLemmaByNameT rhs insertLemmaT new $ Lemma (Disj ql qr) (pl `orP` pr) NotUsed implyLemmasT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> LemmaName -> LemmaName -> Transform c m a () implyLemmasT new lhs rhs = do Lemma ql _ _ <- getLemmaByNameT lhs Lemma qr pr _ <- getLemmaByNameT rhs insertLemmaT new $ Lemma (Impl lhs ql qr) pr NotUsed ------------------------------------------------------------------------------ mergeQuantifiersR :: MonadCatch m => (Var -> Bool) -> (Var -> Bool) -> Rewrite c m Clause mergeQuantifiersR pl pr = contextfreeT $ mergeQuantifiers pl pr mergeQuantifiers :: MonadCatch m => (Var -> Bool) -> (Var -> Bool) -> Clause -> m Clause mergeQuantifiers pl pr cl = prefixFailMsg "merge-quantifiers failed: " $ do (con,lq@(Forall bsl cll),rq@(Forall bsr clr)) <- case cl of Conj q1 q2 -> return (Conj,q1,q2) Disj q1 q2 -> return (Disj,q1,q2) Impl nm q1 q2 -> return (Impl nm,q1,q2) _ -> fail "no quantifiers on either side." let (lBefore,lbs) = break pl bsl (rBefore,rbs) = break pr bsr check b q l r = guardMsg (not (b `elemVarSet` freeVarsClause q)) $ "specified "++l++" binder would capture in "++r++"-hand clause." checkUB v vs = let fvs = freeVarsVar v in guardMsg (not (any (`elemVarSet` fvs) vs)) $ "binder " ++ getOccString v ++ " cannot be floated because it depends on binders not being floated." case (lbs,rbs) of ([],[]) -> fail "no quantifiers match." ([],rb:rAfter) -> do check rb lq "right" "left" checkUB rb rBefore return $ mkForall [rb] $ con lq (mkForall (rBefore++rAfter) clr) (lb:lAfter,[]) -> do check lb rq "left" "right" checkUB lb lBefore return $ mkForall [lb] $ con (mkForall (lBefore++lAfter) cll) rq (lb:lAfter,rb:rAfter) -> do guardMsg (eqType (varType lb) (varType rb)) "specified quantifiers have differing types." check lb rq "left" "right" check rb lq "right" "left" checkUB lb lBefore checkUB rb rBefore let clr' = substClause rb (varToCoreExpr lb) $ mkForall rAfter clr rq' = mkForall rBefore clr' lq' = mkForall (lBefore ++ lAfter) cll return $ mkForall [lb] (con lq' rq') ------------------------------------------------------------------------------ unshadowClauseR :: MonadUnique m => Rewrite c m Clause unshadowClauseR = contextfreeT unshadowClause unshadowClause :: MonadUnique m => Clause -> m Clause unshadowClause c = go emptySubst (mapUniqSet fs (freeVarsClause c)) c where fs = occNameFS . getOccName go subst seen (Forall bs cl) = go1 subst seen bs [] cl go subst seen (Conj q1 q2) = do q1' <- go subst seen q1 q2' <- go subst seen q2 return $ Conj q1' q2' go subst seen (Disj q1 q2) = do q1' <- go subst seen q1 q2' <- go subst seen q2 return $ Disj q1' q2' go subst seen (Impl nm q1 q2) = do q1' <- go subst seen q1 q2' <- go subst seen q2 return $ Impl nm q1' q2' go subst _ (Equiv e1 e2) = let e1' = substExpr (text "unshadowClause e1") subst e1 e2' = substExpr (text "unshadowClause e2") subst e2 in return $ Equiv e1' e2' go _ _ CTrue = return CTrue go1 subst seen [] bs' cl = do cl' <- go subst seen cl return $ mkForall (reverse bs') cl' go1 subst seen (b:bs) bs' cl | fsb `elementOfUniqSet` seen = do b'' <- cloneVarFSH (inventNames seen) b' go1 (extendSubst subst' b' (varToCoreExpr b'')) (addOneToUniqSet seen (fs b'')) bs (b'':bs') cl | otherwise = go1 subst' (addOneToUniqSet seen fsb) bs (b':bs') cl where fsb = fs b' (subst', b') = substBndr subst b inventNames :: UniqSet FastString -> FastString -> FastString inventNames s nm = head [ nm' | i :: Int <- [0..] , let nm' = nm `appendFS` (mkFastString (show i)) , not (nm' `elementOfUniqSet` s) ] ------------------------------------------------------------------------------ instantiateClauseVarR :: (Var -> Bool) -> CoreString -> RewriteH Clause instantiateClauseVarR p cs = prefixFailMsg "instantiation failed: " $ do bs <- forallT idR successT const e <- case filter p bs of [] -> fail "no universally quantified variables match predicate." (b:_) | isId b -> let (before,_) = break (==b) bs in withVarsInScope before $ parseCoreExprT cs | otherwise -> let (before,_) = break (==b) bs in liftM (Type . fst) $ withVarsInScope before $ parseTypeWithHolesT cs transform (\ c -> instClause (boundVars c) p e) >>> (lintClauseT >> idR) -- lint for sanity ------------------------------------------------------------------------------ -- | Replace all occurrences of the given expression with a new quantified variable. abstractClauseR :: forall c m. ( AddBindings c, BoundVars c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb , LemmaContext c, HasHermitMEnv m, HasLemmas m, LiftCoreM m, MonadCatch m, MonadUnique m ) => String -> Transform c m Clause CoreExpr -> Rewrite c m Clause abstractClauseR nm tr = prefixFailMsg "abstraction failed: " $ do e <- tr cl <- idR b <- constT $ newVarH nm (exprKindOrType e) let f = compileFold [Equality [] e (varToCoreExpr b)] -- we don't use mkEquality on purpose, so we can abstract lambdas liftM dropBinders $ return (mkForall [b] cl) >>> extractR (anytdR $ promoteExprR $ runFoldR f :: Rewrite c m LCoreTC) csInQBodyT :: ( AddBindings c, ExtendPath c Crumb, ReadBindings c, ReadPath c Crumb, HasHermitMEnv m, HasLemmas m, LiftCoreM m ) => CoreString -> Transform c m Clause CoreExpr csInQBodyT cs = forallT successT (parseCoreExprT cs) (flip const) ------------------------------------------------------------------------------ getLemmasT :: (LemmaContext c, HasLemmas m, Monad m) => Transform c m x Lemmas getLemmasT = contextonlyT $ \ c -> liftM (Map.union (getAntecedents c)) getLemmas getLemmaByNameT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> Transform c m x Lemma getLemmaByNameT nm = getLemmasT >>= maybe (fail $ "No lemma named: " ++ show nm) return . Map.lookup nm ------------------------------------------------------------------------------ lemmaBiR :: ( AddBindings c, ExtendPath c Crumb, HasEmptyContext c, LemmaContext c, ReadBindings c, ReadPath c Crumb , HasLemmas m, MonadCatch m, MonadUnique m) => Used -> LemmaName -> BiRewrite c m CoreExpr lemmaBiR u nm = afterBiR (beforeBiR (getLemmaByNameT nm) (birewrite . lemmaC)) (markLemmaUsedT nm u >> idR) lemmaConsequentR :: forall c m. ( AddBindings c, ExtendPath c Crumb, HasEmptyContext c, LemmaContext c, ReadBindings c , ReadPath c Crumb, HasLemmas m, MonadCatch m, MonadUnique m) => Used -> LemmaName -> Rewrite c m Clause lemmaConsequentR u nm = prefixFailMsg "lemma-consequent failed:" $ withPatFailMsg "lemma is not an implication." $ do (hs,ante,pat) <- (getLemmaByNameT nm >>^ lemmaC) >>= \case Forall bs (Impl _ ante con) -> return (bs,ante,con) Impl _ ante con -> return ([],ante,con) cl' <- transform $ \ c cl -> do m <- maybeM ("consequent did not match.") $ lemmaMatch hs pat cl subs <- maybeM ("some quantifiers not instantiated.") $ mapM (\h -> (h,) <$> lookupVarEnv m h) hs let cl' = substClauses subs ante guardMsg (all (inScope c) $ varSetElems (freeVarsClause cl')) "some variables in result would be out of scope." return cl' markLemmaUsedT nm u return cl' lemmaConsequentBiR :: forall c m. ( AddBindings c, ExtendPath c Crumb, HasCoreRules c, HasEmptyContext c, LemmaContext c , ReadBindings c, ReadPath c Crumb, HasHermitMEnv m, HasLemmas m, LiftCoreM m , MonadCatch m, MonadUnique m) => Used -> LemmaName -> BiRewrite c m CoreExpr lemmaConsequentBiR u nm = afterBiR (beforeBiR (getLemmaByNameT nm) (go [] . lemmaC)) (markLemmaUsedT nm u >> idR) where go :: [CoreBndr] -> Clause -> BiRewrite c m CoreExpr go bbs (Forall bs cl) = go (bbs++bs) cl go bbs (Impl anteNm ante con) = do let con' = mkForall bbs con bs = forallQs con' eqs = toEqualities con' foldUnfold side f = do (cl,e) <- transform $ \ c e -> do let cf = compileFold $ map f eqs (e',hs) <- maybeM ("expression did not match "++side++"-hand side") $ runFoldMatches cf c e let matches = [ case lookupVarEnv hs b of Nothing -> Left b Just arg -> Right (b,arg) | b <- bs ] (unmatched, subs) = partitionEithers matches acl = substClauses subs ante cl = mkForall unmatched acl return (cl,e') verifyOrCreateT u anteNm cl return e bidirectional (foldUnfold "left" id) (foldUnfold "right" flipEquality) go _ _ = let t = fail $ show nm ++ " is not an implication." in bidirectional t t ------------------------------------------------------------------------------ insertLemmaT :: (HasLemmas m, Monad m) => LemmaName -> Lemma -> Transform c m a () insertLemmaT nm l = constT $ insertLemma nm l insertLemmasT :: (HasLemmas m, Monad m) => [NamedLemma] -> Transform c m a () insertLemmasT = constT . mapM_ (uncurry insertLemma) modifyLemmaT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> (LemmaName -> LemmaName) -- ^ modify lemma name -> Rewrite c m Clause -- ^ rewrite the quantified clause -> (Proven -> Proven) -- ^ modify proven status -> (Used -> Used) -- ^ modify used status -> Transform c m a () modifyLemmaT nm nFn rr pFn uFn = do Lemma cl p u <- getLemmaByNameT nm cl' <- rr <<< return cl constT $ insertLemma (nFn nm) $ Lemma cl' (pFn p) (uFn u) markLemmaUsedT :: (LemmaContext c, HasLemmas m, MonadCatch m) => LemmaName -> Used -> Transform c m a () markLemmaUsedT nm u = ifM (lemmaExistsT nm) (modifyLemmaT nm id idR id (const u)) (return ()) markLemmaProvenT :: (LemmaContext c, HasLemmas m, MonadCatch m) => LemmaName -> Proven -> Transform c m a () markLemmaProvenT nm p = ifM (lemmaExistsT nm) (modifyLemmaT nm id idR (const p) id) (return ()) lemmaExistsT :: (LemmaContext c, HasLemmas m, MonadCatch m) => LemmaName -> Transform c m a Bool lemmaExistsT nm = constT $ Map.member nm <$> getLemmas ------------------------------------------------------------------------------ lemmaNameToClauseT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> Transform c m x Clause lemmaNameToClauseT nm = liftM lemmaC $ getLemmaByNameT nm -- | @e@ ==> @let v = lhs in e@ (also works in a similar manner at Program nodes) lemmaLhsIntroR :: LemmaName -> RewriteH Core lemmaLhsIntroR = lemmaNameToClauseT >=> eqLhsIntroR -- | @e@ ==> @let v = rhs in e@ (also works in a similar manner at Program nodes) lemmaRhsIntroR :: LemmaName -> RewriteH Core lemmaRhsIntroR = lemmaNameToClauseT >=> eqRhsIntroR ------------------------------------------------------------------------------ -- Little DSL for building composite lemmas infixr 5 --> (-->) :: Type -> Type -> Type (-->) = mkFunTy infixr 3 ==> (==>) :: (LemmaName, Clause) -> Clause -> Clause (==>) = uncurry Impl infixr 5 /\ (/\) :: Clause -> Clause -> Clause (/\) = Conj infixr 4 \/ (\/) :: Clause -> Clause -> Clause (\/) = Disj infix 8 === (===) :: (ToCoreExpr a, ToCoreExpr b) => a -> b -> Clause lhs === rhs = Equiv (toCE lhs) (toCE rhs) infixl 9 $$ ($$) :: (ToCoreExpr a, ToCoreExpr b, MonadCatch m) => a -> b -> m CoreExpr f $$ e = buildAppM (toCE f) (toCE e) ($$$) :: (ToCoreExpr a, ToCoreExpr b, MonadCatch m) => a -> [b] -> m CoreExpr f $$$ es = buildAppsM (toCE f) (map toCE es) class ToCoreExpr a where toCE :: a -> CoreExpr instance ToCoreExpr CoreExpr where toCE = id instance ToCoreExpr Var where toCE = varToCoreExpr instance ToCoreExpr Type where toCE = Type -- Create new lemma library with single unproven lemma. newLemma :: LemmaName -> Clause -> Map.Map LemmaName Lemma newLemma nm cl = Map.singleton nm (Lemma cl NotProven NotUsed)