module HERMIT.Core
(
CoreProg(..)
, CoreDef(..)
, CoreTickish
, progSyntaxEq
, bindSyntaxEq
, defSyntaxEq
, exprSyntaxEq
, altSyntaxEq
, typeSyntaxEq
, coercionSyntaxEq
, progAlphaEq
, bindAlphaEq
, defAlphaEq
, exprAlphaEq
, altAlphaEq
, typeAlphaEq
, coercionAlphaEq
, defsToRecBind
, defToIdExpr
, progToBinds
, bindsToProg
, bindToVarExprs
, progIds
, bindVars
, defId
, altVars
, freeVarsProg
, freeVarsBind
, freeVarsDef
, freeVarsExpr
, freeVarsAlt
, freeVarsVar
, localFreeVarsAlt
, freeVarsType
, freeVarsCoercion
, localFreeVarsExpr
, freeIdsExpr
, localFreeIdsExpr
, isCoArg
, exprKindOrType
, exprTypeM
, endoFunType
, splitFunTypeM
, funArgResTypes
, funsWithInverseTypes
, appCount
, mapAlts
, Crumb(..)
, showCrumbs
, deprecatedLeftSibling
, deprecatedRightSibling
) where
import Control.Monad ((>=>))
import Language.KURE.Combinators.Monad
import Language.KURE.MonadCatch
import HERMIT.GHC
import Data.List (intercalate)
-- | Tick annotations attached to Core expressions, with the Tickish type
-- instantiated at Id.
type CoreTickish = Tickish Id
-- | A Core program: a sequence of binding groups. Each group scopes over the
-- groups that follow it (freeVarsProg removes the binders of a group from
-- the free variables of the rest of the program).
data CoreProg = ProgNil                    -- ^ The empty program.
              | ProgCons CoreBind CoreProg -- ^ A binding group followed by the rest of the program.
-- Right-associative, so b1 `ProgCons` b2 `ProgCons` ProgNil nests to the right.
infixr 5 `ProgCons`
-- | Flatten a program into the list of its binding groups, preserving order.
progToBinds :: CoreProg -> [CoreBind]
progToBinds prog = case prog of
    ProgNil           -> []
    ProgCons bind rest -> bind : progToBinds rest
-- | Rebuild a program from a list of binding groups, keeping their order.
-- This is the inverse of progToBinds.
bindsToProg :: [CoreBind] -> CoreProg
bindsToProg []           = ProgNil
bindsToProg (bind:binds) = bind `ProgCons` bindsToProg binds
-- | The (binder, right-hand side) pairs of a binding group.
-- A non-recursive binding yields a singleton list.
bindToVarExprs :: CoreBind -> [(Var,CoreExpr)]
bindToVarExprs bind = case bind of
    NonRec var rhs -> [(var, rhs)]
    Rec pairs      -> pairs
-- | A single definition: an identifier paired with its right-hand side.
-- A list of these can be packed into a recursive group with defsToRecBind.
data CoreDef = Def Id CoreExpr
-- | Unpack a definition into its identifier and right-hand side.
defToIdExpr :: CoreDef -> (Id,CoreExpr)
defToIdExpr def = case def of
    Def ident rhs -> (ident, rhs)
-- | Collect definitions, in order, into a single recursive binding group.
defsToRecBind :: [CoreDef] -> CoreBind
defsToRecBind defs = Rec (map defToIdExpr defs)
-- | Syntactic equality of programs. Both programs must have the same number
-- of binding groups, and corresponding groups must satisfy bindSyntaxEq.
progSyntaxEq :: CoreProg -> CoreProg -> Bool
progSyntaxEq lhs rhs = case (lhs, rhs) of
    (ProgNil, ProgNil)                     -> True
    (ProgCons b1 rest1, ProgCons b2 rest2) -> bindSyntaxEq b1 b2 && progSyntaxEq rest1 rest2
    _                                      -> False
-- | Syntactic equality of binding groups. Binders are compared with (==), so
-- no alpha-renaming is done, and right-hand sides are compared with
-- exprSyntaxEq. Recursive groups are compared pair by pair using all2.
-- A non-recursive group never equals a recursive one.
bindSyntaxEq :: CoreBind -> CoreBind -> Bool
bindSyntaxEq (NonRec var1 rhs1) (NonRec var2 rhs2) = var1 == var2 && exprSyntaxEq rhs1 rhs2
bindSyntaxEq (Rec pairs1) (Rec pairs2) = all2 samePair pairs1 pairs2
  where
    samePair (var1, rhs1) (var2, rhs2) = var1 == var2 && exprSyntaxEq rhs1 rhs2
bindSyntaxEq _ _ = False
-- | Syntactic equality of definitions: identical identifiers and
-- syntactically equal right-hand sides.
defSyntaxEq :: CoreDef -> CoreDef -> Bool
defSyntaxEq d1 d2 = case (d1, d2) of
    (Def v1 rhs1, Def v2 rhs2) -> v1 == v2 && exprSyntaxEq rhs1 rhs2
-- | Syntactic equality of Core expressions.
--
-- Bound and free variables must be identical, compared with (==); no
-- alpha-renaming is done. Nested bindings, alternatives, types and
-- coercions are compared with the matching *SyntaxEq function. Expressions
-- headed by different constructors are never equal.
exprSyntaxEq :: CoreExpr -> CoreExpr -> Bool
exprSyntaxEq x y = case (x, y) of
    (Var v1, Var v2)                 -> v1 == v2
    (Lit l1, Lit l2)                 -> l1 == l2
    (App fun1 arg1, App fun2 arg2)   -> exprSyntaxEq fun1 fun2 && exprSyntaxEq arg1 arg2
    (Lam b1 body1, Lam b2 body2)     -> b1 == b2 && exprSyntaxEq body1 body2
    (Let bnd1 body1, Let bnd2 body2) -> bindSyntaxEq bnd1 bnd2 && exprSyntaxEq body1 body2
    (Case scr1 wild1 ty1 alts1, Case scr2 wild2 ty2 alts2) ->
        wild1 == wild2 && exprSyntaxEq scr1 scr2 && all2 altSyntaxEq alts1 alts2 && typeSyntaxEq ty1 ty2
    (Cast inner1 co1, Cast inner2 co2) -> exprSyntaxEq inner1 inner2 && coercionSyntaxEq co1 co2
    (Tick t1 inner1, Tick t2 inner2) -> t1 == t2 && exprSyntaxEq inner1 inner2
    (Type ty1, Type ty2)             -> typeSyntaxEq ty1 ty2
    (Coercion co1, Coercion co2)     -> coercionSyntaxEq co1 co2
    _                                -> False
-- | Syntactic equality of case alternatives: equal constructors, identical
-- binder lists, and syntactically equal right-hand sides.
altSyntaxEq :: CoreAlt -> CoreAlt -> Bool
altSyntaxEq alt1 alt2 = case alt1 of
    (con1, bndrs1, rhs1) -> case alt2 of
        (con2, bndrs2, rhs2) -> con1 == con2 && bndrs1 == bndrs2 && exprSyntaxEq rhs1 rhs2
-- | Syntactic equality of types. Type variables, including forall binders,
-- must be identical. No alpha-renaming is done, and type constructors are
-- compared directly, so synonyms are not expanded.
typeSyntaxEq :: Type -> Type -> Bool
typeSyntaxEq s t = case (s, t) of
    (TyVarTy a1, TyVarTy a2)                 -> a1 == a2
    (LitTy lit1, LitTy lit2)                 -> lit1 == lit2
    (AppTy f1 x1, AppTy f2 x2)               -> typeSyntaxEq f1 f2 && typeSyntaxEq x1 x2
    (FunTy dom1 cod1, FunTy dom2 cod2)       -> typeSyntaxEq dom1 dom2 && typeSyntaxEq cod1 cod2
    (ForAllTy a1 body1, ForAllTy a2 body2)   -> a1 == a2 && typeSyntaxEq body1 body2
    (TyConApp tc1 args1, TyConApp tc2 args2) -> tc1 == tc2 && all2 typeSyntaxEq args1 args2
    _                                        -> False
-- | Syntactic equality of coercions: same constructor and equal components.
-- Coercion variables and forall binders must be identical (no alpha-renaming);
-- embedded types are compared with typeSyntaxEq. Which constructors exist, and
-- whether they carry a role, depends on the GHC version selected by CPP below.
-- Coercions headed by different constructors are never equal.
coercionSyntaxEq :: Coercion -> Coercion -> Bool
#if __GLASGOW_HASKELL__ > 706
-- Newer GHC: Refl and TyConAppCo carry a role, which must also match.
coercionSyntaxEq (Refl role1 ty1) (Refl role2 ty2) = role1 == role2 && typeSyntaxEq ty1 ty2
coercionSyntaxEq (TyConAppCo role1 tc1 cos1) (TyConAppCo role2 tc2 cos2) = role1 == role2 && tc1 == tc2 && all2 coercionSyntaxEq cos1 cos2
#else
-- Older GHC: no roles.
coercionSyntaxEq (Refl ty1) (Refl ty2) = typeSyntaxEq ty1 ty2
coercionSyntaxEq (TyConAppCo tc1 cos1) (TyConAppCo tc2 cos2) = tc1 == tc2 && all2 coercionSyntaxEq cos1 cos2
#endif
coercionSyntaxEq (AppCo co11 co12) (AppCo co21 co22) = coercionSyntaxEq co11 co21 && coercionSyntaxEq co12 co22
coercionSyntaxEq (ForAllCo v1 co1) (ForAllCo v2 co2) = v1 == v2 && coercionSyntaxEq co1 co2
coercionSyntaxEq (CoVarCo v1) (CoVarCo v2) = v1 == v2
#if __GLASGOW_HASKELL__ > 706
-- Newer GHC: axiom instances carry a branch index; LRCo and UnivCo exist.
coercionSyntaxEq (AxiomInstCo con1 ind1 cos1) (AxiomInstCo con2 ind2 cos2) = con1 == con2 && ind1 == ind2 && all2 coercionSyntaxEq cos1 cos2
coercionSyntaxEq (LRCo lr1 co1) (LRCo lr2 co2) = lr1 == lr2 && coercionSyntaxEq co1 co2
coercionSyntaxEq (UnivCo role1 ty11 ty12) (UnivCo role2 ty21 ty22) = role1 == role2 && typeSyntaxEq ty11 ty21 && typeSyntaxEq ty12 ty22
#else
-- Older GHC: axiom instances without an index, and UnsafeCo instead of UnivCo.
coercionSyntaxEq (AxiomInstCo con1 cos1) (AxiomInstCo con2 cos2) = con1 == con2 && all2 coercionSyntaxEq cos1 cos2
coercionSyntaxEq (UnsafeCo ty11 ty12) (UnsafeCo ty21 ty22) = typeSyntaxEq ty11 ty21 && typeSyntaxEq ty12 ty22
#endif
coercionSyntaxEq (SymCo co1) (SymCo co2) = coercionSyntaxEq co1 co2
coercionSyntaxEq (TransCo co11 co12) (TransCo co21 co22) = coercionSyntaxEq co11 co21 && coercionSyntaxEq co12 co22
coercionSyntaxEq (NthCo n1 co1) (NthCo n2 co2) = n1 == n2 && coercionSyntaxEq co1 co2
coercionSyntaxEq (InstCo co1 ty1) (InstCo co2 ty2) = coercionSyntaxEq co1 co2 && typeSyntaxEq ty1 ty2
-- Mismatched constructors (or constructors not listed above).
coercionSyntaxEq _ _ = False
-- | Alpha-equivalence of programs. For each pair of binding groups, the
-- top-level binder lists must be identical (compared with (==) on bindVars).
-- The groups themselves are then compared with bindAlphaEq.
progAlphaEq :: CoreProg -> CoreProg -> Bool
progAlphaEq lhs rhs = case (lhs, rhs) of
    (ProgNil, ProgNil) -> True
    (ProgCons b1 rest1, ProgCons b2 rest2) ->
        bindVars b1 == bindVars b2 && bindAlphaEq b1 b2 && progAlphaEq rest1 rest2
    _ -> False
-- | Alpha-equivalence of binding groups.
--
-- * Non-recursive: the binders are ignored and the right-hand sides are
--   compared with exprAlphaEq.
-- * Recursive: the groups must contain the same number of bindings. The
--   binders of the two groups are renamed to each other position by position
--   (rnBndrs2), and the right-hand sides are then compared pairwise with
--   eqExprX under that renaming.
-- * A non-recursive group is never alpha-equivalent to a recursive one.
bindAlphaEq :: CoreBind -> CoreBind -> Bool
bindAlphaEq (NonRec _ e1) (NonRec _ e2) = exprAlphaEq e1 e2
bindAlphaEq (Rec ps1) (Rec ps2) =
    -- Check the sizes first. rnBndrs2 walks both binder lists in lockstep and
    -- is not meant for lists of different lengths, and env is forced as soon
    -- as all2 compares the first pair of right-hand sides.
    length ps1 == length ps2 && all2 (eqExprX id_unf env) rs1 rs2
  where
    -- Treat every identifier as having no unfolding during the comparison.
    id_unf _ = noUnfolding
    (bs1,rs1) = unzip ps1
    (bs2,rs2) = unzip ps2
    inScopeSet = mkInScopeSet $ exprsFreeVars (rs1 ++ rs2)
    env = rnBndrs2 (mkRnEnv2 inScopeSet) bs1 bs2
bindAlphaEq _ _ = False
-- | Alpha-equivalence of definitions. Each definition is treated as a
-- one-element recursive group, so the defined identifiers are renamed to
-- each other before the right-hand sides are compared.
defAlphaEq :: CoreDef -> CoreDef -> Bool
defAlphaEq d1 d2 = bindAlphaEq (defsToRecBind [d1]) (defsToRecBind [d2])
-- | Alpha-equivalence of expressions via eqExpr. The free variables of both
-- expressions form the initial in-scope set.
exprAlphaEq :: CoreExpr -> CoreExpr -> Bool
exprAlphaEq e1 e2 = eqExpr scope e1 e2
  where
    scope = mkInScopeSet (exprsFreeVars [e1, e2])
-- | Alpha-equivalence of case alternatives. Three conditions must hold:
--
-- * the constructors are equal;
-- * the binder lists have the same length;
-- * the right-hand sides are equal under eqExprX once the binders of one
--   alternative are renamed, position by position, to those of the other.
altAlphaEq :: CoreAlt -> CoreAlt -> Bool
altAlphaEq (c1,vs1,e1) (c2,vs2,e2) =
    -- The length check is evaluated before env is forced. It guards rnBndrs2,
    -- which pairs the two binder lists in lockstep.
    c1 == c2 && length vs1 == length vs2 && eqExprX id_unf env e1 e2
  where
    -- Treat every identifier as having no unfolding during the comparison.
    id_unf _ = noUnfolding
    inScopeSet = mkInScopeSet $ exprsFreeVars [e1,e2]
    env = rnBndrs2 (mkRnEnv2 inScopeSet) vs1 vs2
-- | Alpha-equivalence of types, delegating to eqType.
typeAlphaEq :: Type -> Type -> Bool
typeAlphaEq t1 t2 = eqType t1 t2
-- | Alpha-equivalence of coercions, delegating to coreEqCoercion.
coercionAlphaEq :: Coercion -> Coercion -> Bool
coercionAlphaEq co1 co2 = coreEqCoercion co1 co2
-- | All top-level binders of a program, in order of appearance.
progIds :: CoreProg -> [Id]
progIds ProgNil             = []
progIds (ProgCons bnd rest) = bindVars bnd ++ progIds rest
-- | The binders introduced by a binding group, in order.
bindVars :: CoreBind -> [Var]
bindVars (NonRec v _) = [v]
bindVars (Rec pairs)  = fst <$> pairs
-- | The identifier bound by a definition.
defId :: CoreDef -> Id
defId def = case def of
    Def ident _ -> ident
-- | The variables bound by a case alternative.
altVars :: CoreAlt -> [Var]
altVars alt = case alt of
    (_, bndrs, _) -> bndrs
-- | Every free variable of an expression: exprSomeFreeVars with a predicate
-- that accepts all variables.
freeVarsExpr :: CoreExpr -> VarSet
freeVarsExpr = exprSomeFreeVars (\ _ -> True)
-- | The free variables of an expression that satisfy isId.
freeIdsExpr :: CoreExpr -> IdSet
freeIdsExpr expr = exprSomeFreeVars isId expr
-- | The free variables of an expression that satisfy isLocalVar.
localFreeVarsExpr :: CoreExpr -> VarSet
localFreeVarsExpr expr = exprSomeFreeVars isLocalVar expr
-- | The free variables of an expression that satisfy isLocalId.
-- The result type is IdSet, matching freeIdsExpr; IdSet is an alias of the
-- same set type as VarSet, so existing callers are unaffected.
localFreeIdsExpr :: CoreExpr -> IdSet
localFreeIdsExpr = exprSomeFreeVars isLocalId
-- | Free variables of a binding group, including those of the binders
-- themselves (freeVarsVar: types, rules and unfoldings).
--
-- For a non-recursive binding the binder is not removed from the free
-- variables of its right-hand side. For a recursive group, the group binders
-- are removed from the free variables of the right-hand sides.
freeVarsBind :: CoreBind -> VarSet
freeVarsBind (NonRec v e) = freeVarsExpr e `unionVarSet` freeVarsVar v
freeVarsBind (Rec defs)   = rhsFvs `unionVarSet` binderFvs
  where
    bndrs     = map fst defs
    rhss      = map snd defs
    rhsFvs    = delVarSetList (unionVarSets (map freeVarsExpr rhss)) bndrs
    binderFvs = unionVarSets (map freeVarsVar bndrs)
-- | Free variables of a binder itself: the type variables of its type,
-- together with the variables mentioned by its rules and unfolding.
freeVarsVar :: Var -> VarSet
freeVarsVar v = unionVarSet (varTypeTyVars v) (bndrRuleAndUnfoldingVars v)
-- | Free variables of a definition, treating it as recursive. The defined
-- identifier is removed from the free variables of its own right-hand side,
-- while the free variables of the identifier itself (freeVarsVar) are kept.
freeVarsDef :: CoreDef -> VarSet
freeVarsDef (Def v e) = unionVarSet (freeVarsExpr e `delVarSet` v) (freeVarsVar v)
-- | Free variables of a case alternative. The variables collected are those
-- of the right-hand side plus those of the alternative binders (freeVarsVar).
-- The binders themselves are then removed from that combined set.
freeVarsAlt :: CoreAlt -> VarSet
freeVarsAlt (_, bndrs, rhs) = delVarSetList candidates bndrs
  where
    candidates = unionVarSet (freeVarsExpr rhs) (unionVarSets (map freeVarsVar bndrs))
-- | Like freeVarsAlt, but only local free variables (localFreeVarsExpr) of
-- the right-hand side are considered. The binder free variables are still
-- collected with freeVarsVar, and the alternative binders are removed.
localFreeVarsAlt :: CoreAlt -> VarSet
localFreeVarsAlt (_, bndrs, rhs) = delVarSetList candidates bndrs
  where
    candidates = unionVarSet (localFreeVarsExpr rhs) (unionVarSets (map freeVarsVar bndrs))
-- | Free variables of a program. Each binding group contributes its own free
-- variables (freeVarsBind). Its binders are removed from the free variables
-- of the rest of the program, which they scope over.
freeVarsProg :: CoreProg -> VarSet
freeVarsProg ProgNil = emptyVarSet
freeVarsProg (ProgCons bnd rest) =
    unionVarSet (freeVarsBind bnd) (delVarSetList (freeVarsProg rest) (bindVars bnd))
-- | Free type variables of a type, via tyVarsOfType.
freeVarsType :: Type -> TyVarSet
freeVarsType ty = tyVarsOfType ty
-- | Free type and coercion variables of a coercion, via tyCoVarsOfCo.
freeVarsCoercion :: Coercion -> VarSet
freeVarsCoercion co = tyCoVarsOfCo co
-- | The type of an expression. When the expression is a Type argument, the
-- result is the kind of that type instead.
exprKindOrType :: CoreExpr -> KindOrType
exprKindOrType expr = case expr of
    Type ty -> typeKind ty
    _       -> exprType expr
-- | The type of an expression in a monad. Fails via fail when the
-- expression is a Type argument, which has a kind rather than a type.
exprTypeM :: Monad m => CoreExpr -> m Type
exprTypeM expr = case expr of
    Type _ -> fail "exprTypeM failed: expression is a type, so does not have a type."
    _      -> return (exprType expr)
-- | True exactly when the expression is a Coercion node.
isCoArg :: CoreExpr -> Bool
isCoArg expr = case expr of
    Coercion _ -> True
    _          -> False
-- | Number of App nodes along the function position of an expression, that
-- is, how many arguments are applied at the head of the spine.
appCount :: CoreExpr -> Int
appCount = go 0
  where
    go acc (App fun _) = go (acc + 1) fun
    go acc _           = acc
-- | Apply a function to the right-hand side of every alternative. The
-- constructor and binders of each alternative are left unchanged.
mapAlts :: (CoreExpr -> CoreExpr) -> [CoreAlt] -> [CoreAlt]
mapAlts f = foldr (\ (ac, vs, e) rest -> (ac, vs, f e) : rest) []
-- | For an expression whose type is a function from T to T, return T.
-- Fails if the expression is not a function, or (via guardMsg) if its
-- argument and result types differ under eqType.
endoFunType :: Monad m => CoreExpr -> m Type
endoFunType f = funArgResTypes f >>= \ (argTy, resTy) ->
    guardMsg (eqType argTy resTy) "argument and result types differ." >> return argTy
-- | Split a function type into its argument and result types.
-- Fails via fail when splitFunTy_maybe returns Nothing.
splitFunTypeM :: Monad m => Type -> m (Type,Type)
splitFunTypeM ty = case splitFunTy_maybe ty of
    Nothing    -> fail "not a function type."
    Just parts -> return parts
-- | Argument and result types of a function-typed expression. Fails if the
-- expression is a type (exprTypeM) or its type is not a function type
-- (splitFunTypeM).
funArgResTypes :: Monad m => CoreExpr -> m (Type,Type)
funArgResTypes expr = do
    ty <- exprTypeM expr
    splitFunTypeM ty
-- | Given f of type A -> B and g of type B -> A (compared with eqType),
-- return (A, B). Fails if either expression is not a function. If the types
-- do not line up, fails with "functions do not have inverse types.".
funsWithInverseTypes :: MonadCatch m => CoreExpr -> CoreExpr -> m (Type,Type)
funsWithInverseTypes f g = do
    (fdom, fcod) <- funArgResTypes f
    (gdom, gcod) <- funArgResTypes g
    setFailMsg "functions do not have inverse types." $ do
        guardM (eqType fdom gcod)
        guardM (eqType gdom fcod)
    return (fdom, fcod)
-- | A crumb identifies one child position within a Core, type or coercion
-- node. A list of crumbs (rendered by showCrumbs) presumably describes a path
-- from a root to a subterm.
--
-- Constructor names follow the pattern Node_Child; for example App_Fun is
-- the function position of an App node. Constructors taking an Int select
-- the n-th element of a list of children. Judging by deprecatedRightSibling,
-- which maps Case_Scrutinee to Case_Alt 0, the index is 0-based.
data Crumb =
      -- Module level
      ModGuts_Prog
      -- Programs, bindings and definitions
    | ProgCons_Head | ProgCons_Tail
    | NonRec_RHS | NonRec_Var
    | Rec_Def Int
    | Def_Id | Def_RHS
      -- Expressions
    | Var_Id
    | Lit_Lit
    | App_Fun | App_Arg
    | Lam_Var | Lam_Body
    | Let_Bind | Let_Body
    | Case_Scrutinee | Case_Binder | Case_Type | Case_Alt Int
    | Cast_Expr | Cast_Co
    | Tick_Tick | Tick_Expr
    | Type_Type
    | Co_Co
      -- Case alternatives
    | Alt_Con | Alt_Var Int | Alt_RHS
      -- Types
    | TyVarTy_TyVar
    | LitTy_TyLit
    | AppTy_Fun | AppTy_Arg
    | TyConApp_TyCon | TyConApp_Arg Int
    | FunTy_Dom | FunTy_CoDom
    | ForAllTy_Var | ForAllTy_Body
      -- Coercions
    | Refl_Type
    | TyConAppCo_TyCon | TyConAppCo_Arg Int
    | AppCo_Fun | AppCo_Arg
    | ForAllCo_TyVar | ForAllCo_Body
    | CoVarCo_CoVar
    | AxiomInstCo_Axiom | AxiomInstCo_Index | AxiomInstCo_Arg Int
    | UnsafeCo_Left | UnsafeCo_Right
    | SymCo_Co
    | TransCo_Left | TransCo_Right
    | NthCo_Int | NthCo_Co
    | InstCo_Co | InstCo_Type
    | LRCo_LR | LRCo_Co
    deriving (Eq,Read,Show)
-- | Render a crumb path as a bracketed, comma-separated list of the labels
-- produced by showCrumb.
showCrumbs :: [Crumb] -> String
showCrumbs crumbs = '[' : intercalate ", " (map showCrumb crumbs) ++ "]"
-- | Render a single crumb as a short label. Indexed crumbs append their index
-- after a space. Crumbs without a label here (for example binder and
-- variable positions) render as a fixed warning string.
showCrumb :: Crumb -> String
showCrumb ModGuts_Prog         = "prog"
showCrumb ProgCons_Head        = "prog-head"
showCrumb ProgCons_Tail        = "prog-tail"
showCrumb NonRec_RHS           = "nonrec-rhs"
showCrumb (Rec_Def n)          = numbered "rec-def" n
showCrumb Def_RHS              = "def-rhs"
showCrumb App_Fun              = "app-fun"
showCrumb App_Arg              = "app-arg"
showCrumb Lam_Body             = "lam-body"
showCrumb Let_Bind             = "let-bind"
showCrumb Let_Body             = "let-body"
showCrumb Case_Scrutinee       = "case-expr"
showCrumb Case_Type            = "case-type"
showCrumb (Case_Alt n)         = numbered "case-alt" n
showCrumb Cast_Expr            = "cast-expr"
showCrumb Cast_Co              = "cast-co"
showCrumb Tick_Expr            = "tick-expr"
showCrumb Alt_RHS              = "alt-rhs"
showCrumb Type_Type            = "type"
showCrumb Co_Co                = "coercion"
showCrumb AppTy_Fun            = "appTy-fun"
showCrumb AppTy_Arg            = "appTy-arg"
showCrumb (TyConApp_Arg n)     = numbered "tyCon-arg" n
showCrumb FunTy_Dom            = "fun-dom"
showCrumb FunTy_CoDom          = "fun-cod"
showCrumb ForAllTy_Body        = "forall-body"
showCrumb Refl_Type            = "refl-type"
showCrumb (TyConAppCo_Arg n)   = numbered "coCon-arg" n
showCrumb AppCo_Fun            = "appCo-fun"
showCrumb AppCo_Arg            = "appCo-arg"
showCrumb ForAllCo_Body        = "coForall-body"
showCrumb (AxiomInstCo_Arg n)  = numbered "axiom-inst" n
showCrumb UnsafeCo_Left        = "unsafe-left"
showCrumb UnsafeCo_Right       = "unsafe-right"
showCrumb SymCo_Co             = "sym-co"
showCrumb TransCo_Left         = "trans-left"
showCrumb TransCo_Right        = "trans-right"
showCrumb NthCo_Co             = "nth-co"
showCrumb InstCo_Co            = "inst-co"
showCrumb InstCo_Type          = "inst-type"
showCrumb LRCo_Co              = "lr-co"
showCrumb _                    = "Warning: Crumb should not be in use! This is probably Neil's fault."

-- | Label for an indexed crumb: the label, a space, then the index.
numbered :: String -> Int -> String
numbered label n = label ++ " " ++ show n
-- | The crumb for the sibling position immediately to the left of the given
-- one, if it has one.
--
-- Indexed crumbs step down by one; index 0 has no left sibling, except
-- Case_Alt 0, whose left sibling is the scrutinee. Crumbs not listed here
-- (and negative indices) yield Nothing.
deprecatedLeftSibling :: Crumb -> Maybe Crumb
deprecatedLeftSibling = \case
    ProgCons_Tail -> Just ProgCons_Head
    -- The index must be decremented (n - 1); the previous text referred to an
    -- unbound name n1.
    Rec_Def n | n > 0 -> Just (Rec_Def (n - 1))
    App_Arg -> Just App_Fun
    Let_Body -> Just Let_Bind
    Case_Alt n | n == 0 -> Just Case_Scrutinee
               | n > 0  -> Just (Case_Alt (n - 1))
    _ -> Nothing
-- | The crumb for the sibling position immediately to the right of the given
-- one. Indexed crumbs step up by one without checking that the resulting
-- index exists. Case_Scrutinee moves to the first alternative, Case_Alt 0.
-- Crumbs not listed here yield Nothing.
deprecatedRightSibling :: Crumb -> Maybe Crumb
deprecatedRightSibling ProgCons_Head  = Just ProgCons_Tail
deprecatedRightSibling (Rec_Def n)    = Just (Rec_Def (n + 1))
deprecatedRightSibling App_Fun        = Just App_Arg
deprecatedRightSibling Let_Bind       = Just Let_Body
deprecatedRightSibling Case_Scrutinee = Just (Case_Alt 0)
deprecatedRightSibling (Case_Alt n)   = Just (Case_Alt (n + 1))
deprecatedRightSibling _              = Nothing