module HERMIT.Dictionary.Reasoning
(
externals
, CoreExprEquality(..)
, RewriteCoreExprEqualityBox(..)
, TransformCoreExprEqualityStringBox(..)
, CoreExprEqualityProof
, flipCoreExprEquality
, eqLhsIntroR
, eqRhsIntroR
, birewrite
, extensionalityR
, lhsT
, rhsT
, bothT
, forallVarsT
, lhsR
, rhsR
, bothR
, ppCoreExprEqualityT
, proveCoreExprEqualityT
, verifyCoreExprEqualityT
, verifyEqualityLeftToRightT
, verifyEqualityCommonTargetT
, verifyIsomorphismT
, verifyRetractionT
, retractionBR
, alphaEqualityR
, unshadowEqualityR
, instantiateDictsR
, instantiateEquality
, instantiateEqualityVar
, instantiateEqualityVarR
, discardUniVars
) where
import Control.Applicative
import Control.Arrow
import Control.Monad
import Control.Monad.IO.Class
import Data.List (nubBy)
import Data.Maybe (fromMaybe)
import Data.Monoid
import Data.Typeable
import HERMIT.Context
import HERMIT.Core
import HERMIT.External
import HERMIT.GHC
import HERMIT.Kure
import HERMIT.Monad
import HERMIT.ParserCore
#if __GLASGOW_HASKELL__ >= 708
import HERMIT.ParserType
#endif
import HERMIT.Utilities
import HERMIT.Dictionary.AlphaConversion hiding (externals)
import HERMIT.Dictionary.Common
import HERMIT.Dictionary.Fold hiding (externals)
import HERMIT.Dictionary.GHC hiding (externals)
import HERMIT.Dictionary.Local.Let (nonRecIntroR)
import HERMIT.Dictionary.Unfold hiding (externals)
import HERMIT.PrettyPrinter.Common
import qualified Text.PrettyPrint.MarkedHughesPJ as PP
-- | Shell commands provided by this module.
--
-- Each entry pairs a command name with the value it runs and the help text
-- shown to the user.
externals :: [External]
externals =
    [ external "retraction" retractionChecked
        [ "Given f :: X -> Y and g :: Y -> X, and a proof that f (g y) ==> y, then"
        , "f (g y) <==> y."
        ] .+ Shallow
    , external "retraction-unsafe" retractionUnchecked
        [ "Given f :: X -> Y and g :: Y -> X, then"
        , "f (g y) <==> y."
        , "Note that the precondition (f (g y) == y) is expected to hold."
        ] .+ Shallow .+ PreCondition
    , external "alpha-equality" renameQuantified
        [ "Alpha-rename a universally quantified variable." ]
    , external "unshadow-equality" unshadowEqualityR
        [ "Unshadow an equality." ]
    ]
  where
    -- Retraction whose precondition is verified by the supplied proof rewrite.
    retractionChecked :: CoreString -> CoreString -> RewriteH Core -> BiRewriteH Core
    retractionChecked fStr gStr proof = promoteExprBiR (retraction (Just proof) fStr gStr)

    -- Retraction whose precondition is assumed rather than checked.
    retractionUnchecked :: CoreString -> CoreString -> BiRewriteH Core
    retractionUnchecked fStr gStr = promoteExprBiR (retraction Nothing fStr gStr)

    -- Rename the quantified binder matching the given name to a fixed new name.
    renameQuantified nm newName = alphaEqualityR (cmpString2Var nm) (const newName)
-- | An equality between two Core expressions, universally quantified over a
-- list of binders.
--
-- @CoreExprEquality bs lhs rhs@ stands for @forall bs. lhs = rhs@. The binders
-- are brought into scope when either side is traversed (see 'lhsT' / 'rhsT').
data CoreExprEquality = CoreExprEquality [CoreBndr] CoreExpr CoreExpr
-- | 'Typeable' wrapper so that rewrites over 'CoreExprEquality' can be passed
-- through the 'Extern' machinery (e.g. as shell command arguments).
data RewriteCoreExprEqualityBox =
    RewriteCoreExprEqualityBox (RewriteH CoreExprEquality) deriving Typeable

-- | Boxing and unboxing of 'CoreExprEquality' rewrites.
instance Extern (RewriteH CoreExprEquality) where
    type Box (RewriteH CoreExprEquality) = RewriteCoreExprEqualityBox

    box rw = RewriteCoreExprEqualityBox rw

    unbox boxed = case boxed of
        RewriteCoreExprEqualityBox rw -> rw
-- | 'Typeable' wrapper so that 'String'-producing transformations over
-- 'CoreExprEquality' can be passed through the 'Extern' machinery.
data TransformCoreExprEqualityStringBox =
    TransformCoreExprEqualityStringBox (TransformH CoreExprEquality String) deriving Typeable

-- | Boxing and unboxing of 'String'-producing 'CoreExprEquality' transformations.
instance Extern (TransformH CoreExprEquality String) where
    type Box (TransformH CoreExprEquality String) = TransformCoreExprEqualityStringBox

    box tr = TransformCoreExprEqualityStringBox tr

    unbox boxed = case boxed of
        TransformCoreExprEqualityStringBox tr -> tr
-- | A proof of a 'CoreExprEquality', given as a pair of rewrites.
--
-- The first rewrite is applied to the left-hand side and the second to the
-- right-hand side; the proof succeeds when the two results are
-- alpha-equivalent (see 'proveCoreExprEqualityT' and
-- 'verifyEqualityCommonTargetT').
type CoreExprEqualityProof c m = (Rewrite c m CoreExpr, Rewrite c m CoreExpr)
-- | Swap the two sides of an equality, keeping its quantified binders.
flipCoreExprEquality :: CoreExprEquality -> CoreExprEquality
flipCoreExprEquality eqn = case eqn of
    CoreExprEquality bndrs l r -> CoreExprEquality bndrs r l
-- | Extensionality: from @forall vs. lhs = rhs@, where both sides have the
-- same type, produce @forall vs x. lhs x = rhs x@ for a fresh variable @x@.
--
-- The type of @x@ is the argument type obtained by 'splitFunTypeM' on the
-- type of @lhs@, so this fails when that split fails. The optional argument
-- names the fresh variable; it defaults to @"x"@. The new binder is appended
-- after the existing ones.
extensionalityR :: Maybe String -> Rewrite c HermitM CoreExprEquality
extensionalityR mn = prefixFailMsg "extensionality failed: " $ do
    CoreExprEquality vs lhs rhs <- idR
    let lhsTy = exprKindOrType lhs
    guardMsg (lhsTy `typeAlphaEq` exprKindOrType rhs)
             "type mismatch between sides of equality. This shouldn't happen, so is probably a bug."
    (argTy, _) <- splitFunTypeM lhsTy
    fresh <- constT (newVarH (fromMaybe "x" mn) argTy)
    let applyToFresh e = mkCoreApp e (varToCoreExpr fresh)
    return (CoreExprEquality (vs ++ [fresh]) (applyToFresh lhs) (applyToFresh rhs))
-- | Introduce a non-recursive binding named @lhs@ for the left side of the
-- equality, abstracted over the equality's quantified binders.
eqLhsIntroR :: CoreExprEquality -> Rewrite c HermitM Core
eqLhsIntroR eqn = case eqn of
    CoreExprEquality bndrs l _ -> nonRecIntroR "lhs" (mkCoreLams bndrs l)
-- | Introduce a non-recursive binding named @rhs@ for the right side of the
-- equality, abstracted over the equality's quantified binders.
eqRhsIntroR :: CoreExprEquality -> Rewrite c HermitM Core
eqRhsIntroR (CoreExprEquality bndrs _ r) = nonRecIntroR "rhs" abstracted
  where
    abstracted = mkCoreLams bndrs r
-- | Use an equality as a bidirectional rewrite.
--
-- In each direction:
--
-- 1. The "from" side, abstracted over the quantified binders, is folded out
--    of the current expression into calls of a temporary variable.
-- 2. That variable is bound to the "to" side, abstracted the same way.
-- 3. The result is unfolded.
--
-- Forward goes left-to-right; backward goes right-to-left.
birewrite :: (AddBindings c, ReadBindings c, ExtendPath c Crumb, ReadPath c Crumb, HasEmptyContext c) => CoreExprEquality -> BiRewrite c HermitM CoreExpr
birewrite (CoreExprEquality bnds l r) = bidirectional (foldThenUnfold l r) (foldThenUnfold r l)
  where
    foldThenUnfold from to = transform $ \ ctx expr -> do
        let fromLam = mkCoreLams bnds from
            toLam   = mkCoreLams bnds to
        tmp <- newIdH "biTemp" (exprType fromLam)
        folded <- maybe (fail "folding LHS failed") return (fold tmp fromLam expr)
        -- Make the temporary variable unfold to the other side.
        let ctx' = addHermitBindings [(tmp, NONREC toLam, mempty)] ctx
        apply unfoldR ctx' folded
-- | Run a transformation on the left-hand side of an equality, with the
-- quantified binders added to the context.
lhsT :: (AddBindings c, Monad m, ReadPath c Crumb) => Transform c m CoreExpr b -> Transform c m CoreExprEquality b
lhsT t = do
    CoreExprEquality bndrs l _ <- idR
    withVarsInScope bndrs t <<< return l
-- | Run a transformation on the right-hand side of an equality, with the
-- quantified binders added to the context.
rhsT :: (AddBindings c, Monad m, ReadPath c Crumb) => Transform c m CoreExpr b -> Transform c m CoreExprEquality b
rhsT t = do
    CoreExprEquality bndrs _ r <- idR
    withVarsInScope bndrs t <<< return r
-- | Run a transformation on both sides of an equality (left side first) and
-- pair the results.
bothT :: (AddBindings c, Monad m, ReadPath c Crumb) => Transform c m CoreExpr b -> Transform c m CoreExprEquality (b,b)
bothT t = do
    onLeft  <- lhsT t
    onRight <- rhsT t
    return (onLeft, onRight)
-- | Run a transformation on the list of quantified binders of an equality.
forallVarsT :: Monad m => Transform c m [Var] b -> Transform c m CoreExprEquality b
forallVarsT t = arr binders >>> t
  where
    binders (CoreExprEquality vs _ _) = vs
-- | Rewrite the left-hand side of an equality, with the quantified binders in
-- scope. The binders and right-hand side are left unchanged.
lhsR :: (AddBindings c, Monad m, ReadPath c Crumb) => Rewrite c m CoreExpr -> Rewrite c m CoreExprEquality
lhsR r = do
    CoreExprEquality bndrs l0 rhs <- idR
    l1 <- return l0 >>> withVarsInScope bndrs r
    return (CoreExprEquality bndrs l1 rhs)
-- | Rewrite the right-hand side of an equality, with the quantified binders in
-- scope. The binders and left-hand side are left unchanged.
rhsR :: (AddBindings c, Monad m, ReadPath c Crumb) => Rewrite c m CoreExpr -> Rewrite c m CoreExprEquality
rhsR r = do
    CoreExprEquality bndrs lhs r0 <- idR
    r1 <- return r0 >>> withVarsInScope bndrs r
    return (CoreExprEquality bndrs lhs r1)
-- | Rewrite both sides of an equality: the left side first, then the right.
--
-- The two steps are combined with KURE's '>+>', which, per KURE's
-- documentation, succeeds if at least one of the two rewrites succeeds.
bothR :: (AddBindings c, MonadCatch m, ReadPath c Crumb) => Rewrite c m CoreExpr -> Rewrite c m CoreExprEquality
bothR r = onLeft >+> onRight
  where
    onLeft  = lhsR r
    onRight = rhsR r
-- | Pretty-print an equality as
-- @\<quantifier doc\> \<lhs\> = \<rhs\>@, laid out with 'PP.sep'.
ppCoreExprEqualityT :: PrettyPrinter -> TransformH CoreExprEquality DocH
ppCoreExprEqualityT pp = do
    let opts = pOptions pp
    quantDoc         <- forallVarsT $ liftPrettyH opts (pForall pp)
    (lhsDoc, rhsDoc) <- bothT $ liftPrettyH opts (extractT (pCoreTC pp))
    let equalsDoc = syntaxColor (PP.text "=")
    return (PP.sep [quantDoc, lhsDoc, equalsDoc, rhsDoc])
-- | Types from which a 'CoreExprEquality' can be built in 'HermitM'.
--
-- NOTE(review): this class is not in the module's export list and has no
-- uses in this module besides its instances.
class BuildEquality a where
    -- | Build the equality described by the argument.
    mkEquality :: a -> HermitM CoreExprEquality
-- | A pair of expressions is an equality with no quantified binders.
instance BuildEquality (CoreExpr,CoreExpr) where
    mkEquality :: (CoreExpr,CoreExpr) -> HermitM CoreExprEquality
    mkEquality = return . uncurry (CoreExprEquality [])
-- | Functions of a 'CoreExpr' would be turned into equalities by quantifying
-- over a fresh argument binder. That binder's type cannot be recovered from a
-- Haskell function @CoreExpr -> a@, so this is not supported.
--
-- The previous implementation gave the fresh binder the type
-- @error "need to create a type"@. That produced a Haskell exception
-- whenever the binder's type was demanded, and the binder was immediately
-- passed to 'varToCoreExpr'. Failing in 'HermitM' instead makes the failure
-- catchable and reportable like any other HERMIT failure.
instance BuildEquality a => BuildEquality (CoreExpr -> a) where
    mkEquality :: (CoreExpr -> a) -> HermitM CoreExprEquality
    mkEquality _ = fail "mkEquality: cannot quantify over a CoreExpr argument because its type is unknown."
-- | Check a proof of an equality.
--
-- The first proof rewrite is applied to the left-hand side and the second to
-- the right-hand side. The resulting sides must then be alpha-equivalent.
proveCoreExprEqualityT :: forall c m. (AddBindings c, Monad m, ReadPath c Crumb)
                       => CoreExprEqualityProof c m -> Transform c m CoreExprEquality ()
proveCoreExprEqualityT (lhsProof, rhsProof) =
    verifyCoreExprEqualityT <<< rhsR rhsProof <<< lhsR lhsProof
-- | Succeed iff the two sides of the equality are alpha-equivalent. The
-- quantified binders are ignored.
verifyCoreExprEqualityT :: Monad m => Transform c m CoreExprEquality ()
verifyCoreExprEqualityT = idR >>= \ (CoreExprEquality _ l r) ->
    guardMsg (exprAlphaEq l r) "the two sides of the equality do not match."
-- | Verify an equality by rewriting its source expression with the given
-- rewrite. The result must be alpha-equivalent to the target expression.
-- The transformation's own input is ignored.
verifyEqualityLeftToRightT :: MonadCatch m => CoreExpr -> CoreExpr -> Rewrite c m CoreExpr -> Transform c m a ()
verifyEqualityLeftToRightT sourceExpr targetExpr r =
    prefixFailMsg "equality verification failed: " $ do
        rewritten <- return sourceExpr >>> r
        guardMsg (exprAlphaEq targetExpr rewritten)
                 "result of running proof on lhs of equality does not match rhs of equality."
-- | Verify an equality by rewriting each side with its own proof rewrite. The
-- two results must be alpha-equivalent. The transformation's own input is
-- ignored.
verifyEqualityCommonTargetT :: MonadCatch m => CoreExpr -> CoreExpr -> CoreExprEqualityProof c m -> Transform c m a ()
verifyEqualityCommonTargetT lhs rhs (lhsProof, rhsProof) =
    prefixFailMsg "equality verification failed: " $ do
        lhsResult <- return lhs >>> lhsProof
        rhsResult <- return rhs >>> rhsProof
        guardMsg (exprAlphaEq lhsResult rhsResult)
                 "results of running proofs on both sides of equality do not match."
-- | Verify that @f :: X -> Y@ and @g :: Y -> X@ form an isomorphism.
--
-- Two fresh global variables @x :: X@ and @y :: Y@ are created. Then @fgR@
-- must rewrite @f (g y)@ to @y@, and @gfR@ must rewrite @g (f x)@ to @x@.
verifyIsomorphismT :: CoreExpr -> CoreExpr -> Rewrite c HermitM CoreExpr -> Rewrite c HermitM CoreExpr -> Transform c HermitM a ()
verifyIsomorphismT f g fgR gfR = prefixFailMsg "Isomorphism verification failed: " $ do
    (tyX, tyY) <- funExprsWithInverseTypes f g
    x <- constT (newGlobalIdH "x" tyX)
    y <- constT (newGlobalIdH "y" tyY)
    let roundTrip outer inner v = App outer (App inner (Var v))
    verifyEqualityLeftToRightT (roundTrip f g y) (Var y) fgR
    verifyEqualityLeftToRightT (roundTrip g f x) (Var x) gfR
-- | Verify that @f :: X -> Y@ and @g :: Y -> X@ form a retraction, i.e. that
-- the given rewrite turns @f (g y)@ into @y@ for a fresh global @y :: Y@.
verifyRetractionT :: CoreExpr -> CoreExpr -> Rewrite c HermitM CoreExpr -> Transform c HermitM a ()
verifyRetractionT f g r = prefixFailMsg "Retraction verification failed: " $ do
    (_, tyY) <- funExprsWithInverseTypes f g
    y <- constT (newGlobalIdH "y" tyY)
    verifyEqualityLeftToRightT (App f (App g (Var y))) (Var y) r
-- | A bidirectional rewrite for a retraction pair @f :: X -> Y@, @g :: Y -> X@.
--
-- Before either direction runs:
--
-- * if a proof rewrite is supplied, the retraction is verified with
--   'verifyRetractionT';
-- * the current expression must have type @Y@, as computed by
--   'funExprsWithInverseTypes'.
--
-- Forward: @f (g y)@ becomes @y@. The outer @f@ and @g@ must be
-- alpha-equivalent to the given components.
--
-- Backward: @y@ becomes @f (g y)@.
retractionBR :: forall c. Maybe (Rewrite c HermitM CoreExpr) -> CoreExpr -> CoreExpr -> BiRewrite c HermitM CoreExpr
retractionBR mr f g = beforeBiR checkTarget (bidirectional retractionL . wrapInFG)
  where
    checkTarget = prefixFailMsg "Retraction failed: " $ do
        whenJust (verifyRetractionT f g) mr
        y <- idR
        (_, tyY) <- funExprsWithInverseTypes f g
        guardMsg (exprKindOrType y `typeAlphaEq` tyY) "type of expression does not match given retraction components."
        return y

    -- Right-to-left direction: rebuild f (g y) around the checked expression.
    wrapInFG :: CoreExpr -> Rewrite c HermitM CoreExpr
    wrapInFG y = return (App f (App g y))

    -- Left-to-right direction. A pattern-match failure here is reported via
    -- withPatFailMsg, so the do-bound pattern is intentional.
    retractionL :: Rewrite c HermitM CoreExpr
    retractionL = prefixFailMsg "Retraction failed: " $
        withPatFailMsg (wrongExprForm "App f (App g y)") $ do
            App f' (App g' y) <- idR
            guardMsg (exprAlphaEq f f' && exprAlphaEq g g') "given retraction components do not match current expression."
            return y
-- | 'retractionBR' with @f@ and @g@ given as Core source strings. They are
-- parsed by 'parse2beforeBiR'. The optional proof rewrite is lifted from
-- 'Core' to 'CoreExpr' with 'extractR'.
retraction :: Maybe (RewriteH Core) -> CoreString -> CoreString -> BiRewriteH CoreExpr
retraction mr = parse2beforeBiR (retractionBR (fmap extractR mr))
-- | Instantiate every universally quantified dictionary binder of an equality
-- with a dictionary expression obtained from 'buildDictionary'.
--
-- * Binders are deduplicated by type. One dictionary is built per distinct
--   dictionary type.
-- * Local free variables of a built dictionary expression that are not
--   already quantified become new quantified binders.
-- * Binders whose dictionary type repeats reuse the same expression, without
--   adding new binders again.
-- * Fails if there are no dictionary binders.
-- * Always fails on GHC < 7.8.
instantiateDictsR :: RewriteH CoreExprEquality
#if __GLASGOW_HASKELL__ >= 708
instantiateDictsR = prefixFailMsg "Dictionary instantiation failed: " $ do
    bs <- forallVarsT idR
    -- Quantified Ids whose type is a class dictionary type.
    let dArgs = filter (\b -> isId b && isDictTy (varType b)) bs
        -- One representative per distinct dictionary type (first occurrence kept).
        uniqDs = nubBy (\ b1 b2 -> eqType (varType b1) (varType b2)) dArgs
    guardMsg (not (null uniqDs)) "no universally quantified dictionaries can be instantiated."
    -- For each representative: (binder, dictionary expression, new binders to quantify).
    ds <- forM uniqDs $ \ b -> constT $ do
            guts <- getModGuts
            (i,bnds) <- liftCoreM $ buildDictionary guts b
            let dExpr = case bnds of
                            -- A single non-recursive binding of the result id: use its RHS directly.
                            [NonRec v e] | i == v -> e
                            -- Otherwise keep all bindings as lets around the result id.
                            _ -> mkCoreLets bnds (varToCoreExpr i)
                -- Local free vars of the dictionary that are not already quantified.
                new = varSetElems $ delVarSetList (localFreeVarsExpr dExpr) bs
            return (b,dExpr,new)
    -- For a non-representative binder, reuse the expression built for its type.
    -- Its new binders were already added by the representative, hence [].
    let buildSubst :: Monad m => Var -> m (Var, CoreExpr, [Var])
        buildSubst b = case [ (b,e,[]) | (b',e,_) <- ds, eqType (varType b) (varType b') ] of
                        [] -> fail "cannot find equivalent dictionary expression (impossible!)"
                        [t] -> return t
                        _ -> fail "multiple dictionary expressions found (impossible!)"
        -- Only called on representatives, whose entries are all in ds, so head
        -- does not fail here.
        lookup3 :: Var -> [(Var,CoreExpr,[Var])] -> (Var,CoreExpr,[Var])
        lookup3 v l = head [ t | t@(v',_,_) <- l, v == v' ]
    -- One instantiation per dictionary binder, in binder order.
    allDs <- forM dArgs $ \ b -> constT $ do
                if b `elem` uniqDs
                    then return $ lookup3 b ds
                    else buildSubst b
    contextfreeT $ instantiateEquality allDs
#else
instantiateDictsR = fail "Dictionaries cannot be instantiated in GHC 7.6"
#endif
-- | Alpha-rename the first quantified binder satisfying the predicate.
--
-- The replacement binder is produced by 'cloneVarH' with the given
-- name-transforming function. The new binder replaces the old one in
-- position. The renaming is substituted into the binders that follow it and
-- into both sides of the equality. Fails if no quantified binder satisfies
-- the predicate.
alphaEqualityR :: (Var -> Bool) -> (String -> String) -> RewriteH CoreExprEquality
alphaEqualityR p f = prefixFailMsg "Alpha-renaming binder in equality failed: " $ do
    CoreExprEquality bs lhs rhs <- idR
    guardMsg (any p bs) "specified variable is not universally quantified."
    -- Safe: the guard above ensures some binder satisfies p, so the suffix is non-empty.
    let (bs',i:vs) = break p bs
    i' <- constT $ cloneVarH f i
    -- In-scope set: free vars of both sides and of the later binders' types,
    -- minus the renamed binder, its clone, and the later binders
    -- (substBndrs re-adds those).
    let inS = delVarSetList (unionVarSets (map localFreeVarsExpr [lhs, rhs] ++ map freeVarsVar vs)) (i:i':vs)
        subst = extendSubst (mkEmptySubst (mkInScopeSet inS)) i (varToCoreExpr i')
        -- Later binders may mention i in their types, so substitute into them too.
        (subst', vs') = substBndrs subst vs
        lhs' = substExpr (text "coreExprEquality-lhs") subst' lhs
        rhs' = substExpr (text "coreExprEquality-rhs") subst' rhs
    return $ CoreExprEquality (bs'++(i':vs')) lhs' rhs'
-- | Rename quantified binders that shadow variables which are already bound
-- in the context or free in the equality. Afterwards, try to unshadow inside
-- both sides as well. Fails if there is nothing to unshadow.
unshadowEqualityR :: RewriteH CoreExprEquality
unshadowEqualityR = prefixFailMsg "Unshadowing equality failed: " $ do
    eqn@(CoreExprEquality binders _ _) <- idR
    ctxBound <- boundVarsT
    let visible   = unionVarSets [ctxBound, freeVarsEquality eqn]
        -- New names avoid everything visible plus the binder being renamed.
        freshName = freshNameGenAvoiding Nothing . extendVarSet visible
    shadowing <- varSetElems <$> detectShadowsM binders visible
    guardMsg (not (null shadowing)) "no shadows to eliminate."
    let renameEach = andR [ alphaEqualityR (== s) (freshName s) | s <- reverse shadowing ]
    renameEach >>> bothR (tryR unshadowExprR)
-- | Variables free in either side of the equality, excluding its quantified
-- binders.
freeVarsEquality :: CoreExprEquality -> VarSet
freeVarsEquality (CoreExprEquality bndrs l r) = delVarSetList bothSides bndrs
  where
    bothSides = unionVarSets [freeVarsExpr l, freeVarsExpr r]
-- | Instantiate the first quantified binder satisfying the predicate with a
-- term parsed from the given Core string.
--
-- * Term binders: the string is parsed as an expression.
-- * Type binders (GHC >= 7.8 only): the string is parsed as a type with
--   holes. The holes become new quantified variables.
--
-- Only the binders before the selected one are in scope while parsing. Both
-- sides of the resulting equality are checked with 'lintExprT'.
instantiateEqualityVarR :: (Var -> Bool) -> CoreString -> RewriteH CoreExprEquality
instantiateEqualityVarR p cs = prefixFailMsg "instantiation failed: " $ do
    bs <- forallVarsT idR
    -- (instantiating expression, extra binders to quantify)
    (e,new) <- case filter p bs of
        [] -> fail "no universally quantified variables match predicate."
        (b:_) | isId b -> let (before,_) = break (==b) bs
                          in liftM (,[]) $ withVarsInScope before $ parseCoreExprT cs
#if __GLASGOW_HASKELL__ >= 708
              | otherwise -> do
                    let (before,_) = break (==b) bs
                    (ty, tvs) <- withVarsInScope before $ parseTypeWithHolesT cs
                    return (Type ty, tvs)
#else
              | otherwise -> fail "cannot instantiate type binders in GHC 7.6"
#endif
    eq <- contextfreeT $ instantiateEqualityVar p e new
    -- Lint both sides; the results are discarded, only failure matters.
    (_,_) <- return eq >>> bothT lintExprT
    return eq
-- | Instantiate the first quantified binder satisfying the predicate with the
-- given expression.
--
-- * The binder's type is unified with the expression's type. Only type
--   variables quantified before the binder may be bound. Any resulting
--   non-trivial type instantiations are applied afterwards through
--   'instantiateEquality'.
-- * The supplied extra binders are quantified in the position of the removed
--   binder.
-- * Fails if no binder matches or the types cannot be unified.
instantiateEqualityVar :: MonadIO m => (Var -> Bool) -- ^ selects the binder to instantiate
                                    -> CoreExpr      -- ^ expression to substitute for it
                                    -> [Var]         -- ^ extra binders to quantify in its place
                                    -> CoreExprEquality -> m CoreExprEquality
instantiateEqualityVar p e new (CoreExprEquality bs lhs rhs)
    | not (any p bs) = fail "specified variable is not universally quantified."
    | otherwise = do
        -- Safe: the first guard ensures some binder satisfies p.
        let (bs',i:vs) = break p bs
            -- Only type variables quantified before i may be unified.
            tyVars = filter isTyVar bs'
            failMsg = fail "type of provided expression differs from selected binder."
            -- Drop trivial a := a matches produced by unification.
            dropSelfSubst :: [(TyVar, Type)] -> [(TyVar,Type)]
            dropSelfSubst ps = [ (v,t) | (v,t) <- ps, case t of
                                                         TyVarTy v' | v' == v -> False
                                                         _ -> True ]
        tvs <- maybe failMsg (return . tyMatchesToCoreExpr . dropSelfSubst)
                   $ unifyTypes tyVars (varType i) (exprKindOrType e)
        -- In-scope set covers free vars of both sides, of e, and of the later
        -- binders' types, minus i and the later binders (substBndrs re-adds them).
        let inS = delVarSetList (unionVarSets (map localFreeVarsExpr [lhs, rhs, e] ++ map freeVarsVar vs)) (i:vs)
            subst = extendSubst (mkEmptySubst (mkInScopeSet inS)) i e
            (subst', vs') = substBndrs subst vs
            lhs' = substExpr (text "coreExprEquality-lhs") subst' lhs
            rhs' = substExpr (text "coreExprEquality-rhs") subst' rhs
        -- Apply the type instantiations found by unification; they add no binders.
        instantiateEquality (noAdds tvs) $ CoreExprEquality (bs'++new++vs') lhs' rhs'
-- | Turn plain (binder, expression) instantiations into the triple form used
-- by 'instantiateEquality', with no extra binders to quantify.
noAdds :: [(Var,CoreExpr)] -> [(Var,CoreExpr,[Var])]
noAdds = map (\ (v, e) -> (v, e, []))
-- | Apply a list of instantiations @(binder, expression, extra binders)@ to an
-- equality with 'instantiateEqualityVar', one at a time. The list is processed
-- from its last element to its first.
instantiateEquality :: MonadIO m => [(Var,CoreExpr,[Var])] -> CoreExprEquality -> m CoreExprEquality
instantiateEquality insts start = foldM step start (reverse insts)
  where
    step acc (v, e, extra) = instantiateEqualityVar (== v) e extra acc
-- | Forget the quantified binders of an equality, leaving both sides as they
-- are.
discardUniVars :: CoreExprEquality -> CoreExprEquality
discardUniVars eqn = case eqn of
    CoreExprEquality _ l r -> CoreExprEquality [] l r