module HERMIT.Core
(
CoreProg(..)
, CoreDef(..)
, CoreTickish
, progSyntaxEq
, bindSyntaxEq
, defSyntaxEq
, exprSyntaxEq
, altSyntaxEq
, typeSyntaxEq
, coercionSyntaxEq
, progAlphaEq
, bindAlphaEq
, defAlphaEq
, exprAlphaEq
, altAlphaEq
, typeAlphaEq
, coercionAlphaEq
, defsToRecBind
, defToIdExpr
, progToBinds
, bindsToProg
, bindToVarExprs
, progIds
, bindVars
, defId
, altVars
, freeVarsProg
, freeVarsBind
, freeVarsDef
, freeVarsExpr
, freeVarsAlt
, freeVarsVar
, localFreeVarsAlt
, freeVarsType
, freeVarsCoercion
, localFreeVarsExpr
, freeIdsExpr
, localFreeIdsExpr
, isCoArg
, exprKindOrType
, exprTypeM
, endoFunTypeM
, splitTyConAppM
, splitFunTypeM
, endoFunExprTypeM
, funExprArgResTypesM
, funExprsWithInverseTypes
, appCount
, mapAlts
, substCoreAlt
, substCoreExpr
, betaReduceAll
, mkDataConApp
, Crumb(..)
, showCrumbs
, leftSibling
, rightSibling
) where
import Control.Monad ((>=>))
import Language.KURE.Combinators.Monad
import Language.KURE.MonadCatch
import HERMIT.GHC
import HERMIT.Utilities
import Data.List (intercalate)
-- | A tick annotation whose payload mentions 'Id's: the annotation type
-- carried by the @Tick@ constructor of a 'CoreExpr'.
type CoreTickish = Tickish Id
-- | A Core program: a cons-list of top-level binding groups terminated by
-- 'ProgNil'.  The 'ProgCons_Head' and 'ProgCons_Tail' crumbs address the
-- two fields of a 'ProgCons' node.
data CoreProg
  = ProgNil                      -- ^ The empty program.
  | ProgCons CoreBind CoreProg   -- ^ A binding group followed by the rest of the program.

infixr 5 `ProgCons`
-- | Flatten a 'CoreProg' into its binding groups, in program order.
progToBinds :: CoreProg -> [CoreBind]
progToBinds prog = case prog of
  ProgNil           -> []
  ProgCons bnd rest -> bnd : progToBinds rest
-- | Build a 'CoreProg' from a list of binding groups, preserving order.
-- This is the inverse of 'progToBinds'.
bindsToProg :: [CoreBind] -> CoreProg
bindsToProg []           = ProgNil
bindsToProg (bnd : bnds) = bnd `ProgCons` bindsToProg bnds
-- | The (binder, right-hand side) pairs of a binding group.  A
-- non-recursive binding yields exactly one pair; a recursive group yields
-- its pairs unchanged.
bindToVarExprs :: CoreBind -> [(Var,CoreExpr)]
bindToVarExprs bnd = case bnd of
  NonRec var rhs -> [(var, rhs)]
  Rec pairs      -> pairs
-- | One definition of a recursive binding group: a binder paired with its
-- right-hand side (see 'defsToRecBind').
data CoreDef
  = Def Id CoreExpr
-- | Unpack a 'CoreDef' into its binder and right-hand side.
defToIdExpr :: CoreDef -> (Id,CoreExpr)
defToIdExpr def = case def of
  Def ident rhs -> (ident, rhs)
-- | Gather a list of definitions into a single recursive binding group.
defsToRecBind :: [CoreDef] -> CoreBind
defsToRecBind defs = Rec [ defToIdExpr d | d <- defs ]
-- | Syntactic equality of programs.  Both programs must have the same
-- number of binding groups, and corresponding groups must be
-- 'bindSyntaxEq'.
progSyntaxEq :: CoreProg -> CoreProg -> Bool
progSyntaxEq p1 p2 = case (p1, p2) of
  (ProgNil, ProgNil)                     -> True
  (ProgCons b1 rest1, ProgCons b2 rest2) -> bindSyntaxEq b1 b2 && progSyntaxEq rest1 rest2
  _                                      -> False
-- | Syntactic equality of binding groups.  Both must be the same kind of
-- binding, with identical binders (in the same order) and syntactically
-- equal right-hand sides.
bindSyntaxEq :: CoreBind -> CoreBind -> Bool
bindSyntaxEq b1 b2 = case (b1, b2) of
    (NonRec v1 e1, NonRec v2 e2) -> sameDef (v1, e1) (v2, e2)
    (Rec ies1,     Rec ies2)     -> all2 sameDef ies1 ies2
    _                            -> False
  where
    -- One binder/right-hand-side pair compared against another.
    sameDef (i1, r1) (i2, r2) = i1 == i2 && exprSyntaxEq r1 r2
-- | Syntactic equality of definitions: identical binders and syntactically
-- equal right-hand sides.
defSyntaxEq :: CoreDef -> CoreDef -> Bool
defSyntaxEq d1 d2 =
  let (i1, e1) = defToIdExpr d1
      (i2, e2) = defToIdExpr d2
  in  i1 == i2 && exprSyntaxEq e1 e2
-- | Syntactic equality of Core expressions.  Binders must be identical
-- rather than merely renamings of one another; use 'exprAlphaEq' for
-- equality up to renaming.
exprSyntaxEq :: CoreExpr -> CoreExpr -> Bool
exprSyntaxEq lhs rhs = case (lhs, rhs) of
  (Var i1,            Var i2)            -> i1 == i2
  (Lit l1,            Lit l2)            -> l1 == l2
  (App f1 a1,         App f2 a2)         -> exprSyntaxEq f1 f2 && exprSyntaxEq a1 a2
  (Lam v1 b1,         Lam v2 b2)         -> v1 == v2 && exprSyntaxEq b1 b2
  (Let bnd1 b1,       Let bnd2 b2)       -> bindSyntaxEq bnd1 bnd2 && exprSyntaxEq b1 b2
  (Case s1 w1 t1 as1, Case s2 w2 t2 as2) -> w1 == w2
                                            && exprSyntaxEq s1 s2
                                            && all2 altSyntaxEq as1 as2
                                            && typeSyntaxEq t1 t2
  (Cast e1 co1,       Cast e2 co2)       -> exprSyntaxEq e1 e2 && coercionSyntaxEq co1 co2
  (Tick t1 e1,        Tick t2 e2)        -> t1 == t2 && exprSyntaxEq e1 e2
  (Type ty1,          Type ty2)          -> typeSyntaxEq ty1 ty2
  (Coercion co1,      Coercion co2)      -> coercionSyntaxEq co1 co2
  _                                      -> False
-- | Syntactic equality of case alternatives: same constructor, identical
-- pattern binders, and syntactically equal right-hand sides.
altSyntaxEq :: CoreAlt -> CoreAlt -> Bool
altSyntaxEq alt1 alt2 =
  let (con1, bndrs1, rhs1) = alt1
      (con2, bndrs2, rhs2) = alt2
  in  con1 == con2 && bndrs1 == bndrs2 && exprSyntaxEq rhs1 rhs2
-- | Syntactic equality of types.  Bound type variables must be identical;
-- any pair of differing constructors compares unequal.
typeSyntaxEq :: Type -> Type -> Bool
typeSyntaxEq ty1 ty2 = case (ty1, ty2) of
  (TyVarTy v1,      TyVarTy v2)      -> v1 == v2
  (LitTy l1,        LitTy l2)        -> l1 == l2
  (AppTy f1 a1,     AppTy f2 a2)     -> typeSyntaxEq f1 f2 && typeSyntaxEq a1 a2
  (FunTy d1 r1,     FunTy d2 r2)     -> typeSyntaxEq d1 d2 && typeSyntaxEq r1 r2
  (ForAllTy v1 b1,  ForAllTy v2 b2)  -> v1 == v2 && typeSyntaxEq b1 b2
  (TyConApp c1 as1, TyConApp c2 as2) -> c1 == c2 && all2 typeSyntaxEq as1 as2
  _                                  -> False
-- | Syntactic equality of coercions: same constructor, equal roles and
-- other annotations, and componentwise syntactic equality of the nested
-- coercions and types.  Any constructor pairing not listed is unequal.
coercionSyntaxEq :: Coercion -> Coercion -> Bool
coercionSyntaxEq lhs rhs = case (lhs, rhs) of
  (Refl r1 t1, Refl r2 t2) ->
    r1 == r2 && typeSyntaxEq t1 t2
  (TyConAppCo r1 tc1 args1, TyConAppCo r2 tc2 args2) ->
    r1 == r2 && tc1 == tc2 && all2 coercionSyntaxEq args1 args2
  (AppCo f1 a1, AppCo f2 a2) ->
    coercionSyntaxEq f1 f2 && coercionSyntaxEq a1 a2
  (ForAllCo v1 b1, ForAllCo v2 b2) ->
    v1 == v2 && coercionSyntaxEq b1 b2
  (CoVarCo v1, CoVarCo v2) ->
    v1 == v2
  (AxiomInstCo ax1 ix1 args1, AxiomInstCo ax2 ix2 args2) ->
    ax1 == ax2 && ix1 == ix2 && all2 coercionSyntaxEq args1 args2
  (LRCo lr1 c1, LRCo lr2 c2) ->
    lr1 == lr2 && coercionSyntaxEq c1 c2
  (UnivCo r1 ta1 tb1, UnivCo r2 ta2 tb2) ->
    r1 == r2 && typeSyntaxEq ta1 ta2 && typeSyntaxEq tb1 tb2
  (SubCo c1, SubCo c2) ->
    coercionSyntaxEq c1 c2
  (SymCo c1, SymCo c2) ->
    coercionSyntaxEq c1 c2
  (TransCo x1 y1, TransCo x2 y2) ->
    coercionSyntaxEq x1 x2 && coercionSyntaxEq y1 y2
  (NthCo n1 c1, NthCo n2 c2) ->
    n1 == n2 && coercionSyntaxEq c1 c2
  (InstCo c1 t1, InstCo c2 t2) ->
    coercionSyntaxEq c1 c2 && typeSyntaxEq t1 t2
  _ -> False
-- | Alpha equality of programs.  Corresponding binding groups must bind
-- exactly the same variables (compared with '=='), and must be
-- 'bindAlphaEq'.
progAlphaEq :: CoreProg -> CoreProg -> Bool
progAlphaEq p1 p2 = case (p1, p2) of
  (ProgNil, ProgNil) -> True
  (ProgCons b1 rest1, ProgCons b2 rest2) ->
    bindVars b1 == bindVars b2 && bindAlphaEq b1 b2 && progAlphaEq rest1 rest2
  _ -> False
-- | Alpha equality of binding groups.
--
-- For non-recursive bindings only the right-hand sides are compared; the
-- binders are ignored.  For recursive groups the binders are paired up in
-- order and treated as renamings of each other while the right-hand sides
-- are compared with GHC's 'eqExprX'.
bindAlphaEq :: CoreBind -> CoreBind -> Bool
bindAlphaEq bnd1 bnd2 = case (bnd1, bnd2) of
    (NonRec _ rhs1, NonRec _ rhs2) -> exprAlphaEq rhs1 rhs2
    (Rec pairs1,    Rec pairs2)    -> recAlphaEq pairs1 pairs2
    _                              -> False
  where
    -- Compare two recursive groups under a renaming of their binders.
    recAlphaEq pairs1 pairs2 =
      let (bndrs1, rhss1) = unzip pairs1
          (bndrs2, rhss2) = unzip pairs2
          fvScope         = mkInScopeSet (exprsFreeVars (rhss1 ++ rhss2))
          bndrEnv         = rnBndrs2 (mkRnEnv2 fvScope) bndrs1 bndrs2
          noUnfoldings    = const noUnfolding
      in  all2 (eqExprX noUnfoldings bndrEnv) rhss1 rhss2
-- | Alpha equality of definitions.  Each definition is viewed as a
-- singleton recursive binding group, so the two binders are treated as
-- renamings of each other.
defAlphaEq :: CoreDef -> CoreDef -> Bool
defAlphaEq d1 d2 = bindAlphaEq (asRecBind d1) (asRecBind d2)
  where
    asRecBind d = defsToRecBind [d]
-- | Alpha equality of expressions (equality up to renaming of bound
-- variables), via GHC's 'eqExpr'.  Every free variable of either
-- expression is placed in the in-scope set.
exprAlphaEq :: CoreExpr -> CoreExpr -> Bool
exprAlphaEq e1 e2 = eqExpr fvScope e1 e2
  where
    fvScope = mkInScopeSet (exprsFreeVars [e1, e2])
-- | Alpha equality of case alternatives.  The constructors must be equal.
-- The pattern binders are paired up in order and treated as renamings of
-- each other while the right-hand sides are compared.
altAlphaEq :: CoreAlt -> CoreAlt -> Bool
altAlphaEq (c1,vs1,e1) (c2,vs2,e2)
  | c1 /= c2  = False
  | otherwise = eqExprX (const noUnfolding) bndrEnv e1 e2
  where
    fvScope = mkInScopeSet (exprsFreeVars [e1, e2])
    bndrEnv = rnBndrs2 (mkRnEnv2 fvScope) vs1 vs2
-- | Alpha equality of types; delegates to GHC's 'eqType'.
typeAlphaEq :: Type -> Type -> Bool
typeAlphaEq t1 t2 = eqType t1 t2
-- | Alpha equality of coercions; delegates to GHC's 'coreEqCoercion'.
coercionAlphaEq :: Coercion -> Coercion -> Bool
coercionAlphaEq co1 co2 = coreEqCoercion co1 co2
-- | Every variable bound at the top level of a program, in program order.
progIds :: CoreProg -> [Id]
progIds = concatMap bindVars . progToBinds
-- | The variables bound by a binding group, in order.
bindVars :: CoreBind -> [Var]
bindVars = map fst . bindToVarExprs
-- | The binder of a definition.
defId :: CoreDef -> Id
defId = fst . defToIdExpr
-- | The pattern binders of a case alternative.
altVars :: CoreAlt -> [Var]
altVars alt = case alt of
  (_, bndrs, _) -> bndrs
-- | The free variables of an expression that are 'Id's (type and kind
-- variables are dropped).
freeIdsExpr :: CoreExpr -> IdSet
freeIdsExpr e = filterVarSet isId (freeVarsExpr e)
-- | The free variables of an expression that satisfy 'isLocalVar'.
localFreeVarsExpr :: CoreExpr -> VarSet
localFreeVarsExpr e = filterVarSet isLocalVar (freeVarsExpr e)
-- | The free variables of an expression that satisfy 'isLocalId'.
localFreeIdsExpr :: CoreExpr -> VarSet
localFreeIdsExpr e = filterVarSet isLocalId (freeVarsExpr e)
-- | Find all free variables in an expression.  Besides the variables that
-- occur free, the result contains the variables reported by 'freeVarsVar'
-- (the type variables of a variable's type, plus its rule/unfolding
-- variables).  These are collected both for occurrences and for binders.
freeVarsExpr :: CoreExpr -> VarSet
freeVarsExpr (Var v) = extendVarSet (freeVarsVar v) v
freeVarsExpr (Lit {}) = emptyVarSet
freeVarsExpr (App e1 e2) = freeVarsExpr e1 `unionVarSet` freeVarsExpr e2
-- A binder's type may mention free type variables even when the body never
-- uses the binder (e.g. \(x :: a) -> 5 is not closed in 'a').  The binder's
-- own variables are therefore collected before the binder itself is removed,
-- as freeVarsAlt and freeVarsBind already do for their binders.
freeVarsExpr (Lam b e) = delVarSet (freeVarsExpr e `unionVarSet` freeVarsVar b) b
freeVarsExpr (Let b e) = freeVarsBind b `unionVarSet` delVarSetList (freeVarsExpr e) (bindersOf b)
freeVarsExpr (Case s b ty alts) =
  -- The case binder scopes over every alternative; its own type/rule
  -- variables are included for the same reason as for lambda binders.
  let altFVs = delVarSet (unionVarSets (freeVarsVar b : map freeVarsAlt alts)) b
  in  unionVarSets [freeVarsExpr s, freeVarsType ty, altFVs]
freeVarsExpr (Cast e co) = freeVarsExpr e `unionVarSet` freeVarsCoercion co
freeVarsExpr (Tick t e) = freeVarsTick t `unionVarSet` freeVarsExpr e
freeVarsExpr (Type ty) = freeVarsType ty
freeVarsExpr (Coercion co) = freeVarsCoercion co
-- | Variables mentioned by a tick annotation: the 'Id's listed in a
-- 'Breakpoint' tick.  Every other kind of tick contributes none.
freeVarsTick :: Tickish Id -> VarSet
freeVarsTick tick = case tick of
  Breakpoint _ bpIds -> mkVarSet bpIds
  _                  -> emptyVarSet
-- | Free variables of a binding group, including those of the binders
-- ('freeVarsVar').  For a recursive group the group's own binders are
-- removed from the result.  A non-recursive binder is not removed, because
-- it does not scope over its own right-hand side.
freeVarsBind :: CoreBind -> VarSet
freeVarsBind bnd = case bnd of
  NonRec v rhs -> freeVarsExpr rhs `unionVarSet` freeVarsVar v
  Rec defs     ->
    let bndrs = map fst defs
        rhss  = map snd defs
        fvs   = unionVarSets (map freeVarsVar bndrs ++ map freeVarsExpr rhss)
    in  delVarSetList fvs bndrs
-- | Variables reachable from a variable's annotations: the type variables
-- of its type, together with the variables of its rules and unfolding.
freeVarsVar :: Var -> VarSet
freeVarsVar v = unionVarSet (varTypeTyVars v) (bndrRuleAndUnfoldingVars v)
-- | Find all free variables in a recursive definition, excluding the bound
-- variable itself.
--
-- The binder is removed from the variables of its own type/rules/unfolding
-- ('freeVarsVar') as well as from the right-hand side.  A recursive binder
-- may appear in its own rules.  Removing it from both sets keeps
-- @freeVarsDef d@ in agreement with @freeVarsBind (defsToRecBind [d])@.
freeVarsDef :: CoreDef -> VarSet
freeVarsDef (Def v e) = delVarSet (freeVarsExpr e `unionVarSet` freeVarsVar v) v
-- | Free variables of a case alternative: those of the right-hand side and
-- of the pattern binders' annotations, minus the pattern binders.
freeVarsAlt :: CoreAlt -> VarSet
freeVarsAlt (_, bndrs, rhs) =
  delVarSetList (unionVarSets (freeVarsExpr rhs : map freeVarsVar bndrs)) bndrs
-- | Find all free local variables ('isLocalVar') of a case alternative.
--
-- The filter is applied to the complete 'freeVarsAlt' result.  This removes
-- non-local variables contributed by the pattern binders' rules or
-- unfoldings, not only those of the right-hand side.  The approach mirrors
-- how 'localFreeVarsExpr' is defined.
localFreeVarsAlt :: CoreAlt -> VarSet
localFreeVarsAlt = filterVarSet isLocalVar . freeVarsAlt
-- | Free variables of a program.  Each binding group's binders are removed
-- from the free variables of the rest of the program that follows it.
freeVarsProg :: CoreProg -> VarSet
freeVarsProg ProgNil = emptyVarSet
freeVarsProg (ProgCons bnd rest) =
  let restFVs = delVarSetList (freeVarsProg rest) (bindVars bnd)
  in  freeVarsBind bnd `unionVarSet` restFVs
-- | Free type variables of a type; delegates to GHC's 'tyVarsOfType'.
freeVarsType :: Type -> TyVarSet
freeVarsType ty = tyVarsOfType ty
-- | Free type and coercion variables of a coercion; delegates to GHC's
-- 'tyCoVarsOfCo'.
freeVarsCoercion :: Coercion -> VarSet
freeVarsCoercion co = tyCoVarsOfCo co
-- | The kind of a 'Type' argument expression, or the type of any other
-- expression.
exprKindOrType :: CoreExpr -> KindOrType
exprKindOrType e = case e of
  Type t -> typeKind t
  _      -> exprType e
-- | The type of an expression.  Fails (via 'fail') when the expression is
-- itself a 'Type', which has no type.
exprTypeM :: Monad m => CoreExpr -> m Type
exprTypeM e = case e of
  Type _ -> fail "exprTypeM failed: expression is a type, so does not have a type."
  _      -> return (exprType e)
-- | Is this expression a 'Coercion' argument?
isCoArg :: CoreExpr -> Bool
isCoArg e = case e of
  Coercion _ -> True
  _          -> False
-- | The number of 'App' nodes along the function spine of an expression,
-- i.e. how many arguments the head is applied to.
appCount :: CoreExpr -> Int
appCount = go 0
  where
    -- Walk down the function position, counting applications.
    go acc (App f _) = go (acc + 1) f
    go acc _         = acc
-- | Apply a function to the right-hand side of every case alternative,
-- leaving constructors and pattern binders unchanged.
mapAlts :: (CoreExpr -> CoreExpr) -> [CoreAlt] -> [CoreAlt]
mapAlts _ [] = []
mapAlts f ((con, bndrs, rhs) : rest) = (con, bndrs, f rhs) : mapAlts f rest
-- | Split a type-constructor application into the constructor and its
-- arguments.  Fails with a message if the type is not such an application.
splitTyConAppM :: Monad m => Type -> m (TyCon, [Type])
splitTyConAppM ty = maybeM "splitTyConApp failed." (splitTyConApp_maybe ty)
-- | Strip the outer foralls of a type, then split the remainder as a
-- function type.  Returns the quantified type variables, the argument type
-- and the result type.  Fails (with a prefixed message) if the remainder is
-- not a function type.
splitFunTypeM :: MonadCatch m => Type -> m ([TyVar], Type, Type)
splitFunTypeM ty =
    prefixFailMsg "Split function type failed: " $
      maybeM "not a function type." (splitFunTy_maybe body)
        >>= \ (dom, cod) -> return (tyVars, dom, cod)
  where
    (tyVars, body) = splitForAllTys ty
-- | Like 'splitFunTypeM', but additionally requires the argument and result
-- types to be equal ('eqType').  Returns the quantified type variables and
-- that shared type.
endoFunTypeM :: MonadCatch m => Type -> m ([TyVar], Type)
endoFunTypeM ty = do
  (tyVars, argTy, resTy) <- splitFunTypeM ty
  guardMsg (argTy `eqType` resTy) "argument and result types differ."
  return (tyVars, argTy)
-- | 'endoFunTypeM' applied to the type of an expression.
endoFunExprTypeM :: MonadCatch m => CoreExpr -> m ([TyVar], Type)
endoFunExprTypeM e = exprTypeM e >>= endoFunTypeM
-- | 'splitFunTypeM' applied to the type of an expression.
funExprArgResTypesM :: MonadCatch m => CoreExpr -> m ([TyVar],Type,Type)
funExprArgResTypesM e = exprTypeM e >>= splitFunTypeM
-- | Given @f : A -> B@ and @g : B -> A@ (after stripping outer foralls),
-- return @(A, B)@.  Fails if either expression is not a function, or if the
-- argument/result types are not mutually inverse.
funExprsWithInverseTypes :: MonadCatch m => CoreExpr -> CoreExpr -> m (Type,Type)
funExprsWithInverseTypes f g = do
  (_, fArg, fRes) <- funExprArgResTypesM f
  (_, gArg, gRes) <- funExprArgResTypesM g
  setFailMsg "functions do not have inverse types." $ do
    guardM ((fArg `eqType` gRes) && (gArg `eqType` fRes))
    return (fArg, fRes)
-- | A single step in a path through a Core syntax tree.  Each constructor
-- names a parent node and the child it steps into.  The 'Int' fields select
-- a position among list-valued children; for example, 'leftSibling' and
-- 'rightSibling' step 'Rec_Def' and 'Case_Alt' by one.  The final group
-- (forall, conjunction, disjunction, implication, equality) appears to
-- address the parts of a logical formula.
data Crumb
    -- module
    = ModGuts_Prog
    -- programs and bindings
    | ProgCons_Head | ProgCons_Tail
    | NonRec_RHS | NonRec_Var
    | Rec_Def Int
    | Def_Id | Def_RHS
    -- expressions
    | Var_Id
    | Lit_Lit
    | App_Fun | App_Arg
    | Lam_Var | Lam_Body
    | Let_Bind | Let_Body
    | Case_Scrutinee | Case_Binder | Case_Type | Case_Alt Int
    | Cast_Expr | Cast_Co
    | Tick_Tick | Tick_Expr
    | Type_Type
    | Co_Co
    -- case alternatives
    | Alt_Con | Alt_Var Int | Alt_RHS
    -- types
    | TyVarTy_TyVar
    | LitTy_TyLit
    | AppTy_Fun | AppTy_Arg
    | TyConApp_TyCon | TyConApp_Arg Int
    | FunTy_Dom | FunTy_CoDom
    | ForAllTy_Var | ForAllTy_Body
    -- coercions
    | Refl_Type
    | TyConAppCo_TyCon | TyConAppCo_Arg Int
    | AppCo_Fun | AppCo_Arg
    | ForAllCo_TyVar | ForAllCo_Body
    | CoVarCo_CoVar
    | AxiomInstCo_Axiom | AxiomInstCo_Index | AxiomInstCo_Arg Int
    | UnsafeCo_Left | UnsafeCo_Right
    | SymCo_Co
    | TransCo_Left | TransCo_Right
    | NthCo_Int | NthCo_Co
    | InstCo_Co | InstCo_Type
    | LRCo_LR | LRCo_Co
    -- formulas
    | Forall_Body
    | Conj_Lhs | Conj_Rhs
    | Disj_Lhs | Disj_Rhs
    | Impl_Lhs | Impl_Rhs
    | Eq_Lhs | Eq_Rhs
    deriving (Eq, Read, Show)
-- | Render a crumb path as a bracketed, comma-separated list of crumb names.
showCrumbs :: [Crumb] -> String
showCrumbs crs = concat ["[", intercalate ", " (map showCrumb crs), "]"]
-- | Render one crumb by its short name.  Crumbs without a dedicated name
-- fall through to a warning string.  Note that 'ForAllTy_Body' and
-- 'Forall_Body' both render as "forall-body".
showCrumb :: Crumb -> String
showCrumb ModGuts_Prog        = "prog"
showCrumb ProgCons_Head       = "prog-head"
showCrumb ProgCons_Tail       = "prog-tail"
showCrumb NonRec_RHS          = "nonrec-rhs"
showCrumb (Rec_Def n)         = "rec-def " ++ show n
showCrumb Def_RHS             = "def-rhs"
showCrumb App_Fun             = "app-fun"
showCrumb App_Arg             = "app-arg"
showCrumb Lam_Body            = "lam-body"
showCrumb Let_Bind            = "let-bind"
showCrumb Let_Body            = "let-body"
showCrumb Case_Scrutinee      = "case-expr"
showCrumb Case_Type           = "case-type"
showCrumb (Case_Alt n)        = "case-alt " ++ show n
showCrumb Cast_Expr           = "cast-expr"
showCrumb Cast_Co             = "cast-co"
showCrumb Tick_Expr           = "tick-expr"
showCrumb Alt_RHS             = "alt-rhs"
showCrumb Type_Type           = "type"
showCrumb Co_Co               = "coercion"
showCrumb AppTy_Fun           = "appTy-fun"
showCrumb AppTy_Arg           = "appTy-arg"
showCrumb (TyConApp_Arg n)    = "tyCon-arg " ++ show n
showCrumb FunTy_Dom           = "fun-dom"
showCrumb FunTy_CoDom         = "fun-cod"
showCrumb ForAllTy_Body       = "forall-body"
showCrumb Refl_Type           = "refl-type"
showCrumb (TyConAppCo_Arg n)  = "coCon-arg " ++ show n
showCrumb AppCo_Fun           = "appCo-fun"
showCrumb AppCo_Arg           = "appCo-arg"
showCrumb ForAllCo_Body       = "coForall-body"
showCrumb (AxiomInstCo_Arg n) = "axiom-inst " ++ show n
showCrumb UnsafeCo_Left       = "unsafe-left"
showCrumb UnsafeCo_Right      = "unsafe-right"
showCrumb SymCo_Co            = "sym-co"
showCrumb TransCo_Left        = "trans-left"
showCrumb TransCo_Right       = "trans-right"
showCrumb NthCo_Co            = "nth-co"
showCrumb InstCo_Co           = "inst-co"
showCrumb InstCo_Type         = "inst-type"
showCrumb LRCo_Co             = "lr-co"
showCrumb Forall_Body         = "forall-body"
showCrumb Conj_Lhs            = "conj-lhs"
showCrumb Conj_Rhs            = "conj-rhs"
showCrumb Disj_Lhs            = "disj-lhs"
showCrumb Disj_Rhs            = "disj-rhs"
showCrumb Impl_Lhs            = "antecedent"
showCrumb Impl_Rhs            = "consequent"
showCrumb Eq_Lhs              = "eq-lhs"
showCrumb Eq_Rhs              = "eq-rhs"
showCrumb _                   = "Warning: Crumb should not be in use! This is probably Neil's fault."
-- | The crumb for the sibling immediately to the left of the given crumb,
-- if there is one.
--
-- Left steps exist for:
--
-- * 'ProgCons_Tail' and 'App_Arg'.
-- * 'Let_Body'.
-- * @'Rec_Def' n@ with @n > 0@.
-- * @'Case_Alt' n@ with @n >= 0@; the left sibling of alternative 0 is the
--   scrutinee.
--
-- Every other crumb yields 'Nothing'.
leftSibling :: Crumb -> Maybe Crumb
leftSibling = \case
  ProgCons_Tail -> Just ProgCons_Head
  -- Step to the previous definition; the index must stay non-negative.
  Rec_Def n | n > 0 -> Just (Rec_Def (n - 1))
  App_Arg -> Just App_Fun
  Let_Body -> Just Let_Bind
  Case_Alt n | n == 0 -> Just Case_Scrutinee
             | n > 0 -> Just (Case_Alt (n - 1))
  _ -> Nothing
-- | The crumb for the sibling immediately to the right of the given crumb,
-- if one can be named.  Indexed crumbs ('Rec_Def', 'Case_Alt') always step
-- to the next index; whether that child exists is left to the caller.  The
-- right sibling of the scrutinee is alternative 0.
rightSibling :: Crumb -> Maybe Crumb
rightSibling ProgCons_Head  = Just ProgCons_Tail
rightSibling (Rec_Def n)    = Just (Rec_Def (n + 1))
rightSibling App_Fun        = Just App_Arg
rightSibling Let_Bind       = Just Let_Body
rightSibling Case_Scrutinee = Just (Case_Alt 0)
rightSibling (Case_Alt n)   = Just (Case_Alt (n + 1))
rightSibling _              = Nothing
-- | Substitute an expression for a variable throughout an expression,
-- using GHC's 'substExpr'.  The in-scope set is built from the local free
-- variables of @let v = e in expr@.
substCoreExpr :: Var -> CoreExpr -> (CoreExpr -> CoreExpr)
substCoreExpr v e expr = substExpr (text "substCoreExpr") singleSubst expr
  where
    inScopeVars = mkInScopeSet (localFreeVarsExpr (Let (NonRec v e) expr))
    singleSubst = extendSubst (mkEmptySubst inScopeVars) v e
-- | Substitute an expression for a variable in a case alternative.  The
-- pattern binders are passed through 'substBndrs' before the right-hand
-- side is substituted.  The in-scope set holds the local free variables of
-- the replacement and of the alternative, minus the substituted variable.
substCoreAlt :: Var -> CoreExpr -> CoreAlt -> CoreAlt
substCoreAlt v e alt = (con, vs', substExpr (text "alt-rhs") subst' rhs)
  where
    (con, vs, rhs) = alt
    inScopeVars    = delVarSet (localFreeVarsExpr e `unionVarSet` localFreeVarsAlt alt) v
    baseSubst      = extendSubst (mkEmptySubst (mkInScopeSet inScopeVars)) v e
    (subst', vs')  = substBndrs baseSubst vs
-- | Beta-reduce leading lambdas against a list of arguments, one argument
-- per lambda.  Returns the residual expression and any unconsumed
-- arguments.
betaReduceAll :: CoreExpr -> [CoreExpr] -> (CoreExpr, [CoreExpr])
betaReduceAll fun args = case (fun, args) of
  (Lam v body, a : rest) -> betaReduceAll (substCoreExpr v a body) rest
  _                      -> (fun, args)
-- | Apply a data constructor to the given type arguments, followed by the
-- given variables as further arguments.  Occurrence info is zapped on each
-- variable first.
mkDataConApp :: [Type] -> DataCon -> [Var] -> CoreExpr
mkDataConApp tys dc vs = mkCoreConApps dc (tyArgs ++ varArgs)
  where
    tyArgs  = map Type tys
    varArgs = map (varToCoreExpr . zapVarOccInfo) vs