{-# Language CPP, DeriveGeneric, DeriveDataTypeable #-}
module Language.Haskell.TH.Datatype
(
DatatypeInfo(..)
, ConstructorInfo(..)
, DatatypeVariant(..)
, ConstructorVariant(..)
, reifyDatatype
, normalizeInfo
, normalizeDec
, normalizeCon
, TypeSubstitution(..)
, quantifyType
, freshenFreeVariables
, equalPred
, classPred
, resolveTypeSynonyms
, unifyTypes
, tvName
, datatypeType
) where
import Data.Data (Typeable, Data)
import Data.Foldable (foldMap, foldl')
import Data.List (union, (\\))
import Data.Map (Map)
import qualified Data.Map as Map
import Control.Monad (foldM)
import GHC.Generics (Generic)
import Language.Haskell.TH
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative (Applicative(..), (<$>))
import Data.Traversable (traverse, sequenceA)
#endif
-- | Normalized description of a @data@ or @newtype@ declaration, or of a
-- @data instance@ / @newtype instance@, as produced by 'normalizeDec'.
data DatatypeInfo = DatatypeInfo
  { datatypeContext :: Cxt
    -- ^ Context of the declaration
  , datatypeName    :: Name
    -- ^ Type constructor name (for instances, the name carried by the
    -- instance declaration, i.e. the family name)
  , datatypeVars    :: [Type]
    -- ^ Type parameters: the declared binders as 'VarT's for plain
    -- declarations, or the instance's argument types for instances
  , datatypeVariant :: DatatypeVariant
    -- ^ Which kind of declaration this was normalized from
  , datatypeCons    :: [ConstructorInfo]
    -- ^ Normalized constructors (one entry per constructor name)
  }
  deriving (Show, Eq, Typeable, Data, Generic)
-- | The kind of declaration a 'DatatypeInfo' was normalized from.
data DatatypeVariant
  = Datatype        -- ^ @data@ declaration
  | Newtype         -- ^ @newtype@ declaration
  | DataInstance    -- ^ @data instance@ declaration
  | NewtypeInstance -- ^ @newtype instance@ declaration
  deriving (Show, Read, Eq, Ord, Typeable, Data, Generic)
-- | Normalized description of a single data constructor. Constructors
-- declared with GADT syntax are rewritten into this uniform shape: renamable
-- return-type indices are substituted away and the rest become equality
-- constraints in 'constructorContext'.
data ConstructorInfo = ConstructorInfo
  { constructorName    :: Name
    -- ^ Constructor name
  , constructorVars    :: [TyVarBndr]
    -- ^ Type variables quantified by the constructor itself (from
    -- 'ForallC' layers), minus any renamed to the datatype's variables
  , constructorContext :: Cxt
    -- ^ Constructor constraints, including equality constraints introduced
    -- while normalizing GADT return types
  , constructorFields  :: [Type]
    -- ^ Field types, in declaration order
  , constructorVariant :: ConstructorVariant
    -- ^ Whether the constructor uses record syntax
  }
  deriving (Show, Eq, Typeable, Data, Generic)
-- | Whether a constructor was declared with record syntax.
data ConstructorVariant
  = NormalConstructor
    -- ^ Prefix or infix constructor (or GADT constructor) without field names
  | RecordConstructor [Name]
    -- ^ Record constructor; carries the field names in declaration order
  deriving (Show, Eq, Ord, Typeable, Data, Generic)
-- | The fully applied type described by a 'DatatypeInfo': its type
-- constructor applied, left to right, to each of 'datatypeVars'.
datatypeType :: DatatypeInfo -> Type
datatypeType DatatypeInfo { datatypeName = tyCon, datatypeVars = args } =
  foldl AppT (ConT tyCon) args
-- | Look a name up with 'reify' and normalize the declaration it refers to.
--
-- Fails in 'Q' unless the name refers to a type constructor whose
-- declaration is accepted by 'normalizeDec' (a @data@/@newtype@ declaration
-- or a @data@/@newtype@ instance).
reifyDatatype ::
  Name {- ^ type constructor name -} ->
  Q DatatypeInfo
reifyDatatype n =
  do info <- reify n
     normalizeInfo info
-- | Normalize the result of 'reify'. Only 'TyConI' results are accepted;
-- their declaration is handed to 'normalizeDec'. Anything else fails in 'Q'.
normalizeInfo :: Info -> Q DatatypeInfo
normalizeInfo info =
  case info of
    TyConI dec -> normalizeDec dec
    _          -> fail "reifyDatatype: Expected a type constructor"
-- | Normalize a type-constructor declaration into a 'DatatypeInfo'.
--
-- Accepts @data@ and @newtype@ declarations as well as @data instance@ and
-- @newtype instance@ declarations; any other 'Dec' is rejected with 'fail'.
-- Plain declarations contribute their binders (as 'VarT's) as
-- 'datatypeVars'; instances contribute their argument types unchanged.
normalizeDec :: Dec -> Q DatatypeInfo
#if MIN_VERSION_template_haskell(2,11,0)
-- Since template-haskell-2.11 these constructors carry an optional kind
-- signature. template-haskell-2.12 only changed the type of the trailing
-- deriving field, which is ignored here, so one set of patterns serves
-- both versions (the former separate 2.12 branch was an exact duplicate).
-- NOTE(review): '_kind' is ignored, so parameters introduced only through a
-- kind signature do not appear in 'datatypeVars'.
normalizeDec (NewtypeD context name tyvars _kind con _derives) =
  normalizeDec' context name (bndrParams tyvars) [con] Newtype
normalizeDec (DataD context name tyvars _kind cons _derives) =
  normalizeDec' context name (bndrParams tyvars) cons Datatype
normalizeDec (NewtypeInstD context name params _kind con _derives) =
  normalizeDec' context name params [con] NewtypeInstance
normalizeDec (DataInstD context name params _kind cons _derives) =
  normalizeDec' context name params cons DataInstance
#else
normalizeDec (NewtypeD context name tyvars con _derives) =
  normalizeDec' context name (bndrParams tyvars) [con] Newtype
normalizeDec (DataD context name tyvars cons _derives) =
  normalizeDec' context name (bndrParams tyvars) cons Datatype
normalizeDec (NewtypeInstD context name params con _derives) =
  normalizeDec' context name params [con] NewtypeInstance
normalizeDec (DataInstD context name params cons _derives) =
  normalizeDec' context name params cons DataInstance
#endif
-- NOTE(review): the instance forms above are accepted too, despite what
-- this message says.
normalizeDec _ = fail "reifyDatatype: DataD or NewtypeD required"
-- | Turn declared type-variable binders into the corresponding 'VarT's,
-- discarding any kind annotations.
bndrParams :: [TyVarBndr] -> [Type]
bndrParams bndrs = [ VarT (tvName bndr) | bndr <- bndrs ]
-- | Assemble a 'DatatypeInfo' from the pieces of a declaration, normalizing
-- every constructor with 'normalizeCon'.
--
-- The free variables of the parameters are what GADT return types are
-- matched against.
-- NOTE(review): for instances whose parameters are not all bare variables,
-- that free-variable list need not line up position-by-position with the
-- instance arguments that 'normalizeCon' compares it to — TODO confirm GADT
-- data instances normalize as intended.
normalizeDec' ::
  Cxt             {- ^ Datatype context    -} ->
  Name            {- ^ Type constructor    -} ->
  [Type]          {- ^ Type parameters     -} ->
  [Con]           {- ^ Constructors        -} ->
  DatatypeVariant {- ^ Extra information   -} ->
  Q DatatypeInfo
normalizeDec' context name params cons variant =
  do perCon <- traverse (normalizeCon name (freeVariables params)) cons
     pure DatatypeInfo
       { datatypeContext = context
       , datatypeName    = name
       , datatypeVars    = params
       , datatypeVariant = variant
       , datatypeCons    = concat perCon
       }
-- | Normalize one constructor declaration.
--
-- Returns a list because GADT syntax can declare several constructor names
-- with a single signature; each name gets its own 'ConstructorInfo'.
-- Enclosing 'ForallC' layers are peeled off and their binders and
-- constraints are accumulated into the result.
--
-- The datatype name and the datatype's variable names are only consulted
-- for GADT constructors, whose return type is matched against them by
-- 'normalizeGadtC'.
normalizeCon ::
  Name   {- ^ Type constructor -} ->
  [Name] {- ^ Type parameters  -} ->
  Con    {- ^ Constructor      -} ->
  Q [ConstructorInfo]
normalizeCon typename vars = walk [] []
  where
    -- Binders and constraints of inner 'ForallC' layers are placed in front
    -- of those collected from outer layers.
    walk bndrs ctx con =
      let single n fields variant =
            pure [ConstructorInfo n bndrs ctx fields variant]
      in case con of
           NormalC n bangTypes ->
             single n (map snd bangTypes) NormalConstructor
           InfixC lhs n rhs ->
             single n [snd lhs, snd rhs] NormalConstructor
           RecC n varBangTypes ->
             single n (takeFieldTypes varBangTypes)
                      (RecordConstructor (takeFieldNames varBangTypes))
           ForallC bndrs' ctx' con' ->
             walk (bndrs' ++ bndrs) (ctx' ++ ctx) con'
#if MIN_VERSION_template_haskell(2,11,0)
           GadtC ns bangTypes innerType ->
             gadt ns innerType (map snd bangTypes) NormalConstructor
           RecGadtC ns varBangTypes innerType ->
             gadt ns innerType (takeFieldTypes varBangTypes)
                  (RecordConstructor (takeFieldNames varBangTypes))
      where
        gadt = normalizeGadtC typename vars bndrs ctx
-- | Build the 'ConstructorInfo's for a GADT-style constructor signature.
--
-- The declared return type is synonym-expanded and must be an application
-- of the datatype being normalized; otherwise this fails in 'Q'. Its
-- arguments are matched against the datatype's variables by
-- 'mergeArguments':
--
-- * return-type variables that can simply be renamed are substituted
--   throughout the context and fields, and dropped from the constructor's
--   own binders;
-- * the remaining positions become equality constraints, placed in front of
--   the existing context.
normalizeGadtC ::
  Name               {- ^ Type constructor             -} ->
  [Name]             {- ^ Type parameters              -} ->
  [TyVarBndr]        {- ^ Constructor binders          -} ->
  Cxt                {- ^ Constructor context          -} ->
  [Name]             {- ^ Constructor names            -} ->
  Type               {- ^ Declared return type         -} ->
  [Type]             {- ^ Field types                  -} ->
  ConstructorVariant {- ^ Normal or record constructor -} ->
  Q [ConstructorInfo]
normalizeGadtC typename vars tyvars context names innerType fields variant =
  do resolved <- resolveTypeSynonyms innerType
     case decomposeType resolved of
       ConT tyCon :| args | tyCon == typename ->
         let (renaming, eqs) = mergeArguments vars args
             subst           = fmap VarT renaming
             bndrs           = filter (\tv -> Map.notMember (tvName tv) subst) tyvars
             ctx             = applySubstitution subst (eqs ++ context)
             fieldTypes      = applySubstitution subst fields
         in pure (map (\n -> ConstructorInfo n bndrs ctx fieldTypes variant) names)
       _ -> fail "normalizeGadtC: Expected type constructor application"
-- | Match the datatype's variable names against the arguments of a GADT
-- constructor's return type, position by position.
--
-- An argument that is a type variable not seen before is renamed to the
-- datatype variable at that position; the renaming is recorded in the
-- returned map. Any other argument (a non-variable type, or a variable
-- already renamed) produces an equality constraint @n ~ arg@ instead.
-- Positions beyond the shorter of the two lists are ignored. Constraints
-- come out most recently produced first.
mergeArguments :: [Name] -> [Type] -> (Map Name Name, Cxt)
mergeArguments ns ts = walk Map.empty [] (zip ns ts)
  where
    walk renaming eqs [] = (renaming, eqs)
    walk renaming eqs ((n, arg) : rest) =
      case arg of
        VarT m
          | Map.notMember m renaming -> walk (Map.insert m n renaming) eqs rest
        _ -> walk renaming (AppT (AppT EqualityT (VarT n)) arg : eqs) rest
#endif
-- | Expand type synonyms at the head of a type application, and recursively
-- in its arguments.
--
-- When the head is a type synonym (as reported by 'reify'), the synonym's
-- parameters are replaced by the leading arguments. Any surplus arguments
-- are re-applied to the expansion, which is then resolved again. Otherwise
-- the head is kept and each argument is resolved.
--
-- NOTE(review): types nested under a non-application head (for example
-- inside a 'ForallT' or 'SigT') are returned untouched.
resolveTypeSynonyms :: Type -> Q Type
resolveTypeSynonyms t =
  case decomposeType t of
    ConT n :| args ->
      do info <- reify n
         case info of
           TyConI (TySynD _ synvars def) ->
             let params       = map tvName synvars
                 (used, rest) = splitAt (length params) args
                 expanded     = applySubstitution (Map.fromList (zip params used)) def
             in resolveTypeSynonyms (foldl AppT expanded rest)
           _ -> rebuild (ConT n) args
    hd :| args -> rebuild hd args
  where
    -- Keep the head as it is and resolve synonyms inside each argument.
    rebuild hd args = foldl AppT hd <$> traverse resolveTypeSynonyms args
-- | Split a type application into its head and its arguments, in order:
-- @AppT (AppT T a) b@ becomes @T :| [a, b]@.
--
-- An infix application @InfixT l T r@ (or 'UInfixT') is treated like
-- @T l r@. Parentheses ('ParensT') are looked through, including when they
-- wrap the head of a larger application.
--
-- NOTE(review): a 'UInfixT' chain whose fixities are still unresolved is
-- split at its outermost operator only.
decomposeType :: Type -> NonEmpty Type
decomposeType = reverseNonEmpty . go
  where
    -- 'go' yields the pieces in REVERSE order (last argument first, head
    -- last); the single 'reverseNonEmpty' above restores the order.
    -- Every case must therefore emit its pieces reversed and recurse with
    -- 'go', never with 'decomposeType'.
    go (AppT f x)      = x <| go f
#if MIN_VERSION_template_haskell(2,11,0)
    go (InfixT l f r)  = r :| [l, ConT f]
    go (UInfixT l f r) = r :| [l, ConT f]
    go (ParensT t)     = go t
#endif
    go t               = t :| []
-- | The variable name bound by a type-variable binder, ignoring any kind
-- annotation.
tvName :: TyVarBndr -> Name
tvName bndr =
  case bndr of
    PlainTV n    -> n
    KindedTV n _ -> n
-- | Field names of a record constructor's fields, in declaration order.
takeFieldNames :: [(Name,a,b)] -> [Name]
takeFieldNames = foldr (\(fieldName, _, _) names -> fieldName : names) []
-- | Field types of a record constructor's fields, in declaration order.
takeFieldTypes :: [(a,b,Type)] -> [Type]
takeFieldTypes = foldr (\(_, _, fieldType) types -> fieldType : types) []
-- | Universally quantify every free type variable of a type, in the order
-- 'freeVariables' reports them. A type with no free variables is returned
-- unchanged (no empty @forall@ is added).
quantifyType :: Type -> Type
quantifyType t =
  case freeVariables t of
    [] -> t
    vs -> ForallT (map PlainTV vs) [] t
-- | Rename every free type variable of a type to a fresh name obtained from
-- 'newName'. Each fresh name keeps the original's base name.
freshenFreeVariables :: Type -> Q Type
freshenFreeVariables t =
  do renaming <- traverse (\n -> VarT <$> newName (nameBase n))
                          (Map.fromList [ (n, n) | n <- freeVariables t ])
     pure (applySubstitution renaming t)
-- | Things containing types whose type variables can be substituted and
-- enumerated.
class TypeSubstitution a where
  -- | Replace free type variables according to the given map.
  applySubstitution :: Map Name Type -> a -> a
  -- | Names of the free type variables.
  freeVariables :: a -> [Name]
-- | Lists are handled element-wise. 'freeVariables' lists each name once,
-- in order of first occurrence. Without that, a variable mentioned by
-- several elements would be reported repeatedly, and removing bound names
-- one occurrence at a time would let it leak out as free.
instance TypeSubstitution a => TypeSubstitution [a] where
  -- 'union' keeps the accumulator and appends only names not already
  -- present (themselves de-duplicated), so starting from '[]' the result
  -- never repeats a name.
  freeVariables = foldl' (\seen x -> seen `union` freeVariables x) []
  applySubstitution = fmap . applySubstitution
-- | Substitution and free-variable computation for 'Type'.
--
-- Variables bound by a 'ForallT' shadow outer bindings: they are removed
-- from the substitution inside the quantified type and excluded from its
-- free variables.
instance TypeSubstitution Type where
  applySubstitution subst = go
    where
      -- NOTE(review): substitution is not capture-avoiding. A replacement
      -- type that mentions a variable bound by an inner 'ForallT' is
      -- captured by that binder. Kinds inside the binders themselves are
      -- not substituted either.
      go (ForallT tvs context t) =
        let subst' = foldl' (flip Map.delete) subst (map tvName tvs) in
        ForallT tvs (applySubstitution subst' context)
                    (applySubstitution subst' t)
      go (AppT f x)      = AppT (go f) (go x)
      go (SigT t k)      = SigT (go t) (applySubstitution subst k)
      go (VarT v)        = Map.findWithDefault (VarT v) v subst
#if MIN_VERSION_template_haskell(2,11,0)
      go (InfixT l c r)  = InfixT (go l) c (go r)
      go (UInfixT l c r) = UInfixT (go l) c (go r)
      go (ParensT t)     = ParensT (go t)
#endif
      go t               = t

  freeVariables t =
    case t of
      -- Remove EVERY occurrence of the bound names. List difference (\\)
      -- removes only one occurrence per binder, so a bound variable that
      -- the context mentions more than once would leak out as free.
      ForallT tvs context t' ->
        let bound = map tvName tvs in
        filter (`notElem` bound)
               (freeVariables context `union` freeVariables t')
      AppT f x      -> freeVariables f  `union` freeVariables x
      SigT t' k     -> freeVariables t' `union` freeVariables k
      VarT v        -> [v]
#if MIN_VERSION_template_haskell(2,11,0)
      InfixT l _ r  -> freeVariables l `union` freeVariables r
      UInfixT l _ r -> freeVariables l `union` freeVariables r
      ParensT t'    -> freeVariables t'
#endif
      _             -> []
-- | A constructor's free variables are those of its context and field types
-- that are not bound by its own 'constructorVars'. Substitution likewise
-- leaves those bound variables alone. The name, binders and variant are
-- never modified.
instance TypeSubstitution ConstructorInfo where
  freeVariables ci =
    -- Drop every occurrence of a bound name. List difference (\\) would
    -- drop only one per binder, leaking a bound variable mentioned twice.
    filter (`notElem` bound)
           (freeVariables (constructorContext ci) `union`
            freeVariables (constructorFields ci))
    where
      bound = map tvName (constructorVars ci)

  applySubstitution subst ci =
    let subst' = foldl' (flip Map.delete) subst (map tvName (constructorVars ci)) in
    ci { constructorContext = applySubstitution subst' (constructorContext ci)
       , constructorFields  = applySubstitution subst' (constructorFields ci)
       }
#if !MIN_VERSION_template_haskell(2,10,0)
-- | Before template-haskell-2.10, 'Pred' was a separate type with class and
-- equality forms; both are traversed structurally.
instance TypeSubstitution Pred where
  freeVariables p =
    case p of
      ClassP _ args  -> freeVariables args
      EqualP lhs rhs -> freeVariables lhs `union` freeVariables rhs
  applySubstitution subst p =
    case p of
      ClassP cls args -> ClassP cls (applySubstitution subst args)
      EqualP lhs rhs  -> EqualP (applySubstitution subst lhs)
                                (applySubstitution subst rhs)
#endif
#if !MIN_VERSION_template_haskell(2,8,0)
-- | Before template-haskell-2.8, 'Kind' was a separate type. Kinds are
-- treated as containing no type variables: substitution is the identity.
instance TypeSubstitution Kind where
  freeVariables = const []
  applySubstitution = const id
#endif
-- | Compose two substitutions so that applying the result acts like
-- applying the first and then the second.
--
-- The second substitution is applied to every right-hand side of the first.
-- Bindings of the second are added for variables the first does not bind;
-- on a clash the first map's binding wins.
combineSubstitutions :: Map Name Type -> Map Name Type -> Map Name Type
combineSubstitutions earlier later =
  Map.map (applySubstitution later) earlier `Map.union` later
-- | Unify every type in the list with the first one, after expanding type
-- synonyms, and return the accumulated substitution.
--
-- An empty list yields the empty substitution. Each later type is unified
-- under the substitution built so far. If some pair cannot be unified, the
-- computation fails in 'Q' with a message showing the offending sub-terms.
unifyTypes :: [Type] -> Q (Map Name Type)
unifyTypes [] = pure Map.empty
unifyTypes (t:ts) =
  do t':ts' <- traverse resolveTypeSynonyms (t:ts)
     let extend sub u =
           combineSubstitutions sub
             <$> unify' (applySubstitution sub t') (applySubstitution sub u)
     either cannotUnify pure (foldM extend Map.empty ts')
  where
    cannotUnify (x, y) =
      fail ("Unable to unify types " ++ showsPrec 11 x
              (" and " ++ showsPrec 11 y ""))
-- | First-order unification of two types.
--
-- Returns a substitution that makes the two types equal, or the first pair
-- of sub-terms that could not be unified.
--
-- * A type variable is bound to the other side, after an occurs check.
-- * Applications are unified component-wise. The substitution found for
--   the function parts is applied before unifying the argument parts.
-- * Any other pair unifies only if the two types are syntactically equal.
--   This covers 'ConT' and 'TupleT', and also 'ListT', 'ArrowT',
--   'EqualityT', promoted constructors, literals, and so on. Previously
--   only 'ConT' and 'TupleT' were accepted, so e.g. @[a]@ vs @[Int]@ failed.
--
-- NOTE(review): 'ForallT' and 'SigT' get no special treatment beyond
-- syntactic equality.
unify' :: Type -> Type -> Either (Type,Type) (Map Name Type)
unify' (VarT n) (VarT m) | n == m = pure Map.empty
unify' (VarT n) t | n `elem` freeVariables t = Left (VarT n, t)
                  | otherwise                = pure (Map.singleton n t)
unify' t (VarT n) | n `elem` freeVariables t = Left (VarT n, t)
                  | otherwise                = pure (Map.singleton n t)
unify' (AppT f1 x1) (AppT f2 x2) =
  do sub1 <- unify' f1 f2
     sub2 <- unify' (applySubstitution sub1 x1) (applySubstitution sub1 x2)
     return (combineSubstitutions sub1 sub2)
unify' t u
  | t == u    = pure Map.empty
  | otherwise = Left (t,u)
-- | Build the equality constraint @x ~ y@ in whichever representation the
-- template-haskell version in use expects.
equalPred :: Type -> Type -> Pred
#if MIN_VERSION_template_haskell(2,10,0)
equalPred x y = EqualityT `AppT` x `AppT` y
#else
equalPred = EqualP
#endif
-- | Build the class constraint @C t1 ... tn@ in whichever representation the
-- template-haskell version in use expects.
classPred :: Name -> [Type] -> Pred
#if MIN_VERSION_template_haskell(2,10,0)
classPred cls args = foldl AppT (ConT cls) args
#else
classPred = ClassP
#endif
-- | Minimal local non-empty list: a first element and a (possibly empty)
-- list of remaining elements. 'decomposeType' uses it to return a type's
-- head followed by its arguments.
data NonEmpty a = a :| [a]
-- | Prepend an element to a non-empty list.
(<|) :: a -> NonEmpty a -> NonEmpty a
new <| ne =
  case ne of
    old :| rest -> new :| (old : rest)
-- | Reverse a non-empty list. Written as a total 'case', so there is no
-- partial pattern binding.
reverseNonEmpty :: NonEmpty a -> NonEmpty a
reverseNonEmpty (x :| xs) =
  case reverse xs of
    []       -> x :| []
    (y : ys) -> y :| (ys ++ [x])