{-# LANGUAGE CPP #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
module Language.Haskell.TH.ExpandSyns(-- * Expand synonyms
                                      expandSyns
                                     ,expandSynsWith
                                     ,SynonymExpansionSettings
                                     ,noWarnTypeFamilies

                                      -- * Misc utilities
                                     ,substInType
                                     ,substInCon
                                     ,evades,evade) where

import Language.Haskell.TH.Datatype
import Language.Haskell.TH.Datatype.TyVarBndr
import Language.Haskell.TH.ExpandSyns.SemigroupCompat as Sem
import Language.Haskell.TH hiding(cxt)
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Set as Set
import Data.Generics
import Control.Monad
import Prelude

#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative
#endif

-- For ghci
#ifndef MIN_VERSION_template_haskell
#define MIN_VERSION_template_haskell(X,Y,Z) 1
#endif

-- | Name of this package; used as the prefix of warnings and error messages.
packagename :: String
packagename = "th-expand-syns"

-- | Replace the name bound by a type variable binder with the given one
-- (implemented with 'mapTVName'; everything else about the binder is left
-- to that function).
tyVarBndrSetName :: Name -> TyVarBndr_ flag -> TyVarBndr_ flag
tyVarBndrSetName newName = mapTVName (\_oldName -> newName)

-- | Settings for 'expandSynsWith'. Start from 'mempty' (the defaults) and
-- adjust with record update or '<>'.
data SynonymExpansionSettings =
  SynonymExpansionSettings {
    sesWarnTypeFamilies :: Bool
    -- ^ Whether 'expandSynsWith' warns about type synonym families it meets.
  }

-- | Settings are combined field-wise; a warning stays enabled only when
-- both operands enable it ('&&').
instance Semigroup SynonymExpansionSettings where
  (<>) (SynonymExpansionSettings warnLeft) (SynonymExpansionSettings warnRight) =
    SynonymExpansionSettings (warnLeft && warnRight)

-- | Default settings ('mempty'):
--
-- * Warn if type families are encountered.
--
-- (Combining settings is currently of little use; the monoid instance is
-- intended to accommodate additional settings in the future.)
instance Monoid SynonymExpansionSettings where
  mempty = SynonymExpansionSettings { sesWarnTypeFamilies = True }

#if !MIN_VERSION_base(4,11,0)
  -- From base-4.11 on, defining mappend is redundant;
  -- at some point `mappend` will be removed from `Monoid`.
  mappend = (Sem.<>)
#endif

-- | Suppresses the warning that type families are unsupported
-- (the default settings with 'sesWarnTypeFamilies' switched off).
noWarnTypeFamilies :: SynonymExpansionSettings
noWarnTypeFamilies = defaults { sesWarnTypeFamilies = False }
  where
    defaults = mempty

-- | Report a non-fatal compile-time message, prefixed with the package name
-- and @": WARNING: "@.
warn :: String -> Q ()
warn msg = emit (packagename ++ ": WARNING: " ++ msg)
  where
#if MIN_VERSION_template_haskell(2,8,0)
    emit = reportWarning
#else
    emit = report False
#endif

-- | 'reify' a name and, when it refers to a type constructor or a family,
-- pass its declaration to 'warnIfDecIsTypeFamily'. All other kinds of
-- 'Info' are ignored.
warnIfNameIsTypeFamily :: Name -> Q ()
warnIfNameIsTypeFamily n = reify n >>= inspect
  where
    inspect :: Info -> Q ()
    inspect (TyConI dec)  = warnIfDecIsTypeFamily dec
#if MIN_VERSION_template_haskell(2,7,0)
    inspect (FamilyI dec _) = warnIfDecIsTypeFamily dec
#endif
    inspect ClassI{}      = return ()
    inspect ClassOpI{}    = return ()
    inspect PrimTyConI{}  = return ()
    inspect DataConI{}    = return ()
    inspect VarI{}        = return ()
    inspect TyVarI{}      = return ()
#if MIN_VERSION_template_haskell(2,12,0)
    inspect PatSynI{}     = return ()
#endif

-- | Warn (via 'maybeWarnTypeFamily') when the declaration is a type synonym
-- family, open or closed. Every other declaration form, including data
-- families, is accepted silently.
--
-- The match lists every 'Dec' constructor explicitly, guarded by the
-- @template-haskell@ version that introduced it. A constructor missing here
-- makes the match non-exhaustive, so new TH versions need a case added.
warnIfDecIsTypeFamily :: Dec -> Q ()
warnIfDecIsTypeFamily = go
  where
    go :: Dec -> Q ()
    go (TySynD {}) = return ()

#if MIN_VERSION_template_haskell(2,11,0)
    go (OpenTypeFamilyD (TypeFamilyHead name _ _ _)) = maybeWarnTypeFamily name
    go (ClosedTypeFamilyD (TypeFamilyHead name _ _ _) _) = maybeWarnTypeFamily name
#else

#if MIN_VERSION_template_haskell(2,9,0)
    go (ClosedTypeFamilyD name _ _ _) = maybeWarnTypeFamily name
#endif

    go (FamilyD TypeFam name _ _) = maybeWarnTypeFamily name
#endif

    go (FunD {}) = return ()
    go (ValD {}) = return ()
    go (DataD {}) = return ()
    go (NewtypeD {}) = return ()
    go (ClassD {}) = return ()
    go (InstanceD {}) = return ()
    go (SigD {}) = return ()
    go (ForeignD {}) = return ()

#if MIN_VERSION_template_haskell(2,8,0)
    go (InfixD {}) = return ()
#endif

    go (PragmaD {}) = return ()

    -- Nothing to expand for data families, so no warning
#if MIN_VERSION_template_haskell(2,11,0)
    go (DataFamilyD {}) = return ()
#else
    go (FamilyD DataFam _ _ _) = return ()
#endif

    go (DataInstD {}) = return ()
    go (NewtypeInstD {}) = return ()
    go (TySynInstD {}) = return ()

#if MIN_VERSION_template_haskell(2,9,0)
    go (RoleAnnotD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,10,0)
    go (StandaloneDerivD {}) = return ()
    go (DefaultSigD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,12,0)
    go (PatSynD {}) = return ()
    go (PatSynSigD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,15,0)
    go (ImplicitParamBindD {}) = return ()
#endif

#if MIN_VERSION_template_haskell(2,16,0)
    go (KiSigD {}) = return ()
#endif

    -- `default (...)` declarations (template-haskell >= 2.19).
#if MIN_VERSION_template_haskell(2,19,0)
    go (DefaultD {}) = return ()
#endif

    -- `type data` declarations (template-haskell >= 2.20) introduce no
    -- type families.
#if MIN_VERSION_template_haskell(2,20,0)
    go (TypeDataD {}) = return ()
#endif

-- | Walk a type and check every type constructor name it mentions
-- ('ConT' and the operators of 'InfixT' / 'UInfixT') with
-- 'warnIfNameIsTypeFamily'.
--
-- Binder kinds and the predicates of a 'ForallT' are walked as well,
-- provided the @template-haskell@ version represents them as 'Type's
-- (see @go_kind@ / @go_pred@).
--
-- Every 'Type' constructor is listed explicitly under the TH version that
-- introduced it. A missing constructor makes the match non-exhaustive.
warnTypeFamiliesInType :: Type -> Q ()
warnTypeFamiliesInType = go
  where
    go :: Type -> Q ()
    go (ConT n)     = warnIfNameIsTypeFamily n
    go (AppT t1 t2) = go t1 >> go t2
    go (SigT t k)   = go t  >> go_kind k
    go ListT{}      = return ()
    go ArrowT{}     = return ()
    go VarT{}       = return ()
    go TupleT{}     = return ()
    go (ForallT tvbs ctxt body) = do
      mapM_ (go_kind . tvKind) tvbs
      mapM_ go_pred ctxt
      go body
#if MIN_VERSION_template_haskell(2,6,0)
    go UnboxedTupleT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,8,0)
    go PromotedT{}      = return ()
    go PromotedTupleT{} = return ()
    go PromotedConsT{}  = return ()
    go PromotedNilT{}   = return ()
    go StarT{}          = return ()
    go ConstraintT{}    = return ()
    go LitT{}           = return ()
#endif
#if MIN_VERSION_template_haskell(2,10,0)
    go EqualityT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,11,0)
    go (InfixT t1 n t2) = do
      warnIfNameIsTypeFamily n
      go t1
      go t2
    go (UInfixT t1 n t2) = do
      warnIfNameIsTypeFamily n
      go t1
      go t2
    go (ParensT t) = go t
    go WildCardT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,12,0)
    go UnboxedSumT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,15,0)
    go (AppKindT t k)       = go t >> go_kind k
    go (ImplicitParamT _ t) = go t
#endif
#if MIN_VERSION_template_haskell(2,16,0)
    go (ForallVisT tvbs body) = do
      mapM_ (go_kind . tvKind) tvbs
      go body
#endif
#if MIN_VERSION_template_haskell(2,17,0)
    go MulArrowT{} = return ()
#endif
#if MIN_VERSION_template_haskell(2,19,0)
    -- The operator of a promoted infix application is a promoted data
    -- constructor (per the template-haskell docs), so only the operands
    -- can mention a type family.
    go (PromotedInfixT t1 _ t2)  = go t1 >> go t2
    go (PromotedUInfixT t1 _ t2) = go t1 >> go t2
#endif

    go_kind :: Kind -> Q ()
#if MIN_VERSION_template_haskell(2,8,0)
    go_kind = go
#else
    go_kind _ = return ()
#endif

    go_pred :: Pred -> Q ()
#if MIN_VERSION_template_haskell(2,10,0)
    go_pred = go
#else
    go_pred (ClassP _ ts)  = mapM_ go ts
    go_pred (EqualP t1 t2) = go t1 >> go t2
#endif

-- | Emit (via 'warn') the "type synonym families are unsupported" message
-- for the given family name.
maybeWarnTypeFamily :: Name -> Q ()
maybeWarnTypeFamily name = warn message
  where
    message =
      "Type synonym families (and associated type synonyms) are currently not supported (they won't be expanded). Name of unsupported family: "
        ++ show name

-- | Calls 'expandSynsWith' with the default settings ('mempty').
expandSyns :: Type -> Q Type
expandSyns ty = expandSynsWith mempty ty

-- | Expands all type synonyms in the given type. Type families currently
-- won't be expanded (but will be passed through).
--
-- If 'sesWarnTypeFamilies' is set, the type is first scanned and a warning
-- is reported for every type synonym family it mentions. The expansion
-- itself is delegated to 'resolveTypeSynonyms'.
expandSynsWith :: SynonymExpansionSettings -> Type -> Q Type
expandSynsWith settings ty = do
  when (sesWarnTypeFamilies settings) (warnTypeFamiliesInType ty)
  resolveTypeSynonyms ty

-- | Make a name (based on the first argument) that is distinct from every
-- 'Name' occurring anywhere inside the second argument.
--
-- The first argument is returned unchanged if it does not occur. Otherwise
-- @'f'@ is repeatedly prefixed to its 'nameBase' (producing 'mkName' names)
-- until the result no longer occurs.
--
-- Example why this is necessary:
--
-- > type E x = forall y. Either x y
-- >
-- > ... expandSyns [t| forall y. y -> E y |]
--
-- The example as given may actually work correctly without any special
-- capture-avoidance, depending on how GHC handles the @y@s. In any case,
-- the input type to expandSyns may be an explicit AST that uses 'mkName'
-- to force a collision.
evade :: Data d => Name -> d -> Name
evade n t = freshen n
  where
    -- Every Name reachable inside t.
    taken :: Set.Set Name
    taken = everything Set.union (mkQ Set.empty Set.singleton) t

    freshen :: Name -> Name
    freshen candidate
      | candidate `Set.member` taken = freshen (mkName ('f' : nameBase candidate))
      | otherwise                    = candidate

-- | Make a list of names (based on the first argument, position by position)
-- such that every name in the result is distinct from every name in the
-- second argument and from the other results.
--
-- Names are chosen right to left: each is evaded against the already
-- chosen names to its right as well as against the second argument.
evades :: (Data t) => [Name] -> t -> [Name]
evades ns t = foldr step [] ns
  where
    step n chosen = evade n (chosen, t) : chosen

-- evadeTest = let v = mkName "x"
--             in
--               evade v (AppT (VarT v) (VarT (mkName "fx")))

-- | Capture-free substitution of a single type variable in a type.
-- Implemented with th-abstraction's 'applySubstitution' on a one-entry map.
substInType :: (Name,Type) -> Type -> Type
substInType vt = applySubstitution (uncurry Map.singleton vt)

-- | Capture-free substitution of a single type variable into a
-- constructor's field types.
--
-- Under a 'ForallC', the binders are handled by 'commonForallCase'. That
-- function produces a substitution map (the requested binding plus any
-- binder renamings) and the binders to use. The map is then applied to the
-- binder kinds, the context and the constructor underneath, with all of
-- its bindings applied /simultaneously/.
--
-- With template-haskell >= 2.11, GADT constructors ('GadtC', 'RecGadtC')
-- are not supported and are rejected with 'error'.
substInCon :: (Name,Type) -> Con -> Con
substInCon vt = go
    where
      go :: Con -> Con
      go (ForallC vars ctxt body) =
          commonForallCase vt vars $ \vts' vars' ->
          ForallC (map (mapTVKind (applySubstitution vts')) vars')
                  (applySubstitution vts' ctxt)
                  (substAll vts' body)
#if MIN_VERSION_template_haskell(2,11,0)
      go c@GadtC{}    = errGadt c
      go c@RecGadtC{} = errGadt c
#endif
      -- Remaining cases: NormalC, RecC, InfixC.
      go c = substAll (Map.fromList [vt]) c

      -- Apply every binding of the map at once. Applying them one after
      -- another is wrong when one binding's value mentions another key.
      -- For example, with {v -> a, a -> fa} on `C v a`, doing v first and
      -- then a yields `C fa fa` instead of `C a fa`.
      substAll :: Map Name Type -> Con -> Con
      substAll m c | Map.null m = c
      substAll m (NormalC n ts) =
          NormalC n [(b, applySubstitution m ty) | (b, ty) <- ts]
      substAll m (RecC n ts) =
          RecC n [(x, b, applySubstitution m ty) | (x, b, ty) <- ts]
      substAll m (InfixC (b1, t1) op (b2, t2)) =
          InfixC (b1, applySubstitution m t1) op (b2, applySubstitution m t2)
      substAll m c = viaTemps m c

      -- Constructors with binders of their own (a nested ForallC), and
      -- unsupported ones, go through substInCon one binding at a time.
      -- This is made equivalent to a simultaneous substitution in two
      -- phases:
      --   1. rename every key to a temporary name that occurs nowhere in
      --      the constructor or in the map;
      --   2. replace the temporaries by the values.
      viaTemps :: Map Name Type -> Con -> Con
      viaTemps m c =
          let keys    = Map.keys m
              temps   = evades keys (c, Map.toList m)
              renamed = foldr (\(key, tmp) -> substInCon (key, VarT tmp)) c (zip keys temps)
          in foldr (\(tmp, ty) -> substInCon (tmp, ty)) renamed (zip temps (Map.elems m))

#if MIN_VERSION_template_haskell(2,11,0)
      errGadt c = error (packagename++": substInCon currently doesn't support GADT constructors with GHC >= 8 ("++pprint c++")")
#endif

-- Apply a substitution to something underneath a @forall@.
--
-- Given the binding @(v, t)@ and the forall's binders, call the
-- continuation with:
--
--   * the substitution map to apply to the binders' kinds, the context
--     and the thing under the forall;
--   * the binders to rebuild the forall with.
--
-- If @v@ is rebound by the forall, it is shadowed from that binder
-- onwards. Only the kinds of the binders up to and including the
-- rebinding one can still mention the outer @v@, so they are substituted
-- here. The continuation then receives an empty map, which leaves the
-- context and body alone.
--
-- Otherwise the binders are renamed to names occurring in neither @v@ nor
-- @t@, so @t@'s free variables are not captured. The map then holds
-- @v -> t@ plus the old-binder-to-new-binder renamings.
--
-- NOTE(review): fresh names are chosen against @v@ and @t@ only. A free
-- variable of the body that happens to equal a chosen fresh name would be
-- captured, and the body is not available here to check against.
commonForallCase :: (Name, Type) -> [TyVarBndr_ flag]
                 -> (Map Name Type -> [TyVarBndr_ flag] -> a)
                 -> a
commonForallCase vt@(v,t) bndrs k
    | (outer, rebinding : inner) <- break ((== v) . tvName) bndrs =
        k Map.empty (map substKind (outer ++ [rebinding]) ++ inner)
    | otherwise =
        k (Map.fromList (vt : substs)) freshTyVarBndrs
  where
    -- Substitute the outer binding into a binder's kind only.
    substKind = mapTVKind (applySubstitution (Map.fromList [vt]))

    -- prevent capture
    vars            = tvName <$> bndrs
    freshes         = evades vars (v, t)
    freshTyVarBndrs = zipWith tyVarBndrSetName freshes bndrs
    substs          = zip vars (VarT <$> freshes)