-----------------------------------------------------------------------------
-- |
-- Module      :  Data.Singletons.TH.Deriving.Eq
-- Copyright   :  (C) 2020 Ryan Scott
-- License     :  BSD-style (see LICENSE)
-- Maintainer  :  Ryan Scott
-- Stability   :  experimental
-- Portability :  non-portable
--
-- Implements deriving of Eq instances
--
----------------------------------------------------------------------------
module Data.Singletons.TH.Deriving.Eq (mkEqInstance) where

import Control.Monad
import Data.Singletons.TH.Deriving.Infer
import Data.Singletons.TH.Deriving.Util
import Data.Singletons.TH.Names
import Data.Singletons.TH.Syntax
import Data.Singletons.TH.Util
import Language.Haskell.TH.Desugar
import Language.Haskell.TH.Syntax

-- | Derive an 'Eq' instance for a data type.
--
-- The instance context is obtained from 'inferConstraintsDef', which is
-- given the optional explicit context, the 'Eq' class constructor, the
-- instance head type, and the data constructors. The instance defines a
-- single method, named by 'equalsName':
--
-- * for a data type with no constructors, one clause @_ _ -> True@;
-- * otherwise, one clause (built by 'mkEqClause') for every ordered pair
--   of constructors, in list-comprehension order (outer = left operand).
mkEqInstance :: DsMonad q => DerivDesc q
mkEqInstance mbCxt headTy (DataDecl _ _ dataCons) = do
  -- Constraints are computed before any clause so that the order of
  -- monadic effects (e.g. fresh-name generation) is unchanged.
  instCxt <- inferConstraintsDef mbCxt (DConT eqName) headTy dataCons
  eqClauses <-
    case dataCons of
      [] -> pure [DClause [DWildP, DWildP] (DConE trueName)]
      _  -> traverse mkEqClause ((,) <$> dataCons <*> dataCons)
  let eqInst = InstDecl
        { id_cxt     = instCxt
        , id_name    = eqName
        , id_arg_tys = [headTy]
        , id_sigs    = mempty
        , id_meths   = [(equalsName, UFunction eqClauses)]
        }
  pure eqInst

-- | Generate the equality-method clause for one ordered pair of
-- constructors.
--
-- If the constructors differ, every field on both sides is matched with a
-- wildcard and the clause body is the 'falseName' constructor.
--
-- If they are the same constructor, the left-hand fields are bound to
-- fresh names derived from @"a"@ and the right-hand fields to fresh names
-- derived from @"b"@. The body applies 'equalsName' to each corresponding
-- pair and chains the results with 'andName', nested to the right. A
-- field-less constructor yields 'trueName'.
mkEqClause :: Quasi q => (DCon, DCon) -> q DClause
mkEqClause (c1, c2)
  | lname /= rname = pure mismatchClause
  | otherwise = do
      -- Request all left-hand names before any right-hand name so the
      -- fresh-name supply is consumed in a fixed order.
      leftNames  <- replicateM lNumArgs (newUniqueName "a")
      rightNames <- replicateM lNumArgs (newUniqueName "b")
      let compareField l r = foldExp (DVarE equalsName) [DVarE l, DVarE r]
          lhsPats =
            [ DConP lname (map DVarP leftNames)
            , DConP rname (map DVarP rightNames)
            ]
          rhs = conjoin (zipWith compareField leftNames rightNames)
      pure (DClause lhsPats rhs)
  where
    (lname, lNumArgs) = extractNameArgs c1
    (rname, rNumArgs) = extractNameArgs c2

    -- Clause for two distinct constructors: ignore all fields.
    mismatchClause :: DClause
    mismatchClause =
      DClause
        [ DConP lname (replicate lNumArgs DWildP)
        , DConP rname (replicate rNumArgs DWildP)
        ]
        (DConE falseName)

    -- Right-nested conjunction; the empty conjunction is 'trueName'.
    conjoin :: [DExp] -> DExp
    conjoin es = case es of
      []       -> DConE trueName
      [e]      -> e
      (e:rest) -> DAppE (DAppE (DVarE andName) e) (conjoin rest)