{-# LANGUAGE CPP #-}
{-# OPTIONS_GHC -Wno-incomplete-record-updates #-}
module CoreSubst (
        
        Subst(..), 
        TvSubstEnv, IdSubstEnv, InScopeSet,
        
        deShadowBinds, substSpec, substRulesForImportedIds,
        substTy, substCo, substExpr, substExprSC, substBind, substBindSC,
        substUnfolding, substUnfoldingSC,
        lookupIdSubst, lookupTCvSubst, substIdOcc,
        substTickish, substDVarSet, substIdInfo,
        
        emptySubst, mkEmptySubst, mkSubst, mkOpenSubst, substInScope, isEmptySubst,
        extendIdSubst, extendIdSubstList, extendTCvSubst, extendTvSubstList,
        extendSubst, extendSubstList, extendSubstWithVar, zapSubstEnv,
        addInScopeSet, extendInScope, extendInScopeList, extendInScopeIds,
        isInScope, setInScope, getTCvSubst, extendTvSubst, extendCvSubst,
        delBndr, delBndrs,
        
        substBndr, substBndrs, substRecBndrs, substTyVarBndr, substCoVarBndr,
        cloneBndr, cloneBndrs, cloneIdBndr, cloneIdBndrs, cloneRecIdBndrs,
    ) where
#include "HsVersions.h"
import GhcPrelude
import CoreSyn
import CoreFVs
import CoreSeq
import CoreUtils
import qualified Type
import qualified Coercion
        
import Type     hiding ( substTy, extendTvSubst, extendCvSubst, extendTvSubstList
                       , isInScope, substTyVarBndr, cloneTyVarBndr )
import Coercion hiding ( substCo, substCoVarBndr )
import PrelNames
import VarSet
import VarEnv
import Id
import Name     ( Name )
import Var
import IdInfo
import UniqSupply
import Maybes
import Util
import Outputable
import Data.List
-- | A substitution for Core.  It bundles the set of variables currently
-- in scope with three separate substitution environments: one for term
-- variables, one for type variables and one for coercion variables.
data Subst
  = Subst InScopeSet  -- ^ Variables in scope.  Binder-renaming code
                      -- (e.g. 'uniqAway' in 'substIdBndr') avoids it, and
                      -- 'lookupIdSubst' falls back to it for unmapped Ids.
          IdSubstEnv  -- ^ Substitution for term variables ('Id's)
          TvSubstEnv  -- ^ Substitution for type variables
          CvSubstEnv  -- ^ Substitution for coercion variables
-- | The term-variable part of a 'Subst': maps each 'Id' to the Core
-- expression that should replace its occurrences.
type IdSubstEnv = IdEnv CoreExpr
-- | True when none of the three environments maps anything.  The
-- in-scope set is not consulted.
isEmptySubst :: Subst -> Bool
isEmptySubst (Subst _ id_env tv_env cv_env)
  = and [ isEmptyVarEnv id_env
        , isEmptyVarEnv tv_env
        , isEmptyVarEnv cv_env ]
-- | The substitution that maps nothing and has an empty in-scope set.
emptySubst :: Subst
emptySubst = mkEmptySubst emptyInScopeSet
-- | A substitution that maps nothing but carries the given in-scope set.
mkEmptySubst :: InScopeSet -> Subst
mkEmptySubst scope = Subst scope emptyVarEnv emptyVarEnv emptyVarEnv
-- | Assemble a substitution from its parts.  Note that the argument order
-- (type env, coercion env, Id env) differs from the field order of the
-- 'Subst' constructor (Id env, type env, coercion env).
mkSubst :: InScopeSet -> TvSubstEnv -> CvSubstEnv -> IdSubstEnv -> Subst
mkSubst scope tv_env cv_env id_env = Subst scope id_env tv_env cv_env
-- | The in-scope set carried by a substitution.
substInScope :: Subst -> InScopeSet
substInScope subst = case subst of
  Subst scope _ _ _ -> scope
-- | Drop every mapping, keeping only the in-scope set.
zapSubstEnv :: Subst -> Subst
zapSubstEnv (Subst scope _ _ _) = mkEmptySubst scope
-- | Map a single (non-coercion) 'Id' to the expression that replaces it.
-- The in-scope set is left unchanged.
extendIdSubst :: Subst -> Id -> CoreExpr -> Subst
extendIdSubst (Subst in_scope ids tvs cvs) v r
  = ASSERT2( isNonCoVarId v, ppr v $$ ppr r )
    Subst in_scope ids' tvs cvs
  where
    ids' = extendVarEnv ids v r
-- | Add several (non-coercion) 'Id' mappings at once.  The in-scope set
-- is left unchanged.
extendIdSubstList :: Subst -> [(Id, CoreExpr)] -> Subst
extendIdSubstList (Subst in_scope ids tvs cvs) prs
  = ASSERT( all (isNonCoVarId . fst) prs )
    Subst in_scope ids' tvs cvs
  where
    ids' = extendVarEnvList ids prs
-- | Map a type variable to the type that replaces it.  The in-scope set
-- is left unchanged.
extendTvSubst :: Subst -> TyVar -> Type -> Subst
extendTvSubst (Subst in_scope ids tvs cvs) tv ty
  = ASSERT( isTyVar tv )
    Subst in_scope ids tvs' cvs
  where
    tvs' = extendVarEnv tvs tv ty
-- | Add each type-variable mapping in turn, left to right, using
-- 'extendTvSubst' and a strict left fold.
extendTvSubstList :: Subst -> [(TyVar,Type)] -> Subst
extendTvSubstList subst vrs
  = foldl' (\acc (tv, ty) -> extendTvSubst acc tv ty) subst vrs
-- | Map a coercion variable to the coercion that replaces it.  The
-- in-scope set is left unchanged.
extendCvSubst :: Subst -> CoVar -> Coercion -> Subst
extendCvSubst (Subst in_scope ids tvs cvs) v r
  = ASSERT( isCoVar v )
    Subst in_scope ids tvs cvs'
  where
    cvs' = extendVarEnv cvs v r
-- | Add one mapping, choosing the environment from the form of the
-- argument: a @Type@ goes to the type env, a @Coercion@ to the coercion
-- env, and anything else to the Id env.
extendSubst :: Subst -> Var -> CoreArg -> Subst
extendSubst subst var (Type ty)
  = ASSERT( isTyVar var ) extendTvSubst subst var ty
extendSubst subst var (Coercion co)
  = ASSERT( isCoVar var ) extendCvSubst subst var co
extendSubst subst var arg
  = ASSERT( isId var ) extendIdSubst subst var arg
-- | Map the first variable to the second.  The second variable is wrapped
-- as a type, a coercion or a term according to the sort of the first.
extendSubstWithVar :: Subst -> Var -> Var -> Subst
extendSubstWithVar subst old new
  | isTyVar old = ASSERT( isTyVar new ) extendTvSubst subst old (mkTyVarTy new)
  | isCoVar old = ASSERT( isCoVar new ) extendCvSubst subst old (mkCoVarCo new)
  | otherwise   = ASSERT( isId new )    extendIdSubst subst old (Var new)
-- | Add each @(variable, argument)@ pair in turn, left to right, using
-- 'extendSubst' (which picks the environment from the argument form).
--
-- Uses a strict left fold, as 'extendTvSubstList' does, so that a long
-- list does not build a chain of unevaluated 'extendSubst' thunks before
-- the result is demanded.
extendSubstList :: Subst -> [(Var,CoreArg)] -> Subst
extendSubstList subst prs
  = foldl' (\acc (var, rhs) -> extendSubst acc var rhs) subst prs
-- | Find the expression that should replace an occurrence of an 'Id'.
--
-- The cases are tried in order:
--
--   * a non-local 'Id' (not 'isLocalId') is returned unchanged;
--   * an 'Id' bound in the Id environment is replaced by its mapping;
--   * otherwise the copy of the variable found in the in-scope set is used;
--   * if it is in neither, a warning is produced through the WARN macro,
--     labelled with the 'SDoc' argument, and the 'Id' is returned as is.
lookupIdSubst :: SDoc -> Subst -> Id -> CoreExpr
lookupIdSubst doc (Subst in_scope ids _ _) v
  | not (isLocalId v) = Var v
  | Just e  <- lookupVarEnv ids       v = e
  | Just v' <- lookupInScope in_scope v = Var v'
        -- The in-scope copy v' is returned rather than v itself; presumably
        -- so that information attached to the binder is picked up.  This
        -- fallback must be tried before the warning case below.
  | otherwise = WARN( True, text "CoreSubst.lookupIdSubst" <+> doc <+> ppr v
                            $$ ppr in_scope)
                Var v
-- | Look up a type or coercion variable, returning a 'Type'.
--
-- A type variable becomes its mapped type, or itself as a type if it is
-- unmapped.  Any other variable is treated as a coercion variable: its
-- mapped coercion (or itself as a coercion) is wrapped with 'mkCoercionTy'.
lookupTCvSubst :: Subst -> TyVar -> Type
lookupTCvSubst (Subst _ _ tv_env cv_env) var
  | isTyVar var
  = case lookupVarEnv tv_env var of
      Just ty -> ty
      Nothing -> Type.mkTyVarTy var
  | otherwise
  = mkCoercionTy $ case lookupVarEnv cv_env var of
      Just co -> co
      Nothing -> mkCoVarCo var
-- | Forget any mapping for a binder.  The mapping is removed only from
-- the environment matching the binder's sort (coercion, type or term).
delBndr :: Subst -> Var -> Subst
delBndr (Subst in_scope ids tvs cvs) bndr
  | isCoVar bndr = Subst in_scope ids tvs (cvs `delVarEnv` bndr)
  | isTyVar bndr = Subst in_scope ids (tvs `delVarEnv` bndr) cvs
  | otherwise    = Subst in_scope (ids `delVarEnv` bndr) tvs cvs
-- | Forget any mappings for several binders.  Unlike 'delBndr', every
-- variable is removed from all three environments, whatever its sort.
delBndrs :: Subst -> [Var] -> Subst
delBndrs (Subst in_scope ids tvs cvs) vs
  = Subst in_scope (drop_vs ids) (drop_vs tvs) (drop_vs cvs)
  where
    drop_vs :: VarEnv a -> VarEnv a
    drop_vs env = env `delVarEnvList` vs
-- | Build a substitution directly from @(variable, argument)@ pairs.
--
--   * Every pair whose variable satisfies 'isId' goes into the Id
--     environment, whatever its argument.
--   * Pairs whose argument is a @Type@ go into the type environment.
--   * Pairs whose argument is a @Coercion@ go into the coercion environment.
--
-- NOTE(review): if 'isId' holds for coercion variables, a coercion pair
-- lands in both the Id and coercion environments; confirm that is intended.
mkOpenSubst :: InScopeSet -> [(Var,CoreArg)] -> Subst
mkOpenSubst in_scope pairs
  = Subst in_scope id_env tv_env cv_env
  where
    id_env = mkVarEnv [ (v, arg) | (v, arg)         <- pairs, isId v ]
    tv_env = mkVarEnv [ (v, ty)  | (v, Type ty)     <- pairs ]
    cv_env = mkVarEnv [ (v, co)  | (v, Coercion co) <- pairs ]
-- | Is the variable in the substitution's in-scope set?
isInScope :: Var -> Subst -> Bool
isInScope v (Subst scope _ _ _) = elemInScopeSet v scope
-- | Add a set of variables to the in-scope set.  Unlike 'extendInScope',
-- existing mappings for them are kept.
addInScopeSet :: Subst -> VarSet -> Subst
addInScopeSet subst@(Subst scope _ _ _) vs
  = setInScope subst (extendInScopeSetSet scope vs)
-- | Bring a variable into scope and remove any mapping for it from all
-- three environments.  Later occurrences are then resolved through the
-- in-scope set rather than an old mapping.
extendInScope :: Subst -> Var -> Subst
extendInScope (Subst in_scope ids tvs cvs) v
  = Subst (in_scope `extendInScopeSet` v) (forget ids) (forget tvs) (forget cvs)
  where
    forget :: VarEnv a -> VarEnv a
    forget env = env `delVarEnv` v
-- | Like 'extendInScope' for several variables: bring them all into scope
-- and remove any mappings for them from all three environments.
extendInScopeList :: Subst -> [Var] -> Subst
extendInScopeList (Subst in_scope ids tvs cvs) vs
  = Subst (in_scope `extendInScopeSetList` vs) (forget ids) (forget tvs) (forget cvs)
  where
    forget :: VarEnv a -> VarEnv a
    forget env = env `delVarEnvList` vs
-- | Bring some 'Id's into scope and remove their mappings from the Id
-- environment only.  The type and coercion environments are untouched.
extendInScopeIds :: Subst -> [Id] -> Subst
extendInScopeIds (Subst in_scope ids tvs cvs) vs
  = Subst in_scope' ids' tvs cvs
  where
    in_scope' = in_scope `extendInScopeSetList` vs
    ids'      = ids `delVarEnvList` vs
-- | Replace the in-scope set, keeping all three environments.
setInScope :: Subst -> InScopeSet -> Subst
setInScope subst new_scope = case subst of
  Subst _ ids tvs cvs -> Subst new_scope ids tvs cvs
-- | Debug rendering: the in-scope set followed by the three environments,
-- one per line, closed by a final @>@.
instance Outputable Subst where
  ppr (Subst in_scope ids tvs cvs)
    = vcat [ text "<InScope =" <+> in_scope_doc
           , text " IdSubst   =" <+> ppr ids
           , text " TvSubst   =" <+> ppr tvs
           , text " CvSubst   =" <+> ppr cvs <> char '>' ]
    where
      in_scope_doc = pprVarSet (getInScopeVars in_scope) (braces . fsep . map ppr)
-- | Like 'substExpr', but returns the expression untouched when the
-- substitution is empty (see 'isEmptySubst').
substExprSC :: SDoc -> Subst -> CoreExpr -> CoreExpr
substExprSC doc subst orig_expr
  = if isEmptySubst subst
      then orig_expr
      else subst_expr doc subst orig_expr
-- | Apply a substitution to a whole expression.  The 'SDoc' labels any
-- warning that 'lookupIdSubst' raises for a variable it cannot resolve.
substExpr :: SDoc -> Subst -> CoreExpr -> CoreExpr
substExpr = subst_expr
-- | Worker for 'substExpr' and 'substExprSC'.  It walks the expression,
-- replacing variable occurrences via 'lookupIdSubst' and substituting
-- into types, coercions and ticks.  Each binder (lambda, let, case binder,
-- case-alternative binders) is processed first, and the substitution that
-- comes back is used only inside that binder's scope.
subst_expr :: SDoc -> Subst -> CoreExpr -> CoreExpr
subst_expr doc subst expr
  = go expr
  where
    go (Var v)         = lookupIdSubst (doc $$ text "subst_expr") subst v
    go (Type ty)       = Type (substTy subst ty)
    go (Coercion co)   = Coercion (substCo subst co)
    go (Lit lit)       = Lit lit
    go (App fun arg)   = App (go fun) (go arg)
    go (Tick tickish e) = mkTick (substTickish subst tickish) (go e)
    go (Cast e co)     = Cast (go e) (substCo subst co)
       -- The cast coercion is substituted, but no further simplification
       -- of the cast is attempted here.  This function is also applied to
       -- rule arguments (via substExpr in substRule), where changing the
       -- shape of the expression may be undesirable; TODO confirm.
       --
       -- Binding forms: the body is substituted with the substitution
       -- returned for the binder(s), via a recursive call to subst_expr.
    go (Lam bndr body) = Lam bndr' (subst_expr doc subst' body)
                       where
                         (subst', bndr') = substBndr subst bndr
    go (Let bind body) = Let bind' (subst_expr doc subst' body)
                       where
                         (subst', bind') = substBind subst bind
    -- The scrutinee and result type use the outer substitution; only the
    -- alternatives see the (possibly renamed) case binder.
    go (Case scrut bndr ty alts) = Case (go scrut) bndr' (substTy subst ty) (map (go_alt subst') alts)
                                 where
                                 (subst', bndr') = substBndr subst bndr
    go_alt subst (con, bndrs, rhs) = (con, bndrs', subst_expr doc subst' rhs)
                                 where
                                   (subst', bndrs') = substBndrs subst bndrs
-- | Apply a substitution to a binding group.  The result pairs the
-- substitution to use in the scope of the binders with the new binding.
-- 'substBindSC' is a short-cut variant: when the incoming substitution is
-- empty, right-hand sides are left alone wherever possible.
substBind, substBindSC :: Subst -> CoreBind -> (Subst, CoreBind)
substBindSC subst bind
  | isEmptySubst subst = short_cut bind
  | otherwise          = substBind subst bind
  where
    -- Even with an empty substitution the binders still go through
    -- substBndr / substRecBndrs, because they may need renaming away
    -- from the in-scope set.
    short_cut (NonRec bndr rhs)
      = (subst', NonRec bndr' rhs)
      where
        (subst', bndr') = substBndr subst bndr
    short_cut (Rec pairs)
      = (subst', Rec (bndrs' `zip` rhss'))
      where
        (bndrs, rhss)    = unzip pairs
        (subst', bndrs') = substRecBndrs subst bndrs
        -- Renaming a recursive binder makes the substitution non-empty,
        -- and then the right-hand sides must be substituted too.
        rhss' | isEmptySubst subst' = rhss
              | otherwise           = map (subst_expr (text "substBindSC") subst') rhss
-- Non-recursive binding: the right-hand side is substituted with the
-- incoming substitution, since the binder is not in scope in its own RHS.
substBind subst (NonRec bndr rhs)
  = (subst', NonRec bndr' rhs')
  where
    (subst', bndr') = substBndr subst bndr
    rhs'            = subst_expr (text "substBind") subst rhs
-- Recursive group: every binder scopes over every right-hand side, so
-- the right-hand sides use the substitution extended with all binders.
substBind subst (Rec pairs)
  = (subst', Rec (zip bndrs' rhss'))
  where
    (subst', bndrs') = substRecBndrs subst (map fst pairs)
    rhss'            = map (subst_expr (text "substBind") subst' . snd) pairs
-- | Run 'substBind' over a whole program, starting from 'emptySubst' and
-- threading the substitution from each binding group to the next.  Since
-- each binder is added to the in-scope set, a later binder that would
-- clash with an earlier one gets renamed.
deShadowBinds :: CoreProgram -> CoreProgram
deShadowBinds = snd . mapAccumL substBind emptySubst
-- | Substitute into a single binder of any sort, dispatching to the
-- type-variable, coercion-variable or 'Id' worker.
substBndr :: Subst -> Var -> (Subst, Var)
substBndr subst bndr
  | isTyVar bndr = substTyVarBndr subst bndr
  | isCoVar bndr = substCoVarBndr subst bndr
  | otherwise    = substIdBndr (text "var-bndr") rec_subst subst bndr
  where
    -- A lone binder is not part of a recursive group, so its IdInfo is
    -- substituted with the incoming substitution itself.
    rec_subst = subst
-- | Substitute into a sequence of binders left to right.  Each binder
-- sees the substitution produced by the ones before it.
substBndrs :: Subst -> [Var] -> (Subst, [Var])
substBndrs = mapAccumL substBndr
-- | Substitute into the binders of a recursive group.
--
-- Each binder's IdInfo (rules and unfolding) is substituted with
-- @new_subst@, the substitution produced by processing the whole group,
-- so every binder's IdInfo sees the renamings of all binders in the
-- group.  Because @new_subst@ is used in its own definition, this relies
-- on laziness: the IdInfo substitution must not be forced while the
-- group's substitution is being built.
substRecBndrs :: Subst -> [Id] -> (Subst, [Id])
substRecBndrs subst bndrs
  = (new_subst, new_bndrs)
  where         -- new_subst is fed back in as the IdInfo substitution (a lazy knot)
    (new_subst, new_bndrs) = mapAccumL (substIdBndr (text "rec-bndr") new_subst) subst bndrs
-- | Substitute into an 'Id' binder.
--
-- The binder is passed through 'uniqAway' against the in-scope set, its
-- type is substituted when that could change it, and its IdInfo (rules
-- and unfolding) is substituted with @rec_subst@.  The returned
-- substitution has the new 'Id' in its in-scope set.  Its Id environment
-- either maps the old 'Id' to the new one (if the binder was renamed) or
-- has any existing mapping for it removed.
substIdBndr :: SDoc
            -> Subst            -- ^ Substitution to use for the IdInfo
            -> Subst -> Id      -- ^ Substitution and Id to transform
            -> (Subst, Id)      -- ^ Transformed pair
                                -- NB: unfolding may be zapped (see substUnfolding)
substIdBndr _doc rec_subst subst@(Subst in_scope env tvs cvs) old_id
  = -- The SDoc argument is currently unused.
    (Subst (in_scope `extendInScopeSet` new_id) new_env tvs cvs, new_id)
  where
    id1 = uniqAway in_scope old_id      -- presumably old_id itself unless it clashes with in_scope
    id2 | no_type_change = id1
        | otherwise      = setIdType id1 (substTy subst old_ty)
    old_ty = idType old_id
    no_type_change = (isEmptyVarEnv tvs && isEmptyVarEnv cvs) ||
                     noFreeVarsOfType old_ty
        -- Skip the type substitution when there is nothing to substitute,
        -- or the type has no free variables to substitute for.
        --
        -- IdInfo is substituted with rec_subst.  For a recursive group this
        -- is the final substitution of the whole group (see substRecBndrs),
        -- so it must stay lazy here.
    new_id = maybeModifyIdInfo mb_new_info id2
    mb_new_info = substIdInfo rec_subst id2 (idInfo id2)
        -- NB: substUnfolding may discard (zap) an unstable unfolding.
        --
        -- If the binder was renamed, record old_id |-> new_id.  Otherwise
        -- delete any stale mapping for old_id, since this binder now shadows it.
    new_env | no_change = delVarEnv env old_id
            | otherwise = extendVarEnv env old_id (Var new_id)
    no_change = id1 == old_id
        -- no_change compares id1 (after renaming) with old_id; type and
        -- IdInfo changes alone do not add an env mapping.
-- | Give an 'Id' binder a fresh unique taken from the supply, substituting
-- into its type and IdInfo.  A single binder is not recursive, so the
-- incoming substitution is also used for its IdInfo.
cloneIdBndr :: Subst -> UniqSupply -> Id -> (Subst, Id)
cloneIdBndr subst us old_id
  = clone_id subst subst (old_id, fresh)
  where
    fresh = uniqFromSupply us
-- | Clone a sequence of (non-recursive) 'Id' binders left to right, each
-- with a fresh unique from the supply.  IdInfo is substituted with the
-- original incoming substitution.
cloneIdBndrs :: Subst -> UniqSupply -> [Id] -> (Subst, [Id])
cloneIdBndrs subst us ids
  = mapAccumL (clone_id subst) subst (zip ids uniqs)
  where
    uniqs = uniqsFromSupply us
-- | Clone a sequence of binders of any sort left to right, each with a
-- fresh unique from the supply, via 'cloneBndr'.
cloneBndrs :: Subst -> UniqSupply -> [Var] -> (Subst, [Var])
cloneBndrs subst us vs
  = mapAccumL step subst (zip vs (uniqsFromSupply us))
  where
    step acc (v, u) = cloneBndr acc u v
-- | Clone one binder with the given unique.  Type variables go through
-- 'cloneTyVarBndr'.  Everything else, including coercion variables (which
-- clone_id records in the coercion environment), goes through clone_id.
cloneBndr :: Subst -> Unique -> Var -> (Subst, Var)
cloneBndr subst uniq v
  | isTyVar v = cloneTyVarBndr subst v uniq
  | otherwise = clone_id rec_subst subst (v, uniq)
  where
    -- Not a recursive group: IdInfo uses the incoming substitution.
    rec_subst = subst
-- | Clone the binders of a recursive group, each with a fresh unique from
-- the supply.
--
-- As in substRecBndrs, each binder's IdInfo is substituted with
-- @subst'@, the substitution produced by processing the whole group.
-- Because @subst'@ is used in its own definition, this relies on the
-- IdInfo substitution in clone_id staying lazy.
cloneRecIdBndrs :: Subst -> UniqSupply -> [Id] -> (Subst, [Id])
cloneRecIdBndrs subst us ids
  = (subst', ids')
  where
    -- subst' is fed back in as clone_id's IdInfo substitution (a lazy knot)
    (subst', ids') = mapAccumL (clone_id subst') subst
                               (ids `zip` uniqsFromSupply us)
-- | Worker for the Id-cloning functions.  It gives @old_id@ the supplied
-- unique, substitutes into its type (with @subst@) and its IdInfo (with
-- @rec_subst@), and adds the result to the in-scope set.  The renaming is
-- recorded in the coercion environment for a coercion variable and in
-- the Id environment otherwise.
clone_id    :: Subst                    -- ^ Substitution for the IdInfo
            -> Subst -> (Id, Unique)    -- ^ Substitution and Id to transform
            -> (Subst, Id)              -- ^ Transformed pair
clone_id rec_subst subst@(Subst in_scope idvs tvs cvs) (old_id, uniq)
  = (Subst in_scope' idvs' tvs cvs', new_id)
  where
    in_scope' = in_scope `extendInScopeSet` new_id
    renamed   = setVarUnique old_id uniq
    retyped   = substIdType subst renamed
    -- The IdInfo is read from old_id but substituted with rec_subst, which
    -- may be a lazily tied knot (see cloneRecIdBndrs); keep it lazy.
    new_id    = maybeModifyIdInfo (substIdInfo rec_subst retyped (idInfo old_id)) retyped
    (idvs', cvs')
      | isCoVar old_id = (idvs, extendVarEnv cvs old_id (mkCoVarCo new_id))
      | otherwise      = (extendVarEnv idvs old_id (Var new_id), cvs)
-- | Substitute into a type-variable binder by delegating to
-- 'Type.substTyVarBndr' on the type/coercion part of the substitution.
-- The Id environment is carried through unchanged.
substTyVarBndr :: Subst -> TyVar -> (Subst, TyVar)
substTyVarBndr (Subst scope ids tv_env cv_env) tv
  = case Type.substTyVarBndr tcv_subst tv of
      (TCvSubst scope' tv_env' cv_env', tv') ->
        (Subst scope' ids tv_env' cv_env', tv')
  where
    tcv_subst = TCvSubst scope tv_env cv_env
-- | Clone a type-variable binder with the given unique by delegating to
-- 'Type.cloneTyVarBndr'.  The Id environment is carried through unchanged.
cloneTyVarBndr :: Subst -> TyVar -> Unique -> (Subst, TyVar)
cloneTyVarBndr (Subst scope ids tv_env cv_env) tv uniq
  = case Type.cloneTyVarBndr tcv_subst tv uniq of
      (TCvSubst scope' tv_env' cv_env', tv') ->
        (Subst scope' ids tv_env' cv_env', tv')
  where
    tcv_subst = TCvSubst scope tv_env cv_env
-- | Substitute into a coercion-variable binder by delegating to
-- 'Coercion.substCoVarBndr' on the type/coercion part of the
-- substitution.  The Id environment is carried through unchanged.
--
-- The signature says 'CoVar' rather than 'TyVar' because 'substBndr' only
-- calls this for binders satisfying 'isCoVar'.  Both are synonyms for
-- 'Var' in GHC's Var module, so callers see the same type.
substCoVarBndr :: Subst -> CoVar -> (Subst, CoVar)
substCoVarBndr (Subst in_scope id_env tv_env cv_env) cv
  = case Coercion.substCoVarBndr (TCvSubst in_scope tv_env cv_env) cv of
        (TCvSubst in_scope' tv_env' cv_env', cv')
           -> (Subst in_scope' id_env tv_env' cv_env', cv')
-- | Apply the type/coercion part of the substitution to a type, using
-- 'Type.substTyUnchecked'.  Judging by its name, that function skips the
-- in-scope sanity check (TODO confirm).
substTy :: Subst -> Type -> Type
substTy subst = Type.substTyUnchecked (getTCvSubst subst)
-- | Project out the type/coercion substitution, sharing the same in-scope
-- set and dropping the Id environment.
getTCvSubst :: Subst -> TCvSubst
getTCvSubst subst = case subst of
  Subst scope _ tv_env cv_env -> TCvSubst scope tv_env cv_env
-- | Apply the type/coercion part of the substitution to a coercion.
substCo :: HasCallStack => Subst -> Coercion -> Coercion
substCo subst = Coercion.substCo (getTCvSubst subst)
-- | Substitute into the type of an 'Id'.  The 'Id' is returned unchanged
-- when the type and coercion environments are both empty, or when its
-- type has no free variables.
substIdType :: Subst -> Id -> Id
substIdType subst@(Subst _ _ tv_env cv_env) id
  | nothing_to_subst = id
  | otherwise        = setIdType id (substTy subst old_ty)
  where
    old_ty           = idType id
    nothing_to_subst = (isEmptyVarEnv tv_env && isEmptyVarEnv cv_env)
                       || noFreeVarsOfType old_ty
-- | Substitute into the rules and unfolding held in an 'IdInfo'.
-- Returns 'Nothing' when there are no rules and the unfolding is not
-- fragile (per 'isFragileUnfolding'), meaning nothing needs to change.
substIdInfo :: Subst -> Id -> IdInfo -> Maybe IdInfo
substIdInfo subst new_id info
  | isEmptyRuleInfo old_rules
  , not (isFragileUnfolding old_unf)
  = Nothing
  | otherwise
  = Just new_info
  where
    old_rules = ruleInfo info
    old_unf   = unfoldingInfo info
    new_info  = info `setRuleInfo`      substSpec subst new_id old_rules
                     `setUnfoldingInfo` substUnfolding subst old_unf
-- | Substitute into an unfolding.  'substUnfoldingSC' is a short-cut
-- variant that returns the unfolding unchanged when the substitution is
-- empty.
substUnfolding, substUnfoldingSC :: Subst -> Unfolding -> Unfolding
substUnfoldingSC subst unf
  = if isEmptySubst subst then unf else substUnfolding subst unf
-- DFun unfoldings: substitute through the binders, then through the
-- argument expressions using the extended substitution.
substUnfolding subst df@(DFunUnfolding { df_bndrs = bndrs, df_args = args })
  = df { df_bndrs = bndrs', df_args = map subst_arg args }
  where
    (subst', bndrs') = substBndrs subst bndrs
    subst_arg        = substExpr (text "subst-unf:dfun") subst'
-- Core unfoldings: a stable one keeps its template, substituted and
-- forced with seqExpr; an unstable one is discarded.
substUnfolding subst unf@(CoreUnfolding { uf_tmpl = tmpl, uf_src = src })
  | isStableSource src
  = seqExpr new_tmpl `seq` unf { uf_tmpl = new_tmpl }
  | otherwise
  = NoUnfolding
  where
    new_tmpl = substExpr (text "subst-unf") subst tmpl
-- Every other unfolding is returned as is.
substUnfolding _ unf = unf
-- | Substitute into an 'Id' occurrence that must map to another variable.
-- Panics if the substitution maps it to any other kind of expression.
substIdOcc :: Subst -> Id -> Id
substIdOcc subst v = occ_of (lookupIdSubst (text "substIdOcc") subst v)
  where
    occ_of (Var v') = v'
    occ_of other    = pprPanic "substIdOcc" (vcat [ppr v <+> ppr other, ppr subst])
-- | Substitute into the rules attached to @new_id@ and into their
-- free-variable set.  Local rules have their head function renamed to
-- @new_id@'s name.  The result is forced with 'seqRuleInfo' before it is
-- returned.
substSpec :: Subst -> Id -> RuleInfo -> RuleInfo
substSpec subst new_id (RuleInfo rules rhs_fvs)
  = let rename_fn = const (idName new_id)
        new_spec  = RuleInfo (map (substRule subst rename_fn) rules)
                             (substDVarSet subst rhs_fvs)
    in seqRuleInfo new_spec `seq` new_spec
-- | Substitute into rules for imported 'Id's.  substRule only renames the
-- head function of local rules; these rules are presumably never local,
-- so the renaming function should never be called and panics if it is.
substRulesForImportedIds :: Subst -> [CoreRule] -> [CoreRule]
substRulesForImportedIds subst = map (substRule subst not_needed)
  where
    not_needed name = pprPanic "substRulesForImportedIds" (ppr name)
-- | Substitute into a single rule.  Built-in rules are returned unchanged.
-- For an ordinary rule the binders are substituted first, and the
-- resulting substitution is applied to the arguments and the right-hand
-- side.  The head function name is rewritten with @subst_ru_fn@ only for
-- local rules.
substRule :: Subst -> (Name -> Name) -> CoreRule -> CoreRule
substRule _ _ rule@(BuiltinRule {}) = rule
substRule subst subst_ru_fn rule@(Rule { ru_bndrs = bndrs, ru_args = args
                                       , ru_fn = fn_name, ru_rhs = rhs
                                       , ru_local = is_local })
  = rule { ru_bndrs = bndrs'
         , ru_fn    = if is_local
                        then subst_ru_fn fn_name
                        else fn_name
         , ru_args  = map (substExpr doc subst') args
         , ru_rhs   = substExpr doc subst' rhs }
           -- The right-hand side is only substituted, not optimised.
  where
    -- Label for any lookupIdSubst warning raised inside this rule.  It is
    -- shared by the arguments and the RHS, so every warning names the rule's
    -- head function instead of an uninformative placeholder.
    doc = text "subst-rule" <+> ppr fn_name
    (subst', bndrs') = substBndrs subst bndrs
-- | Apply the substitution to a deterministic set of free variables.
-- Each variable is replaced by the free variables of whatever it maps to.
-- For an 'Id', only variables satisfying 'isLocalVar' are kept; for a
-- type or coercion variable, all free variables are kept.
substDVarSet :: Subst -> DVarSet -> DVarSet
substDVarSet subst fvs
  = mkDVarSet (fst (foldr add_fv ([], emptyVarSet) (dVarSetElems fvs)))
  where
    add_fv fv acc
      | isId fv   = expr_fvs (lookupIdSubst (text "substDVarSet") subst fv)
                             isLocalVar emptyVarSet $! acc
      | otherwise = tyCoFVsOfType (lookupTCvSubst subst fv)
                                  (const True) emptyVarSet $! acc
-- | Substitute into a tick.  For a breakpoint, each listed 'Id' is looked
-- up and turned back into an 'Id' with 'getIdFromTrivialExpr'.  All other
-- ticks are returned unchanged.
substTickish :: Subst -> Tickish Id -> Tickish Id
substTickish subst tick
  = case tick of
      Breakpoint n ids -> Breakpoint n (map subst_occ ids)
      _                -> tick
  where
    subst_occ = getIdFromTrivialExpr . lookupIdSubst (text "subst_tickish") subst