-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
module Retrie.Rewrites.Patterns (patternSynonymsToRewrites) where

import Control.Monad.State (StateT(runStateT), lift)
import Control.Monad
import Control.Monad.IO.Class
import Data.Maybe
import Data.Void

import Retrie.ExactPrint
import Retrie.Expr
import Retrie.GHC
import Retrie.Quantifiers
import Retrie.Rewrites.Function
import Retrie.Types
import Retrie.Universe
import Retrie.Util

-- | Turn the pattern synonyms declared in a module into Retrie rewrites.
--
-- A @pattern@ declaration is used only if its name is listed in @specs@.
-- For each requested 'Direction' it yields:
--
-- * one pattern-level rewrite (see 'mkPatRewrite'), followed by
-- * the expression-level rewrites for its builder (see 'mkExpRewrite').
--
-- Synonyms for which 'dLPat' returns 'Nothing' are skipped. The result maps
-- each synonym name to all of its rewrites. When a name is requested in
-- several directions, those rewrite lists are concatenated.
--
-- Imports are always computed with 'RightToLeft', whatever direction was
-- requested. NOTE(review): confirm this is intended for 'LeftToRight'
-- specs as well.
patternSynonymsToRewrites
  :: LibDir
  -> [(FastString, Direction)]
  -> AnnotatedModule
  -> IO (UniqFM FastString [Rewrite Universe])
patternSynonymsToRewrites libdir specs am =
  fmap astA . transformA am $ \(L _ hsMod) -> do
    -- Requested directions, keyed by pattern-synonym name.
    let wanted :: UniqFM FastString [Direction]
        wanted = uniqBag specs
    imps <- getImports libdir RightToLeft (hsmodName hsMod)
    let rewritesFor dir name synName details fullRhs lrhs synDir = do
          patRw <- mkPatRewrite dir imps synName details lrhs
          expRws <- mkExpRewrite dir imps synName details fullRhs synDir
          return (name, toURewrite patRw : map toURewrite expRws)
    named <- sequence
      [ rewritesFor dir name synName details fullRhs lrhs synDir
      | L _ (ValD _ (PatSynBind _ (PSB _ synName details fullRhs synDir)))
          <- hsmodDecls hsMod
      , let name = rdrFS (unLoc synName)
      , dir <- fromMaybe [] (lookupUFM wanted name)
      , Just lrhs <- [dLPat fullRhs]
      ]
    return (listToUFM_C (++) named)

-- | Build the pattern-level rewrite for one pattern synonym.
--
-- With 'LeftToRight', the synonym applied to its parameters (built by
-- 'asPat') is matched and replaced by the synonym's right-hand-side pattern.
--
-- With 'RightToLeft', the right-hand side is matched and replaced by the
-- synonym application. The entry delta of that application is first reset
-- to @SameLine 0@.
--
-- The quantified variables are the binders of the replacement pattern, and
-- @imports@ are attached to the resulting rewrite.
mkPatRewrite
  :: Direction
  -> AnnotatedImports
  -> LocatedN RdrName
  -> HsConDetails Void (LocatedN RdrName) [RecordPatSynField GhcPs]
  -> LPat GhcPs
  -> TransformT IO (Rewrite (LPat GhcPs))
mkPatRewrite dir imports patName params rhs = do
  synApp <- asPat patName params
  let (matchPat, replacement) = case dir of
        LeftToRight -> (synApp, rhs)
        RightToLeft -> (rhs, resetEntry synApp)
      quantified = collectPatBinders CollNoDictBinders (cLPat replacement)
  prunedMatch <- pruneA matchPat
  prunedReplacement <- pruneA replacement
  return
    $ addRewriteImports imports
    $ mkRewrite (mkQs quantified) prunedMatch prunedReplacement
  where
    -- Patterns built from the synonym's lhs have awkward annotations: the
    -- leading space is attached to the constructor name, not to the 'ConPat'
    -- node. Both positions are therefore reset.
    resetEntry :: LPat GhcPs -> LPat GhcPs
    resetEntry p =
      setEntryDPTunderConPatIn (setEntryDP p (SameLine 0)) (SameLine 0)

    -- Set the entry delta of the constructor name inside a 'ConPat'; any
    -- other pattern is returned unchanged.
    setEntryDPTunderConPatIn :: LPat GhcPs -> DeltaPos -> LPat GhcPs
    setEntryDPTunderConPatIn lpat dp = case lpat of
      L l (ConPat x con args) -> L l (ConPat x (setEntryDP con dp) args)
      _ -> lpat

-- | Build the pattern that applies a pattern synonym to its parameters.
--
-- Prefix and infix parameters become variable patterns. Record parameters
-- become a record pattern that binds each field to its pattern variable,
-- with no @..@. No type arguments are emitted, because the parameter list
-- carries none (its type-argument slot is 'Void').
asPat
  :: Monad m
  => LocatedN RdrName
  -> HsConDetails Void (LocatedN RdrName) [RecordPatSynField GhcPs]
  -> TransformT m (LPat GhcPs)
asPat patName params =
  mkConPatIn patName
    =<< bitraverseHsConDetails convertTyVars mkVarPat convertFields params
  where

#if __GLASGOW_HASKELL__ <= 904
    convertTyVars :: (Monad m) => [Void] -> TransformT m [HsPatSigType GhcPs]
#else
    convertTyVars :: (Monad m) => [Void] -> TransformT m [HsConPatTyArg GhcPs]
#endif
    -- There are no type arguments to convert.
    convertTyVars _ = pure []

    -- Record fields, in declaration order, and no @..@ wildcard.
    convertFields :: (Monad m) => [RecordPatSynField GhcPs]
                      -> TransformT m (HsRecFields GhcPs (LPat GhcPs))
    convertFields fields = do
      binds <- mapM convertField fields
      pure (HsRecFields binds Nothing)

    -- One non-punned field binding: @field = var@.
    convertField :: (Monad m) => RecordPatSynField GhcPs
                      -> TransformT m (LHsRecField GhcPs (LPat GhcPs))
    convertField RecordPatSynField{recordPatSynField, recordPatSynPatVar} = do
#if __GLASGOW_HASKELL__ < 904
      lbl <- mkLoc recordPatSynField
      arg <- mkVarPat recordPatSynPatVar
      mkLocA (SameLine 0) HsRecField
        { hsRecFieldAnn = noAnn
        , hsRecFieldLbl = lbl
        , hsRecFieldArg = arg
        , hsRecPun = False
        }
#else
      fieldSpan <- uniqueSrcSpanT
      entryAnn <- mkEpAnn (SameLine 0) NoEpAnns
      arg <- mkVarPat recordPatSynPatVar
      mkLocA (SameLine 0) HsFieldBind
        { hfbAnn = noAnn
        , hfbLHS = L (SrcSpanAnn entryAnn fieldSpan) recordPatSynField
        , hfbRHS = arg
        , hfbPun = False
        }
#endif

-- | Build expression-level rewrites from a pattern synonym's builder.
--
-- Where the builder's matches come from depends on the synonym's direction:
--
-- * Explicitly bidirectional: the matches of its builder are used as they
--   are.
-- * Implicitly bidirectional: a single match @P v1 ... vn = rhs@ is
--   synthesised by 'buildMatch'. Record synonyms use their pattern
--   variables positionally, in field order. This is also a valid way to
--   apply a record pattern synonym in an expression, so such synonyms no
--   longer abort the run with 'missingSyntax'.
-- * Any other direction (i.e. unidirectional synonyms): no rewrites.
--
-- Each match is turned into rewrites by 'matchToRewrites' in direction
-- @dir@.
mkExpRewrite
  :: Direction
  -> AnnotatedImports
  -> LocatedN RdrName
  -> HsConDetails Void (LocatedN RdrName) [RecordPatSynField GhcPs]
  -> LPat GhcPs
  -> HsPatSynDir GhcPs
  -> TransformT IO [Rewrite (LHsExpr GhcPs)]
mkExpRewrite dir imports patName params rhs patDir = do
  fe <- mkLocatedHsVar patName
  let altsFromParams = case params of
        PrefixCon _tyargs names -> buildMatch names rhs
        InfixCon a1 a2 -> buildMatch [a1, a2] rhs
        -- A record synonym's builder also accepts its fields positionally,
        -- in declaration order. Rewrite that form instead of failing.
        RecCon fields -> buildMatch (map recordPatSynPatVar fields) rhs
  alts <- case patDir of
    ExplicitBidirectional MG{mg_alts} -> pure $ unLoc mg_alts
    ImplicitBidirectional -> altsFromParams
    _ -> pure []
  fmap concat $ forM alts $ matchToRewrites fe imports dir

-- | Build the single match @P v1 ... vn = e@ for an implicitly bidirectional
-- pattern synonym @P@.
--
-- The @vi@ are variable patterns for @names@, and @e@ is @rhs@ converted to
-- an expression by 'patToExpr'. The match uses the 'PatSyn' context and has
-- no local bindings.
--
-- The conversion state is seeded with @(wildSupply bs, bs)@, where @bs@ are
-- the variables bound by @rhs@. NOTE(review): presumably 'wildSupply' yields
-- fresh names for wildcards that avoid @bs@ — confirm.
buildMatch
  :: MonadIO m
  => [LocatedN RdrName]
  -> LPat GhcPs
  -> TransformT m [LMatch GhcPs (LHsExpr GhcPs)]
buildMatch names rhs = do
  argPats <- mapM mkVarPat names
  let rhsBinders = collectPatBinders CollNoDictBinders rhs
      initialState = (wildSupply rhsBinders, rhsBinders)
  (body, (_supply, _used)) <- runStateT (patToExpr rhs) initialState
  pure [mkMatch PatSyn argPats body emptyLocalBinds]