-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}
module Retrie.Elaborate
  ( defaultElaborations
  , elaborateRewritesInternal
  ) where

import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import "list-t" ListT
import Data.Maybe

import Retrie.Context
import Retrie.ExactPrint
import Retrie.Expr
import Retrie.Fixity
import Retrie.GHC
import Retrie.Quantifiers
import Retrie.Rewrites
import Retrie.Subst
import Retrie.Substitution
import Retrie.SYB
import Retrie.Types
import Retrie.Universe

-- | Elaborations applied by default.
--
-- Currently a single ad-hoc equation relating an application written with
-- @$@ to the same application written with explicit parentheses.
defaultElaborations :: [RewriteSpec]
defaultElaborations =
  [ Adhoc "forall f x. f $ x = f (x)"
  ]

-- | Expand each rewrite into the list of rewrites produced by elaborating
-- its pattern with the given elaboration rewrites.
--
-- The first list holds the elaborations, the second the rewrites to expand.
-- When there are no elaborations the rewrites are returned untouched, without
-- traversing any pattern.
elaborateRewritesInternal
  :: FixityEnv
  -> [Rewrite Universe]
  -> [Rewrite Universe]
  -> IO [Rewrite Universe]
elaborateRewritesInternal _ [] rewrites = return rewrites
elaborateRewritesInternal fixityEnv elaborations rewrites =
  concat <$> mapM (elaborateOne fixityEnv elaborator) rewrites
  where
    -- All elaborations combined into one 'Rewriter' via its 'Monoid' instance.
    elaborator :: Rewriter
    elaborator = foldMap mkRewriter elaborations

-- | Elaborate the pattern of a single rewrite.
--
-- The pattern ('qPattern') is traversed top-down with 'elaborate' in the
-- 'ListT' monad, so every combination of choices made along the way yields
-- one alternative pattern. Each alternative becomes a copy of the input
-- rewrite that differs only in its 'qPattern'.
elaborateOne :: FixityEnv -> Rewriter -> Rewrite Universe -> IO [Rewrite Universe]
elaborateOne fixityEnv elaborator rr = do
  patterns <-
    transformA (qPattern rr) $ toList .
      everywhereMWithContextBut topDown
        -- NOTE(review): the constant-False query presumably means that no
        -- subtree is excluded from the traversal -- confirm against
        -- 'everywhereMWithContextBut'.
        (const False)
        (\c i x -> lift $ updateContext c i x)
        elaborate
        ctxt
  -- 'patterns' is a single annotated list; 'sequenceA' turns it into a list
  -- of individually annotated patterns.
  return [ rr { qPattern = pat } | pat <- sequenceA patterns ]
  where
    -- Starting context: the elaborator is the rewriter to match with, and the
    -- remaining rewriter argument is left empty.
    ctxt :: Context
    ctxt = emptyContext fixityEnv elaborator mempty

-- | Generic elaboration step, dispatching on the type of the visited node.
--
-- Located expressions, statements and types go through 'elaborateImpl';
-- patterns go through 'elaboratePat'. A node of any other type is returned
-- unchanged as the only alternative (the default behaviour of 'mkM').
elaborate
  :: (Data a, MonadIO m) => Context -> a -> ListT (TransformT m) a
elaborate c =
  mkM (elaborateImpl @(HsExpr GhcPs) c)
    `extM` (elaborateImpl @(Stmt GhcPs (LHsExpr GhcPs)) c)
    `extM` (elaborateImpl @(HsType GhcPs) c)
    `extM` (elaboratePat c)

-- | Elaborate a pattern.
--
-- We need to ensure we have a location available at the top level so we can
-- transfer annotations. This ensures we don't try to rewrite a naked Pat:
-- when 'dLPat' finds no located pattern, the input is returned unchanged.
elaboratePat :: MonadIO m => Context -> LPat GhcPs -> ListT (TransformT m) (LPat GhcPs)
elaboratePat c p
  | Just lp <- dLPat p = cLPat <$> elaborateImpl c lp
  | otherwise = return p

-- | Produce every elaboration of a located node.
--
-- The context's rewriter ('ctxtRewriter') is matched against the node with
-- outer parentheses removed ('getUnparened'). Each valid match is
-- instantiated into a replacement node. The resulting alternatives are the
-- original node first, followed by one replacement per valid match.
elaborateImpl
  :: forall ast m. (Data ast, ExactPrint ast, Matchable (LocatedA ast), MonadIO m)
  => Context -> LocatedA ast -> ListT (TransformT m) (LocatedA ast)
elaborateImpl ctxt e = do
  elaborations <- lift $ do
    matches <- runMatcher ctxt (ctxtRewriter ctxt) (getUnparened e)
    validMatches <- allMatches ctxt matches
    -- The pattern match in the generator keeps only 'MatchResult' values.
    forM [ (sub, tmpl) | MatchResult sub tmpl <- validMatches ] $ \(sub, tmpl) -> do
      -- graft template into target
      t' <- graftA (tTemplate tmpl)
      -- substitute for quantifiers in grafted template
      r <- subst sub ctxt t'
      -- copy appropriate annotations from old expression to template
      r0 <- addAllAnnsT e r
      -- add parens to template if needed
      (mkM (parenify ctxt) `extM` parenifyT ctxt `extM` parenifyP ctxt) r0

  fromFoldable (e : elaborations)

-- | Collect every 'valid' match.
-- Runs the user's 'MatchResultTransformer' on each candidate and sanity
-- checks the result. Candidates that fail the check are dropped, and the
-- survivors are projected from 'Universe' to the requested AST type.
allMatches
  :: (Matchable ast, MonadIO m)
  => Context
  -> [(Substitution, RewriterResult Universe)]
  -> TransformT m [MatchResult ast]
allMatches ctxt matchResults = do
  results <-
    forM matchResults $ \(sub, RewriterResult{..}) -> do
      -- The transformer runs in IO, so it is lifted into 'TransformT'.
      result <- TransformT $ lift $ liftIO $ rrTransformer ctxt $ MatchResult sub rrTemplate
      return (rrQuantifiers, result)
  return
    [ project <$> result
      -- The pattern also discards any result that is not a 'MatchResult'.
    | (quantifiers, result@(MatchResult sub' _)) <- results
      -- Check that all quantifiers from the original rewrite have mappings
      -- in the resulting substitution. This is mostly to prevent a bad
      -- user-defined MatchResultTransformer from causing havoc.
    , all (isJust . (`lookupSubst` sub')) (qList quantifiers)
    ]