module Language.Haskell.TH.Extras where
import Control.Monad
import Data.Generics
import Data.Maybe
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- | 'True' when 'maxBound' of 'Int' exceeds @2^32@, which distinguishes a
-- 64-bit 'Int' from a 32-bit one.
intIs64 :: Bool
intIs64 = toInteger intMax > limit
  where
    intMax = maxBound :: Int
    limit = 2 ^ (32 :: Int) :: Integer
-- | @replace f x@ applies the partial rewrite @f@ to @x@. When @f x@ is
-- 'Nothing', @x@ is returned unchanged.
replace :: (a -> Maybe a) -> (a -> a)
replace f x = case f x of
  Just y  -> y
  Nothing -> x
-- | Chain a list of expressions together with '(.)', associating to the
-- right: @[f, g, h]@ becomes @f . (g . h)@. An empty list produces @id@,
-- and a single expression is returned as-is.
composeExprs :: [ExpQ] -> ExpQ
composeExprs exprs = case exprs of
  []        -> [| id |]
  [only]    -> only
  hd : rest -> [| $hd . $(composeExprs rest) |]
-- | The name of a data constructor. A 'ForallC' is looked through to the
-- constructor it wraps.
--
-- GADT-syntax constructors ('GadtC', 'RecGadtC') carry a list of names.
-- Only the single-name form has an answer. Any other list is reported with
-- an explicit 'error' rather than an anonymous pattern-match failure.
nameOfCon :: Con -> Name
nameOfCon (NormalC name _) = name
nameOfCon (RecC name _) = name
nameOfCon (InfixC _ name _) = name
nameOfCon (ForallC _ _ con) = nameOfCon con
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
nameOfCon (GadtC [name] _ _) = name
nameOfCon (RecGadtC [name] _ _) = name
-- Reached only by GADT-syntax constructors whose name list is not a
-- singleton (e.g. @A, B :: T@): there is no single name to return.
nameOfCon con =
  error ("Language.Haskell.TH.Extras.nameOfCon: expected a single constructor name in " ++ show con)
#endif
-- | The field types of a data constructor, in declaration order. A
-- 'ForallC' is looked through to the constructor it wraps.
argTypesOfCon :: Con -> [Type]
argTypesOfCon con = case con of
  NormalC _ fields -> [ty | (_, ty) <- fields]
  RecC _ fields -> map (\(_, _, ty) -> ty) fields
  InfixC lhs _ rhs -> [snd lhs, snd rhs]
  ForallC _ _ inner -> argTypesOfCon inner
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
  GadtC _ fields _ -> [ty | (_, ty) <- fields]
  RecGadtC _ fields _ -> map (\(_, _, ty) -> ty) fields
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 700
-- | The variable name introduced by a type-variable binder, whether or not
-- it carries a kind annotation.
nameOfBinder :: TyVarBndr -> Name
nameOfBinder bndr = case bndr of
  PlainTV n -> n
  KindedTV n _ -> n
#else
-- Fallback: binders are represented as bare names here, so the name is the
-- binder itself.
type TyVarBndr = Name

nameOfBinder :: TyVarBndr -> Name
nameOfBinder n = n
#endif
-- | Every type-variable binder introduced by the (possibly nested)
-- 'ForallC' wrappers around a constructor, outermost first. A constructor
-- without a 'ForallC' binds none.
varsBoundInCon :: Con -> [TyVarBndr]
varsBoundInCon con = case con of
  ForallC bndrs _ inner -> bndrs ++ varsBoundInCon inner
  _ -> []
-- | The variables bound by a pattern, in left-to-right order.
--
-- Pattern forms that contain sub-patterns are traversed. Literals,
-- wildcards and any form not listed here bind nothing.
namesBoundInPat :: Pat -> [Name]
namesBoundInPat (VarP name) = [name]
namesBoundInPat (TupP pats) = pats >>= namesBoundInPat
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 902
-- From GHC 9.2, 'ConP' also carries visible type arguments.
namesBoundInPat (ConP _ _ pats) = pats >>= namesBoundInPat
#else
namesBoundInPat (ConP _ pats) = pats >>= namesBoundInPat
#endif
namesBoundInPat (InfixP p1 _ p2) = namesBoundInPat p1 ++ namesBoundInPat p2
namesBoundInPat (TildeP pat) = namesBoundInPat pat
namesBoundInPat (AsP name pat) = name : namesBoundInPat pat
namesBoundInPat (RecP _ fieldPats) = map snd fieldPats >>= namesBoundInPat
namesBoundInPat (ListP pats) = pats >>= namesBoundInPat
namesBoundInPat (SigP pat _) = namesBoundInPat pat
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 612
namesBoundInPat (BangP pat) = namesBoundInPat pat
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 700
namesBoundInPat (ViewP _ pat) = namesBoundInPat pat
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
namesBoundInPat (UnboxedTupP pats) = pats >>= namesBoundInPat
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 704
-- Unresolved-infix and parenthesised patterns still bind their operands'
-- variables; previously they fell through to the catch-all below.
namesBoundInPat (UInfixP p1 _ p2) = namesBoundInPat p1 ++ namesBoundInPat p2
namesBoundInPat (ParensP pat) = namesBoundInPat pat
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 802
namesBoundInPat (UnboxedSumP pat _ _) = namesBoundInPat pat
#endif
namesBoundInPat _ = []
-- | The names a declaration introduces at its head:
--
-- * functions, and variables bound by value patterns;
-- * data, newtype, type synonym, class and type/data family names;
-- * foreign imports and pattern synonyms.
--
-- Data constructors, record fields and class methods are not included.
-- Declaration forms not listed here bind nothing.
namesBoundInDec :: Dec -> [Name]
namesBoundInDec (FunD name _) = [name]
namesBoundInDec (ValD pat _ _) = namesBoundInPat pat
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
namesBoundInDec (DataD _ name _ _ _ _) = [name]
namesBoundInDec (NewtypeD _ name _ _ _ _) = [name]
#else
namesBoundInDec (DataD _ name _ _ _) = [name]
namesBoundInDec (NewtypeD _ name _ _ _) = [name]
#endif
namesBoundInDec (TySynD name _ _) = [name]
namesBoundInDec (ClassD _ name _ _ _) = [name]
namesBoundInDec (ForeignD (ImportF _ _ _ name _)) = [name]
#if defined(__GLASGOW_HASKELL__)
#if __GLASGOW_HASKELL__ >= 800
namesBoundInDec (OpenTypeFamilyD (TypeFamilyHead name _ _ _)) = [name]
namesBoundInDec (ClosedTypeFamilyD (TypeFamilyHead name _ _ _) _) = [name]
-- Data families get their own constructor here instead of sharing
-- 'FamilyD' with type families, so they must be matched separately.
namesBoundInDec (DataFamilyD name _ _) = [name]
#elif __GLASGOW_HASKELL__ >= 612
namesBoundInDec (FamilyD _ name _ _) = [name]
#endif
#if __GLASGOW_HASKELL__ >= 708 && __GLASGOW_HASKELL__ < 800
-- Closed type families are a separate constructor, not covered by 'FamilyD'.
namesBoundInDec (ClosedTypeFamilyD name _ _ _) = [name]
#endif
#if __GLASGOW_HASKELL__ >= 802
namesBoundInDec (PatSynD name _ _ _) = [name]
#endif
#endif
namesBoundInDec _ = []
-- | Rebuild a name from its base string alone with 'mkName'. Any module
-- qualification is dropped, leaving an unqualified, capturable name.
genericalizeName :: Name -> Name
genericalizeName n = mkName (nameBase n)
-- | Rewrite every occurrence, anywhere inside the declarations, of a name
-- they bind (per 'namesBoundInDec') to its 'genericalizeName' form. All
-- other names are left untouched.
genericalizeDecs :: [Dec] -> [Dec]
genericalizeDecs decs = everywhere (mkT rename) decs
  where
    -- Association list from each bound name to its generic replacement.
    renaming = [ (n, genericalizeName n) | n <- concatMap namesBoundInDec decs ]
    rename n = fromMaybe n (lookup n renaming)
-- | The constructor or variable at the head of a type. Foralls,
-- applications, kind signatures and parentheses are looked through.
-- For example:
--
-- * the head of @forall a. Either a Int@ is @''Either@;
-- * the head of @a -> b@ is @''(->)@.
--
-- Unresolved infix types ('UInfixT') are not supported, because their head
-- depends on operator fixities. Those, and any other type without a usable
-- head, raise a descriptive 'error'.
headOfType :: Type -> Name
headOfType (ForallT _ _ ty) = headOfType ty
headOfType (VarT name) = name
headOfType (ConT name) = name
headOfType (TupleT n) = tupleTypeName n
headOfType ArrowT = ''(->)
headOfType ListT = ''[]
headOfType (AppT t _) = headOfType t
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 612
headOfType (SigT t _) = headOfType t
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 702
headOfType (UnboxedTupleT n) = unboxedTupleTypeName n
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 706
headOfType (PromotedT name) = name
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
-- In a resolved infix application the operator is the head: a `Either` b.
headOfType (InfixT _ name _) = name
headOfType (ParensT t) = headOfType t
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 808
headOfType (AppKindT t _) = headOfType t
#endif
headOfType ty =
  error ("Language.Haskell.TH.Extras.headOfType: no head found in " ++ show ty)
-- | Does the type variable @var@ occur free in the type?
--
-- A 'ForallT' that binds @var@ again shadows it, so nothing beneath that
-- forall counts. Kind annotations ('SigT', 'AppKindT') are not searched.
occursInType :: Name -> Type -> Bool
occursInType var ty = case ty of
    ForallT bndrs ctx body
      | any (var ==) (map nameOfBinder bndrs)
        -> False
      | otherwise
        -> inContext ctx || occursInType var body
    VarT name -> name == var
    AppT ty1 ty2 -> occursInType var ty1 || occursInType var ty2
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 612
    SigT inner _ -> occursInType var inner
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 800
    -- Infix and parenthesised types contain ordinary sub-types; previously
    -- these fell through to the catch-all and were reported as 'False'.
    InfixT ty1 _ ty2 -> occursInType var ty1 || occursInType var ty2
    UInfixT ty1 _ ty2 -> occursInType var ty1 || occursInType var ty2
    ParensT inner -> occursInType var inner
#endif
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 808
    AppKindT inner _ -> occursInType var inner
#endif
    _ -> False
  where
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ >= 710
    -- Constraints are ordinary types here, so a variable mentioned only in
    -- the context (forall b. Show a => b) still occurs in the type.
    inContext = any (occursInType var)
#else
    inContext _ = False
#endif