{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE NoImplicitPrelude #-}

module HERMIT.Lemma
    ( -- * Clause
      Clause(..)
    , mkClause
    , mkForall
    , forallQs
    , instClause
    , instsClause
    , discardUniVars
    , freeVarsClause
    , clauseSyntaxEq
    , substClause
    , substClauses
    , dropBinders
    , redundantDicts
      -- * Lemmas
    , LemmaName(..)
    , Lemma(..)
    , Proven(..)
    , andP, orP
    , Used(..)
    , Lemmas
    , NamedLemma
    ) where

import Prelude.Compat hiding (lookup)

import Control.Monad

import Data.Dynamic (Typeable)
import Data.String (IsString(..))
import qualified Data.Map as M

import HERMIT.Core
import HERMIT.GHC hiding ((<>))

import Language.KURE.MonadCatch

----------------------------------------------------------------------------

-- | Build a 'Clause' from a list of universally quantified binders and two expressions.
-- If the head of either expression is a lambda, its binder becomes a universally
-- quantified binder over both sides. The two expressions are assumed to have the same type.
-- Type variables free in either side (and not otherwise bound) are quantified first;
-- the result is then passed through 'dropBinders' and 'redundantDicts'.
--
-- Ex. mkClause [] (\x. foo x) bar         === forall x. foo x = bar x
--     mkClause [] (baz y z) (\x. foo x x) === forall x. baz y z x = foo x x
--     mkClause [] (\x. foo x) (\y. bar y) === forall x. foo x = bar x
mkClause :: [CoreBndr] -> CoreExpr -> CoreExpr -> Clause
mkClause vs lhs rhs =
    let (lamsL, bodyL)    = collectBinders lhs
        -- feed the left-hand lambda binders to the right-hand side, beta-reducing
        (rhsFun, rhsArgs) = betaReduceAll rhs (map varToCoreExpr lamsL)
        (lamsR, bodyR)    = collectBinders (mkCoreApps rhsFun rhsArgs)
        -- any lambdas still left on the right are applied to the left body
        bodyL'            = mkCoreApps bodyL (map varToCoreExpr lamsR)
        bound             = vs ++ lamsL ++ lamsR
        freeTys           = varSetElems
                          . filterVarSet isTyVar
                          $ delVarSetList (unionVarSets [freeVarsExpr bodyL', freeVarsExpr bodyR]) bound
    in redundantDicts (dropBinders (Forall (freeTys ++ bound) (Equiv bodyL' bodyR)))

-- | Variables free in a clause: those free in its expressions, minus any bound by an
-- enclosing 'Forall'. The antecedent name of an 'Impl' does not contribute.
freeVarsClause :: Clause -> VarSet
freeVarsClause clause = case clause of
    Forall bs body -> freeVarsClause body `delVarSetList` bs
    Conj l r       -> both freeVarsClause l r
    Disj l r       -> both freeVarsClause l r
    Impl _ l r     -> both freeVarsClause l r
    Equiv l r      -> both freeVarsExpr l r
    CTrue          -> emptyVarSet
  where
    both f a b = unionVarSets [f a, f b]

-- | Remove quantified binders that are not free in the clause they scope over,
-- recursing through 'Conj', 'Disj' and 'Impl'. Within one binder list the rightmost
-- binders are examined first, so each binder is tested against the already-pruned rest.
dropBinders :: Clause -> Clause
dropBinders clause = case clause of
    Forall [] body -> dropBinders body
    Forall (b:rest) body ->
        let inner = dropBinders (mkForall rest body)
        in if elemVarSet b (freeVarsClause inner) then addBinder b inner else inner
    Conj l r    -> Conj (dropBinders l) (dropBinders r)
    Disj l r    -> Disj (dropBinders l) (dropBinders r)
    Impl nm l r -> Impl nm (dropBinders l) (dropBinders r)
    _           -> clause

-- | Quantify one binder over a clause, merging with an outer 'Forall' if present.
addBinder :: CoreBndr -> Clause -> Clause
addBinder b cl = mkForall [b] cl

-- | Quantify binders over a clause. If the clause is already a 'Forall', the new
-- binders are prepended to its binder list instead of nesting a second 'Forall'.
-- An empty binder list still yields a 'Forall' node.
mkForall :: [CoreBndr] -> Clause -> Clause
mkForall bs cl = case cl of
    Forall inner body -> Forall (bs ++ inner) body
    _                 -> Forall bs cl

-- | The binders of the outermost 'Forall', or @[]@ when the clause is not a 'Forall'.
forallQs :: Clause -> [CoreBndr]
forallQs cl = case cl of
    Forall bs _ -> bs
    _           -> []

-- | A name for lemmas. Use a newtype so we can tab-complete in shell.
newtype LemmaName = LemmaName String deriving (Eq, Ord, Typeable)

instance Monoid LemmaName where
    mempty = LemmaName mempty
    mappend (LemmaName n1) (LemmaName n2) = LemmaName (mappend n1 n2)

instance IsString LemmaName where
    fromString = LemmaName

instance Show LemmaName where
    show (LemmaName s) = s

-- | A 'Clause' together with its proven/used status.
data Lemma = Lemma { lemmaC :: Clause
                   , lemmaP :: Proven -- ^ whether lemma has been proven
                   , lemmaU :: Used   -- ^ whether lemma has been used
                   }

-- | How far a lemma is trusted. See the 'Ord' instance for the ranking.
data Proven = Proven
            | Assumed   -- ^ Assumed by user
            | BuiltIn   -- ^ Assumed by library/HERMIT
            | NotProven
    deriving (Eq, Typeable)

instance Show Proven where
    show Proven    = "Proven"
    show Assumed   = "Assumed"
    show BuiltIn   = "Built In"
    show NotProven = "Not Proven"

-- Ints 0..3 correspond to NotProven, Assumed, BuiltIn, Proven;
-- 'toEnum' sends any other Int to NotProven.
instance Enum Proven where
    toEnum 1 = Assumed
    toEnum 2 = BuiltIn
    toEnum 3 = Proven
    toEnum _ = NotProven

    fromEnum NotProven = 0
    fromEnum Assumed   = 1
    fromEnum BuiltIn   = 2
    fromEnum Proven    = 3

    -- The class defaults enumerate the underlying Ints with no upper bound, and
    -- 'toEnum' maps every out-of-range Int to NotProven, so e.g. [Assumed ..]
    -- would be an infinite list. Stop at the last constructor instead.
    enumFrom x = enumFromTo x Proven
    enumFromThen x y = enumFromThenTo x y limit
      where limit = if fromEnum y >= fromEnum x then Proven else NotProven

-- Ordering: NotProven < Assumed < BuiltIn < Proven
instance Ord Proven where
    compare :: Proven -> Proven -> Ordering
    compare p1 p2 = compare (fromEnum p1) (fromEnum p2)

-- | When conjuncting, result is as proven as the least of the two.
andP :: Proven -> Proven -> Proven
andP = min

-- | When disjuncting, result is as proven as the most of the two.
orP :: Proven -> Proven -> Proven
orP = max

data Used = Obligation -- ^ this MUST be proven immediately
          | UnsafeUsed -- ^ used, but can be proven later (only introduced in unsafe shell)
          | NotUsed
    deriving (Eq, Typeable)

instance Show Used where
    show Obligation = "Obligation"
    show UnsafeUsed = "Used"
    show NotUsed    = "Not Used"

-- | A clause built from quantification, conjunction, disjunction, named implication,
-- equivalence of two Core expressions, and the trivially true clause.
data Clause = Forall [CoreBndr] Clause
            | Conj Clause Clause
            | Disj Clause Clause
            | Impl LemmaName Clause Clause -- ^ name for the antecedent when it is in scope
            | Equiv CoreExpr CoreExpr
            | CTrue -- ^ the always true clause

-- | A collection of named lemmas.
type Lemmas = M.Map LemmaName Lemma

-- | A LemmaName, Lemma pair.
type NamedLemma = (LemmaName, Lemma)

------------------------------------------------------------------------------

-- | Drop the outermost 'Forall', leaving its binders free in the body.
-- A clause without an outer 'Forall' is returned unchanged.
discardUniVars :: Clause -> Clause
discardUniVars (Forall _ cl) = cl
discardUniVars cl = cl

------------------------------------------------------------------------------

-- | Assumes Var is free in Clause. If not, no substitution will happen, though uniques might be freshened.
substClause :: Var -> CoreArg -> Clause -> Clause
substClause v e = substClauses [(v,e)]

-- | Substitute every (variable, argument) pair into a clause using a single 'Subst'.
substClauses :: [(Var,CoreArg)] -> Clause -> Clause
substClauses ps cl = substClauseSubst (extendSubstList sub ps) cl
  where
    (vs,es) = unzip ps
    -- The in-scope set must contain every variable that may be free in the result:
    -- the clause's free variables that are not being substituted away, plus everything
    -- free in the replacement expressions. A substituted variable that also occurs in a
    -- replacement (e.g. [(x, f y), (y, z)]) must stay in scope, otherwise an inner binder
    -- with the same unique is not renamed and captures it.
    inScope = unionVarSets (delVarSetList (freeVarsClause cl) vs : map freeVarsExpr es)
    sub = mkEmptySubst (mkInScopeSet inScope)

-- | Note: Subst must be properly set up with an InScopeSet that includes all vars
-- in scope in the *range* of the substitution.
substClauseSubst :: Subst -> Clause -> Clause
substClauseSubst = go
  where
    go sub (Forall bs cl) = let (bs', cl') = go1 sub bs [] cl in mkForall bs' cl'
    go _ CTrue = CTrue
    go subst (Conj q1 q2) = Conj (go subst q1) (go subst q2)
    go subst (Disj q1 q2) = Disj (go subst q1) (go subst q2)
    go subst (Impl nm q1 q2) = Impl nm (go subst q1) (go subst q2)
    go subst (Equiv e1 e2) =
        let e1' = substExpr (text "substClauseSubst e1") subst e1
            e2' = substExpr (text "substClauseSubst e2") subst e2
        in Equiv e1' e2'

    -- substitute binders left to right, threading the extended Subst into the body
    go1 subst [] bs' cl = (reverse bs', go subst cl)
    go1 subst (b:bs) bs' cl = let (subst',b') = substBndr subst b in go1 subst' bs (b':bs') cl

------------------------------------------------------------------------------

-- | Remove dictionary binders of the outermost 'Forall' whose type equals that of an
-- earlier dictionary binder, substituting the earlier binder for the later one.
-- Clauses that are not a 'Forall' are returned unchanged.
redundantDicts :: Clause -> Clause
redundantDicts (Forall bs cl) = go [] [] cl bs
  where
    -- go keptBindersReversed seenDicts body remainingBinders
    go [] _ c [] = c
    go bnds _ c [] = mkForall (reverse bnds) c
    go bnds tys c (b:bs')
        | isDictTy bTy = -- is a dictionary binder
            case [ varToCoreExpr pb | (pb,ty) <- tys, eqType bTy ty ] of
                [] -> go (b:bnds) ((b,bTy):tys) c bs' -- not seen before
                (prev:_) ->
                    -- seen: replace b by the earlier binder, then continue with the
                    -- (possibly freshened) remaining binders
                    case substClause b prev (mkForall bs' c) of
                        Forall bs'' c' -> go bnds tys c' bs''
                        -- substClause maps a Forall to a Forall, so this is not expected;
                        -- handle it without a failing pattern match anyway
                        c' -> go bnds tys c' []
        | otherwise = go (b:bnds) tys c bs'
      where bTy = varType b
redundantDicts cl = cl

------------------------------------------------------------------------------

-- | Instantiate one of the universally quantified variables in a 'Clause'.
-- Note: assumes implicit ordering of variables, such that substitution happens to the right
-- as it does in case alternatives. Only first variable that matches predicate is
-- instantiated.
instClause :: MonadCatch m
           => VarSet          -- vars in scope
           -> (Var -> Bool)   -- predicate to select var
           -> CoreExpr        -- expression to instantiate with
           -> Clause
           -> m Clause
instClause inScope p e = prefixFailMsg "clause instantiation failed: " . liftM fst . go []
  where
    go bbs (Forall bs cl)
        | not (any p bs) = -- not quantified at this level, so try further down
            let go2 con q1 q2 = do
                    er <- attemptM $ go (bs++bbs) q1
                    (cl',s) <- case er of
                                   Right (q1',s) -> return (con q1' q2, s)
                                   Left _ -> do
                                       er' <- attemptM $ go (bs++bbs) q2
                                       case er' of
                                           Right (q2',s) -> return (con q1 q2', s)
                                           Left msg -> fail msg
                    return (replaceVars s bs cl', s)
            in case cl of
                Equiv{}        -> fail "specified variable is not universally quantified."
                CTrue          -> fail "specified variable is not universally quantified."
                Conj q1 q2     -> go2 Conj q1 q2
                Disj q1 q2     -> go2 Disj q1 q2
                Impl nm q1 q2  -> go2 (Impl nm) q1 q2
                -- mkForall merges adjacent quantifiers, but a nested Forall can still be
                -- built with the constructor directly; treat it as one binder block.
                Forall bs2 cl2 -> go bbs (Forall (bs ++ bs2) cl2)

        | otherwise = do -- quantified here, so do substitution and start bubbling up
            let (bs',i:vs) = break p bs -- this is safe because we know i is in bs
                (eTvs, eTy) = splitForAllTys $ exprKindOrType e
                bsInScope = bs'++bbs
                tyVars = eTvs ++ filter isTyVar bsInScope
                failMsg = fail "type of provided expression differs from selected binder."

                bindFn v = if v `elem` tyVars then BindMe else Skolem
            sub <- maybe failMsg return $ tcUnifyTys bindFn [varType i] [eTy]

            -- if i is a tyvar, we know e is a type, so free vars will be tyvars
            let instTy v = case lookupTyVar sub v of
                               Nothing -> mkTyVarTy v
                               Just ty -> ty
                e' = mkCoreApps e [ Type (instTy v) | v <- eTvs ]
            let newBs = varSetElems
                      $ filterVarSet (\v -> not (isId v) || isLocalId v)
                      $ delVarSetList (minusVarSet (freeVarsExpr e') inScope) bsInScope
                cl' = substClause i e' $ mkForall vs cl

            return (replaceVars sub (bs' ++ newBs) cl', sub)

    go _ _ = fail "only applies to clauses with quantifiers."

-- | The function which 'bubbles up' after the instantiation takes place,
-- replacing any type variables that were instantiated as a result of specialization.
replaceVars :: TvSubst -> [Var] -> Clause -> Clause
replaceVars sub vs = go (reverse vs)
  where
    go [] cl = cl
    go (b:bs) cl
        | isTyVar b = case lookupTyVar sub b of
                          Nothing -> go bs (addBinder b cl)
                          Just ty -> let new = varSetElems (freeVarsType ty)
                                     in go (new++bs) (substClause b (Type ty) cl)
        | otherwise = go bs (addBinder b cl)

-- tvSubstToSubst :: TvSubst -> Subst
-- tvSubstToSubst (TvSubst inS tEnv) = mkSubst inS tEnv emptyVarEnv emptyVarEnv

-- | Instantiate a set of universally quantified variables in a 'Clause'.
-- It is important that all type variables appear before any value-level variables in the first argument.
instsClause :: MonadCatch m => VarSet -> [(Var,CoreExpr)] -> Clause -> m Clause
instsClause inScope = flip (foldM (\ q (v,e) -> instClause inScope (==v) e q)) . reverse
  -- foldM is a left-to-right fold, so the reverse is important to do substitutions in reverse order
  -- which is what we want (all value variables should be instantiated before type variables).

------------------------------------------------------------------------------
-- Syntactic Equality

-- | Syntactic Equality of clauses.
-- Constructors must match pairwise; binder lists are compared with '==' (same order),
-- antecedent names with '==', and expressions with 'exprSyntaxEq'.
clauseSyntaxEq :: Clause -> Clause -> Bool
clauseSyntaxEq (Forall bs1 c1) (Forall bs2 c2) = (bs1 == bs2) && clauseSyntaxEq c1 c2
clauseSyntaxEq (Conj q1 q2) (Conj p1 p2) = clauseSyntaxEq q1 p1 && clauseSyntaxEq q2 p2
clauseSyntaxEq (Disj q1 q2) (Disj p1 p2) = clauseSyntaxEq q1 p1 && clauseSyntaxEq q2 p2
clauseSyntaxEq (Impl n1 q1 q2) (Impl n2 p1 p2) = n1 == n2 && clauseSyntaxEq q1 p1 && clauseSyntaxEq q2 p2
clauseSyntaxEq (Equiv e1 e2) (Equiv e1' e2') = exprSyntaxEq e1 e1' && exprSyntaxEq e2 e2'
-- CTrue has no components, so it is syntactically equal exactly to itself.
clauseSyntaxEq CTrue CTrue = True
clauseSyntaxEq _ _ = False

------------------------------------------------------------------------------