-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE ViewPatterns #-}
module Retrie.Subst (subst) where

import Control.Monad.Writer.Strict
import Data.Generics

import Retrie.Context
import Retrie.ExactPrint
import Retrie.Expr
import Retrie.GHC
import Retrie.Substitution
import Retrie.SYB
import Retrie.Types

------------------------------------------------------------------------

-- | Perform the given 'Substitution' on an AST, avoiding variable capture
-- by alpha-renaming binders as needed.
--
-- The traversal is 'bottomUp' and never skips a subtree (the pruning
-- predicate is @const False@). 'updateContext' (from "Retrie.Context") is
-- the traversal's context-update hook. NOTE(review): presumably it is what
-- records capture-avoiding renames as 'HoleRdr' entries, which the
-- @subst*@ helpers in this module then apply — confirm in "Retrie.Context".
subst
  :: (MonadIO m, Data ast)
  => Substitution
  -> Context
  -> ast
  -> TransformT m ast
subst sub ctxt =
  everywhereMWithContextBut
    bottomUp (const False) updateContext substNode substCtxt
  where
    -- Install the substitution so that 'lookupHoleVar' can find it.
    substCtxt = ctxt { ctxtSubst = Just sub }

    -- Per-node action: dispatch on the node's type to the matching
    -- substitution function. Nodes of any other type are returned
    -- unchanged (the default behaviour of 'mkM'/'extM').
    substNode c =
      mkM (substExpr c)
        `extM` substPat c
        `extM` substType c
        `extM` substHsMatchContext c
        `extM` substBind c

-- | Look up the value bound to a variable name in the context's
-- substitution, keyed by the name's 'FastString' (via 'rdrFS').
--
-- Returns 'Nothing' when the context carries no substitution or when the
-- name is not bound in it.
lookupHoleVar :: RdrName -> Context -> Maybe HoleVal
lookupHoleVar rdr ctxt = ctxtSubst ctxt >>= lookupSubst (rdrFS rdr)

-- | Substitute a single variable occurrence in an expression.
--
-- * 'HoleExpr': graft the bound expression (outer parentheses stripped by
--   'unparen'). Move the entry delta of the replaced occurrence onto the
--   graft unless the graft has comments. Carry over comma annotations, then
--   re-parenthesise as the surrounding context requires ('parenify').
-- * 'HoleRdr': rename the variable in place, keeping both source spans.
-- * Anything else (unbound, or bound to a non-expression hole) and every
--   non-variable expression is returned unchanged.
substExpr
  :: Monad m
  => Context
  -> LHsExpr GhcPs
  -> TransformT m (LHsExpr GhcPs)
substExpr ctxt e@(L l1 (HsVar x (L l2 v))) =
  case lookupHoleVar v ctxt of
    Just (HoleExpr eA) -> do
      e' <- graftA (unparen <$> eA)
      -- Only transfer the entry delta when the graft has no comments.
      -- NOTE(review): presumably so attached comments keep their own
      -- positioning.
      graftHasComments <- hasComments e'
      unless graftHasComments $ transferEntryDPT e e'
      transferAnnsT isComma e e'
      parenify ctxt e'
    Just (HoleRdr rdr) ->
      return $ L l1 $ HsVar x $ L l2 rdr
    _ -> return e
substExpr _ e = return e

-- | Substitute a single variable pattern.
--
-- * 'HolePat': graft the bound pattern (outer parentheses stripped by
--   'unparenP'), transfer entry and comma annotations from the replaced
--   pattern, then re-parenthesise as the context requires ('parenifyP').
-- * 'HoleRdr': rename the bound variable in place, keeping both source spans.
-- * Anything else, and every non-variable pattern, is returned unchanged.
--
-- The pattern is viewed through 'dLPat' and rebuilt with 'cLPat'.
substPat
  :: Monad m
  => Context
  -> LPat GhcPs
  -> TransformT m (LPat GhcPs)
substPat ctxt (dLPat -> Just p@(L l1 (VarPat x vl@(L l2 v)))) =
  cLPat <$>
    case lookupHoleVar v ctxt of
      Just (HolePat pA) -> do
        p' <- graftA (unparenP <$> pA)
        transferEntryAnnsT isComma p p'
        -- the relevant entry delta is sometimes attached to
        -- the OccName and not to the VarPat.
        -- This seems to be the case only when the pattern comes from a lhs,
        -- whereas it has no annotations in patterns found in rhs's.
        tryTransferEntryDPT vl p'
        parenifyP ctxt p'
      Just (HoleRdr rdr) ->
        return $ L l1 $ VarPat x $ L l2 rdr
      _ -> return p
substPat _ p = return p

-- | Substitute a type variable bound to a 'HoleType'.
--
-- The bound type is grafted in (outer parentheses stripped by 'unparenT').
-- Entry and comma annotations are transferred from the replaced type, and
-- the result is re-parenthesised as the context requires ('parenifyT').
-- Only 'HoleType' bindings are applied here; every other type, including a
-- type variable bound to some other kind of hole, is returned unchanged.
substType
  :: Monad m
  => Context
  -> LHsType GhcPs
  -> TransformT m (LHsType GhcPs)
substType ctxt ty
  | Just (L _ v) <- tyvarRdrName (unLoc ty)
  , Just (HoleType tyA) <- lookupHoleVar v ctxt = do
    ty' <- graftA (unparenT <$> tyA)
    transferEntryAnnsT isComma ty ty'
    parenifyT ctxt ty'
substType _ ty = return ty

-- | Rename the function being defined in a 'FunRhs' match context when the
-- substitution maps its name to a 'HoleRdr'. All other match contexts are
-- returned unchanged.
--
-- You might reasonably think that we would replace the 'RdrName' in
-- 'FunBind'... but no, exactprint only cares about the 'RdrName' in the
-- 'MatchGroup' matches, which are here. In case that changes in the future,
-- we define 'substBind' too.
substHsMatchContext
  :: Monad m
  => Context
#if __GLASGOW_HASKELL__ < 900
  -> HsMatchContext RdrName
  -> TransformT m (HsMatchContext RdrName)
#else
  -> HsMatchContext GhcPs
  -> TransformT m (HsMatchContext GhcPs)
#endif
substHsMatchContext ctxt (FunRhs (L l v) fixity strictness)
  | Just (HoleRdr rdr) <- lookupHoleVar v ctxt =
    return $ FunRhs (L l rdr) fixity strictness
substHsMatchContext _ other = return other

-- | Rename the binder ('fun_id') of a 'FunBind' when the substitution maps
-- it to a 'HoleRdr', keeping the binder's source span. Other bindings are
-- returned unchanged.
--
-- Exactprint currently takes the printed name from the match context, which
-- 'substHsMatchContext' handles. This function keeps 'fun_id' consistent in
-- case that ever changes.
substBind
  :: Monad m
  => Context
  -> HsBind GhcPs
  -> TransformT m (HsBind GhcPs)
substBind ctxt fb@FunBind{}
  | L l v <- fun_id fb
  , Just (HoleRdr rdr) <- lookupHoleVar v ctxt =
    return fb { fun_id = L l rdr }
substBind _ other = return other