module TH.Utilities where
import Data.Proxy
import Data.Typeable
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
-- | Extract the 'Name' bound by a type-variable binder. A kind annotation
-- on the binder, if present, is ignored.
tyVarBndrName :: TyVarBndr -> Name
tyVarBndrName bndr = case bndr of
  PlainTV name -> name
  KindedTV name _ -> name
-- | Flatten a left-nested chain of type applications into a list: the head
-- of the application comes first, followed by the arguments in the order
-- they were applied. A type that is not an application yields a singleton.
unAppsT :: Type -> [Type]
unAppsT ty = collect ty []
  where
    -- Walk down the function side of each 'AppT', pushing arguments onto
    -- the accumulator so they come out in left-to-right order.
    collect (AppT fun arg) acc = collect fun (arg : acc)
    collect headTy acc = headTy : acc
-- | View a type as a named type constructor together with its arguments.
--
-- Returns @Just (name, args)@ when the type is a 'ConT' applied to zero or
-- more arguments, or (on template-haskell >= 2.11) an infix application of
-- a named operator to two operands. Returns 'Nothing' for anything else.
typeToNamedCon :: Type -> Maybe (Name, [Type])
typeToNamedCon ty = case ty of
#if MIN_VERSION_template_haskell(2,11,0)
  InfixT lhs op rhs -> Just (op, [lhs, rhs])
  UInfixT lhs op rhs -> Just (op, [lhs, rhs])
#endif
  _ -> case unAppsT ty of
    ConT con : args -> Just (con, args)
    _ -> Nothing
-- | Expect the given type to be the named constructor (plain or promoted)
-- applied to exactly one argument, and return that argument. Any other
-- shape, or a different constructor name, fails in 'Q' with a message
-- showing what was expected and what was found.
expectTyCon1 :: Name -> Type -> Q Type
expectTyCon1 expected ty = case ty of
  AppT (ConT con) arg | expected == con -> return arg
  AppT (PromotedT con) arg | expected == con -> return arg
  _ ->
    fail $
      concat
        [ "Expected "
        , pprint expected
        , ", applied to one argument, but instead got "
        , pprint ty
        , "."
        ]
-- | Expect the given type to be the named constructor (plain or promoted)
-- applied to exactly two arguments, and return both arguments in order.
--
-- On template-haskell >= 2.11 an infix application of the expected name
-- ('InfixT' or 'UInfixT') is accepted as well, yielding the left and right
-- operands. Any other shape, or a different constructor name, fails in 'Q'
-- with a message showing what was expected and what was found.
expectTyCon2 :: Name -> Type -> Q (Type, Type)
expectTyCon2 expected (AppT (AppT (ConT n) x) y) | expected == n = return (x, y)
expectTyCon2 expected (AppT (AppT (PromotedT n) x) y) | expected == n = return (x, y)
#if MIN_VERSION_template_haskell(2,11,0)
-- The right operand must be bound as y: it is the second component of the
-- returned pair.
expectTyCon2 expected (InfixT x n y) | expected == n = return (x, y)
expectTyCon2 expected (UInfixT x n y) | expected == n = return (x, y)
#endif
expectTyCon2 expected x = fail $
  "Expected " ++ pprint expected ++
  ", applied to two arguments, but instead got " ++ pprint x ++ "."
-- | Build the expression @Proxy :: Proxy ty@ for the given type, i.e. the
-- 'Proxy' constructor annotated with the 'Proxy' type applied to @ty@.
proxyE :: TypeQ -> ExpQ
proxyE ty = sigE proxyCon proxyTy
  where
    proxyCon = conE 'Proxy
    proxyTy = appT (conT ''Proxy) ty
-- | Rebuild a name from only its base (unqualified) part via 'mkName',
-- dropping whatever else the original 'Name' carried.
dequalify :: Name -> Name
dequalify qualified = mkName (nameBase qualified)
-- | Collect the names of type variables ('VarT') occurring free in a type.
--
-- Variables bound by a 'ForallT' are removed from the result for its body;
-- the context of the 'ForallT' is not inspected. Applications, kind
-- signatures and (on template-haskell >= 2.11) infix applications are
-- traversed on both sides. Every other form of type contributes nothing.
-- A variable that occurs more than once appears more than once.
freeVarsT :: Type -> [Name]
freeVarsT (ForallT tvs _ ty) = filter (`notElem` bound) (freeVarsT ty)
  where
    bound = map tyVarBndrName tvs
freeVarsT (AppT l r) = freeVarsT l ++ freeVarsT r
freeVarsT (SigT ty k) = freeVarsT ty ++ freeVarsT k
freeVarsT (VarT n) = [n]
#if MIN_VERSION_template_haskell(2,11,0)
-- The operator in the middle is a plain Name rather than a Type, so only
-- the two operands are traversed.
freeVarsT (InfixT x _ y) = freeVarsT x ++ freeVarsT y
freeVarsT (UInfixT x _ y) = freeVarsT x ++ freeVarsT y
#endif
freeVarsT _ = []
-- | Wrapper around an expression-producing 'Q' action. Its 'Lift' instance
-- yields the wrapped expression itself.
data ExpLifter
  = ExpLifter (Q Exp)
  deriving (Typeable)

instance Lift ExpLifter where
  -- Lifting does no conversion: it simply unwraps the stored action.
  lift lifter = case lifter of
    ExpLifter expr -> expr