{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

-- | Equational reasoning over HERMIT 'Clause's and lemmas.
--
-- This module provides the shell 'externals' for lemma manipulation,
-- combinators that lift transformations over the sides of a 'Clause',
-- functions that read and update the lemma library, and a small DSL for
-- building composite lemmas.
module HERMIT.Dictionary.Reasoning
    ( -- * Equational Reasoning
      externals
    , EqualityProof
    , eqLhsIntroR
    , eqRhsIntroR
    , birewrite
    , extensionalityR
    , getLemmasT
    , getLemmaByNameT
    , insertLemmaT
    , insertLemmasT
    , lemmaBiR
    , lemmaConsequentR
    , markLemmaUsedT
    , markLemmaProvenT
    , modifyLemmaT
    , showLemmaT
    , showLemmasT
    , ppLemmaT
    , retraction
    , mergeQuantifiersR
    , conjunctLemmasT
    , disjunctLemmasT
    , implyLemmasT
    , lemmaConsequentBiR
    , lemmaLhsIntroR
    , lemmaRhsIntroR
    , splitAntecedentR
      -- ** Lifting transformations over 'Clause'
    , lhsT
    , rhsT
    , bothT
    , lhsR
    , rhsR
    , bothR
    , verifyClauseT
    , lemmaR
    , quantIdentitiesR
    , verifyOrCreateT
    , verifyEqualityLeftToRightT
    , verifyEqualityCommonTargetT
    , verifyIsomorphismT
    , verifyRetractionT
    , reflexivityR
    , simplifyClauseR
    , retractionBR
    , unshadowClauseR
    , instantiateDictsR
    , abstractClauseR
    , csInQBodyT
    , instantiateClauseVarR
      -- * Constructing Composite Lemmas
    , ($$)
    , ($$$)
    , (==>)
    , (-->)
    , (===)
    , (/\)
    , (\/)
    , ToCoreExpr(..)
    , newLemma
    ) where

import Control.Arrow hiding ((<+>))
import Control.Monad
import Data.Either (partitionEithers)
import Data.List (isInfixOf, nubBy)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Monoid

import HERMIT.Context
import HERMIT.Core
import HERMIT.External
import HERMIT.GHC hiding ((<>), (<+>), nest, ($+$), ($$))
import HERMIT.Kure
import HERMIT.Lemma
import HERMIT.Monad
import HERMIT.Name
import HERMIT.ParserCore
import HERMIT.ParserType
import HERMIT.Utilities

import HERMIT.Dictionary.Common
import HERMIT.Dictionary.Fold hiding (externals)
import HERMIT.Dictionary.Function hiding (externals)
import HERMIT.Dictionary.GHC hiding (externals)
import HERMIT.Dictionary.Local.Let (nonRecIntroR)

import HERMIT.PrettyPrinter.Common

import qualified Text.PrettyPrint.MarkedHughesPJ as PP

------------------------------------------------------------------------------

-- | Shell commands provided by this module.
--
-- Several command names (abstract-forall, show-lemmas, extensionality, lhs,
-- rhs, both) are registered more than once with different argument types.
-- Presumably the shell selects the entry whose type fits the supplied
-- arguments (TODO confirm against HERMIT.External).
externals :: [External]
externals =
    [ external "retraction" ((\ f g r -> promoteExprBiR $ retraction (Just r) f g) :: CoreString -> CoreString -> RewriteH LCore -> BiRewriteH LCore)
        [ "Given f :: X -> Y and g :: Y -> X, and a proof that f (g y) ==> y, then"
        , "f (g y) <==> y."
        ] .+ Shallow
      -- Same as "retraction", but passes Nothing as the proof, so the
      -- retraction precondition is never checked.
    , external "retraction-unsafe" ((\ f g -> promoteExprBiR $ retraction Nothing f g) :: CoreString -> CoreString -> BiRewriteH LCore)
        [ "Given f :: X -> Y and g :: Y -> X, then"
        , "f (g y) <==> y."
        , "Note that the precondition (f (g y) == y) is expected to hold."
        ] .+ Shallow .+ PreCondition
    , external "unshadow-quantified" (promoteClauseR unshadowClauseR :: RewriteH LCoreTC)
        [ "Unshadow a quantified clause." ]
    , external "merge-quantifiers" (\n1 n2 -> promoteR (mergeQuantifiersR (cmpHN2Var n1) (cmpHN2Var n2)) :: RewriteH LCore)
        [ "Merge quantifiers from two clauses if they have the same type."
        , "Example:"
        , "(forall (x::Int). foo x = x) ^ (forall (y::Int). bar y y = 5)"
        , "merge-quantifiers 'x 'y"
        , "forall (x::Int). (foo x = x) ^ (bar x x = 5)"
        , "Note: if only one quantifier matches, it will be floated if possible." ]
      -- float-left / float-right reuse mergeQuantifiersR with a predicate that
      -- never matches on the other side.
    , external "float-left" (\n1 -> promoteR (mergeQuantifiersR (cmpHN2Var n1) (const False)) :: RewriteH LCore)
        [ "Float quantifier out of left-hand side." ]
    , external "float-right" (\n1 -> promoteR (mergeQuantifiersR (const False) (cmpHN2Var n1)) :: RewriteH LCore)
        [ "Float quantifier out of right-hand side." ]
    , external "conjunct" (\n1 n2 n3 -> conjunctLemmasT n1 n2 n3 :: TransformH LCore ())
        [ "conjunct new-name lhs-name rhs-name" ]
    , external "disjunct" (\n1 n2 n3 -> disjunctLemmasT n1 n2 n3 :: TransformH LCore ())
        [ "disjunct new-name lhs-name rhs-name" ]
    , external "imply" (\n1 n2 n3 -> implyLemmasT n1 n2 n3 :: TransformH LCore ())
        [ "imply new-name antecedent-name consequent-name" ]
    , external "lemma-birewrite" (promoteExprBiR . lemmaBiR Obligation :: LemmaName -> BiRewriteH LCore)
        [ "Generate a bi-directional rewrite from a lemma." ]
    , external "lemma-forward" (forwardT . promoteExprBiR . lemmaBiR Obligation :: LemmaName -> RewriteH LCore)
        [ "Generate a rewrite from a lemma, left-to-right." ]
    , external "lemma-backward" (backwardT . promoteExprBiR . lemmaBiR Obligation :: LemmaName -> RewriteH LCore)
        [ "Generate a rewrite from a lemma, right-to-left." ]
    , external "lemma-consequent" (promoteClauseR . lemmaConsequentR Obligation :: LemmaName -> RewriteH LCore)
        [ "Match the current lemma with the consequent of an implication lemma."
        , "Upon success, replaces with antecedent of the implication, properly instantiated." ]
    , external "lemma-consequent-birewrite" (promoteExprBiR . lemmaConsequentBiR Obligation :: LemmaName -> BiRewriteH LCore)
        [ "Generate a bi-directional rewrite from the consequent of an implication lemma."
        , "The antecedent is instantiated and introduced as an unproven obligation." ]
    , external "lemma-lhs-intro" (promoteCoreR . lemmaLhsIntroR :: LemmaName -> RewriteH LCore)
        [ "Introduce the LHS of a lemma as a non-recursive binding, in either an expression or a program."
        , "body ==> let v = lhs in body" ] .+ Introduce .+ Shallow
    , external "lemma-rhs-intro" (promoteCoreR . lemmaRhsIntroR :: LemmaName -> RewriteH LCore)
        [ "Introduce the RHS of a lemma as a non-recursive binding, in either an expression or a program."
        , "body ==> let v = rhs in body" ] .+ Introduce .+ Shallow
      -- NOTE(review): modifyLemmaT is given `id` as its name function here, so
      -- the instantiated lemma is stored under the ORIGINAL name rather than a
      -- new one, despite the help text below.  Confirm the intended behaviour.
    , external "inst-lemma" (\ nm v cs -> modifyLemmaT nm id (instantiateClauseVarR (cmpHN2Var v) cs) id id :: TransformH LCore ())
        [ "Instantiate one of the universally quantified variables of the given lemma,"
        , "with the given Core expression, creating a new lemma. Instantiating an"
        , "already proven lemma will result in the new lemma being considered proven." ]
    , external "inst-dictionaries" (promoteClauseR instantiateDictsR :: RewriteH LCore)
        [ "Instantiate all of the universally quantified dictionaries of the given lemma." ]
      -- Two variants: the expression to abstract is given either as a Core
      -- string (parsed in the forall body) or via a rewrite focusing on it.
    , external "abstract-forall" ((\nm -> promoteClauseR . abstractClauseR nm . csInQBodyT) :: String -> CoreString -> RewriteH LCore)
        [ "Weaken a lemma by abstracting an expression to a new quantifier." ]
    , external "abstract-forall" ((\nm rr -> promoteClauseR $ abstractClauseR nm $ extractT rr >>> setFailMsg "path must focus on an expression" projectT) :: String -> RewriteH LCore -> RewriteH LCore)
        [ "Weaken a lemma by abstracting an expression to a new quantifier." ]
    , external "copy-lemma" (\ nm newName -> modifyLemmaT nm (const newName) idR id id :: TransformH LCore ())
        [ "Copy a given lemma, with a new name." ]
    , external "modify-lemma" ((\ nm rr -> modifyLemmaT nm id (extractR rr) (const NotProven) (const NotUsed)) :: LemmaName -> RewriteH LCore -> TransformH LCore ())
        [ "Modify a given lemma. Resets proven status to Not Proven and used status to Not Used." ]
    , external "query-lemma" ((\ nm t -> getLemmaByNameT nm >>> arr lemmaC >>> extractT t) :: LemmaName -> TransformH LCore String -> TransformH LCore String)
        [ "Apply a transformation to a lemma, returning the result." ]
    , external "show-lemma" ((\pp n -> showLemmaT n pp) :: PrettyPrinter -> LemmaName -> PrettyH LCore)
        [ "Display a lemma." ]
    , external "show-lemmas" ((\pp n -> showLemmasT (Just n) pp) :: PrettyPrinter -> LemmaName -> PrettyH LCore)
        [ "List lemmas whose names match search string." ]
    , external "show-lemmas" (showLemmasT Nothing :: PrettyPrinter -> PrettyH LCore)
        [ "List lemmas." ]
    , external "extensionality" (promoteR . extensionalityR . Just :: String -> RewriteH LCore)
        [ "Given a name 'x, then"
        , "f == g ==> forall x. f x == g x" ]
    , external "extensionality" (promoteR (extensionalityR Nothing) :: RewriteH LCore)
        [ "f == g ==> forall x. f x == g x" ]
    , external "lhs" (promoteClauseT . lhsT :: TransformH LCore String -> TransformH LCore String)
        [ "Apply a transformation to the LHS of a quantified clause." ]
    , external "lhs" (promoteClauseR . lhsR :: RewriteH LCore -> RewriteH LCore)
        [ "Apply a rewrite to the LHS of a quantified clause." ]
    , external "rhs" (promoteClauseT . rhsT :: TransformH LCore String -> TransformH LCore String)
        [ "Apply a transformation to the RHS of a quantified clause." ]
    , external "rhs" (promoteClauseR . rhsR :: RewriteH LCore -> RewriteH LCore)
        [ "Apply a rewrite to the RHS of a quantified clause." ]
    , external "both" (promoteClauseR . bothR :: RewriteH LCore -> RewriteH LCore)
        [ "Apply a rewrite to both sides of an equality, succeeding if either succeed." ]
      -- The two String results are joined with `unlines`.
    , external "both" ((\t -> do (r,s) <- promoteClauseT (bothT t); return (unlines [r,s])) :: TransformH LCore String -> TransformH LCore String)
        [ "Apply a transformation to both sides of a quantified clause." ]
      -- Also tries reflexivity directly under a leading forall.
    , external "reflexivity" (promoteClauseR (reflexivityR <+ forallR idR reflexivityR) :: RewriteH LCore)
        [ "Rewrite alpha-equivalence to true." ]
    , external "simplify-lemma" (simplifyClauseR :: RewriteH LCore)
        [ "Reduce a proof by applying reflexivity and logical operator identities." ]
    , external "split-antecedent" (promoteClauseR splitAntecedentR :: RewriteH LCore)
        [ "Split an implication of the form (q1 ^ q2) => q3 into q1 => (q2 => q3)" ]
    , external "lemma" (promoteClauseR . lemmaR Obligation :: LemmaName -> RewriteH LCore)
        [ "Rewrite clause to true using given lemma." ]
      -- Same as "lemma", but marks the lemma as UnsafeUsed instead of Obligation.
    , external "lemma-unsafe" (promoteClauseR . lemmaR UnsafeUsed :: LemmaName -> RewriteH LCore)
        [ "Rewrite clause to true using given lemma." ] .+ Unsafe
    ]

------------------------------------------------------------------------------

-- | A pair of rewrites, one for each side of an equality.  Used by
-- 'verifyEqualityCommonTargetT', which checks that both sides reach
-- alpha-equivalent results.
type EqualityProof c m = (Rewrite c m CoreExpr, Rewrite c m CoreExpr)

-- | f == g ==> forall x. f x == g x
--
-- The clause's leading quantifiers are stripped and its body must be an
-- 'Equiv'.  A fresh identifier is then created, named by the argument
-- (default "x"), at the argument type of the left-hand side.  Both sides are
-- applied to it under a new 'Forall', and the original quantifiers are
-- restored.  Fails if the body is not an equivalence or if the two sides'
-- types differ.
extensionalityR :: Maybe String -> Rewrite c HermitM Clause
extensionalityR mn = prefixFailMsg "extensionality failed: " $
  do (vs, Equiv lhs rhs) <- arr collectQs
     let tyL = exprKindOrType lhs
         tyR = exprKindOrType rhs
     guardMsg (tyL `typeAlphaEq` tyR) "type mismatch between sides of equality. This shouldn't happen, so is probably a bug."

     -- TODO: use the fresh-name-generator in AlphaConversion to avoid shadowing.
     (_,argTy,_) <- splitFunTypeM tyL
     v <- constT $ newIdH (fromMaybe "x" mn) argTy

     let x = varToCoreExpr v

     -- The newer-GHC branch passes an extra 'text' label as the first argument
     -- of mkCoreApp (presumably required by that GHC API version).
#if __GLASGOW_HASKELL__ > 710
     return $ mkForall vs $ Forall v $ Equiv (mkCoreApp (text "extensionalityR-lhs") lhs x) (mkCoreApp (text "extensionalityR-rhs") rhs x)
#else
     return $ mkForall vs $ Forall v $ Equiv (mkCoreApp lhs x) (mkCoreApp rhs x)
#endif

------------------------------------------------------------------------------

-- | @e@ ==> @let v = lhs in e@
--
-- The clause's quantified binders become lambdas around the left-hand side
-- ('mkCoreLams'), which is then let-bound with the name hint "lhs" via
-- 'nonRecIntroR'.  Clauses whose body is not a single equivalence are
-- rejected.
eqLhsIntroR :: Clause -> Rewrite c HermitM Core
eqLhsIntroR cl | (bs, Equiv lhs _) <- collectQs cl = nonRecIntroR "lhs" (mkCoreLams bs lhs)
eqLhsIntroR _ = fail "compound lemmas not supported."
-- | @e@ ==> @let v = rhs in e@
--
-- The clause's quantified binders are turned into lambdas around the
-- right-hand side before it is let-bound with the name hint "rhs".  Only a
-- clause whose body is a single equivalence is accepted.
eqRhsIntroR :: Clause -> Rewrite c HermitM Core
eqRhsIntroR cl =
    case collectQs cl of
        (bs, Equiv _ rhs) -> nonRecIntroR "rhs" (mkCoreLams bs rhs)
        _                 -> fail "compound lemmas not supported."

------------------------------------------------------------------------------

-- | Create a 'BiRewrite' from a 'Clause'.
--
-- The forward direction folds the current expression with the clause's
-- equalities as written; the backward direction uses each equality flipped.
birewrite :: (ReadBindings c, MonadCatch m) => Clause -> BiRewrite c m CoreExpr
birewrite cl = bidirectional (viaSide "left" id) (viaSide "right" flipEquality)
  where
    -- Fold the focused expression using the (re-oriented) equalities.
    viaSide side orient = transform $ \ ctx expr ->
        maybeM ("expression did not match " ++ side ++ "-hand side")
               (fold (map orient (toEqualities cl)) ctx expr)

------------------------------------------------------------------------------

-- TODO: deprecate these?
-- Yes, but later. They're in the paper now.
-- We should be using "childR crumb", really.

-- | Lift a transformation over 'LCore' into a transformation over the left-hand side of a 'Clause'.
--
-- Leading quantifiers are descended through; otherwise the first left-hand
-- crumb that applies (conjunction, disjunction, implication, equivalence)
-- receives the transformation.
lhsT :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m)
     => Transform c m LCore a -> Transform c m Clause a
lhsT t = extractT (catchesM (underBinder : map (`childT` t) lhsCrumbs))
  where
    underBinder = childT Forall_Body (promoteT (lhsT t))
    lhsCrumbs   = [Conj_Lhs, Disj_Lhs, Impl_Lhs, Eq_Lhs]

-- | Lift a transformation over 'LCore' into a transformation over the right-hand side of a 'Clause'.
--
-- Mirror image of 'lhsT', using the right-hand crumbs.
rhsT :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m)
     => Transform c m LCore a -> Transform c m Clause a
rhsT t = extractT (catchesM (underBinder : map (`childT` t) rhsCrumbs))
  where
    underBinder = childT Forall_Body (promoteT (rhsT t))
    rhsCrumbs   = [Conj_Rhs, Disj_Rhs, Impl_Rhs, Eq_Rhs]

-- | Lift a transformation over 'LCore' into a transformation over both sides of a 'Clause'.
--
-- The left-hand result is paired with the right-hand result; if either side
-- fails, the whole transformation fails.
bothT :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m)
      => Transform c m LCore a -> Transform c m Clause (a, a)
bothT t = do
    l <- lhsT t
    r <- rhsT t
    return (l, r)

-- | Lift a rewrite over 'LCore' into a rewrite over the left-hand side of a 'Clause'.
--
-- Leading quantifiers are descended through; otherwise the first left-hand
-- crumb that applies (conjunction, disjunction, implication, equivalence)
-- receives the rewrite.
lhsR :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m)
     => Rewrite c m LCore -> Rewrite c m Clause
lhsR r = extractR (catchesM (underBinder : map (`childR` r) lhsCrumbs))
  where
    underBinder = childR Forall_Body (promoteR (lhsR r))
    lhsCrumbs   = [Conj_Lhs, Disj_Lhs, Impl_Lhs, Eq_Lhs]

-- | Lift a rewrite over 'LCore' into a rewrite over the right-hand side of a 'Clause'.
--
-- Mirror image of 'lhsR', using the right-hand crumbs.
rhsR :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m)
     => Rewrite c m LCore -> Rewrite c m Clause
rhsR r = extractR (catchesM (underBinder : map (`childR` r) rhsCrumbs))
  where
    underBinder = childR Forall_Body (promoteR (rhsR r))
    rhsCrumbs   = [Conj_Rhs, Disj_Rhs, Impl_Rhs, Eq_Rhs]

-- | Lift a rewrite over 'LCoreTC' into a rewrite over both sides of a 'Clause'.
--
-- Succeeds if the rewrite succeeds on at least one side ('>+>').
bothR :: (AddBindings c, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, ExtendPath c Crumb, MonadCatch m)
      => Rewrite c m LCore -> Rewrite c m Clause
bothR r = lhsR r >+> rhsR r

------------------------------------------------------------------------------

-- | Pretty-print the lemmas in scope.  With @Just pat@, only lemmas whose
-- shown name contains the shown @pat@ as a substring are listed.
showLemmasT :: Maybe LemmaName -> PrettyPrinter -> PrettyH a
showLemmasT mnm pp = do
    lemmas <- getLemmasT
    let matches key _ = maybe True (\ pat -> show pat `isInfixOf` show key) mnm
        selected      = Map.toList (Map.filterWithKey matches lemmas)
    docs <- mapM (\ (key, l) -> return l >>> ppLemmaT pp key) selected
    return (PP.vcat docs)

-- | Pretty-print a single lemma, looked up by name.
showLemmaT :: LemmaName -> PrettyPrinter -> PrettyH a
showLemmaT nm pp = ppLemmaT pp nm <<< getLemmaByNameT nm

-- | Render a lemma as a header line (name and proven status) with its clause
-- indented by two columns underneath.
ppLemmaT :: PrettyPrinter -> LemmaName -> PrettyH Lemma
ppLemmaT pp nm = do
    Lemma cl proven _ <- idR
    body <- extractT (pLCoreTC pp) <<< return cl
    let header = PP.text (show nm) PP.<+> PP.text ("(" ++ show proven ++ ")")
    return (header PP.$+$ PP.nest 2 body)

------------------------------------------------------------------------------

-- | Succeed only if the focused clause is literally 'CTrue'.
verifyClauseT :: MonadCatch m => Transform c m Clause ()
verifyClauseT =
    setFailMsg "verification failed: clause must be true (perhaps try reflexivity first)" $ do
        CTrue <- idR
        return ()

-- | Rewrite the focused clause to 'CTrue' when the named lemma proves it,
-- recording the lemma's use with the given 'Used' status.
lemmaR :: (LemmaContext c, HasLemmas m, MonadCatch m) => Used -> LemmaName -> Rewrite c m Clause
lemmaR used nm = prefixFailMsg "verification failed: " $ do
    Lemma lemmaClause _ _ <- getLemmaByNameT nm
    equivalent <- arr (proves lemmaClause)
    guardMsg equivalent "lemmas are not equivalent."
    markLemmaUsedT nm used
    return CTrue

-- | If a lemma with the given name exists, use it to prove the clause;
-- otherwise send an 'AddObligation' message registering the clause as a new,
-- unproven lemma under that name.
verifyOrCreateT :: ( HasCoreRules c, LemmaContext c, ReadBindings c, ReadPath c Crumb
                   , HasHermitMEnv m, HasLemmas m, LiftCoreM m, MonadCatch m )
                => Used -> LemmaName -> Clause -> Transform c m a ()
verifyOrCreateT u nm cl =
    ifM (testM (getLemmaByNameT nm))
        (return cl >>> lemmaR u nm >>> verifyClauseT)
        (contextonlyT $ \ c -> sendKEnvMessage $ AddObligation (toHermitC c) nm $ Lemma cl NotProven u)

-- | Rewrite an equivalence whose two sides are alpha-equivalent to 'CTrue'.
reflexivityR :: MonadCatch m => Rewrite c m Clause
reflexivityR = withPatFailMsg "reflexivity may only be applied to equivalence lemmas" $ do
    Equiv l r <- idR
    guardMsg (exprAlphaEq l r) "the two sides are not alpha-equivalent."
    return CTrue

-- | Bottom-up, apply the logical identities and reflexivity wherever possible.
simplifyClauseR :: (AddBindings c, ExtendPath c Crumb, HasEmptyContext c, LemmaContext c, ReadPath c Crumb, MonadCatch m)
                => Rewrite c m LCore
simplifyClauseR = anybuR (promoteR quantIdentitiesR <+ promoteR reflexivityR)

-- | The first applicable logical identity, tried in this order.
quantIdentitiesR :: MonadCatch m => Rewrite c m Clause
quantIdentitiesR =  trueConjLR
                 <+ trueConjRR
                 <+ trueDisjLR
                 <+ trueDisjRR
                 <+ trueImpliesR
                 <+ impliesTrueR
                 <+ aImpliesAR
                 <+ forallTrueR

-- | @True ^ q@ ==> @q@
trueConjLR :: Monad m => Rewrite c m Clause
trueConjLR = do
    Conj CTrue rest <- idR
    return rest

-- | @q ^ True@ ==> @q@
trueConjRR :: Monad m => Rewrite c m Clause
trueConjRR = do
    Conj rest CTrue <- idR
    return rest

-- | @True v q@ ==> @True@
trueDisjLR :: Monad m => Rewrite c m Clause
trueDisjLR = do
    Disj CTrue _ <- idR
    return CTrue

-- | @q v True@ ==> @True@
trueDisjRR :: Monad m => Rewrite c m Clause
trueDisjRR = do
    Disj _ CTrue <- idR
    return CTrue

-- | @True => q@ ==> @q@
trueImpliesR :: Monad m => Rewrite c m Clause
trueImpliesR = do
    Impl _ CTrue rest <- idR
    return rest

-- | @q => True@ ==> @True@
impliesTrueR :: Monad m => Rewrite c m Clause
impliesTrueR = do
    Impl _ _ CTrue <- idR
    return CTrue

-- | @forall x. True@ ==> @True@
forallTrueR :: Monad m => Rewrite c m Clause
forallTrueR = do
    Forall _ CTrue <- idR
    return CTrue

-- | An implication whose antecedent proves its consequent ==> @True@.
aImpliesAR :: Monad m => Rewrite c m Clause
aImpliesAR = do
    Impl _ ante con <- idR
    guardMsg (ante `proves` con) "antecedent does not prove consequent."
    return CTrue

-- | Curry a conjunctive antecedent.  The two new antecedents are named after
-- the original one with "0" and "1" appended.
splitAntecedentR :: MonadCatch m => Rewrite c m Clause
splitAntecedentR =
    prefixFailMsg "antecedent split failed: " $
    withPatFailMsg (wrongExprForm "(ante1 ^ ante2) => con") $ do
        Impl name (Conj ante1 ante2) con <- idR
        let inner = Impl (name <> "1") ante2 con
        return (Impl (name <> "0") ante1 inner)

------------------------------------------------------------------------------

-- TODO: everything between here and instantiateDictsR needs to be rethought/removed

-- TODO: this is used in century plugin, but otherwise should be removed
-- | Given two expressions, and a rewrite from the former to the latter, verify that rewrite.
verifyEqualityLeftToRightT :: MonadCatch m => CoreExpr -> CoreExpr -> Rewrite c m CoreExpr -> Transform c m a ()
verifyEqualityLeftToRightT sourceExpr targetExpr r =
    prefixFailMsg "equality verification failed: " $ do
        resultExpr <- return sourceExpr >>> r
        guardMsg (exprAlphaEq targetExpr resultExpr) "result of running proof on lhs of equality does not match rhs of equality."

-- | Given two expressions, and a rewrite to apply to each, verify that the resulting expressions are equal.
verifyEqualityCommonTargetT :: MonadCatch m => CoreExpr -> CoreExpr -> EqualityProof c m -> Transform c m a ()
verifyEqualityCommonTargetT lhs rhs (l,r) =
    prefixFailMsg "equality verification failed: " $ do
        lhsResult <- return lhs >>> l
        rhsResult <- return rhs >>> r
        guardMsg (exprAlphaEq lhsResult rhsResult) "results of running proofs on both sides of equality do not match."

------------------------------------------------------------------------------

-- Note: We use global Ids for verification to avoid out-of-scope errors.

-- | Given f :: X -> Y and g :: Y -> X, verify that f (g y) ==> y and g (f x) ==> x.
--
-- Fresh global identifiers @x :: X@ and @y :: Y@ are created (in that
-- order) and the two round trips are checked with
-- 'verifyEqualityLeftToRightT'.
verifyIsomorphismT :: CoreExpr -> CoreExpr -> Rewrite c HermitM CoreExpr -> Rewrite c HermitM CoreExpr -> Transform c HermitM a ()
verifyIsomorphismT f g fgR gfR =
    prefixFailMsg "Isomorphism verification failed: " $ do
        (tyX, tyY) <- funExprsWithInverseTypes f g
        x <- constT (newGlobalIdH "x" tyX)
        y <- constT (newGlobalIdH "y" tyY)
        let roundTrip outer inner v = App outer (App inner (Var v))
        verifyEqualityLeftToRightT (roundTrip f g y) (Var y) fgR
        verifyEqualityLeftToRightT (roundTrip g f x) (Var x) gfR

-- | Given f :: X -> Y and g :: Y -> X, verify that f (g y) ==> y.
verifyRetractionT :: CoreExpr -> CoreExpr -> Rewrite c HermitM CoreExpr -> Transform c HermitM a ()
verifyRetractionT f g r =
    prefixFailMsg "Retraction verification failed: " $ do
        (_, tyY) <- funExprsWithInverseTypes f g
        y <- constT (newGlobalIdH "y" tyY)
        verifyEqualityLeftToRightT (App f (App g (Var y))) (Var y) r

------------------------------------------------------------------------------

-- | Given f :: X -> Y and g :: Y -> X, and a proof that f (g y) ==> y, then f (g y) <==> y.
--
-- Before either direction runs, the optional proof is checked and the focused
-- expression's type is compared against @Y@.
retractionBR :: forall c. Maybe (Rewrite c HermitM CoreExpr) -> CoreExpr -> CoreExpr -> BiRewrite c HermitM CoreExpr
retractionBR mr f g = beforeBiR checkR (bidirectional retractionL . introduceR)
  where
    -- Verify the retraction (when a proof is given) and check the type of the
    -- focused expression; passes the expression through unchanged.
    checkR :: Rewrite c HermitM CoreExpr
    checkR = prefixFailMsg "Retraction failed: " $ do
        whenJust (verifyRetractionT f g) mr
        y <- idR
        (_, tyY) <- funExprsWithInverseTypes f g
        guardMsg (exprKindOrType y `typeAlphaEq` tyY) "type of expression does not match given retraction components."
        return y

    -- Right-to-left: wrap the expression back up as f (g y).
    introduceR :: CoreExpr -> Rewrite c HermitM CoreExpr
    introduceR y = return $ App f (App g y)

    -- Left-to-right: f (g y) ==> y, provided f and g match the given components.
    retractionL :: Rewrite c HermitM CoreExpr
    retractionL =
        prefixFailMsg "Retraction failed: " $
        withPatFailMsg (wrongExprForm "App f (App g y)") $ do
            App f' (App g' y) <- idR
            guardMsg (exprAlphaEq f f' && exprAlphaEq g g') "given retraction components do not match current expression."
            return y

-- | Given @f :: X -> Y@ and @g :: Y -> X@, and a proof that @f (g y)@ ==> @y@, then @f (g y)@ <==> @y@.
--
-- The two 'CoreString's are parsed as @f@ and @g@; when a proof rewrite is
-- supplied it is used to verify the retraction before rewriting.
retraction :: Maybe (RewriteH LCore) -> CoreString -> CoreString -> BiRewriteH CoreExpr
retraction mr = parse2beforeBiR (retractionBR (extractR <$> mr))

------------------------------------------------------------------------------

-- TODO: revisit this and rewrite to act only on current quantifier? (more KURE-like)

-- | Instantiate every universally quantified dictionary binder of a clause.
--
-- Dictionary binders are grouped by type.  One dictionary is built per
-- distinct type with 'buildDictionary' and substituted for every binder of
-- that type.  Fails if the clause quantifies over no dictionaries.
instantiateDictsR :: RewriteH Clause
instantiateDictsR = prefixFailMsg "Dictionary instantiation failed: " $ do
    (bs,_) <- arr collectQs
    let dArgs  = filter (\b -> isId b && isDictTy (varType b)) bs
        uniqDs = nubBy (\ b1 b2 -> eqType (varType b1) (varType b2)) dArgs
    guardMsg (not (null uniqDs)) "no universally quantified dictionaries can be instantiated."
    ds <- forM uniqDs $ \ b -> constT $ do
            (i,bnds) <- buildDictionary b
            let dExpr = case bnds of
                            -- the common case that we would have gotten a single non-recursive let
                            [NonRec v e] | i == v -> e
                            _ -> mkCoreLets bnds (varToCoreExpr i)
            return (b,dExpr)
    -- Find the dictionary that was built for a binder of the same type.
    let buildSubst :: Monad m => Var -> m (Var, CoreExpr)
        buildSubst b = case [ (b,e) | (b',e) <- ds, eqType (varType b) (varType b') ] of
                        [] -> fail "cannot find equivalent dictionary expression (impossible!)"
                        [t] -> return t
                        _ -> fail "multiple dictionary expressions found (impossible!)"
    -- The keys of 'ds' are exactly 'uniqDs'.  A binder that is one of those
    -- representatives therefore reuses its own entry; any other binder is
    -- matched by type.  A single total lookup replaces the former 'elem' test
    -- followed by a partial 'head'.
    allDs <- forM dArgs $ \ b -> constT $
                case [ t | t@(b',_) <- ds, b' == b ] of
                    t : _ -> return t
                    []    -> buildSubst b
    transform (\ c -> instsClause (boundVars c) allDs) >>> arr redundantDicts

------------------------------------------------------------------------------

-- | Add a new lemma named @new@ that is the conjunction of lemmas @lhs@ and
-- @rhs@.  Its proven status is the 'andP' of theirs; it starts 'NotUsed'.
conjunctLemmasT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> LemmaName -> LemmaName -> Transform c m a ()
conjunctLemmasT new lhs rhs = do
    Lemma ql pl _ <- getLemmaByNameT lhs
    Lemma qr pr _ <- getLemmaByNameT rhs
    insertLemmaT new $ Lemma (Conj ql qr) (pl `andP` pr) NotUsed

-- | Add a new lemma named @new@ that is the disjunction of lemmas @lhs@ and
-- @rhs@.  Its proven status is the 'orP' of theirs; it starts 'NotUsed'.
disjunctLemmasT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> LemmaName -> LemmaName -> Transform c m a ()
disjunctLemmasT new lhs rhs = do
    Lemma ql pl _ <- getLemmaByNameT lhs
    Lemma qr pr _ <- getLemmaByNameT rhs
    insertLemmaT new $ Lemma (Disj ql qr) (pl `orP` pr) NotUsed

-- | Add a new lemma named @new@: the implication from lemma @lhs@ to lemma
-- @rhs@.  The antecedent is named @lhs@.  The proven status is taken from the
-- consequent alone; the new lemma starts 'NotUsed'.
implyLemmasT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> LemmaName -> LemmaName -> Transform c m a ()
implyLemmasT new lhs rhs = do
    Lemma ql _  _ <- getLemmaByNameT lhs
    Lemma qr pr _ <- getLemmaByNameT rhs
    insertLemmaT new $ Lemma (Impl lhs ql qr) pr NotUsed

------------------------------------------------------------------------------

-- | Rewrite form of 'mergeQuantifiers'.
mergeQuantifiersR :: MonadCatch m => (Var -> Bool) -> (Var -> Bool) -> Rewrite c m Clause
mergeQuantifiersR pl pr = contextfreeT $ mergeQuantifiers pl pr

-- | For a conjunction, disjunction or implication, float the first quantifier
-- selected by @pl@ on the left side and/or by @pr@ on the right side outward.
--
-- If both sides select a binder, the binders must have the same type: the
-- right binder is replaced by the left one and a single quantifier results.
-- Fails if a floated binder would capture a variable of the other side, or
-- if it depends on binders that stay behind.
mergeQuantifiers :: MonadCatch m => (Var -> Bool) -> (Var -> Bool) -> Clause -> m Clause
mergeQuantifiers pl pr cl = prefixFailMsg "merge-quantifiers failed: " $ do
    (con,lq,rq) <- case cl of
                    Conj q1 q2    -> return (Conj,q1,q2)
                    Disj q1 q2    -> return (Disj,q1,q2)
                    Impl nm q1 q2 -> return (Impl nm,q1,q2)
                    _             -> fail "no quantifiers on either side."

    let (bsl, cll) = collectQs lq
        (bsr, clr) = collectQs rq
        (lBefore,lbs) = break pl bsl
        (rBefore,rbs) = break pr bsr

        -- A binder floated out of one side must not be free on the other side.
        check b q l r = guardMsg (not (b `elemVarSet` freeVarsClause q)) $ "specified "++l++" binder would capture in "++r++"-hand clause."

        -- A binder cannot be floated past binders it depends on.
        checkUB v vs = let fvs = freeVarsVar v
                       in guardMsg (not (any (`elemVarSet` fvs) vs)) $ "binder " ++ getOccString v ++ " cannot be floated because it depends on binders not being floated."

    case (lbs,rbs) of
        ([],[]) -> fail "no quantifiers match."
        ([],rb:rAfter) -> do
            check rb lq "right" "left"
            checkUB rb rBefore
            return $ mkForall [rb] $ con lq (mkForall (rBefore++rAfter) clr)
        (lb:lAfter,[]) -> do
            check lb rq "left" "right"
            checkUB lb lBefore
            return $ mkForall [lb] $ con (mkForall (lBefore++lAfter) cll) rq
        (lb:lAfter,rb:rAfter) -> do
            guardMsg (eqType (varType lb) (varType rb)) "specified quantifiers have differing types."
            check lb rq "left" "right"
            check rb lq "right" "left"
            checkUB lb lBefore
            checkUB rb rBefore
            let clr' = substClause rb (varToCoreExpr lb) $ mkForall rAfter clr
                rq'  = mkForall rBefore clr'
                lq'  = mkForall (lBefore ++ lAfter) cll
            return $ mkForall [lb] (con lq' rq')

------------------------------------------------------------------------------

-- | Rewrite form of 'unshadowClause'.
unshadowClauseR :: MonadUnique m => Rewrite c m Clause
unshadowClauseR = contextfreeT unshadowClause

-- | Rename quantified binders whose occurrence name clashes with a free
-- variable of the clause or with an enclosing binder.  The fresh names come
-- from 'inventNames', and the substitution is applied to the equivalences
-- underneath.
unshadowClause :: MonadUnique m => Clause -> m Clause
unshadowClause c = go emptySubst (mapUniqSet fs (freeVarsClause c)) c
    where fs = occNameFS . getOccName

          -- 'seen' holds every occurrence name already in use on this path.
          go subst seen (Forall b cl)
            | fsb `elementOfUniqSet` seen = do
                b'' <- cloneVarFSH (inventNames seen) b'
                cl' <- go (extendSubst subst' b' (varToCoreExpr b'')) (addOneToUniqSet seen (fs b'')) cl
                return $ addBinder b'' cl'
            | otherwise = do
                cl' <- go subst' (addOneToUniqSet seen fsb) cl
                return $ addBinder b' cl'
            where fsb = fs b'
                  (subst', b') = substBndr subst b
          go subst seen (Conj q1 q2) = do
            q1' <- go subst seen q1
            q2' <- go subst seen q2
            return $ Conj q1' q2'
          go subst seen (Disj q1 q2) = do
            q1' <- go subst seen q1
            q2' <- go subst seen q2
            return $ Disj q1' q2'
          go subst seen (Impl nm q1 q2) = do
            q1' <- go subst seen q1
            q2' <- go subst seen q2
            return $ Impl nm q1' q2'
          go subst _ (Equiv e1 e2) =
            let e1' = substExpr (text "unshadowClause e1") subst e1
                e2' = substExpr (text "unshadowClause e2") subst e2
            in return $ Equiv e1' e2'
          go _ _ CTrue = return CTrue

-- | The first of @nm0@, @nm1@, @nm2@, ... that is not in the given set.
inventNames :: UniqSet FastString -> FastString -> FastString
inventNames s nm = head [ nm' | i :: Int <- [0..]
                              , let nm' = nm `appendFS` (mkFastString (show i))
                              , not (nm' `elementOfUniqSet` s) ]

------------------------------------------------------------------------------

-- TODO: revisit design of this, it's ugly
-- | Instantiate the first quantifier (top-down) whose binder satisfies the
-- predicate with the parsed 'CoreString'.
instantiateClauseVarR :: (Var -> Bool) -> CoreString -> RewriteH Clause
instantiateClauseVarR p cs = setFailMsg "instantiation failed: no quantifier matches" $
                             extractR (onetdR (promoteClauseR $ instantiateForallVarR p cs) :: RewriteH LCore)

-- | Instantiate the focused 'Forall' if its binder satisfies the predicate.
-- The string is parsed as an expression for a term binder and as a type
-- otherwise; the result is linted.
instantiateForallVarR :: (Var -> Bool) -> CoreString -> RewriteH Clause
instantiateForallVarR p cs = prefixFailMsg "instantiation failed: " $ do
    Forall b _ <- idR
    guardMsg (p b) "universally quantified variable does not match predicate."
    e <- if isId b
         then parseCoreExprT cs
         else liftM (Type . fst) $ parseTypeWithHolesT cs
    transform (\ c -> instClause (boundVars c) p e) >>> (lintClauseT >> idR) -- lint for sanity

------------------------------------------------------------------------------

-- | Replace all occurrences of the given expression with a new quantified variable.
--
-- The new binder is a type variable when the expression is a 'Type', and an
-- identifier otherwise.
abstractClauseR :: forall c m. ( AddBindings c, ExtendPath c Crumb, HasEmptyContext c, ReadBindings c, ReadPath c Crumb
                               , LemmaContext c, LiftCoreM m, MonadCatch m, MonadUnique m )
                => String -> Transform c m Clause CoreExpr -> Rewrite c m Clause
abstractClauseR nm tr = prefixFailMsg "abstraction failed: " $ do
    e <- tr
    cl <- idR
    b <- constT $ case e of
                    Type _ -> newTyVarH nm (exprKindOrType e)
                    _      -> newIdH nm (exprKindOrType e)
    let f = compileFold [Equality [] e (varToCoreExpr b)] -- we don't use mkEquality on purpose, so we can abstract lambdas
    liftM dropBinders $ return (mkForall [b] cl) >>> extractR (anytdR $ promoteExprR $ runFoldR f :: Rewrite c m LCoreTC)

-- | Parse a 'CoreString' in the context of the body of the focused 'Forall'.
csInQBodyT :: ( AddBindings c, ExtendPath c Crumb, ReadBindings c, ReadPath c Crumb, HasHermitMEnv m, HasLemmas m, LiftCoreM m )
           => CoreString -> Transform c m Clause CoreExpr
csInQBodyT cs = forallT successT (parseCoreExprT cs) (flip const)

------------------------------------------------------------------------------

-- | All lemmas in scope: the context's antecedents together with the global
-- library.  'Map.union' is left-biased, so an antecedent shadows a global
-- lemma of the same name.
getLemmasT :: (LemmaContext c, HasLemmas m, Monad m) => Transform c m x Lemmas
getLemmasT = contextonlyT $ \ c -> liftM (Map.union (getAntecedents c)) getLemmas

-- | Look up a lemma in scope by name, failing if there is none.
getLemmaByNameT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> Transform c m x Lemma
getLemmaByNameT nm = getLemmasT >>= maybe (fail $ "No lemma named: " ++ show nm) return . Map.lookup nm

------------------------------------------------------------------------------

-- | Bidirectional rewrite from a named lemma's clause (see 'birewrite').  On
-- success the lemma is marked with the given 'Used' status.
lemmaBiR :: (LemmaContext c, ReadBindings c, HasLemmas m, MonadCatch m) => Used -> LemmaName -> BiRewrite c m CoreExpr
lemmaBiR u nm = afterBiR (beforeBiR (getLemmaByNameT nm) (birewrite . lemmaC)) (markLemmaUsedT nm u >> idR)

-- | Match the focused clause against the consequent of an implication lemma.
-- The clause is replaced by the correspondingly instantiated antecedent.
-- Every quantifier of the lemma must be instantiated by the match, and the
-- result must mention only variables in scope.
lemmaConsequentR :: forall c m. (LemmaContext c, ReadBindings c, HasLemmas m, MonadCatch m)
                 => Used -> LemmaName -> Rewrite c m Clause
lemmaConsequentR u nm = prefixFailMsg "lemma-consequent failed: " $
                        withPatFailMsg "lemma is not an implication." $ do
    (hs, Impl _ ante pat) <- getLemmaByNameT nm >>^ (collectQs . lemmaC)
    cl' <- transform $ \ c cl -> do
        m <- maybeM ("consequent did not match.") $ lemmaMatch hs pat cl
        subs <- maybeM ("some quantifiers not instantiated.") $ mapM (\h -> (h,) <$> lookupVarEnv m h) hs
        let cl' = substClauses subs ante
        guardMsg (all (inScope c) $ varSetElems (freeVarsClause cl')) "some variables in result would be out of scope."
        return cl'
    markLemmaUsedT nm u
    return cl'

-- | Bidirectional rewrite from the consequent of an implication lemma.
--
-- After a successful fold, the antecedent is instantiated with the matched
-- quantifiers, and any unmatched quantifiers are re-quantified.  The result
-- is then either proved with an existing lemma named after the antecedent, or
-- registered as a new obligation (see 'verifyOrCreateT').
lemmaConsequentBiR :: forall c m. ( HasCoreRules c, LemmaContext c
                                  , ReadBindings c, ReadPath c Crumb, HasHermitMEnv m, HasLemmas m, LiftCoreM m
                                  , MonadCatch m)
                   => Used -> LemmaName -> BiRewrite c m CoreExpr
lemmaConsequentBiR u nm = afterBiR (beforeBiR (getLemmaByNameT nm) (go [] . lemmaC)) (markLemmaUsedT nm u >> idR)
  where go :: [CoreBndr] -> Clause -> BiRewrite c m CoreExpr
        -- Collect outer quantifiers so that they scope over the consequent.
        go bs (Forall b cl) = go (b:bs) cl
        go bs (Impl anteNm ante con) = do
            let con' = mkForall (reverse bs) con
                (bs',_) = collectQs con'
                eqs = toEqualities con'
                foldUnfold side f = do
                    (cl,e) <- transform $ \ c e -> do
                        let cf = compileFold $ map f eqs
                        (e',hs) <- maybeM ("expression did not match "++side++"-hand side") $ runFoldMatches cf c e
                        let matches = [ case lookupVarEnv hs b of
                                            Nothing  -> Left b
                                            Just arg -> Right (b,arg)
                                      | b <- bs' ]
                            (unmatched, subs) = partitionEithers matches
                            acl = substClauses subs ante
                        return (mkForall unmatched acl, e')
                    verifyOrCreateT u anteNm cl
                    return e
            bidirectional (foldUnfold "left" id) (foldUnfold "right" flipEquality)
        go _ _ = let t = fail $ show nm ++ " is not an implication."
                 in bidirectional t t

------------------------------------------------------------------------------

-- | Insert a lemma into the global library under the given name.
insertLemmaT :: (HasLemmas m, Monad m) => LemmaName -> Lemma -> Transform c m a ()
insertLemmaT nm l = constT $ insertLemma nm l

-- | Insert several named lemmas into the global library, in list order.
insertLemmasT :: (HasLemmas m, Monad m) => [NamedLemma] -> Transform c m a ()
insertLemmasT = constT . mapM_ (uncurry insertLemma)

-- | Look up a lemma in scope, rewrite its clause and adjust its status, then
-- insert the result under the (possibly renamed) name.  The original entry is
-- not removed, so a renaming function yields a copy.
modifyLemmaT :: (LemmaContext c, HasLemmas m, Monad m)
             => LemmaName                -- ^ name of the lemma to look up
             -> (LemmaName -> LemmaName) -- ^ modify lemma name
             -> Rewrite c m Clause       -- ^ rewrite the quantified clause
             -> (Proven -> Proven)       -- ^ modify proven status
             -> (Used -> Used)           -- ^ modify used status
             -> Transform c m a ()
modifyLemmaT nm nFn rr pFn uFn = do
    Lemma cl p u <- getLemmaByNameT nm
    cl' <- rr <<< return cl
    constT $ insertLemma (nFn nm) $ Lemma cl' (pFn p) (uFn u)

-- | Set the 'Used' status of a lemma in the global library.  This is a no-op
-- when no such global lemma exists (e.g. for a context antecedent).
markLemmaUsedT :: (LemmaContext c, HasLemmas m, MonadCatch m) => LemmaName -> Used -> Transform c m a ()
markLemmaUsedT nm u = ifM (lemmaExistsT nm) (modifyLemmaT nm id idR id (const u)) (return ())

-- | Set the 'Proven' status of a lemma in the global library.  This is a
-- no-op when no such global lemma exists.
markLemmaProvenT :: (LemmaContext c, HasLemmas m, MonadCatch m) => LemmaName -> Proven -> Transform c m a ()
markLemmaProvenT nm p = ifM (lemmaExistsT nm) (modifyLemmaT nm id idR (const p) id) (return ())

-- | Whether the global library (not the context antecedents) has the name.
lemmaExistsT :: (HasLemmas m, MonadCatch m) => LemmaName -> Transform c m a Bool
lemmaExistsT nm = constT $ Map.member nm <$> getLemmas

------------------------------------------------------------------------------

-- | The clause of a named lemma in scope.
lemmaNameToClauseT :: (LemmaContext c, HasLemmas m, Monad m) => LemmaName -> Transform c m x Clause
lemmaNameToClauseT nm = liftM lemmaC $ getLemmaByNameT nm

-- | @e@ ==> @let v = lhs in e@ (also works in a similar manner at Program nodes)
lemmaLhsIntroR :: LemmaName -> RewriteH Core
lemmaLhsIntroR = lemmaNameToClauseT >=> eqLhsIntroR

-- | @e@ ==> @let v = rhs in e@ (also works in a similar manner at Program nodes)
lemmaRhsIntroR :: LemmaName -> RewriteH Core
lemmaRhsIntroR = lemmaNameToClauseT >=> eqRhsIntroR

------------------------------------------------------------------------------

-- Little DSL for building composite lemmas

-- | Function type.
infixr 5 -->
(-->) :: Type -> Type -> Type
(-->) = mkFunTy

-- | Implication with a named antecedent.
infixr 3 ==>
(==>) :: (LemmaName, Clause) -> Clause -> Clause
(==>) = uncurry Impl

-- | Conjunction.
infixr 5 /\ -- this comment is required to avoid a CPP issue with backslash
(/\) :: Clause -> Clause -> Clause
(/\) = Conj

-- | Disjunction.
infixr 4 \/
(\/) :: Clause -> Clause -> Clause
(\/) = Disj

-- | Equivalence between two things convertible to Core expressions.
infix 8 ===
(===) :: (ToCoreExpr a, ToCoreExpr b) => a -> b -> Clause
lhs === rhs = Equiv (toCE lhs) (toCE rhs)

-- | Application of one argument, built with 'buildAppM'.
infixl 9 $$
($$) :: (ToCoreExpr a, ToCoreExpr b, MonadCatch m) => a -> b -> m CoreExpr
f $$ e = buildAppM (toCE f) (toCE e)

-- | Application of several arguments, built with 'buildAppsM'.  The fixity is
-- stated explicitly; it is the same as the default.
infixl 9 $$$
($$$) :: (ToCoreExpr a, ToCoreExpr b, MonadCatch m) => a -> [b] -> m CoreExpr
f $$$ es = buildAppsM (toCE f) (map toCE es)

-- | Things that can be embedded as a Core expression.
class ToCoreExpr a where
    toCE :: a -> CoreExpr

instance ToCoreExpr CoreExpr where
    toCE = id

instance ToCoreExpr Var where
    toCE = varToCoreExpr

instance ToCoreExpr Type where
    toCE = Type

-- | Create new lemma library with single unproven lemma.
newLemma :: LemmaName -> Clause -> Map.Map LemmaName Lemma
newLemma nm cl = Map.singleton nm (Lemma cl NotProven NotUsed)