module Language.SequentCore.Subst (
Subst(..),
TvSubstEnv, IdSubstEnv, InScopeSet,
deShadowBinds, substSpec, substRulesForImportedIds,
substTy, substCo, substTerm, substBind, substBindSC,
substUnfolding, substUnfoldingSC,
lookupIdSubst, lookupTvSubst, lookupCvSubst, substIdOcc,
substTickish, substVarSet,
emptySubst, mkEmptySubst, mkSubst, mkOpenSubst, substInScope, isEmptySubst,
extendIdSubst, extendIdSubstList, extendTvSubst, extendTvSubstList,
extendCvSubst, extendCvSubstList,
extendSubst, extendSubstList, extendSubstWithVar, zapSubstEnv,
addInScopeSet, extendInScope, extendInScopeList, extendInScopeIds,
isInScope, setInScope,
delBndr, delBndrs,
substBndr, substBndrs, substRecBndrs,
cloneBndr, cloneBndrs, cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs
) where
import Language.SequentCore.Syntax
import Language.SequentCore.Translate
import qualified Type
import qualified Coercion
import Type hiding ( substTy, extendTvSubst, extendTvSubstList
, isInScope, substTyVarBndr, cloneTyVarBndr )
import Coercion hiding ( substTy, substCo, extendTvSubst
, substTyVarBndr, substCoVarBndr )
import CoreFVs
import qualified CoreSubst
import CoreSubst ( simpleOptExprWith )
import CoreSyn ( Tickish(..), Unfolding(..), CoreRule(..)
, seqExpr, isStableSource, isClosedUnfolding )
import FastString
import Id
import IdInfo
import Maybes
import Name
import Outputable
import UniqSupply
import Unique
import Var
import VarEnv
import VarSet
import Control.Exception ( assert )
import Data.List ( mapAccumL )
-- | A combined substitution over identifiers, type variables and coercion
-- variables, paired with an in-scope set.
--
-- The fields are deliberately left lazy: 'substRecBndrs' and
-- 'cloneRecIdBndrs' build a 'Subst' that refers to itself.
data Subst = Subst
  InScopeSet   -- the in-scope set
  IdSubstEnv   -- replacements for identifiers
  TvSubstEnv   -- replacements for type variables
  CvSubstEnv   -- replacements for coercion variables
-- | Environment mapping identifiers to the Sequent Core terms that replace
-- them.
type IdSubstEnv = IdEnv SeqCoreTerm
-- | True when all three substitution environments are empty.  The in-scope
-- set is not consulted.
isEmptySubst :: Subst -> Bool
isEmptySubst (Subst _ ids tvs cvs) =
  and [isEmptyVarEnv ids, isEmptyVarEnv tvs, isEmptyVarEnv cvs]
-- | The substitution with an empty in-scope set and empty environments.
emptySubst :: Subst
emptySubst = mkEmptySubst emptyInScopeSet
-- | A substitution with the given in-scope set and empty environments.
mkEmptySubst :: InScopeSet -> Subst
mkEmptySubst scope = mkSubst scope emptyVarEnv emptyVarEnv emptyVarEnv
-- | Assemble a 'Subst' from its parts.  Note the argument order (type,
-- coercion, identifier) differs from the field order of the constructor.
mkSubst :: InScopeSet -> TvSubstEnv -> CvSubstEnv -> IdSubstEnv -> Subst
mkSubst scope tv_env cv_env id_env = Subst scope id_env tv_env cv_env
-- | Project the in-scope set out of a substitution.
substInScope :: Subst -> InScopeSet
substInScope subst = case subst of
  Subst scope _ _ _ -> scope
-- | Drop every binding from the three environments, keeping only the
-- in-scope set.
zapSubstEnv :: Subst -> Subst
zapSubstEnv (Subst scope _ _ _) = mkEmptySubst scope
-- | Bind an identifier to a replacement term in the identifier environment.
-- The in-scope set is left untouched.
extendIdSubst :: Subst -> Id -> SeqCoreTerm -> Subst
extendIdSubst (Subst scope id_env tv_env cv_env) bndr term =
  Subst scope id_env' tv_env cv_env
  where
    id_env' = extendVarEnv id_env bndr term
-- | Bind several identifiers to replacement terms at once.
extendIdSubstList :: Subst -> [(Id, SeqCoreTerm)] -> Subst
extendIdSubstList (Subst scope id_env tv_env cv_env) bindings =
  Subst scope id_env' tv_env cv_env
  where
    id_env' = extendVarEnvList id_env bindings
-- | Bind a type variable to a replacement type in the type environment.
extendTvSubst :: Subst -> TyVar -> Type -> Subst
extendTvSubst (Subst scope id_env tv_env cv_env) tv ty =
  Subst scope id_env tv_env' cv_env
  where
    tv_env' = extendVarEnv tv_env tv ty
-- | Bind several type variables to replacement types at once.
extendTvSubstList :: Subst -> [(TyVar,Type)] -> Subst
extendTvSubstList (Subst scope id_env tv_env cv_env) bindings =
  Subst scope id_env tv_env' cv_env
  where
    tv_env' = extendVarEnvList tv_env bindings
-- | Bind a coercion variable to a replacement coercion in the coercion
-- environment.
extendCvSubst :: Subst -> CoVar -> Coercion -> Subst
extendCvSubst (Subst scope id_env tv_env cv_env) cv co =
  Subst scope id_env tv_env cv_env'
  where
    cv_env' = extendVarEnv cv_env cv co
-- | Bind several coercion variables to replacement coercions at once.
extendCvSubstList :: Subst -> [(CoVar,Coercion)] -> Subst
extendCvSubstList (Subst scope id_env tv_env cv_env) bindings =
  Subst scope id_env tv_env cv_env'
  where
    cv_env' = extendVarEnvList cv_env bindings
-- | Bind a variable to a term, choosing the environment from the shape of
-- the term: a 'Type' goes to the type environment, a 'Coercion' to the
-- coercion environment, anything else to the identifier environment.  Each
-- case asserts that the variable is of the matching sort.
extendSubst :: Subst -> Var -> SeqCoreTerm -> Subst
extendSubst subst var (Type ty) =
  assert (isTyVar var) (extendTvSubst subst var ty)
extendSubst subst var (Coercion co) =
  assert (isCoVar var) (extendCvSubst subst var co)
extendSubst subst var term =
  assert (isId var) (extendIdSubst subst var term)
-- | Bind one variable to another, choosing the environment from the sort of
-- the first variable and asserting that the second is of the same sort.
extendSubstWithVar :: Subst -> Var -> Var -> Subst
extendSubstWithVar subst old new
  | isTyVar old = assert (isTyVar new) (extendTvSubst subst old new_ty)
  | isCoVar old = assert (isCoVar new) (extendCvSubst subst old new_co)
  | otherwise   = assert (isId new) (extendIdSubst subst old new_term)
  where
    new_ty   = mkTyVarTy new
    new_co   = mkCoVarCo new
    new_term = Var new
-- | Apply 'extendSubst' to each pair in turn, left to right, so later pairs
-- are added on top of earlier ones.
extendSubstList :: Subst -> [(Var,SeqCoreTerm)] -> Subst
extendSubstList subst bindings = case bindings of
  []                -> subst
  (var, rhs) : rest -> extendSubstList (extendSubst subst var rhs) rest
-- | Find the replacement for an identifier occurrence.
--
-- Non-local identifiers are returned unchanged.  Otherwise the identifier
-- environment is tried first, then the in-scope set; if neither knows the
-- identifier a trace message (tagged with the caller-supplied document) is
-- emitted and the identifier is returned as is.
lookupIdSubst :: SDoc -> Subst -> Id -> SeqCoreTerm
lookupIdSubst doc (Subst in_scope ids _ _) v
  | not (isLocalId v) = Var v
  | otherwise =
      case lookupVarEnv ids v of
        Just term -> term
        Nothing ->
          case lookupInScope in_scope v of
            Just v' -> Var v'
            Nothing ->
              pprTrace "Subst.lookupIdSubst"
                       (doc <+> ppr v $$ ppr in_scope)
                       (Var v)
-- | Find the replacement for a type variable, defaulting to the variable
-- itself (as a type) when the type environment has no entry for it.
lookupTvSubst :: Subst -> TyVar -> Type
lookupTvSubst (Subst _ _ tv_env _) tv =
  assert (isTyVar tv) $
    case lookupVarEnv tv_env tv of
      Just ty -> ty
      Nothing -> Type.mkTyVarTy tv
-- | Find the replacement for a coercion variable, defaulting to the variable
-- itself (as a coercion) when the coercion environment has no entry for it.
lookupCvSubst :: Subst -> CoVar -> Coercion
lookupCvSubst (Subst _ _ _ cv_env) cv =
  assert (isCoVar cv) $
    case lookupVarEnv cv_env cv of
      Just co -> co
      Nothing -> mkCoVarCo cv
-- | Remove a variable's binding from the one environment that matches its
-- sort.  Coercion variables are tested for before type variables.
delBndr :: Subst -> Var -> Subst
delBndr (Subst scope id_env tv_env cv_env) var
  | isCoVar var = Subst scope id_env tv_env (without cv_env)
  | isTyVar var = Subst scope id_env (without tv_env) cv_env
  | otherwise   = Subst scope (without id_env) tv_env cv_env
  where
    without env = delVarEnv env var
-- | Remove the given variables from all three environments, regardless of
-- each variable's sort.
delBndrs :: Subst -> [Var] -> Subst
delBndrs (Subst scope id_env tv_env cv_env) vars =
  Subst scope id_env' tv_env' cv_env'
  where
    id_env' = delVarEnvList id_env vars
    tv_env' = delVarEnvList tv_env vars
    cv_env' = delVarEnvList cv_env vars
-- | Build a substitution from an in-scope set and a list of bindings.
--
-- Pairs whose variable satisfies 'isId' populate the identifier environment;
-- pairs whose term is a 'Type' populate the type environment; pairs whose
-- term is a 'Coercion' populate the coercion environment.
mkOpenSubst :: InScopeSet -> [(Var,SeqCoreTerm)] -> Subst
mkOpenSubst in_scope pairs = Subst in_scope id_env tv_env cv_env
  where
    id_env = mkVarEnv [(var, term) | (var, term) <- pairs, isId var]
    tv_env = mkVarEnv [(tv, ty) | (tv, Type ty) <- pairs]
    cv_env = mkVarEnv [(cv, co) | (cv, Coercion co) <- pairs]
-- | Is the variable a member of the substitution's in-scope set?
isInScope :: Var -> Subst -> Bool
isInScope var (Subst scope _ _ _) = elemInScopeSet var scope
-- | Add a set of variables to the in-scope set.  Unlike 'extendInScope',
-- the environments are left untouched.
addInScopeSet :: Subst -> VarSet -> Subst
addInScopeSet (Subst scope id_env tv_env cv_env) vars =
  Subst scope' id_env tv_env cv_env
  where
    scope' = extendInScopeSetSet scope vars
-- | Add a variable to the in-scope set and delete any binding for it from
-- all three environments.
extendInScope :: Subst -> Var -> Subst
extendInScope (Subst scope id_env tv_env cv_env) var =
  Subst scope' id_env' tv_env' cv_env'
  where
    scope'  = extendInScopeSet scope var
    id_env' = delVarEnv id_env var
    tv_env' = delVarEnv tv_env var
    cv_env' = delVarEnv cv_env var
-- | Add several variables to the in-scope set and delete any bindings for
-- them from all three environments.
extendInScopeList :: Subst -> [Var] -> Subst
extendInScopeList (Subst scope id_env tv_env cv_env) vars =
  Subst scope' id_env' tv_env' cv_env'
  where
    scope'  = extendInScopeSetList scope vars
    id_env' = delVarEnvList id_env vars
    tv_env' = delVarEnvList tv_env vars
    cv_env' = delVarEnvList cv_env vars
-- | Add several identifiers to the in-scope set, deleting their bindings
-- from the identifier environment only; the type and coercion environments
-- are left as they are.
extendInScopeIds :: Subst -> [Id] -> Subst
extendInScopeIds (Subst scope id_env tv_env cv_env) new_ids =
  Subst scope' id_env' tv_env cv_env
  where
    scope'  = extendInScopeSetList scope new_ids
    id_env' = delVarEnvList id_env new_ids
-- | Replace the in-scope set wholesale, keeping the environments.
setInScope :: Subst -> InScopeSet -> Subst
setInScope (Subst _ id_env tv_env cv_env) scope =
  Subst scope id_env tv_env cv_env
-- | Debug rendering: the in-scope variables followed by each of the three
-- environments on its own line, bracketed by @<@ and @>@.
instance Outputable Subst where
  ppr (Subst in_scope ids tvs cvs) =
    ptext (sLit "<InScope =") <+> braces (fsep scope_docs)
      $$ ptext (sLit " IdSubst =") <+> ppr ids
      $$ ptext (sLit " TvSubst =") <+> ppr tvs
      $$ ptext (sLit " CvSubst =") <+> ppr cvs
      <> char '>'
    where
      scope_docs = map ppr (varEnvElts (getInScopeVars in_scope))
-- | Convert an identifier environment over Sequent Core terms into one over
-- Core expressions by translating every range element with
-- 'termToCoreExpr'.
toCoreIdSubstEnv :: IdSubstEnv -> CoreSubst.IdSubstEnv
toCoreIdSubstEnv id_env = mapVarEnv termToCoreExpr id_env
-- | Convert a Sequent Core substitution into a 'CoreSubst.Subst'.  Only the
-- identifier environment needs translating; the other fields carry over.
toCoreSubst :: Subst -> CoreSubst.Subst
toCoreSubst (Subst scope id_env tv_env cv_env) =
  CoreSubst.Subst scope core_ids tv_env cv_env
  where
    core_ids = toCoreIdSubstEnv id_env
-- | Apply a substitution to a term.  The document argument is accepted for
-- interface compatibility and is not used.
substTerm :: SDoc -> Subst -> SeqCoreTerm -> SeqCoreTerm
substTerm _doc = subst_term
-- | Apply a substitution to a term, by wrapping it as a 'SeqCoreExpr',
-- delegating to 'subst_expr', and unwrapping the result.
subst_term :: Subst -> SeqCoreTerm -> SeqCoreTerm
subst_term subst term = unT (subst_expr subst (T term))

-- | As 'subst_term', for commands.
subst_comm :: Subst -> SeqCoreCommand -> SeqCoreCommand
subst_comm subst comm = unC (subst_expr subst (C comm))

-- | As 'subst_term', for continuations.
subst_cont :: Subst -> SeqCoreCont -> SeqCoreCont
subst_cont subst cont = unK (subst_expr subst (K cont))
-- | Apply a substitution to a term, command or continuation.
--
-- Binding constructs ('Compute', 'Lam', 'Case', case alternatives and the
-- bindings of a 'Command') first substitute their binders, yielding an
-- extended substitution that is then used for everything in their scope.
-- A 'Return' whose continuation variable maps to a 'Cont' term is replaced
-- by that continuation; any replacement other than a variable or a
-- continuation is a panic.
subst_expr :: Subst -> SeqCoreExpr -> SeqCoreExpr
subst_expr subst expr = case expr of
    T term -> T (termS term)
    C comm -> C (commS comm)
    K cont -> K (contS cont)
  where
    -- Terms
    termS (Var v)          = lookupIdSubst (text "subst_term") subst v
    termS (Type ty)        = Type (substTy subst ty)
    termS (Coercion co)    = Coercion (substCo subst co)
    termS (Cont cont)      = Cont (contS cont)
    termS (Lit lit)        = Lit lit
    termS (Cons con args)  = Cons con (map termS args)
    termS (Compute k comm) =
      let (inner, k') = substBndr subst k
      in  Compute k' (subst_comm inner comm)
    termS (Lam xs k body)  =
      let (after_xs, xs') = mapAccumL substBndr subst xs
          (after_k, k')   = substBndr after_xs k
      in  Lam xs' k' (subst_comm after_k body)

    -- Continuations
    contS (App arg cont)      = App (termS arg) (contS cont)
    contS (Tick tickish cont) = Tick (substTickish subst tickish) (contS cont)
    contS (Cast co cont)      = Cast (substCo subst co) (contS cont)
    contS (Case bndr alts)    =
      let (inner, bndr') = substBndr subst bndr
      in  Case bndr' (map (altS inner) alts)
    contS (Return k)          =
      case lookupIdSubst (text "subst_cont") subst k of
        Var k'    -> Return k'
        Cont cont -> cont
        other     -> pprPanic "subst_expr::goK" (ppr other)

    -- Commands: the bindings scope over both the term and the continuation.
    commS (Command binds term cont) =
      let (inner, binds') = mapAccumL substBind subst binds
      in  Command binds' (subst_term inner term) (subst_cont inner cont)

    -- Case alternatives, under the substitution extended by the case binder.
    altS alt_subst (Alt con bndrs rhs) =
      let (inner, bndrs') = substBndrs alt_subst bndrs
      in  Alt con bndrs' (subst_comm inner rhs)
-- | Apply a substitution to a binding, returning the substitution extended
-- for the binding's scope together with the new binding.
--
-- 'substBindSC' is a short-cut variant: when the incoming substitution is
-- empty it leaves right-hand sides alone unless substituting the binders
-- produced a non-empty substitution (recursive case), and otherwise defers
-- to 'substBind'.
substBind, substBindSC :: Subst -> SeqCoreBind -> (Subst, SeqCoreBind)

substBindSC subst bind
  | isEmptySubst subst = shortCut bind
  | otherwise          = substBind subst bind
  where
    shortCut (NonRec bndr rhs) =
      let (subst', bndr') = substBndr subst bndr
      in  (subst', NonRec bndr' rhs)
    shortCut (Rec pairs) =
      let (bndrs, rhss)    = unzip pairs
          (subst', bndrs') = substRecBndrs subst bndrs
          rhss' | isEmptySubst subst' = rhss
                | otherwise           = map (subst_term subst') rhss
      in  (subst', Rec (zip bndrs' rhss'))

substBind subst bind = case bind of
  NonRec bndr rhs ->
    -- The right-hand side is outside the binder's scope, so it sees the
    -- incoming substitution.
    let (subst', bndr') = substBndr subst bndr
    in  (subst', NonRec bndr' (subst_term subst rhs))
  Rec pairs ->
    -- Recursive right-hand sides are inside the binders' scope, so they see
    -- the extended substitution.
    let (bndrs, rhss)    = unzip pairs
        (subst', bndrs') = substRecBndrs subst bndrs
        rhss'            = map (subst_term subst') rhss
    in  (subst', Rec (zip bndrs' rhss'))
-- | Run 'substBind' over a list of bindings starting from 'emptySubst',
-- threading the substitution left to right and returning only the new
-- bindings.
deShadowBinds :: [SeqCoreBind] -> [SeqCoreBind]
deShadowBinds = snd . mapAccumL substBind emptySubst
-- | Substitute in a binder of any sort, dispatching to the type-variable,
-- coercion-variable or identifier routine.  For identifiers the same
-- substitution is used for the binder and for its 'IdInfo'.
substBndr :: Subst -> Var -> (Subst, Var)
substBndr subst var
  | isTyVar var = substTyVarBndr subst var
  | isCoVar var = substCoVarBndr subst var
  | otherwise   = substIdBndr tag subst subst var
  where
    tag = text "var-bndr"
-- | Substitute in a sequence of binders, left to right, threading the
-- substitution so that each binder scopes over the ones after it.
substBndrs :: Subst -> [Var] -> (Subst, [Var])
substBndrs = mapAccumL substBndr
-- | Substitute in a group of mutually recursive binders.
--
-- The final substitution is fed back (lazily) as the substitution used for
-- each binder's 'IdInfo', so the binding below is deliberately recursive.
substRecBndrs :: Subst -> [Id] -> (Subst, [Id])
substRecBndrs subst bndrs =
  let (final_subst, bndrs') =
        mapAccumL (substIdBndr (text "rec-bndr") final_subst) subst bndrs
  in  (final_subst, bndrs')
-- | Substitute in an identifier binder.
--
-- Arguments: a document (unused), the substitution applied to the binder's
-- 'IdInfo', the substitution applied to the binder itself, and the binder.
-- Returns the substitution to use in the binder's scope and the new binder.
substIdBndr :: SDoc
            -> Subst
            -> Subst -> Id
            -> (Subst, Id)
substIdBndr _doc rec_subst subst@(Subst in_scope env tvs cvs) old_id =
  (Subst (extendInScopeSet in_scope new_id) env' tvs cvs, new_id)
  where
    old_ty = idType old_id

    -- Step 1: pick a name via 'uniqAway' against the current in-scope set.
    renamed = uniqAway in_scope old_id

    -- Step 2: substitute in the type, unless that is certainly a no-op.
    type_unchanged =
      isEmptyVarEnv tvs || isEmptyVarSet (Type.tyVarsOfType old_ty)
    retyped =
      if type_unchanged
        then renamed
        else setIdType renamed (substTy subst old_ty)

    -- Step 3: substitute in the IdInfo, using the (possibly recursive)
    -- rec_subst.
    new_info = substIdInfo rec_subst retyped (idInfo retyped)
    new_id   = maybeModifyIdInfo new_info retyped

    -- Record a mapping only when step 1 produced a different identifier;
    -- otherwise drop any existing mapping for the binder.
    same_id = renamed == old_id
    env' =
      if same_id
        then delVarEnv env old_id
        else extendVarEnv env old_id (Var new_id)
-- | Clone an identifier binder, giving it a unique drawn from the supply.
cloneIdBndr :: Subst -> UniqSupply -> Id -> (Subst, Id)
cloneIdBndr subst us old_id = clone_id subst subst (old_id, fresh)
  where
    fresh = uniqFromSupply us
-- | Clone several identifier binders, left to right.  The substitution used
-- for each binder's 'IdInfo' is the original one, not the accumulated one.
cloneIdBndrs :: Subst -> UniqSupply -> [Id] -> (Subst, [Id])
cloneIdBndrs subst us ids = mapAccumL (clone_id subst) subst tagged
  where
    tagged = zip ids (uniqsFromSupply us)
-- | Clone several binders of any sort, left to right, threading the
-- substitution and pairing each binder with a unique from the supply.
cloneBndrs :: Subst -> UniqSupply -> [Var] -> (Subst, [Var])
cloneBndrs subst us vs = mapAccumL step subst (zip vs (uniqsFromSupply us))
  where
    step acc (v, u) = cloneBndr acc u v
-- | Clone a single binder with the given unique: type variables go through
-- 'cloneTyVarBndr', everything else through 'clone_id'.
cloneBndr :: Subst -> Unique -> Var -> (Subst, Var)
cloneBndr subst uniq v =
  if isTyVar v
    then cloneTyVarBndr subst v uniq
    else clone_id subst subst (v, uniq)
-- | Clone a group of mutually recursive identifier binders.
--
-- The final substitution is fed back (lazily) as the substitution used for
-- each binder's 'IdInfo', so the binding below is deliberately recursive.
cloneRecIdBndrs :: Subst -> UniqSupply -> [Id] -> (Subst, [Id])
cloneRecIdBndrs subst us ids =
  let tagged = zip ids (uniqsFromSupply us)
      (final_subst, ids') = mapAccumL (clone_id final_subst) subst tagged
  in  (final_subst, ids')
-- | Clone one identifier binder.
--
-- Arguments: the substitution applied to the binder's 'IdInfo', the
-- substitution applied to the binder itself, and the binder paired with the
-- unique it should receive.  The old binder is mapped to the new one in the
-- coercion environment if it is a coercion variable, and in the identifier
-- environment otherwise.
clone_id :: Subst
         -> Subst -> (Id, Unique)
         -> (Subst, Id)
clone_id rec_subst subst@(Subst in_scope idvs tvs cvs) (old_id, uniq) =
  (Subst (extendInScopeSet in_scope new_id) idvs' tvs cvs', new_id)
  where
    renamed  = setVarUnique old_id uniq
    retyped  = substIdType subst renamed
    -- The IdInfo being substituted is that of the *old* binder.
    new_info = substIdInfo rec_subst retyped (idInfo old_id)
    new_id   = maybeModifyIdInfo new_info retyped

    (idvs', cvs') =
      if isCoVar old_id
        then (idvs, extendVarEnv cvs old_id (mkCoVarCo new_id))
        else (extendVarEnv idvs old_id (Var new_id), cvs)
-- | Substitute in a type-variable binder by delegating to
-- 'Type.substTyVarBndr' on the in-scope set and type environment, then
-- re-attaching the untouched identifier and coercion environments.
substTyVarBndr :: Subst -> TyVar -> (Subst, TyVar)
substTyVarBndr (Subst in_scope id_env tv_env cv_env) tv =
  rebuild (Type.substTyVarBndr (TvSubst in_scope tv_env) tv)
  where
    rebuild (TvSubst in_scope' tv_env', tv') =
      (Subst in_scope' id_env tv_env' cv_env, tv')
-- | Clone a type-variable binder with the given unique by delegating to
-- 'Type.cloneTyVarBndr', then re-attaching the untouched identifier and
-- coercion environments.
cloneTyVarBndr :: Subst -> TyVar -> Unique -> (Subst, TyVar)
cloneTyVarBndr (Subst in_scope id_env tv_env cv_env) tv uniq =
  rebuild (Type.cloneTyVarBndr (TvSubst in_scope tv_env) tv uniq)
  where
    rebuild (TvSubst in_scope' tv_env', tv') =
      (Subst in_scope' id_env tv_env' cv_env, tv')
-- | Substitute in a coercion-variable binder by delegating to
-- 'Coercion.substCoVarBndr' on the in-scope set, type environment and
-- coercion environment, then re-attaching the untouched identifier
-- environment.
substCoVarBndr :: Subst -> TyVar -> (Subst, TyVar)
substCoVarBndr (Subst in_scope id_env tv_env cv_env) cv =
  rebuild (Coercion.substCoVarBndr (CvSubst in_scope tv_env cv_env) cv)
  where
    rebuild (CvSubst in_scope' tv_env' cv_env', cv') =
      (Subst in_scope' id_env tv_env' cv_env', cv')
-- | Apply the type part of a substitution to a type.
substTy :: Subst -> Type -> Type
substTy = Type.substTy . getTvSubst
-- | Extract the in-scope set and type environment as a 'TvSubst'.
getTvSubst :: Subst -> TvSubst
getTvSubst (Subst scope _ tv_env _) = TvSubst scope tv_env
-- | Extract the in-scope set, type environment and coercion environment as a
-- 'CvSubst'.
getCvSubst :: Subst -> CvSubst
getCvSubst (Subst scope _ tv_env cv_env) = CvSubst scope tv_env cv_env
-- | Apply the type and coercion parts of a substitution to a coercion.
substCo :: Subst -> Coercion -> Coercion
substCo = Coercion.substCo . getCvSubst
-- | Substitute in the type of an identifier.  The identifier is returned
-- unchanged when both the type and coercion environments are empty, or when
-- its type has no free type variables.
substIdType :: Subst -> Id -> Id
substIdType subst@(Subst _ _ tv_env cv_env) ident
  | nothing_bound || closed_type = ident
  | otherwise                    = setIdType ident (substTy subst old_ty)
  where
    old_ty        = idType ident
    nothing_bound = isEmptyVarEnv tv_env && isEmptyVarEnv cv_env
    closed_type   = isEmptyVarSet (Type.tyVarsOfType old_ty)
-- | Substitute in the rules and unfolding of an 'IdInfo'.
--
-- Returns 'Nothing' when there are no rules and the unfolding is closed, so
-- the caller can keep the existing info; otherwise returns the info with
-- both fields replaced by their substituted versions.
substIdInfo :: Subst -> Id -> IdInfo -> Maybe IdInfo
substIdInfo subst new_id info
  | isEmptySpecInfo rules && isClosedUnfolding unf = Nothing
  | otherwise =
      Just (info `setSpecInfo` rules' `setUnfoldingInfo` unf')
  where
    rules  = specInfo info
    unf    = unfoldingInfo info
    rules' = substSpec subst new_id rules
    unf'   = substUnfolding subst unf
-- | Apply a substitution to an unfolding.
--
-- * A 'DFunUnfolding' has its binders and arguments substituted.
-- * A 'CoreUnfolding' whose source is not stable is discarded
--   ('NoUnfolding'); a stable one has its template substituted, and the new
--   template is forced with 'seqExpr' before the unfolding is returned.
-- * Any other unfolding is returned unchanged.
--
-- 'substUnfoldingSC' is the short-cut variant that returns its argument
-- untouched when the substitution is empty.
substUnfolding, substUnfoldingSC :: Subst -> Unfolding -> Unfolding

substUnfoldingSC subst unf =
  if isEmptySubst subst
    then unf
    else substUnfolding subst unf

substUnfolding subst unfolding = case unfolding of
  df@(DFunUnfolding { df_bndrs = bndrs, df_args = args }) ->
    let (subst', bndrs') = substBndrs subst bndrs
        substArg = onCoreExpr (substTerm (text "subst-unf:dfun") subst')
    in  df { df_bndrs = bndrs', df_args = map substArg args }
  unf@(CoreUnfolding { uf_tmpl = tmpl, uf_src = src })
    | not (isStableSource src) -> NoUnfolding
    | otherwise ->
        let new_tmpl = onCoreExpr (substTerm (text "subst-unf") subst) tmpl
        in  seqExpr new_tmpl `seq` unf { uf_tmpl = new_tmpl }
  other -> other
-- | Substitute for an identifier occurrence where the result must itself be
-- an identifier.  Panics if the substitution maps it to anything other than
-- a variable.
substIdOcc :: Subst -> Id -> Id
substIdOcc subst v =
  case lookupIdSubst (text "substIdOcc") subst v of
    Var v' -> v'
    other  -> pprPanic "substIdOcc" (vcat [ppr v <+> ppr other, ppr subst])
-- | Apply a substitution to the rules attached to an identifier and to
-- their recorded free variables.  Local rules are renamed to belong to the
-- given (new) identifier.  The result is forced with 'seqSpecInfo' before
-- being returned.
substSpec :: Subst -> Id -> SpecInfo -> SpecInfo
substSpec subst new_id (SpecInfo rules rhs_fvs) =
  seqSpecInfo result `seq` result
  where
    rename _ = idName new_id
    rules'   = map (substRule subst rename) rules
    fvs'     = substVarSet subst rhs_fvs
    result   = SpecInfo rules' fvs'
-- | Apply a substitution to a list of rules for imported identifiers.  The
-- renaming function passed to 'substRule' panics if called, which happens
-- only if one of the rules is marked local.
substRulesForImportedIds :: Subst -> [CoreRule] -> [CoreRule]
substRulesForImportedIds subst rules =
  [substRule subst unexpected rule | rule <- rules]
  where
    unexpected name = pprPanic "substRulesForImportedIds" (ppr name)
-- | Apply a substitution to a single rule.
--
-- Built-in rules are returned unchanged.  For an ordinary rule the binders
-- are substituted first; the arguments and right-hand side are then
-- processed under the extended substitution.  The function name is passed
-- through the supplied renaming only when the rule is local.
substRule :: Subst -> (Name -> Name) -> CoreRule -> CoreRule
substRule _ _ rule@(BuiltinRule {}) = rule
substRule subst rename rule@(Rule { ru_bndrs = bndrs, ru_args = args
                                  , ru_fn = fn_name, ru_rhs = rhs
                                  , ru_local = is_local }) =
  rule { ru_bndrs = bndrs'
       , ru_fn    = new_fn
       , ru_args  = map substArg args
       , ru_rhs   = simpleOptExprWith (toCoreSubst subst') rhs
       }
  where
    (subst', bndrs') = substBndrs subst bndrs
    new_fn | is_local  = rename fn_name
           | otherwise = fn_name
    substArg =
      onCoreExpr (substTerm (text "subst-rule" <+> ppr fn_name) subst')
-- | Apply a substitution to a set of free variables: each variable is
-- replaced by the free variables of whatever it maps to, and the results
-- are unioned.  Variables satisfying 'isId' go through the identifier
-- lookup; all others through the type-variable lookup.
substVarSet :: Subst -> VarSet -> VarSet
substVarSet subst fvs =
  foldVarSet (\fv acc -> freeVarsOf fv `unionVarSet` acc) emptyVarSet fvs
  where
    freeVarsOf fv
      | isId fv =
          exprFreeVars
            (termToCoreExpr (lookupIdSubst (text "substVarSet") subst fv))
      | otherwise = Type.tyVarsOfType (lookupTvSubst subst fv)
-- | Apply a substitution to a tick.  Only 'Breakpoint' ticks mention
-- variables; every other tick is returned unchanged.
--
-- Each breakpoint variable must map to a variable.  If the substitution
-- maps one to any other term this panics with a message naming the
-- offending variable and its replacement, rather than failing on a partial
-- pattern match.
substTickish :: Subst -> Tickish Id -> Tickish Id
substTickish subst (Breakpoint n ids) = Breakpoint n (map do_one ids)
  where
    do_one x =
      case lookupIdSubst (text "subst_tickish") subst x of
        Var x' -> x'
        other  -> pprPanic "substTickish" (ppr x <+> ppr other)
substTickish _subst other = other