-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Retrie.Context
  ( ContextUpdater
  , updateContext
  , emptyContext
  ) where

import Control.Monad.IO.Class
import Data.Char (isDigit)
import Data.Either (partitionEithers)
import Data.Generics hiding (Fixity)
import Data.List
import Data.Maybe

import Retrie.AlphaEnv
import Retrie.ExactPrint
import Retrie.Fixity
import Retrie.FreeVars
import Retrie.GHC
import Retrie.Substitution
import Retrie.SYB
import Retrie.Types
import Retrie.Universe

-------------------------------------------------------------------------------

-- | The type of context-update functions, for use with 'apply'.
--
-- A custom 'ContextUpdater' is usually written by extending 'updateContext'
-- with SYB combinators such as 'mkQ' and 'extQ'.
type ContextUpdater =
  forall m. (MonadIO m) => GenericCU (TransformT m) Context

-- | Default context update function.
--
-- Given the parent's 'Context' @c@ and the index @i@ of the child node being
-- entered, compute the 'Context' to use for that child: the parent precedence
-- (used to decide parenthesization) and any binders the node brings into
-- scope. Node types not listed below inherit @c@ unchanged.
--
-- Because of Trees-That-Grow, most GHC AST constructors carry an extension
-- field as their first child, hence the uses of 'firstChild' below.
updateContext :: forall m. MonadIO m => GenericCU (TransformT m) Context
updateContext c i =
  const (return c)
    `extQ` (return . updExp)
    `extQ` (return . updType)
    `extQ` (return . updMatch)
    `extQ` (return . updGRHSs)
    `extQ` (return . updGRHS)
    `extQ` (return . updStmt)
    `extQ` (return . updPat)
    `extQ` updStmtList
    `extQ` (return . updHsBind)
    `extQ` (return . updTyClDecl)
  where
    -- The parent context with parenthesization switched off; the starting
    -- point for every node that does not impose a precedence of its own.
    neverParen :: Context
    neverParen = c { ctxtParentPrec = NeverParen }

    updExp :: HsExpr GhcPs -> Context
    updExp HsApp{} =
      c { ctxtParentPrec =
            HasPrec $ Fixity (SourceText "HsApp") (10 + i - firstChild) InfixL }
    -- Reason for 10 + (i - firstChild): (i - firstChild) is 0 for the left
    -- (function) child and 1 for the right (argument) child.
    -- In left child, prec is 10, so HsApp child will NOT get paren'd
    -- In right child, prec is 11, so every child gets paren'd (unless atomic)
    updExp (OpApp _ _ op _) =
      c { ctxtParentPrec = HasPrec $ lookupOp op (ctxtFixityEnv c) }
#if __GLASGOW_HASKELL__ < 904
    updExp (HsLet _ lbs _) =
      addInScope neverParen $ collectLocalBinders CollNoDictBinders lbs
#else
    updExp (HsLet _ _ lbs _ _) =
      addInScope neverParen $ collectLocalBinders CollNoDictBinders lbs
#endif
    updExp _ = neverParen

    updType :: HsType GhcPs -> Context
    updType HsAppTy{}
      -- only children after the function position are marked
      | i > firstChild = c { ctxtParentPrec = IsHsAppsTy }
    updType _ = neverParen

    -- The match's pattern binders are in scope in every child; inside the
    -- patterns themselves the parent precedence is 'IsLhs'.
    updMatch :: Match GhcPs (LHsExpr GhcPs) -> Context
    updMatch
      | i == 2  -- m_pats field
      = addInScope c{ctxtParentPrec = IsLhs}
          . collectPatsBinders CollNoDictBinders . m_pats
      | otherwise =
          addInScope neverParen . collectPatsBinders CollNoDictBinders . m_pats

    updGRHSs :: GRHSs GhcPs (LHsExpr GhcPs) -> Context
    updGRHSs =
      addInScope neverParen
        . collectLocalBinders CollNoDictBinders
        . grhssLocalBinds

    updGRHS :: GRHS GhcPs (LHsExpr GhcPs) -> Context
#if __GLASGOW_HASKELL__ < 900
    updGRHS XGRHS{} = neverParen
#endif
    updGRHS (GRHS _ gs _)
        -- binders are in scope over the body (right child) only
      | i > firstChild = addInScope neverParen bs
      | otherwise = fst $ updateSubstitution neverParen bs
      where
        bs = collectLStmtsBinders CollNoDictBinders gs

    updStmt :: Stmt GhcPs (LHsExpr GhcPs) -> Context
    updStmt _ = neverParen

    updStmtList :: [LStmt GhcPs (LHsExpr GhcPs)] -> TransformT m Context
    updStmtList [] = return neverParen
    updStmtList (ls:_)
        -- binders are in scope over tail of list (right child); a list cons
        -- cell has no extension field, so this compares against 0 rather
        -- than 'firstChild'
      | i > 0 = insertDependentRewrites neverParen bs ls
        -- lets are recursive in do-blocks
      | L _ (LetStmt _ bnds) <- ls =
          return $ addInScope neverParen $ collectLocalBinders CollNoDictBinders bnds
      | otherwise = return $ fst $ updateSubstitution neverParen bs
      where
        bs = collectLStmtBinders CollNoDictBinders ls

    -- A function binding's name is both in scope and recorded as a binder.
    updHsBind :: HsBind GhcPs -> Context
    updHsBind FunBind{fun_id = L _ rdr} =
      addBinders (addInScope neverParen [rdr]) [rdr]
    updHsBind _ = neverParen

    -- Type synonym, data and class declarations bring their name into scope.
    updTyClDecl :: TyClDecl GhcPs -> Context
    updTyClDecl SynDecl{tcdLName = L _ name} = addInScope neverParen [name]
    updTyClDecl DataDecl{tcdLName = L _ name} = addInScope neverParen [name]
    updTyClDecl ClassDecl{tcdLName = L _ name} = addInScope neverParen [name]
    updTyClDecl _ = neverParen

    updPat :: Pat GhcPs -> Context
    updPat _ = neverParen

-- | Create an empty 'Context' with given 'FixityEnv', rewriter, and dependent
-- rewrite generator.
--
-- The result has no recorded binders, the empty 'AlphaEnv' as its scope, no
-- pending substitution, and 'NeverParen' as its parent precedence.
emptyContext :: FixityEnv -> Rewriter -> Rewriter -> Context
emptyContext fixityEnv rewriter dependents = Context
  { ctxtBinders = []
  , ctxtDependents = dependents
  , ctxtFixityEnv = fixityEnv
  , ctxtInScope = emptyAlphaEnv
  , ctxtParentPrec = NeverParen
  , ctxtRewriter = rewriter
  , ctxtSubst = Nothing
  }

-- | Index of the first "real" child of a GHC AST constructor.
--
-- Deal with Trees-That-Grow adding extension points
-- as the first child everywhere: child 0 is the extension field, so the
-- first syntactic child is at index 1.
firstChild :: Int
firstChild = 1

-- | Add dependent rewrites to 'ctxtRewriter' if necessary.
--
-- Runs the context's dependent rewriter ('ctxtDependents') on @x@. The
-- binders @bs@ are always added to the scope of the returned context. If the
-- dependent rewriter matched, the matched template's dependent rewrites (if
-- any) are prepended to 'ctxtRewriter', and the rewrites that
-- 'rewritesWithDependents' derives from them are prepended to
-- 'ctxtDependents'.
insertDependentRewrites
  :: (Matchable k, MonadIO m) => Context -> [RdrName] -> k -> TransformT m Context
insertDependentRewrites c bs x = do
  r <- runRewriter id c (ctxtDependents c) x
  let c' = addInScope c bs
  case r of
    NoMatch -> return c'
    MatchResult _ tmpl -> do
      let
        rrs = fromMaybe [] (tDependents tmpl)
        ds = rewritesWithDependents rrs
        -- new local rewriters use the in-scope env that already includes bs
        f = foldMap (mkLocalRewriter $ ctxtInScope c')
      return c'
        { ctxtRewriter = f rrs <> ctxtRewriter c'
        , ctxtDependents = f ds <> ctxtDependents c'
        }

-- | Add set of binders to 'ctxtInScope'.
--
-- The binders are first passed through 'updateSubstitution', so any binder
-- that would capture a free variable of the pending substitution is
-- alpha-renamed before being added to the in-scope environment.
addInScope :: Context -> [RdrName] -> Context
addInScope c bs =
  c' { ctxtInScope = foldr extendAlphaEnv (ctxtInScope c') bs' }
  where
    (c', bs') = updateSubstitution c bs

-- | Add set of binders to 'ctxtBinders'.
--
-- The new binders are prepended, in the given order, to the existing list.
addBinders :: Context -> [RdrName] -> Context
addBinders c bs = c { ctxtBinders = bs ++ ctxtBinders c }

-- Capture-avoiding substitution
--------------------------------------------------------------------------------

-- | Update the Context's substitution appropriately for a set of binders.
-- Returns a new Context and a potentially alpha-renamed set of binders.
--
-- When there is no pending substitution, the context and binders are
-- returned unchanged.
--
-- NOTE(review): 'renameBinder' only avoids names in the substitution's free
-- variables, not the other binders in @rdrs@, so e.g. binders @[x, x1]@ with
-- @x@ free could both end up as @x1@. Confirm whether this can occur.
updateSubstitution :: Context -> [RdrName] -> (Context, [RdrName])
updateSubstitution c rdrs =
  case ctxtSubst c of
    Nothing -> (c, rdrs)
    Just sub ->
      let
        -- This prevents substituting for 'x' under a binding for 'x'.
        sub' = deleteSubst sub $ map rdrFS rdrs
        -- Compute free vars of substitution that could possibly be captured.
        fvs = substFVs sub'
        -- Partition binders into noncapturing and capturing.
        (noncapturing, capturing) =
          partitionEithers $ map (updateBinder fvs) rdrs
        -- Extend substitution with alpha-renamings.
        alphaSub =
          foldl' (uncurry . extendSubst) sub'
            [ (rdrFS rdr, HoleRdr rdr') | (rdr, rdr') <- capturing ]
        -- There are no telescopes in source Haskell, so order doesn't matter.
        -- Capturing should be rare, so put it first to avoid quadratic append.
        rdrs' = map snd capturing ++ noncapturing
      in (c { ctxtSubst = Just alphaSub }, rdrs')

-- | Check if RdrName is in FreeVars.
--
-- If so, return a pair of it and its new name (Right).
-- If not, return it unchanged (Left).
updateBinder :: FreeVars -> RdrName -> Either RdrName (RdrName, RdrName)
updateBinder fvs rdr
  | rdr `elemFVs` fvs = Right (rdr, renameBinder rdr fvs)
  | otherwise = Left rdr

-- | Given a RdrName, rename it to something not in given FreeVars.
--
--   x => x1
--   x1 => x2
--   x9 => x10
--
-- etc. Any trailing digits are treated as a numeric suffix; candidates with
-- increasing suffixes are tried until one is not in the FreeVars (this
-- terminates as long as the FreeVars is finite).
--
-- The suffix is an 'Integer' rather than an 'Int': reading a very long
-- digit suffix as 'Int' wraps silently and could produce a negative suffix,
-- i.e. an invalid identifier such as @x-9223372036854775807@.
--
-- Only works on unqualified RdrNames. This is fine, as we only use this to
-- rename local binders.
renameBinder :: RdrName -> FreeVars -> RdrName
renameBinder rdr fvs = firstFree start
  where
    -- Try suffix k, moving on to k + 1 while the candidate is taken.
    firstFree :: Integer -> RdrName
    firstFree k
      | candidate `elemFVs` fvs = firstFree (k + 1)
      | otherwise = candidate
      where
        candidate = mkVarUnqual $ mkFastString $ baseName ++ show k

    -- Split the occurrence name into its base and its (reversed) digit suffix.
    (revDigits, revBase) = span isDigit $ reverse $ occNameString $ occName rdr

    baseName :: String
    baseName = reverse revBase

    -- 'read' is safe here: revDigits is non-empty and 'isDigit' only
    -- accepts ASCII '0'..'9'.
    start :: Integer
    start
      | null revDigits = 1
      | otherwise = read (reverse revDigits) + 1