-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE TupleSections #-}
module Retrie.Rewrites.Function
  ( dfnsToRewrites
  , getImports
  , matchToRewrites
  ) where

import Control.Monad
import Control.Monad.State.Lazy
import Data.List
import Data.Maybe
import Data.Traversable

import Retrie.ExactPrint
import Retrie.Expr
import Retrie.GHC
import Retrie.Quantifiers
import Retrie.Types
import Retrie.Util

-- | Build rewrites from the function definitions named in @specs@.
--
-- Every top-level 'FunBind' in the module whose name appears in @specs@ is
-- turned into rewrites, once per 'Direction' listed for that name. Each
-- equation (match) of the definition contributes rewrites via
-- 'matchToRewrites'. Declarations other than function bindings are skipped.
--
-- The result maps each function name to all rewrites generated for it;
-- entries for the same name (e.g. one per direction) are combined with
-- @(++)@.
dfnsToRewrites
  :: LibDir
  -> [(FastString, Direction)]
  -> AnnotatedModule
  -> IO (UniqFM FastString [Rewrite (LHsExpr GhcPs)])
dfnsToRewrites libdir specs am = fmap astA $ transformA am $ \ (L _ m) -> do
  let
    -- Requested directions, grouped by function name.
    fsMap = uniqBag specs

  rrs <- sequence
    [ do
        fe <- mkLocatedHsVar fRdrName
        -- lift $ debugPrint Loud "dfnsToRewrites:ef="  [showAst fe]
        imps <- getImports libdir dir (hsmodName m)
        (fName,) . concat <$>
          forM (unLoc $ mg_alts $ fun_matches f) (matchToRewrites fe imps dir)
    | L _ (ValD _ f@FunBind{}) <- hsmodDecls m
    , let fRdrName = fun_id f
    , let fName = occNameFS (occName (unLoc fRdrName))
      -- Names absent from the specs produce no rewrites.
    , dir <- fromMaybe [] (lookupUFM fsMap fName)
    ]

  -- The same name can occur several times in rrs (once per direction);
  -- merge those entries by appending their rewrite lists.
  return $ listToUFM_C (++) rrs

------------------------------------------------------------------------

-- | Imports that rewrites generated in the given direction must add.
--
-- Only right-to-left rewrites (folds) of functions from a module with a
-- name get an import: a single @import M@ for that module @M@. It is
-- produced by 'parseImports', which is given @libdir@. Every other case
-- yields no imports. See Note [fold only].
getImports
  :: LibDir -> Direction -> Maybe (LocatedA ModuleName) -> TransformT IO AnnotatedImports
getImports libdir RightToLeft (Just (L _ mn)) = -- See Note [fold only]
  TransformT $ lift $ liftIO $ parseImports libdir ["import " ++ moduleNameString mn]
getImports _ _ _ = return mempty

-- | Rewrites generated by a single equation (match) of a function.
--
-- The equation's patterns are split at every position into a prefix and a
-- suffix. The prefix becomes the arguments applied to @e@ (built with
-- 'mkApps'), and the suffix becomes lambda binders on the other side; see
-- 'makeFunctionQuery'. The operator and section forms from 'backtickRules'
-- are added as well. Those are only produced for two-pattern equations in
-- the left-to-right direction.
matchToRewrites
  :: LHsExpr GhcPs
  -> AnnotatedImports
  -> Direction
  -> LMatch GhcPs (LHsExpr GhcPs)
  -> TransformT IO [Rewrite (LHsExpr GhcPs)]
matchToRewrites e imps dir (L _ alt) = do
  -- lift $ debugPrint Loud "matchToRewrites:e="  [showAst e]
  let
    pats = m_pats alt
    grhss = m_grhss alt
  -- zip (inits pats) (tails pats) enumerates every (prefix, suffix) split
  -- of the patterns, from ([], pats) up to (pats, []).
  qss <- for (zip (inits pats) (tails pats)) $
    makeFunctionQuery e imps dir grhss mkApps
  qs <- backtickRules e imps dir grhss pats
  return $ qs ++ concat qss

-- | Builds the applied side of a rewrite from the function expression and
-- its argument expressions. 'mkApps' is used for ordinary prefix
-- application; 'backtickRules' supplies operator-application and section
-- builders.
type AppBuilder =
  LHsExpr GhcPs -> [LHsExpr GhcPs] -> TransformT IO (LHsExpr GhcPs)

-- | Conservative check that a pattern can never fail to match.
--
-- 'makeFunctionQuery' uses this to decide whether patterns may be turned
-- into lambda binders. It follows the Haskell 2010 Report's notion of
-- irrefutability, restricted to what is visible in the parsed AST:
--
-- * variables, wildcards and lazy patterns (@~p@, for any @p@) are
--   irrefutable;
-- * as-patterns, parenthesised patterns and bang patterns are irrefutable
--   when their sub-pattern is.
--
-- Everything else is treated as refutable. For example, newtype
-- constructor patterns are not recognised, since that would need type
-- information.
irrefutablePat :: LPat GhcPs -> Bool
irrefutablePat = go . unLoc
  where
    go :: Pat GhcPs -> Bool
    go WildPat{} = True
    go VarPat{} = True
    -- A lazy pattern always matches, whatever its sub-pattern is
    -- (Haskell 2010 Report, section 3.17.2).
    go LazyPat{} = True
#if __GLASGOW_HASKELL__ <= 904
    go (AsPat _ _ p) = irrefutablePat p
#else
    -- GHC > 9.4: AsPat carries an extra field before the sub-pattern.
    go (AsPat _ _ _ p) = irrefutablePat p
#endif
#if __GLASGOW_HASKELL__ < 904
    go (ParPat _ p) = irrefutablePat p
#else
    -- GHC >= 9.4: ParPat carries fields around the sub-pattern.
    go (ParPat _ _ p _) = irrefutablePat p
#endif
    go (BangPat _ p) = irrefutablePat p
    go _ = False

-- | Rewrites for one equation, given a split of its patterns into
-- @(argpats, bndpats)@.
--
-- * @argpats@ are converted to expressions and passed, together with the
--   function expression @e@, to @mkAppFn@ to build the applied side.
-- * @bndpats@ become lambda binders (via 'mkLams') around each right-hand
--   side. Before that, the equation's local bindings are wrapped around the
--   right-hand side with 'mkLet'.
--
-- One rewrite is produced per element of the equation's right-hand sides.
-- @dir@ selects which side is the pattern and which is the template. If any
-- binder pattern is refutable, no rewrites are produced. This is presumably
-- because a lambda cannot fall through to a later equation the way a
-- function clause can.
makeFunctionQuery
  :: LHsExpr GhcPs
  -> AnnotatedImports
  -> Direction
  -> GRHSs GhcPs (LHsExpr GhcPs)
  -> AppBuilder
  -> ([LPat GhcPs], [LPat GhcPs])
  -> TransformT IO [Rewrite (LHsExpr GhcPs)]
makeFunctionQuery e imps dir grhss mkAppFn (argpats, bndpats)
  | any (not . irrefutablePat) bndpats = return []
  | otherwise = do
    let
      GRHSs _ rhss lbs = grhss
      -- Variables bound by the argument patterns.
      bs = collectPatsBinders CollNoDictBinders argpats
    -- See Note [Wildcards]
    -- The state starts as (wildSupply bs, bs). Its final second component
    -- bs' gives the rewrite's quantifiers; patToExpr may extend it.
    (es,(_,bs')) <- runStateT (mapM patToExpr argpats) (wildSupply bs, bs)
    -- lift $ debugPrint Loud "makeFunctionQuery:e="  [showAst e]
    lhs <- mkAppFn e es
    for rhss $ \ grhs -> do
      le <- mkLet lbs (grhsToExpr grhs)
      rhs <- mkLams bndpats le
      let
        (pat, temp) =
          case dir of
            LeftToRight -> (lhs,rhs)
            RightToLeft -> (rhs,lhs)
      p <- pruneA pat
      t <- pruneA temp
      return $ addRewriteImports imps $ mkRewrite (mkQs bs') p t

-- | Extra left-to-right rewrites for two-pattern equations, covering the
-- operator-application and section forms of a call to @e@:
--
-- * both patterns as operands ('OpApp' @l e r@);
-- * first pattern as the left operand, second as a lambda binder
--   ('SectionL' @l e@);
-- * second pattern as the right operand, first as a lambda binder
--   ('SectionR' @e r@).
--
-- Any other direction or number of patterns yields no rewrites.
backtickRules
  :: LHsExpr GhcPs
  -> AnnotatedImports
  -> Direction
  -> GRHSs GhcPs (LHsExpr GhcPs)
  -> [LPat GhcPs]
  -> TransformT IO [Rewrite (LHsExpr GhcPs)]
backtickRules e imps dir@LeftToRight grhss ps@[p1, p2] = do
  let
    -- Each builder is only ever called with the number of arguments given
    -- by the split passed alongside it below, so the fallback equations
    -- are unreachable.
    both, left, right :: AppBuilder
    both op [l, r] = mkLocA (SameLine 1) (OpApp noAnn l op r)
    both _ _ = fail "backtickRules - both: impossible!"

    left op [l] = mkLocA (SameLine 1) (SectionL noAnn l op)
    left _ _ = fail "backtickRules - left: impossible!"

    right op [r] = mkLocA (SameLine 1) (SectionR noAnn op r)
    right _ _ = fail "backtickRules - right: impossible!"
  qs <- makeFunctionQuery e imps dir grhss both (ps, [])
  qsl <- makeFunctionQuery e imps dir grhss left ([p1], [p2])
  qsr <- makeFunctionQuery e imps dir grhss right ([p2], [p1])
  return $ qs ++ qsl ++ qsr
backtickRules _ _ _ _ _ = return []

-- Note [fold only]
-- Currently we only generate imports for folds (right-to-left rewrites),
-- because that case is easy: we only need to add an import for the module
-- that defines the folded function. Generating the imports for unfolds
-- will require some sort of name analysis (e.g. with haskell-names) and
-- is a TODO.