{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}

{-|
Module      : Data.JoinSemilattice.Class.Eq
Description : Equality relationships.
Copyright   : (c) Tom Harding, 2020
License     : MIT
-}
module Data.JoinSemilattice.Class.Eq where

import Control.Applicative (liftA2)
import Data.JoinSemilattice.Class.Boolean (BooleanR (..))
import Data.JoinSemilattice.Class.Merge (Merge)
import Data.JoinSemilattice.Defined (Defined (..))
import Data.JoinSemilattice.Intersect (Intersect (..), Intersectable)
import qualified Data.JoinSemilattice.Intersect as Intersect
import Data.Kind (Constraint, Type)

-- | A class-level synonym for @'EqC' f x@. The superclass lets @EqC f x@ be
-- recovered from @EqC' f x@, and the single instance below provides
-- @EqC' f x@ whenever @EqC f x@ holds, so the two constraints are
-- interchangeable.
--
-- 'EqR' mentions this class, rather than 'EqC' itself, in its quantified
-- superclass constraint and in the type of 'eqR'.
--
-- NOTE(review): presumably this indirection exists because the associated
-- type family can't be applied directly inside the quantified constraint -
-- confirm.
class EqC f x => EqC' f x
instance EqC f x => EqC' f x

-- | Equality between two variables as a relationship between them and their
-- result. The hope here is that, if we learn the output before the inputs, we
-- can often "work backwards" to learn something about them. If we know the
-- result is exactly /true/, for example, we can effectively then
-- 'Control.Monad.Cell.Class.unify' the two input cells, as we know that their
-- values will always be the same.
--
-- The class constraints are a bit ugly here, and it's something I'm hoping I
-- can tidy up down the line. The idea is that, previously, our class was
-- defined as:
--
-- @
--   class EqR (x :: Type) (b :: Type) | x -> b where
--     eqR :: (x -> x -> b) -> (x -> x -> b)
-- @
--
-- The problem here was that, if we said @x .== x :: Prop m (Defined Bool)@, we
-- couldn't even infer that the type of @x@ was @Defined@-wrapped, which made
-- the overloaded literals, for example, largely pointless.
--
-- To fix it, the class was rewritten to parameterise the wrapper type, which
-- means we can always make this inference. However, the constraints got a bit
-- grisly when I hacked it together.
class (forall x. EqC' f x => Merge (f x), BooleanR f)
    => EqR (f :: Type -> Type) where
  -- | The constraint that an element type @x@ must satisfy before 'eqR' can
  -- be used at @f x@. Each instance picks its own: the 'Defined' instance in
  -- this module uses 'Eq', and the 'Intersect' instance uses 'Intersectable'.
  type EqC f :: Type -> Constraint

  -- | The equality relationship itself. The input triple holds the two
  -- operands and the result of comparing them; the output triple has the
  -- same shape. In the instances in this module, an output position for
  -- which nothing can be said is filled with 'mempty'.
  eqR :: EqC' f x => ( f x, f x, f Bool ) -> ( f x, f x, f Bool )

-- | A relationship between two variables and the result of a not-equals
-- comparison between them.
--
-- This is defined entirely in terms of 'eqR' and 'notR': the given result is
-- passed through 'notR' before being handed to 'eqR', and the result that
-- 'eqR' produces is passed through 'notR' again on the way out. The two
-- operand outputs are exactly those produced by 'eqR'.
neR :: (EqR f, EqC' f x) => ( f x, f x, f Bool ) -> ( f x, f x, f Bool )
neR ( x, y, z ) = ( x', y', z' )
  where
    -- Run 'notR' with only the second component known ('mempty' in the
    -- first) and keep what it reports for the first component. Presumably
    -- this is the negation of @z@, i.e. the "are they equal?" result.
    ( notZ', _ ) = notR ( mempty, z )

    -- Apply the equality relationship using that translated result.
    ( x', y', notZR ) = eqR ( x, y, notZ' )

    -- Translate the equality result back the other way: 'eqR''s output goes
    -- in the first component, and we keep what 'notR' reports for the second.
    ( _, z' ) = notR ( notZR, mempty )

-- | Equality over 'Defined' values, available for any element type with an
-- 'Eq' instance.
instance EqR Defined where
  type EqC Defined = Eq

  -- When the result @z@ is equal to 'trueR', each operand's output is the
  -- /other/ operand; for any other @z@ both operand outputs are 'mempty'.
  -- The result output is always '(==)' lifted over the two operands with
  -- 'liftA2'.
  eqR ( x, y, z ) = ( whenEqual y, whenEqual x, liftA2 (==) x y )
    where
      -- Pass the given operand through only if the result is known true.
      whenEqual other = if z == trueR then other else mempty

-- | Equality over 'Intersect' values, available for any 'Intersectable'
-- element type.
instance EqR Intersect where
  type EqC Intersect = Intersectable

  -- Each operand's output is computed from the /other/ operand and the
  -- result @z@ (see @fromOther@). The result output is always '(==)' lifted
  -- over the two operands with 'Intersect.lift2'.
  eqR ( x, y, z ) = ( fromOther y, fromOther x, Intersect.lift2 (==) x y )
    where
      -- What one operand's output should be, given the other operand:
      --
      -- * result equals 'trueR': the other operand itself;
      -- * result equals 'falseR' and the other operand has
      --   'Intersect.size' 1: 'Intersect.except' applied to it (presumably
      --   "everything but that one remaining candidate" - confirm against
      --   "Data.JoinSemilattice.Intersect");
      -- * anything else: 'mempty'.
      fromOther other
        | z == trueR                               = other
        | z == falseR && Intersect.size other == 1 = Intersect.except other
        | otherwise                                = mempty