-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module Retrie.Rewrites.Types where

import Control.Monad
import Data.Maybe

import Retrie.ExactPrint
import Retrie.Expr
import Retrie.GHC
import Retrie.Quantifiers
import Retrie.Types

-- | Build rewrites from the type synonyms declared at the top level of a
-- module.
--
-- Each @(name, direction)@ pair in @specs@ selects a type synonym by name.
-- A synonym gets one rewrite per pair that mentions it, so listing a name
-- with both directions yields two rewrites. Synonyms not mentioned in
-- @specs@ are ignored. The resulting rewrites are grouped by synonym name.
-- See 'mkTypeRewrite' for how a single rewrite is constructed.
typeSynonymsToRewrites
  :: [(FastString, Direction)]
  -> AnnotatedModule
#if __GLASGOW_HASKELL__ < 900
  -> IO (UniqFM [Rewrite (LHsType GhcPs)])
#else
  -> IO (UniqFM FastString [Rewrite (LHsType GhcPs)])
#endif
typeSynonymsToRewrites specs am = fmap astA $ transformA am $ \m -> do
  let
    -- Requested directions, grouped by synonym name.
    fsMap = uniqBag specs
    tySyns =
      [ (rdr, (dir, (nm, hsq_explicit vars, rhs)))
        -- only hsq_explicit is available pre-renaming
      | L _ (TyClD _ (SynDecl _ nm vars _ rhs)) <- hsmodDecls $ unLoc m
      , let rdr = rdrFS (unLoc nm)
        -- One entry per requested direction; unrequested synonyms drop out.
      , dir <- fromMaybe [] (lookupUFM fsMap rdr)
      ]
  fmap uniqBag $
    forM tySyns $ \(rdr, args) -> (rdr,) <$> uncurry mkTypeRewrite args

------------------------------------------------------------------------

-- | Compile a single type synonym into a rewrite.
--
-- The argument triple is the synonym's name, its explicitly bound type
-- variables and its right-hand side, i.e. the pieces of
-- @type T a b = rhs@. From these the type application @T a b@ is built.
--
-- * 'LeftToRight' rewrites @T a b@ into @rhs@ (unfolds the synonym).
-- * 'RightToLeft' rewrites @rhs@ into @T a b@ (folds it back).
--
-- The synonym's type variables become the quantifiers of the rewrite.
mkTypeRewrite
  :: Direction
#if __GLASGOW_HASKELL__ < 900
  -> (Located RdrName, [LHsTyVarBndr GhcPs], LHsType GhcPs)
#else
  -> (Located RdrName, [LHsTyVarBndr () GhcPs], LHsType GhcPs)
#endif
  -> TransformT IO (Rewrite (LHsType GhcPs))
mkTypeRewrite d (lhsName, vars, rhs) = do
  -- The synonym name heads the new application, so reset its entry delta
  -- (no line or column offset) instead of keeping its declaration layout.
  setEntryDPT lhsName $ DP (0,0)
  tc <- mkTyVar lhsName
  let
    lvs = tyBindersToLocatedRdrNames vars
  args <- forM lvs $ \lv -> do
    tv <- mkTyVar lv
    -- Same line, one column over: a single space before each argument.
    setEntryDPT tv (DP (0,1))
    return tv
  lhsApps <- mkHsAppsTy (tc:args)
  let
    (pat, tmp) = case d of
      LeftToRight -> (lhsApps, rhs)
      RightToLeft -> (rhs, lhsApps)
  p <- pruneA pat
  t <- pruneA tmp
  return $ mkRewrite (mkQs $ map unLoc lvs) p t