module Data.Singletons.TH.Deriving.Ord ( mkOrdInstance ) where
import Language.Haskell.TH.Desugar
import Language.Haskell.TH.Syntax
import Data.Singletons.TH.Deriving.Infer
import Data.Singletons.TH.Deriving.Util
import Data.Singletons.TH.Names
import Data.Singletons.TH.Syntax
import Data.Singletons.TH.Util
-- | Build a (non-singleton) 'Ord' instance description for a data type.
--
-- The generated instance defines only the method named by 'compareName':
--
-- * one clause per constructor, comparing two values built with that same
--   constructor field-by-field (see 'mk_equal_clause');
--
-- * one clause per ordered pair of constructors with different names,
--   ordering them by their 1-based position in the declaration
--   (see 'mk_nonequal_clause');
--
-- * for a data type with no constructors, a single catch-all clause
--   (see 'mk_empty_clause').
--
-- The instance context is whatever 'inferConstraintsDef' returns for
-- @mb_ctxt@, the 'Ord' class constructor, @ty@ and the constructors.
mkOrdInstance :: DsMonad q => DerivDesc q
mkOrdInstance mb_ctxt ty (DataDecl _ _ _ cons) = do
  constraints <- inferConstraintsDef mb_ctxt (DConT ordName) ty cons
  compare_eq_clauses <- mapM mk_equal_clause cons
  let -- Every constructor paired with its declaration position (from 1);
      -- the position alone decides how two different constructors compare.
      numbered_cons :: [(DCon, Int)]
      numbered_cons = zip cons [1 ..]

      compare_noneq_clauses =
        [ mk_nonequal_clause con1 con2
        | con1@(c1, _) <- numbered_cons
        , con2@(c2, _) <- numbered_cons
        , extractName c1 /= extractName c2
        ]

      clauses
        -- With no constructors both clause lists above are empty, so emit a
        -- single catch-all clause instead.
        | null cons = [mk_empty_clause]
        | otherwise = compare_eq_clauses ++ compare_noneq_clauses
  pure (InstDecl { id_cxt     = constraints
                 , id_name    = ordName
                 , id_arg_tys = [ty]
                 , id_sigs    = mempty
                 , id_meths   = [(compareName, UFunction clauses)]
                 })
-- | Build the 'compare' clause for two values that use the same constructor.
--
-- Every field of the left argument is bound to a fresh variable with prefix
-- @a@, and every field of the right argument to one with prefix @b@. The
-- clause body is
--
-- > foldlName sappendName cmpEQName [compareName a1 b1, compareName a2 b2, ...]
--
-- which is presumably @foldl (<>) EQ [compare a1 b1, ...]@, i.e. the fields
-- are compared lexicographically from left to right. For a constructor
-- without fields the list is empty.
mk_equal_clause :: Quasi q => DCon -> q DClause
mk_equal_clause (DCon _tvbs _cxt name fields _rty) = do
  let tys = tysOfConFields fields
      -- One fresh name per field, all sharing the given prefix.
      fresh prefix = mapM (const (newUniqueName prefix)) tys
  a_names <- fresh "a"
  b_names <- fresh "b"
  let con_pat = DConP name [] . map DVarP
      compare_fields a b = DVarE compareName `DAppE` DVarE a `DAppE` DVarE b
      body = DVarE foldlName `DAppE` DVarE sappendName `DAppE` DConE cmpEQName
               `DAppE` mkListE (zipWith compare_fields a_names b_names)
  return (DClause [con_pat a_names, con_pat b_names] body)
-- | Build the 'compare' clause for two values built with (normally) different
-- constructors. Each constructor comes paired with its position in the data
-- declaration. The clause ignores all fields (wildcard patterns) and returns
-- 'cmpLTName', 'cmpEQName' or 'cmpGTName' depending on how the two
-- positions compare.
mk_nonequal_clause :: (DCon, Int) -> (DCon, Int) -> DClause
mk_nonequal_clause (DCon _tvbs1 _cxt1 name1 fields1 _rty1, n1)
                   (DCon _tvbs2 _cxt2 name2 fields2 _rty2, n2) =
  DClause [wild_con_pat name1 fields1, wild_con_pat name2 fields2]
          (DConE ordering_name)
  where
    -- mkOrdInstance only pairs constructors with different names, so the EQ
    -- branch is not expected to be reached; it is kept so this stays total.
    ordering_name :: Name
    ordering_name = case compare n1 n2 of
      LT -> cmpLTName
      EQ -> cmpEQName
      GT -> cmpGTName

    -- Match the given constructor while ignoring every one of its fields.
    wild_con_pat :: Name -> DConFields -> DPat
    wild_con_pat con con_fields =
      DConP con [] (map (const DWildP) (tysOfConFields con_fields))
-- | The single 'compare' clause used for a data type with no constructors:
-- both arguments are matched with wildcards and the result is 'cmpEQName'.
mk_empty_clause :: DClause
mk_empty_clause = DClause [DWildP, DWildP] (DConE cmpEQName)