-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Singletons.TH.Deriving.Ord
-- Copyright   :  (C) 2015 Richard Eisenberg
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Ryan Scott
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Implements deriving of Ord instances
--
----------------------------------------------------------------------------

module Data.Singletons.TH.Deriving.Ord ( mkOrdInstance ) where

import Language.Haskell.TH.Desugar
import Language.Haskell.TH.Syntax
import Data.Singletons.TH.Deriving.Infer
import Data.Singletons.TH.Deriving.Util
import Data.Singletons.TH.Names
import Data.Singletons.TH.Syntax
import Data.Singletons.TH.Util

-- | Make a *non-singleton* 'Ord' instance.
--
-- The generated instance defines only 'compare':
--
-- * For a datatype with no constructors, a single clause
--   @compare _ _ = EQ@ (see 'mk_empty_clause').
--
-- * Otherwise, first one clause per constructor that compares two values
--   built with that same constructor field-by-field (see 'mk_equal_clause'),
--   then one clause for every ordered pair of distinct constructors, which
--   orders them by their position in the declaration
--   (see 'mk_nonequal_clause').
--
-- The instance context is computed by 'inferConstraintsDef' from @mb_ctxt@,
-- the instance head type, and the constructors.
mkOrdInstance :: DsMonad q => DerivDesc q
mkOrdInstance mb_ctxt ty (DataDecl _ _ _ cons) = do
  constraints <- inferConstraintsDef mb_ctxt (DConT ordName) ty cons
  compare_eq_clauses <- mapM mk_equal_clause cons
  let -- Pair each constructor with its 1-based declaration position; the
      -- position decides the LT/GT result between distinct constructors.
      numbered_cons :: [(DCon, Int)]
      numbered_cons = zip cons [1..]
      -- Positions are distinct exactly when the constructors are distinct,
      -- since constructor names within one declaration are unique.
      compare_noneq_clauses =
        [ mk_nonequal_clause con1 con2
        | con1@(_, n1) <- numbered_cons
        , con2@(_, n2) <- numbered_cons
        , n1 /= n2
        ]
      clauses
        | null cons = [mk_empty_clause]
        | otherwise = compare_eq_clauses ++ compare_noneq_clauses
  return (InstDecl { id_cxt     = constraints
                   , id_name    = ordName
                   , id_arg_tys = [ty]
                   , id_sigs    = mempty
                   , id_meths   = [(compareName, UFunction clauses)] })

-- | Build the 'compare' clause for two values that use the same constructor.
--
-- For a constructor @K@ with @n@ fields this generates, roughly,
--
-- @
-- compare (K a_1 ... a_n) (K b_1 ... b_n)
--   = foldl (<>) EQ [compare a_1 b_1, ..., compare a_n b_n]
-- @
--
-- NOTE(review): this reading assumes 'foldlName' and 'sappendName' name
-- @foldl@ and @(<>)@; confirm against "Data.Singletons.TH.Names".
-- The pattern variables are fresh names obtained from 'newUniqueName'.
-- A constructor with no fields yields @compare K K = foldl (<>) EQ []@.
mk_equal_clause :: Quasi q => DCon -> q DClause
mk_equal_clause (DCon _tvbs _cxt name fields _rty) = do
  let tys = tysOfConFields fields
  a_names <- mapM (const $ newUniqueName "a") tys
  b_names <- mapM (const $ newUniqueName "b") tys
  let pat1 = DConP name [] (map DVarP a_names)
      pat2 = DConP name [] (map DVarP b_names)
      -- One @compare a_i b_i@ call per field, in field order.
      field_cmps = zipWith
                     (\a b -> DVarE compareName `DAppE` DVarE a
                                                `DAppE` DVarE b)
                     a_names b_names
  return $ DClause [pat1, pat2]
                   (DVarE foldlName `DAppE` DVarE sappendName
                                    `DAppE` DConE cmpEQName
                                    `DAppE` mkListE field_cmps)

-- | Build the 'compare' clause for two values that use different constructors.
--
-- Each argument pairs a constructor with its position in the declaration.
-- Both patterns ignore every field of their constructor. The result depends
-- only on comparing the two positions: a constructor with a smaller position
-- compares 'LT' to one with a larger position, and vice versa for 'GT'.
-- The 'EQ' branch is produced only when both positions are equal.
mk_nonequal_clause :: (DCon, Int) -> (DCon, Int) -> DClause
mk_nonequal_clause (DCon _tvbs1 _cxt1 name1 fields1 _rty1, n1)
                   (DCon _tvbs2 _cxt2 name2 fields2 _rty2, n2) =
  DClause [wild_con_pat name1 fields1, wild_con_pat name2 fields2] result
  where
    result :: DExp
    result = case n1 `compare` n2 of
      LT -> DConE cmpLTName
      EQ -> DConE cmpEQName
      GT -> DConE cmpGTName

    -- A pattern for the given constructor with a wildcard for every field.
    wild_con_pat :: Name -> DConFields -> DPat
    wild_con_pat con con_fields =
      DConP con [] (map (const DWildP) (tysOfConFields con_fields))

-- | A variant of 'mk_equal_clause' tailored to empty datatypes: with no
-- constructors to match on, the single clause @compare _ _ = EQ@ ignores
-- both arguments.
mk_empty_clause :: DClause
mk_empty_clause = DClause [DWildP, DWildP] (DConE cmpEQName)