{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
module Data.GADT.Compare.TH
    ( DeriveGEQ(..)
    , DeriveGCompare(..)
    , GComparing, runGComparing, geq', compare'
    ) where

import Control.Applicative
import Control.Monad
import Data.Dependent.Sum
import Data.Dependent.Sum.TH.Internal
import Data.Functor.Identity
import Data.GADT.Compare
import Data.Traversable (for)
import Data.Type.Equality ((:~:) (..))
import Language.Haskell.TH
import Language.Haskell.TH.Extras

-- | Overloading-only class: lets 'deriveGEq' accept several descriptions of
-- the target data type (see the instances below). It carries no laws.
class DeriveGEQ src where
    -- | Generate the declarations that derive 'GEq' for the described type.
    deriveGEq :: src -> Q [Dec]

-- | Derive from the 'Name' of a type constructor: the name is 'reify'd and the
-- resulting declaration is handed to the 'Dec' instance. Any other kind of
-- name (value, class, ...) is rejected with 'fail'.
instance DeriveGEQ Name where
    deriveGEq typeName = do
        typeInfo <- reify typeName
        case typeInfo of
            TyConI dec -> deriveGEq dec
            _ -> fail "deriveGEq: the name of a type constructor is required"

-- | Derive from a data declaration: delegates to 'deriveForDec', using
-- @GEq@ as the class, @GEq t@ as the instance head, and 'geqFunction' to
-- build the method body from the declaration's binders and constructors.
instance DeriveGEQ Dec where
    deriveGEq = deriveForDec ''GEq (\t -> [t| GEq $t |]) geqFunction

-- | Accept a single-element list (as produced by e.g. a declaration quote
-- containing one declaration); any other length is rejected with 'fail'.
instance DeriveGEQ t => DeriveGEQ [t] where
    deriveGEq [it] = deriveGEq it
    deriveGEq _ = fail "deriveGEq: [] instance only applies to single-element lists"

-- | Run the 'Q' action, then derive from whatever it produced.
instance DeriveGEQ t => DeriveGEQ (Q t) where
    deriveGEq = (>>= deriveGEq)

-- | Build the @geq@ method: one clause per constructor (see 'geqClause'),
-- followed by a catch-all @geq _ _ = Nothing@.
--
-- The catch-all is omitted when there is exactly one constructor, because the
-- single constructor clause then already matches every pair of arguments and
-- the extra clause would be redundant.
geqFunction :: [TyVarBndr] -> [Con] -> Q Dec
geqFunction bndrs cons =
    funD 'geq (map (geqClause bndrs) cons ++ catchAll)
  where
    -- pattern match instead of @length cons /= 1@ to avoid walking the list
    catchAll = case cons of
        [_] -> []
        _   -> [clause [wildP, wildP] (normalB [| Nothing |]) []]

-- | Build the @geq@ clause for one constructor @C@ with @n@ arguments:
--
-- > geq (C x1 .. xn) (C y1 .. yn) = do { s1; ..; sn; return Refl }
--
-- For each argument pair, if the argument's type mentions a type variable
-- from @bndrs@ or one bound by the constructor itself ('varsBoundInCon'), the
-- statement is @Refl <- geq xi yi@, so the equality proof is learned by
-- pattern matching. Otherwise it is @guard (xi == yi)@, which requires an
-- 'Eq' instance for that field type.
geqClause :: [TyVarBndr] -> Con -> ClauseQ
geqClause bndrs con = do
    let argTypes = argTypesOfCon con
        nArgs = length argTypes
        -- does this argument type mention any type variable in scope for con?
        needsGEq argType =
            any ((`occursInType` argType) . nameOfBinder)
                (bndrs ++ varsBoundInCon con)

    lArgNames <- replicateM nArgs (newName "x")
    rArgNames <- replicateM nArgs (newName "y")

    clause [ conP conName (map varP lArgNames)
           , conP conName (map varP rArgNames)
           ]
        ( normalB $ doE
            (  [ if needsGEq argType
                    then bindS (conP 'Refl []) [| geq $(varE lArg) $(varE rArg) |]
                    else noBindS [| guard ($(varE lArg) == $(varE rArg)) |]
               | (lArg, rArg, argType) <- zip3 lArgNames rArgNames argTypes
               ]
            ++ [ noBindS [| return Refl |] ]
            )
        ) []
  where
    conName = nameOfCon con

-- | Comparison monad that lets @gcompare@ be written in the same
-- statement-by-statement style as @geq@: a 'Left' holds an already-decided
-- ordering and short-circuits the rest, while a 'Right' carries a value
-- (e.g. a type-equality proof) forward to the next comparison.
newtype GComparing x y r = GComparing (Either (GOrdering x y) r)

-- | Map over the carried value; a decided ordering passes through untouched.
instance Functor (GComparing a b) where
    fmap f (GComparing x) = GComparing (fmap f x)

-- | 'pure' continues the comparison with a value; '<*>' sequences via 'ap'.
-- 'pure' is defined here directly (the canonical form), rather than as
-- @pure = return@ with 'return' defined in 'Monad'.
instance Applicative (GComparing a b) where
    pure = GComparing . Right
    (<*>) = ap

-- | Binding stops at the first 'Left' (a decided 'GLT' / 'GGT'); otherwise the
-- 'Right' value is passed to the continuation. 'return' keeps its default
-- definition, @return = pure@.
instance Monad (GComparing a b) where
    GComparing (Left ordering) >>= _ = GComparing (Left ordering)
    GComparing (Right x) >>= f = f x

-- | Compare two indexed values with 'gcompare' inside 'GComparing'.
--
-- On 'GEQ' the comparison continues with the proof @a :~: b@; on 'GLT' or
-- 'GGT' that ordering becomes the final result.
geq' :: GCompare t => t a -> t b -> GComparing x y (a :~: b)
geq' x y = GComparing $ case gcompare x y of
    GLT -> Left GLT
    GEQ -> Right Refl
    GGT -> Left GGT

-- | Compare two plain (non-indexed) values with 'compare' inside 'GComparing'.
--
-- 'EQ' continues the comparison; 'LT' / 'GT' become the final 'GLT' / 'GGT'.
-- The compared type @t@ is independent of the 'GComparing' indices @a@ and
-- @b@, and needs an 'Ord' instance.
compare' :: Ord t => t -> t -> GComparing a b ()
compare' x y = GComparing $ case compare x y of
    LT -> Left GLT
    EQ -> Right ()
    GT -> Left GGT

-- | Extract the result of a comparison. A short-circuited ordering ('Left')
-- and a final computed ordering ('Right') are returned the same way.
runGComparing :: GComparing a b (GOrdering a b) -> GOrdering a b
runGComparing (GComparing x) = either id id x

-- | Overloading-only class: lets 'deriveGCompare' accept several descriptions
-- of the target data type (see the instances below).
class DeriveGCompare src where
    -- | Generate the declarations that derive 'GCompare' for the described type.
    deriveGCompare :: src -> Q [Dec]

-- | Derive from the 'Name' of a type constructor: the name is 'reify'd and the
-- resulting declaration is handed to the 'Dec' instance. Any other kind of
-- name is rejected with 'fail'.
instance DeriveGCompare Name where
    deriveGCompare typeName = do
        typeInfo <- reify typeName
        case typeInfo of
            TyConI dec -> deriveGCompare dec
            _ -> fail "deriveGCompare: the name of a type constructor is required"

-- | Derive from a data declaration: delegates to 'deriveForDec', using
-- @GCompare@ as the class, @GCompare t@ as the instance head, and
-- 'gcompareFunction' to build the method body.
instance DeriveGCompare Dec where
    deriveGCompare = deriveForDec ''GCompare (\t -> [t| GCompare $t |]) gcompareFunction

-- | Accept a single-element list; any other length is rejected with 'fail'.
instance DeriveGCompare t => DeriveGCompare [t] where
    deriveGCompare [it] = deriveGCompare it
    deriveGCompare _ = fail "deriveGCompare: [] instance only applies to single-element lists"

-- | Run the 'Q' action, then derive from whatever it produced.
instance DeriveGCompare t => DeriveGCompare (Q t) where
    deriveGCompare = (>>= deriveGCompare)

-- | Build the @gcompare@ method from a data type's binders and constructors.
--
-- With no constructors the method is @\\x y -> seq x (seq y undefined)@:
-- both arguments are forced before failing. Otherwise every constructor, in
-- the order given, contributes three clauses (see @gcompareClauses@).
--
-- The 'Foldable' constraint is required because @cons@ is consumed with
-- 'null' and 'concatMap'.
gcompareFunction :: Foldable t => [TyVarBndr] -> t Con -> Q Dec
gcompareFunction boundVars cons
    | null cons = funD 'gcompare [clause [] (normalB [| \x y -> seq x (seq y undefined) |]) []]
    | otherwise = funD 'gcompare (concatMap gcompareClauses cons)
    where
        -- for every constructor, first check for equality (recursively comparing
        -- arguments) then add catch-all cases; all not-yet-matched patterns are
        -- "greater than" the constructor under consideration.
        gcompareClauses :: Con -> [ClauseQ]
        gcompareClauses con =
            [ mainClause con
            , clause [recP conName [], wildP] (normalB [| GLT |]) []
            , clause [wildP, recP conName []] (normalB [| GGT |]) []
            ] where conName = nameOfCon con

        -- an argument needs a recursive gcompare when its type mentions a
        -- variable bound in the data declaration or by the constructor
        needsGCompare :: Type -> Con -> Bool
        needsGCompare argType con =
            any ((`occursInType` argType) . nameOfBinder)
                (boundVars ++ varsBoundInCon con)

        -- main clause; using the 'GComparing' monad, compare all arguments to the
        -- constructor recursively, attempting to unify type variables by recursive
        -- calls to gcompare whenever needed (that is, whenever a constructor argument's
        -- type contains a variable bound in the data declaration or in the constructor's
        -- type signature)
        mainClause :: Con -> ClauseQ
        mainClause con = do
            let conName = nameOfCon con
                argTypes = argTypesOfCon con
                nArgs = length argTypes

            lArgNames <- replicateM nArgs (newName "x")
            rArgNames <- replicateM nArgs (newName "y")

            clause [ conP conName (map varP lArgNames)
                   , conP conName (map varP rArgNames)
                   ]
                ( normalB
                    [| runGComparing $
                        $(doE
                            (  [ if needsGCompare argType con
                                    then bindS (conP 'Refl []) [| geq' $(varE lArg) $(varE rArg) |]
                                    else noBindS [| compare' $(varE lArg) $(varE rArg) |]
                               | (lArg, rArg, argType) <- zip3 lArgNames rArgNames argTypes
                               ]
                            ++ [ noBindS [| return GEQ |] ]
                            )
                        )
                    |]
                ) []