{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
module Data.GADT.Compare.TH
( DeriveGEQ(..)
, DeriveGCompare(..)
, module Data.GADT.Compare.Monad
) where
import Control.Monad
import Control.Monad.Writer
import Data.GADT.TH.Internal
import Data.Functor.Identity
import Data.GADT.Compare
import Data.GADT.Compare.Monad
import Data.Type.Equality ((:~:) (..))
import qualified Data.Set as Set
import Data.Set (Set)
import qualified Data.Map as Map
import qualified Data.Map.Merge.Lazy as Map
import Data.Map (Map)
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
-- | Inputs from which a 'GEq' instance can be generated with Template
-- Haskell. This module provides instances for a type's 'Name', a quoted
-- data declaration ('Dec'), a single-element list of such inputs, and a
-- 'Q' action producing one.
class DeriveGEQ t where
  -- | Generate the declarations for the 'GEq' instance.
  deriveGEq :: t -> Q [Dec]
-- | Reify the named data type and emit a complete
-- @instance ctx => GEq (T a1 ... a(n-1))@ declaration for it, where the
-- context is whatever the generated clauses report they need.
instance DeriveGEQ Name where
  deriveGEq typeName = do
    typeInfo <- reifyDatatype typeName
    let instTypes = datatypeInstTypes typeInfo
        -- Type variables occurring anywhere in the instance's type arguments.
        paramVars = Set.unions [freeTypeVariables t | t <- instTypes]
    -- The instance is for the type constructor applied to every parameter
    -- except the last one. This check must run in Q: a 'fail' evaluated at
    -- type [Type] would use the list monad's 'fail' and quietly produce []
    -- instead of reporting the error to the user.
    instTypes' <- case reverse instTypes of
      [] -> fail "deriveGEq: Not enough type parameters"
      (_ : xs) -> pure (reverse xs)
    let instanceHead = AppT (ConT ''GEq) (foldl AppT (ConT typeName) instTypes')
    -- One clause per constructor; the constraints logged by the clauses
    -- become the instance context.
    (clauses, ctx) <- runWriterT (mapM (geqClause paramVars) (datatypeCons typeInfo))
    return [InstanceD Nothing ctx instanceHead [geqFunction clauses]]
-- | Generate the 'geq' method for a quoted data declaration. The method is
-- built in @WriterT Cxt Q@ and handed to 'deriveForDec' (from
-- "Data.GADT.TH.Internal"), which presumably assembles the final instance
-- and its context — confirm there.
instance DeriveGEQ Dec where
  deriveGEq = deriveForDec ''GEq buildGeq
    where
      -- Every type variable in the instance arguments may be referenced by
      -- the constructors' field types.
      buildGeq info =
        let params = Set.unions (map freeTypeVariables (datatypeInstTypes info))
         in geqFunction <$> traverse (geqClause params) (datatypeCons info)
-- | Accept a single-element list (as produced by a @[d| ... |]@ quote) and
-- derive for its only element; any other length is rejected in 'Q'.
instance DeriveGEQ t => DeriveGEQ [t] where
  deriveGEq xs = case xs of
    [single] -> deriveGEq single
    _ -> fail "deriveGEq: [] instance only applies to single-element lists"
-- | Run the 'Q' action first, then derive for whatever it produced.
instance DeriveGEQ t => DeriveGEQ (Q t) where
  deriveGEq action = deriveGEq =<< action
-- | Assemble the 'geq' method from the per-constructor clauses.
--
-- Each clause produced by 'geqClause' only matches two values built with
-- the same constructor, so a final @geq _ _ = Nothing@ catch-all handles
-- mismatched constructors. The callers pass exactly one clause per
-- constructor:
--
-- * one clause: the catch-all could never be reached (GHC would flag it as
--   a redundant pattern in the user's module), so it is omitted;
-- * zero clauses (no constructors): the catch-all is the only clause, since
--   a function declaration needs at least one;
-- * two or more clauses: the catch-all is required.
geqFunction :: [Clause] -> Dec
geqFunction clauses = FunD 'geq (clauses ++ fallback)
  where
    fallback = case clauses of
      [_] -> []
      _ -> [Clause [WildP, WildP] (NormalB (ConE 'Nothing)) []]
-- | Generate the 'geq' clause for one constructor.
--
-- The clause matches that constructor in both arguments (fresh variables
-- named from @x@ on the left, from @y@ on the right). It compares the fields
-- pairwise in a @do@ block that ends in @return Refl@:
--
-- * A field whose type is an application @f t@ that mentions one of the
--   instance's type variables (@paramVars@) or one of the constructor's own
--   type variables is compared with 'geq', binding the result to 'Refl'.
--   The constraint this needs is written to the 'WriterT' log: either the
--   context of the single matching 'GEq' instance found by
--   'reifyInstancesWithRigids', or @GEq f@ itself when none is found.
-- * Every other field is compared with '==' under 'guard', and nothing is
--   logged for it.
--
-- Fails in 'Q' if the instance lookup returns anything other than zero or
-- one instance declaration.
geqClause :: Set Name -> ConstructorInfo -> WriterT Cxt Q Clause
geqClause paramVars con = do
  lNames <- freshNames "x"
  rNames <- freshNames "y"
  comparisons <- sequence (zipWith3 compareField lNames rNames fields)
  finish <- lift (noBindS [| return Refl |])
  patterns <- lift (mapM (conP conName . map varP) [lNames, rNames])
  pure (Clause patterns (NormalB (doUnqualifiedE (comparisons ++ [finish]))) [])
  where
    conName = constructorName con
    fields = constructorFields con
    relevantVars =
      Set.union paramVars (Set.fromList (map tvName (constructorVars con)))

    -- A field needs 'geq' when its type mentions a variable in scope for
    -- the instance or bound by the constructor.
    needsGEq ty =
      not (Set.null (Set.intersection (freeTypeVariables ty) relevantVars))

    -- One fresh name per field.
    freshNames prefix = mapM (const (lift (newName prefix))) fields

    compareField l r ty = case ty of
      AppT tyFun _ | needsGEq ty -> do
        found <- lift (reifyInstancesWithRigids paramVars ''GEq [tyFun])
        case found of
          [] -> tell [AppT (ConT ''GEq) tyFun]
          [InstanceD _ instCtx _ _] -> tell instCtx
          _ ->
            fail $
              "More than one instance found for GEq ("
                ++ show (ppr tyFun)
                ++ "), and unsure what to do. Please report this."
        lift (bindS (conP 'Refl []) [| geq $(varE l) $(varE r) |])
      _ -> lift (noBindS [| guard ($(varE l) == $(varE r)) |])
-- | Inputs from which a 'GCompare' instance can be generated with Template
-- Haskell. This module provides instances for a type's 'Name', a quoted
-- data declaration ('Dec'), a single-element list of such inputs, and a
-- 'Q' action producing one.
class DeriveGCompare t where
  -- | Generate the declarations for the 'GCompare' instance.
  deriveGCompare :: t -> Q [Dec]
-- | Reify the named data type and emit a complete
-- @instance ctx => GCompare (T a1 ... a(n-1))@ declaration for it, where
-- the context is whatever the generated clauses report they need.
instance DeriveGCompare Name where
  deriveGCompare typeName = do
    typeInfo <- reifyDatatype typeName
    let instTypes = datatypeInstTypes typeInfo
        -- Type variables occurring anywhere in the instance's type arguments.
        paramVars = Set.unions [freeTypeVariables t | t <- instTypes]
    -- The instance is for the type constructor applied to every parameter
    -- except the last one. This check must run in Q: a 'fail' evaluated at
    -- type [Type] would use the list monad's 'fail' and quietly produce []
    -- instead of reporting the error to the user.
    instTypes' <- case reverse instTypes of
      [] -> fail "deriveGCompare: Not enough type parameters"
      (_ : xs) -> pure (reverse xs)
    let instanceHead = AppT (ConT ''GCompare) (foldl AppT (ConT typeName) instTypes')
    -- Three clauses per constructor, concatenated in constructor order; the
    -- constraints logged by the clauses become the instance context.
    (clauses, ctx) <-
      runWriterT (concat <$> mapM (gcompareClauses paramVars) (datatypeCons typeInfo))
    dec <- gcompareFunction clauses
    return [InstanceD Nothing ctx instanceHead [dec]]
-- | Generate the 'gcompare' method for a quoted data declaration. The
-- method is built in @WriterT Cxt Q@ and handed to 'deriveForDec' (from
-- "Data.GADT.TH.Internal"), which presumably assembles the final instance
-- and its context — confirm there.
instance DeriveGCompare Dec where
  deriveGCompare = deriveForDec ''GCompare buildGCompare
    where
      buildGCompare info = do
        let params = Set.unions (map freeTypeVariables (datatypeInstTypes info))
        perConstructor <- traverse (gcompareClauses params) (datatypeCons info)
        lift (gcompareFunction (concat perConstructor))
-- | Accept a single-element list (as produced by a @[d| ... |]@ quote) and
-- derive for its only element; any other length is rejected in 'Q'.
instance DeriveGCompare t => DeriveGCompare [t] where
  deriveGCompare xs = case xs of
    [single] -> deriveGCompare single
    _ -> fail "deriveGCompare: [] instance only applies to single-element lists"
-- | Run the 'Q' action first, then derive for whatever it produced.
instance DeriveGCompare t => DeriveGCompare (Q t) where
  deriveGCompare action = deriveGCompare =<< action
-- | Assemble the 'gcompare' method from the per-constructor clauses.
--
-- An empty clause list arises for a type with no constructors. In that case
-- the method is @\\x y -> seq x (seq y undefined)@, which forces both
-- arguments before 'undefined' could be reached.
gcompareFunction :: [Clause] -> Q Dec
gcompareFunction clauses
  | null clauses =
      funD 'gcompare [clause [] (normalB [| \x y -> seq x (seq y undefined) |]) []]
  | otherwise = pure (FunD 'gcompare clauses)
-- | Generate the three 'gcompare' clauses for one constructor, in order:
--
-- 1. Both arguments built with this constructor. The fields are compared
--    pairwise in a @do@ block run by 'runGComparing', ending in
--    @return GEQ@. A field whose type is an application @f t@ that mentions
--    one of the instance's or the constructor's type variables is compared
--    with @geq'@ (binding 'Refl'). Its constraint is logged to the
--    'WriterT': the single matching instance's context, or @GCompare f@
--    when none is found. Any other field is compared with @compare'@.
-- 2. This constructor on the left, anything on the right: 'GLT'.
-- 3. Anything on the left, this constructor on the right: 'GGT'.
--
-- Concatenating these lists in constructor order (as this module's
-- instances do) orders earlier constructors before later ones. Fails in
-- 'Q' if the instance lookup for a field returns anything other than zero
-- or one instance declaration.
gcompareClauses :: Set Name -> ConstructorInfo -> WriterT Cxt Q [Clause]
gcompareClauses paramVars con = do
  lNames <- freshNames "x"
  rNames <- freshNames "y"
  comparisons <- sequence (zipWith3 compareField lNames rNames fields)
  finish <- lift (noBindS [| return GEQ |])
  patterns <- lift (mapM (conP conName . map varP) [lNames, rNames])
  let body = AppE (VarE 'runGComparing) (doUnqualifiedE (comparisons ++ [finish]))
      sameCon = Clause patterns (NormalB body) []
      leftIsCon = Clause [RecP conName [], WildP] (NormalB (ConE 'GLT)) []
      rightIsCon = Clause [WildP, RecP conName []] (NormalB (ConE 'GGT)) []
  pure [sameCon, leftIsCon, rightIsCon]
  where
    conName = constructorName con
    fields = constructorFields con
    relevantVars =
      Set.union paramVars (Set.fromList (map tvName (constructorVars con)))

    -- A field needs geq' when its type mentions a variable in scope for
    -- the instance or bound by the constructor.
    needsGCompare ty =
      not (Set.null (Set.intersection (freeTypeVariables ty) relevantVars))

    -- One fresh name per field.
    freshNames prefix = mapM (const (lift (newName prefix))) fields

    compareField l r ty = case ty of
      AppT tyFun _ | needsGCompare ty -> do
        found <- lift (reifyInstancesWithRigids paramVars ''GCompare [tyFun])
        case found of
          [] -> tell [AppT (ConT ''GCompare) tyFun]
          [InstanceD _ instCtx _ _] -> tell instCtx
          _ ->
            fail $
              "More than one instance of GCompare ("
                ++ show (ppr tyFun)
                ++ ") found, and unsure what to do. Please report this."
        lift (bindS (conP 'Refl []) [| geq' $(varE l) $(varE r) |])
      _ -> lift (noBindS [| compare' $(varE l) $(varE r) |])
-- | Build an ordinary, unqualified @do@ expression from a list of
-- statements.
--
-- From template-haskell 2.17 the 'DoE' constructor takes an extra
-- @Maybe ModName@ module qualifier; passing 'Nothing' requests a plain
-- @do@. One signature covers both CPP branches.
doUnqualifiedE :: [Stmt] -> Exp
#if MIN_VERSION_template_haskell(2,17,0)
doUnqualifiedE = DoE Nothing
#else
doUnqualifiedE = DoE
#endif