-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE ViewPatterns #-}
module Retrie.Subst (subst) where

import Control.Monad.Writer.Strict
import Data.Generics

import Retrie.Context
import Retrie.ExactPrint
import Retrie.Expr
import Retrie.GHC
import Retrie.Substitution
import Retrie.SYB
import Retrie.Types
import Retrie.Util

------------------------------------------------------------------------

-- | Perform the given 'Substitution' on an AST, avoiding variable capture
-- by alpha-renaming binders as needed.
subst
  :: (MonadIO m, Data ast)
  => Substitution
  -> Context
  -> ast
  -> TransformT m ast
subst sub ctxt =
  -- Walk the whole AST with the 'bottomUp' strategy. The 'const False' query
  -- presumably never prunes a subtree. 'updateContext' adjusts the threaded
  -- 'Context' from node to node. The substitution is installed in the
  -- starting context, which is where 'lookupHoleVar' reads it.
  everywhereMWithContextBut
    bottomUp
    (const False)
    updateContext
    substNode
    (ctxt { ctxtSubst = Just sub })
  where
    -- Per-node rewrite: 'mkM' returns nodes of any other type unchanged, and
    -- each 'extM' adds a type-specific substituter. This binding mentions
    -- neither 'sub' nor 'ctxt', so it stays closed and generalizes over the
    -- node type; the substitution reaches it via the 'Context' argument.
    substNode c =
      mkM (substExpr c)
        `extM` substPat c
        `extM` substType c
        `extM` substHsMatchContext c
        `extM` substBind c

-- | Find the hole value bound to a name in the context's substitution.
--
-- Returns 'Nothing' in two cases: the context carries no substitution
-- ('ctxtSubst' is 'Nothing'), or the name is not bound in it. Lookup is keyed
-- on the name's 'FastString' (via 'rdrFS').
lookupHoleVar :: RdrName -> Context -> Maybe HoleVal
lookupHoleVar rdr ctxt = ctxtSubst ctxt >>= lookupSubst (rdrFS rdr)

-- | Substitute into a variable expression whose name is bound in the current
-- substitution.
--
-- * 'HoleExpr': graft the bound expression with its outer parentheses
--   stripped. If the grafted expression has no comments of its own, it takes
--   the replaced variable's entry delta. Annotations are then transferred from
--   the variable using 'isComma' as the trailing-annotation filter. Finally
--   the result is re-parenthesised as the context requires ('parenify').
-- * 'HoleRdr': keep the node and its annotations, swapping only the name.
--
-- Any other expression, an unbound variable, or a variable bound to a
-- different kind of hole is returned unchanged.
substExpr
  :: MonadIO m
  => Context
  -> LHsExpr GhcPs
  -> TransformT m (LHsExpr GhcPs)
substExpr ctxt e@(L l1 (HsVar x (L l2 v))) =
  case lookupHoleVar v ctxt of
    Just (HoleExpr eA) -> do
      grafted <- graftA (unparen <$> eA)
      -- A grafted expression that carries comments keeps its own entry
      -- position; otherwise it inherits the replaced variable's.
      positioned <-
        if hasComments grafted
          then pure grafted
          else transferEntryDP e grafted
      withAnns <- transferAnnsT isComma e positioned
      parenify ctxt withAnns
    Just (HoleRdr rdr) ->
      pure (L l1 (HsVar x (L l2 rdr)))
    _ ->
      pure e
substExpr _ e = pure e

-- | Substitute into a variable pattern whose name is bound in the current
-- substitution.
--
-- * 'HolePat': graft the bound pattern with its outer parentheses stripped.
--   Entry delta and annotations are transferred from the replaced pattern,
--   using 'isComma' as the trailing-annotation filter. The result is then
--   re-parenthesised as the context requires ('parenifyP').
-- * 'HoleRdr': keep the pattern and its annotations, swapping only the name.
--
-- Any other pattern, an unbound variable, or a variable bound to a different
-- kind of hole is returned unchanged.
substPat
  :: MonadIO m
  => Context
  -> LPat GhcPs
  -> TransformT m (LPat GhcPs)
substPat ctxt (dLPat -> Just p@(L l1 (VarPat x (L l2 v)))) =
  cLPat <$> case lookupHoleVar v ctxt of
    Just (HolePat pA) -> do
      grafted <- graftA (unparenP <$> pA)
      -- The relevant entry delta is sometimes attached to the variable's
      -- OccName rather than to the VarPat. This seems to happen only for
      -- patterns taken from a lhs; patterns found in rhs's carry no such
      -- annotation. Transferring from the name (tryTransferEntryDPT) is
      -- currently disabled.
      withAnns <- transferEntryAnnsT isComma p grafted
      parenifyP ctxt withAnns
    Just (HoleRdr rdr) ->
      pure (L l1 (VarPat x (L l2 rdr)))
    _ ->
      pure p
substPat _ p = pure p

-- | Substitute into a type variable that is bound to a 'HoleType' in the
-- current substitution.
--
-- The bound type is grafted with its outer parentheses stripped. Entry delta
-- and annotations are transferred from the replaced type, using 'isComma' as
-- the trailing-annotation filter. The result is then re-parenthesised as the
-- context requires ('parenifyT'). Every other type, including a type variable
-- with no 'HoleType' binding, is returned unchanged.
substType
  :: MonadIO m
  => Context
  -> LHsType GhcPs
  -> TransformT m (LHsType GhcPs)
substType ctxt ty
  | Just (L _ v) <- tyvarRdrName (unLoc ty)
  , Just (HoleType tyA) <- lookupHoleVar v ctxt = do
      grafted <- graftA (unparenT <$> tyA)
      withAnns <- transferEntryAnnsT isComma ty grafted
      parenifyT ctxt withAnns
  | otherwise = pure ty

-- You might reasonably expect us to replace the RdrName in FunBind... but no:
-- exactprint only uses the RdrName in the MatchGroup's matches, which are
-- handled here. In case that changes in the future, we define substBind too.
-- | Rename the function name in a 'FunRhs' match context when that name is
-- bound to a 'HoleRdr' in the current substitution.
--
-- The name's location and the fixity and strictness fields are preserved.
-- Every other match context, and a 'FunRhs' whose name has no 'HoleRdr'
-- binding, is returned unchanged.
substHsMatchContext
  :: Monad m
  => Context
#if __GLASGOW_HASKELL__ < 900
  -> HsMatchContext RdrName
  -> TransformT m (HsMatchContext RdrName)
#else
  -> HsMatchContext GhcPs
  -> TransformT m (HsMatchContext GhcPs)
#endif
substHsMatchContext ctxt other@(FunRhs (L loc name) fixity strictness) =
  pure $ case lookupHoleVar name ctxt of
    Just (HoleRdr rdr) -> FunRhs (L loc rdr) fixity strictness
    _ -> other
substHsMatchContext _ other = pure other

-- | Rename the bound identifier ('fun_id') of a 'FunBind' when it is bound to
-- a 'HoleRdr' in the current substitution.
--
-- The identifier's location is preserved. Every other binding, and a
-- 'FunBind' whose name has no 'HoleRdr' binding, is returned unchanged.
substBind
  :: Monad m
  => Context
  -> HsBind GhcPs
  -> TransformT m (HsBind GhcPs)
substBind ctxt fb@FunBind{ fun_id = L loc name }
  | Just (HoleRdr rdr) <- lookupHoleVar name ctxt =
      pure (fb { fun_id = L loc rdr })
substBind _ other = pure other