{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs             #-}
{-# LANGUAGE TypeOperators     #-}
--------------------------------------------------------------------------------
-- |
-- Module      :  Data.Comp.Multi.Equality
-- Copyright   :  (c) Patrick Bahr, 2011
-- License     :  BSD3
-- Maintainer  :  Patrick Bahr <paba@diku.dk>
-- Stability   :  experimental
-- Portability :  non-portable (GHC Extensions)
--
-- This module defines equality for (higher-order) signatures, which lifts to
-- equality for (higher-order) terms and contexts. All definitions are
-- generalised versions of those in "Data.Comp.Equality".
--
--------------------------------------------------------------------------------
module Data.Comp.Multi.Equality
    (
     EqHF(..),
     KEq(..),
     heqMod
    ) where

import Data.Comp.Multi.HFoldable
import Data.Comp.Multi.HFunctor
import Data.Comp.Multi.Ops
import Data.Comp.Multi.Term

-- | Heterogeneous equality for an indexed type constructor @f@: 'keq'
-- compares two @f@-values whose type indices need not coincide.
class KEq f where
    keq :: f ix -> f jx -> Bool

{-| Equality of (higher-order) signatures. Comparing two @f h@ values only
  requires 'KEq' for the argument functor @h@; this module lifts an
  @EqHF f@ instance to @KEq (Cxt h f a)@, i.e. to equality of terms and
  contexts. -}
class EqHF f where
    eqHF :: KEq h => f h ix -> f h jx -> Bool

-- | Constant functors are compared by the equality of their payloads; the
-- type index plays no role in the comparison.
instance Eq a => KEq (K a) where
    keq (K x) (K y) = x == y

-- | Values wrapped in 'E' (whose index is hidden) are compared with 'keq',
-- which does not require the two hidden indices to agree.
instance KEq a => Eq (E a) where
    E x == E y = x `keq` y

{-|
  'EqHF' is propagated through sums: two values are equal only if both are
  injected on the same side and the injected values are equal.
-}
instance (EqHF f, EqHF g) => EqHF (f :+: g) where
    eqHF (Inl x) (Inl y) = eqHF x y
    eqHF (Inr x) (Inr y) = eqHF x y
    -- Mixed injections (Inl vs. Inr) are never equal.
    eqHF _       _       = False

-- | Contexts are compared structurally: two 'Term' nodes are compared with
-- the signature's 'eqHF', and two 'Hole's are compared with 'keq' on their
-- contents. A 'Term' never equals a 'Hole'.
instance EqHF f => EqHF (Cxt h f) where
    eqHF (Term e1) (Term e2) = e1 `eqHF` e2
    eqHF (Hole h1) (Hole h2) = h1 `keq` h2
    eqHF _         _         = False

-- | Heterogeneous equality on contexts (and hence terms), obtained directly
-- from the signature-level 'eqHF'.
instance (EqHF f, KEq a) => KEq (Cxt h f a) where
    keq = eqHF

{-|
  From an 'EqHF' signature, an 'Eq' instance of the corresponding
  context and term types can be derived.
-}
instance (EqHF f, KEq a) => Eq (Cxt h f a i) where
    (==) = keq

{-| This function implements equality of values of type @f a i@ modulo
the equality of @a@ itself. If two functorial values are equal in this
sense, 'heqMod' returns a 'Just' value containing a list of pairs
consisting of corresponding components of the two functorial
values.

The shapes are compared by replacing every component with @K ()@ and
applying 'eqHF'. The components are paired in the order given by
'htoList'. -}
heqMod :: (EqHF f, HFunctor f, HFoldable f) => f a i -> f b i -> Maybe [(E a, E b)]
heqMod s t
    | unit s `eqHF` unit' t = Just args
    | otherwise             = Nothing
    where
      -- Erase the components so that only the shape is compared; one
      -- eraser for each argument's functor (@a@ and @b@).
      unit  = hfmap (const $ K ())
      unit' = hfmap (const $ K ())
      args  = htoList s `zip` htoList t