module HERMIT.Lemma
(
Clause(..)
, mkClause
, mkForall
, forallQs
, instClause
, instsClause
, discardUniVars
, freeVarsClause
, clauseSyntaxEq
, substClause
, substClauses
, dropBinders
, redundantDicts
, LemmaName(..)
, Lemma(..)
, Proven(..)
, andP, orP
, Used(..)
, Lemmas
, NamedLemma
) where
import Prelude.Compat hiding (lookup)
import Control.Monad
import Data.Dynamic (Typeable)
import Data.String (IsString(..))
import qualified Data.Map as M
import HERMIT.Core
import HERMIT.GHC hiding ((<>))
import Language.KURE.MonadCatch
-- | Build an equivalence clause @lhs = rhs@ universally quantified over the
-- given binders.
--
-- * The leading lambda binders of @lhs@ become quantified variables. @rhs@ is
--   applied to them and reduced with 'betaReduceAll'; the leftover arguments
--   are re-applied with 'mkCoreApps'.
-- * The leading lambda binders of the resulting right-hand side also become
--   quantified variables, and the left-hand body is applied to them.
-- * Type variables free in either side that are not otherwise bound are
--   quantified too, in front of all other binders.
--
-- Finally unused binders are removed ('dropBinders') and duplicate dictionary
-- binders are merged ('redundantDicts').
mkClause :: [CoreBndr] -> CoreExpr -> CoreExpr -> Clause
mkClause vs lhs rhs =
    let (lhsBndrs, lhsBody) = collectBinders lhs
        lhsArgs             = map varToCoreExpr lhsBndrs
        (rhsBndrs, rhsBody) = collectBinders (uncurry mkCoreApps (betaReduceAll rhs lhsArgs))
        lhsFull             = mkCoreApps lhsBody (map varToCoreExpr rhsBndrs)
        explicitBndrs       = vs ++ lhsBndrs ++ rhsBndrs
        bodyFvs             = unionVarSets [freeVarsExpr lhsFull, freeVarsExpr rhsBody]
        freeTyVars          = varSetElems (filterVarSet isTyVar (delVarSetList bodyFvs explicitBndrs))
    in redundantDicts (dropBinders (Forall (freeTyVars ++ explicitBndrs) (Equiv lhsFull rhsBody)))
-- | Free variables of a clause: the free variables of its expressions, minus
-- any variables bound by an enclosing 'Forall'. The 'LemmaName' carried by an
-- 'Impl' contributes nothing, and 'CTrue' has no free variables.
freeVarsClause :: Clause -> VarSet
freeVarsClause clause =
    case clause of
        Forall bs body -> delVarSetList (freeVarsClause body) bs
        Conj q1 q2     -> ofPair q1 q2
        Disj q1 q2     -> ofPair q1 q2
        Impl _ q1 q2   -> ofPair q1 q2
        Equiv e1 e2    -> unionVarSets [freeVarsExpr e1, freeVarsExpr e2]
        CTrue          -> emptyVarSet
  where
    -- Union of the free variables of two sub-clauses.
    ofPair l r = unionVarSets [freeVarsClause l, freeVarsClause r]
-- | Remove universally quantified binders that do not occur free in the part
-- of the clause they scope over, recursing through 'Conj', 'Disj' and 'Impl'.
-- A 'Forall' whose binder list ends up empty disappears entirely.
-- 'Equiv' and 'CTrue' are returned unchanged.
dropBinders :: Clause -> Clause
dropBinders clause =
    case clause of
        Forall [] body -> dropBinders body
        Forall (b:rest) body ->
            -- Prune the inner binders first, then keep b only if still referenced.
            let pruned = dropBinders (mkForall rest body)
            in if b `elemVarSet` freeVarsClause pruned
                   then addBinder b pruned
                   else pruned
        Conj q1 q2    -> Conj (dropBinders q1) (dropBinders q2)
        Disj q1 q2    -> Disj (dropBinders q1) (dropBinders q2)
        Impl nm q1 q2 -> Impl nm (dropBinders q1) (dropBinders q2)
        _             -> clause
-- | Quantify a clause over one more binder, placed in front of any binders of
-- an existing outermost 'Forall' (see 'mkForall').
addBinder :: CoreBndr -> Clause -> Clause
addBinder b cl = mkForall [b] cl
-- | Universally quantify a clause over the given binders.
--
-- If the clause is already a 'Forall', the new binders are prepended to its
-- binder list rather than nesting a second 'Forall'. The result is always a
-- 'Forall', even when the binder list is empty.
mkForall :: [CoreBndr] -> Clause -> Clause
mkForall bs cl =
    case cl of
        Forall inner body -> Forall (bs ++ inner) body
        _                 -> Forall bs cl
-- | The binders of the outermost 'Forall', or @[]@ if the clause does not
-- start with a quantifier.
forallQs :: Clause -> [CoreBndr]
forallQs cl =
    case cl of
        Forall bs _ -> bs
        _           -> []
-- | The name a lemma is known by.
newtype LemmaName = LemmaName String deriving (Eq, Ord, Typeable)

-- | Names combine by plain string concatenation; the empty name is the identity.
instance Monoid LemmaName where
    mempty = LemmaName ""
    mappend (LemmaName a) (LemmaName b) = LemmaName (a ++ b)

-- | Allows lemma names to be written as string literals.
instance IsString LemmaName where
    fromString s = LemmaName s

-- | Shows the bare name, with no constructor and no quotes.
instance Show LemmaName where
    show (LemmaName name) = name
-- | A lemma: a 'Clause' together with its proof status and usage status.
data Lemma = Lemma { lemmaC :: Clause  -- ^ the statement of the lemma
                   , lemmaP :: Proven  -- ^ proof status (see 'Proven')
                   , lemmaU :: Used    -- ^ usage status (see 'Used')
                   }
-- | The proof status of a lemma.
--
-- The 'Enum' and 'Ord' instances rank the statuses as
-- @NotProven < Assumed < BuiltIn < Proven@. This is NOT the order in which the
-- constructors are declared. 'andP' and 'orP' take the 'min' and 'max' of this
-- ranking.
data Proven = Proven
            | Assumed
            | BuiltIn
            | NotProven
    deriving (Eq, Typeable)

instance Show Proven where
    show Proven    = "Proven"
    show Assumed   = "Assumed"
    show BuiltIn   = "Built In"
    show NotProven = "Not Proven"

-- | Codes 0..3 in ascending rank.
instance Enum Proven where
    -- Code 0, and any code outside 0..3, decodes to 'NotProven'. This is kept
    -- for compatibility with existing callers of 'toEnum'.
    toEnum 1 = Assumed
    toEnum 2 = BuiltIn
    toEnum 3 = Proven
    toEnum _ = NotProven

    fromEnum NotProven = 0
    fromEnum Assumed   = 1
    fromEnum BuiltIn   = 2
    fromEnum Proven    = 3

    -- The class defaults are @map toEnum [fromEnum x ..]@ and
    -- @map toEnum [fromEnum x, fromEnum y ..]@, which are unbounded. Combined
    -- with the catch-all 'toEnum' case, they made e.g. @[Assumed ..]@ an
    -- infinite list trailing endless 'NotProven's. Stop at the last
    -- constructor in the direction of travel, as a derived instance does.
    enumFrom x = enumFromTo x Proven
    enumFromThen x y = enumFromThenTo x y lastInDirection
      where
        lastInDirection
          | fromEnum y >= fromEnum x = Proven
          | otherwise                = NotProven

-- | Orders statuses by their 'fromEnum' rank.
instance Ord Proven where
    compare p1 p2 = compare (fromEnum p1) (fromEnum p2)
-- | Combine the statuses of two facts that are both needed: the result is the
-- lower-ranked status under the 'Ord' instance of 'Proven'.
andP :: Proven -> Proven -> Proven
andP p q = min p q
-- | Combine the statuses of two alternatives: the result is the higher-ranked
-- status under the 'Ord' instance of 'Proven'.
orP :: Proven -> Proven -> Proven
orP p q = max p q
-- | Usage status of a lemma. Note that 'Show' renders 'UnsafeUsed' as
-- @"Used"@.
data Used = Obligation
          | UnsafeUsed
          | NotUsed
    deriving (Eq, Typeable)

instance Show Used where
    show u =
        case u of
            Obligation -> "Obligation"
            UnsafeUsed -> "Used"
            NotUsed    -> "Not Used"
-- | The statement of a lemma: a small logic over GHC Core expressions.
data Clause = Forall [CoreBndr] Clause      -- ^ universal quantification; the binders scope over the clause
            | Conj Clause Clause            -- ^ conjunction of two clauses
            | Disj Clause Clause            -- ^ disjunction of two clauses
            | Impl LemmaName Clause Clause  -- ^ implication; NOTE(review): the 'LemmaName' presumably names the antecedent — confirm
            | Equiv CoreExpr CoreExpr       -- ^ equivalence of two expressions ('mkClause' builds these from its lhs and rhs)
            | CTrue                         -- ^ presumably the trivially true clause — TODO confirm
-- | A collection of lemmas keyed by name.
type Lemmas = M.Map LemmaName Lemma
-- | A lemma paired with its name.
type NamedLemma = (LemmaName, Lemma)
-- | Strip the outermost 'Forall' (only that one), leaving its binders free in
-- the body. Clauses without an outer quantifier are returned unchanged.
discardUniVars :: Clause -> Clause
discardUniVars cl =
    case cl of
        Forall _ body -> body
        _             -> cl
-- | Substitute a single argument for a variable throughout a clause
-- (a one-element 'substClauses').
substClause :: Var -> CoreArg -> Clause -> Clause
substClause v e cl = substClauses [(v, e)] cl
-- | Simultaneously substitute each argument for its variable throughout a
-- clause, renaming quantified binders where the substitution requires it.
--
-- The in-scope set must satisfy GHC's substitution invariant: it has to contain
-- (a) the free variables of every replacement expression, and (b) the free
-- variables of the clause other than the substituted variables. Only (b) may
-- exclude the substituted variables. Previously they were also removed from
-- (a), so if a substituted variable occurred free in a replacement, an inner
-- binder of the same name would not be renamed and could capture it.
substClauses :: [(Var,CoreArg)] -> Clause -> Clause
substClauses ps cl = substClauseSubst (extendSubstList sub ps) cl
  where
    (vs, es) = unzip ps
    inScope  = unionVarSets (delVarSetList (freeVarsClause cl) vs : map freeVarsExpr es)
    sub      = mkEmptySubst (mkInScopeSet inScope)
-- | Apply a GHC 'Subst' to every expression in a clause.
--
-- Quantified binders are passed through 'substBndr' left to right. Each call
-- yields a possibly-updated binder and substitution, and the updated
-- substitution is used for the remaining binders and the body. Rebuilt
-- quantifiers go through 'mkForall'.
substClauseSubst :: Subst -> Clause -> Clause
substClauseSubst subst clause =
    case clause of
        Forall bs body ->
            let (subst', bs') = substBinders subst bs
            in mkForall bs' (substClauseSubst subst' body)
        Conj q1 q2    -> Conj (substClauseSubst subst q1) (substClauseSubst subst q2)
        Disj q1 q2    -> Disj (substClauseSubst subst q1) (substClauseSubst subst q2)
        Impl nm q1 q2 -> Impl nm (substClauseSubst subst q1) (substClauseSubst subst q2)
        Equiv e1 e2   -> Equiv (substExpr (text "substClauseSubst e1") subst e1)
                               (substExpr (text "substClauseSubst e2") subst e2)
        CTrue         -> CTrue
  where
    -- Thread the substitution through a binder list, each binder seeing the
    -- substitution produced by the binders before it.
    substBinders s []       = (s, [])
    substBinders s (b:rest) =
        let (s1, b')    = substBndr s b
            (s2, rest') = substBinders s1 rest
        in (s2, b' : rest')
-- | Eliminate redundant dictionary binders from the outermost quantifier.
--
-- Binders are scanned left to right. Suppose a binder's type is a dictionary
-- type ('isDictTy') that is 'eqType'-equal to the type of an earlier, retained
-- dictionary binder. Then the binder is dropped and the earlier binder is
-- substituted for it in the rest of the clause.
--
-- All other binders are kept in order. If no binders remain, the 'Forall' is
-- removed. Clauses that do not start with 'Forall' are returned unchanged.
redundantDicts :: Clause -> Clause
redundantDicts (Forall bs cl) = go [] [] cl bs
  where
    -- go kept dicts body pending
    --   kept    : binders retained so far, most recent first
    --   dicts   : retained dictionary binders paired with their types
    --   body    : the clause under the quantifier
    --   pending : binders still to examine
    go []   _ c [] = c
    go kept _ c [] = mkForall (reverse kept) c
    go kept dicts c (b:rest)
      | isDictTy bTy =
          case [ pb | (pb, ty) <- dicts, eqType bTy ty ] of
              [] -> go (b:kept) ((b, bTy):dicts) c rest
              (pb:_) ->
                  -- Substitute under the remaining binders so they are passed
                  -- through the substitution too, then continue with them.
                  case substClause b (varToCoreExpr pb) (mkForall rest c) of
                      Forall rest' c' -> go kept dicts c' rest'
                      -- Not expected: mkForall always yields a Forall and
                      -- substClause rebuilds it with mkForall. Kept for totality.
                      c'              -> go kept dicts c' []
      | otherwise = go (b:kept) dicts c rest
      where bTy = varType b
redundantDicts cl = cl
-- | Instantiate one universally quantified binder of a clause with an
-- expression.
--
-- The binder chosen is the first one, in the first quantifier list searched,
-- that satisfies the predicate. If the outermost 'Forall' has no such binder,
-- the search descends into the sub-clauses of 'Conj', 'Disj' or 'Impl'. It tries
-- the left sub-clause first and the right one only if the left fails. The
-- enclosing binders are then re-quantified with the resulting type
-- substitution applied.
--
-- The expression's own leading type abstractions ('splitForAllTys' of its
-- type) are instantiated by unifying its type with the binder's type
-- ('tcUnifyTys'). Only those type variables and type-variable binders
-- quantified before the chosen one (or in enclosing quantifiers) may be bound.
--
-- All failures are reported with the prefix @"clause instantiation failed: "@.
instClause :: MonadCatch m => VarSet         -- ^ variables already in scope; free variables of the instantiating expression found here are not re-quantified
                           -> (Var -> Bool)  -- ^ selects the binder to instantiate
                           -> CoreExpr       -- ^ expression to instantiate the binder with
                           -> Clause         -- ^ clause to instantiate
                           -> m Clause
instClause inScope p e = prefixFailMsg "clause instantiation failed: " . liftM fst . go []
  where
    -- go bbs cl: bbs are the binders of enclosing quantifiers already passed
    -- on the way down. Returns the new clause plus the type substitution found.
    go bbs (Forall bs cl)
      | not (any p bs) =
          -- No binder selected at this level: try the sub-clauses.
          let go2 con q1 q2 = do
                er <- attemptM $ go (bs++bbs) q1
                (cl',s) <- case er of
                             Right (q1',s) -> return (con q1' q2, s)
                             Left _ -> do
                               -- Left side failed; only the right side's error is reported.
                               er' <- attemptM $ go (bs++bbs) q2
                               case er' of
                                 Right (q2',s) -> return (con q1 q2', s)
                                 Left msg -> fail msg
                -- Re-quantify this level's binders, applying the substitution.
                return (replaceVars s bs cl', s)
          in case cl of
               Equiv{} -> fail "specified variable is not universally quantified."
               CTrue -> fail "specified variable is not universally quantified."
               Conj q1 q2 -> go2 Conj q1 q2
               Disj q1 q2 -> go2 Disj q1 q2
               Impl nm q1 q2 -> go2 (Impl nm) q1 q2
               Forall _ _ -> fail "impossible case!"
      | otherwise = do
          -- Split bs around the selected binder i: bs' before it, vs after it.
          -- The pattern cannot fail because `any p bs` holds on this branch.
          let (bs',i:vs) = break p bs
              (eTvs, eTy) = splitForAllTys $ exprKindOrType e
              bsInScope = bs'++bbs
              -- Only e's own type variables and earlier type-variable binders
              -- may be bound by unification; everything else is rigid.
              tyVars = eTvs ++ filter isTyVar bsInScope
              failMsg = fail "type of provided expression differs from selected binder."
              bindFn v = if v `elem` tyVars then BindMe else Skolem
          sub <- maybe failMsg return $ tcUnifyTys bindFn [varType i] [eTy]
          -- Apply e to a type for each of its type abstractions: the unified
          -- type if there is one, otherwise the type variable itself.
          let e' = mkCoreApps e [ case lookupTyVar sub v of
                                    Nothing -> Type (mkTyVarTy v)
                                    Just ty -> Type ty | v <- eTvs ]
          -- Free variables of e' that are neither in scope nor already bound
          -- become new quantified binders (type variables and local Ids only).
          let newBs = varSetElems
                    $ filterVarSet (\v -> not (isId v) || isLocalId v)
                    $ delVarSetList (minusVarSet (freeVarsExpr e') inScope) bsInScope
              cl' = substClause i e' $ mkForall vs cl
          return (replaceVars sub (bs' ++ newBs) cl', sub)
    go _ _ = fail "only applies to clauses with quantifiers."
-- | Re-quantify a clause over the given binders, applying a type substitution.
--
-- The binders are processed from last to first and each is prepended with
-- 'addBinder', so they end up in their original order. A type-variable binder
-- in the substitution's domain is not re-bound. Instead, its image is
-- substituted into the clause, and the free type variables of that image are
-- queued to be quantified in its place. All other binders are simply added
-- back.
replaceVars :: TvSubst -> [Var] -> Clause -> Clause
replaceVars sub vs cl = loop (reverse vs) cl
  where
    loop [] acc = acc
    loop (b:pending) acc
      | not (isTyVar b) = loop pending (addBinder b acc)
      | otherwise =
          case lookupTyVar sub b of
              Nothing -> loop pending (addBinder b acc)
              Just ty -> loop (varSetElems (freeVarsType ty) ++ pending)
                              (substClause b (Type ty) acc)
-- | Instantiate several binders in turn with 'instClause'. Each pair @(v, e)@
-- instantiates the binder equal to @v@ with @e@. The pairs are applied starting
-- from the LAST element of the list, and the first failing step fails the
-- whole call.
instsClause :: MonadCatch m => VarSet -> [(Var,CoreExpr)] -> Clause -> m Clause
instsClause inScope ps cl = foldM step cl (reverse ps)
  where
    step q (v, e) = instClause inScope (== v) e q
-- | Syntactic equality of clauses.
--
-- Quantifier lists must contain the same binders in the same order, 'Impl'
-- names must match, and expressions are compared with 'exprSyntaxEq'. No
-- alpha-renaming is performed.
--
-- 'CTrue' is equal to itself. Previously it fell through to the catch-all and
-- compared unequal, so the relation was not reflexive.
clauseSyntaxEq :: Clause -> Clause -> Bool
clauseSyntaxEq (Forall bs1 c1) (Forall bs2 c2) = (bs1 == bs2) && clauseSyntaxEq c1 c2
clauseSyntaxEq (Conj q1 q2) (Conj p1 p2) = clauseSyntaxEq q1 p1 && clauseSyntaxEq q2 p2
clauseSyntaxEq (Disj q1 q2) (Disj p1 p2) = clauseSyntaxEq q1 p1 && clauseSyntaxEq q2 p2
clauseSyntaxEq (Impl n1 q1 q2) (Impl n2 p1 p2) = n1 == n2 && clauseSyntaxEq q1 p1 && clauseSyntaxEq q2 p2
clauseSyntaxEq (Equiv e1 e2) (Equiv e1' e2') = exprSyntaxEq e1 e1' && exprSyntaxEq e2 e2'
clauseSyntaxEq CTrue CTrue = True
clauseSyntaxEq _ _ = False