{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.SEq
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.SEq
  ( -- * Symbolic equality
    SEq (..),
    SEq' (..),
  )
where

import Control.Monad.Except (ExceptT (ExceptT))
import Control.Monad.Identity
  ( Identity (Identity),
    IdentityT (IdentityT),
  )
import Control.Monad.Trans.Maybe (MaybeT (MaybeT))
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import qualified Data.ByteString as B
import Data.Functor.Sum (Sum)
import Data.Int (Int16, Int32, Int64, Int8)
import qualified Data.Text as T
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeNats (KnownNat, type (<=))
import Generics.Deriving
  ( Default (Default),
    Generic (Rep, from),
    K1 (K1),
    M1 (M1),
    U1,
    V1,
    type (:*:) ((:*:)),
    type (:+:) (L1, R1),
  )
import Grisette.Internal.Core.Control.Exception (AssertionError, VerificationConditions)
import Grisette.Internal.Core.Data.Class.LogicalOp (LogicalOp (symNot, (.&&)))
import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con))
import Grisette.Internal.SymPrim.BV (IntN, WordN)
import Grisette.Internal.SymPrim.Prim.Term (pevalEqTerm)
import Grisette.Internal.SymPrim.SymBV
  ( SymIntN (SymIntN),
    SymWordN (SymWordN),
  )
import Grisette.Internal.SymPrim.SymBool (SymBool (SymBool))
import Grisette.Internal.SymPrim.SymInteger (SymInteger (SymInteger))

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> :set -XDataKinds
-- >>> :set -XBinaryLiterals
-- >>> :set -XFlexibleContexts
-- >>> :set -XFlexibleInstances
-- >>> :set -XFunctionalDependencies

-- | Symbolic equality. Note that we can't use Haskell's 'Eq' class since
-- symbolic comparison won't necessarily return a concrete 'Bool' value.
--
-- >>> let a = 1 :: SymInteger
-- >>> let b = 2 :: SymInteger
-- >>> a .== b
-- false
-- >>> a ./= b
-- true
--
-- >>> let a = "a" :: SymInteger
-- >>> let b = "b" :: SymInteger
-- >>> a .== b
-- (= a b)
-- >>> a ./= b
-- (! (= a b))
--
-- __Note:__ This type class can be derived for algebraic data types.
-- You may need the @DerivingVia@ and @DerivingStrategies@ extensions.
--
-- > data X = ... deriving Generic deriving SEq via (Default X)
class SEq a where
  -- | Symbolic equality. The default definition is the symbolic negation
  -- ('symNot') of '(./=)'.
  (.==) :: a -> a -> SymBool
  a .== b = symNot $ a ./= b
  {-# INLINE (.==) #-}

  infix 4 .==

  -- | Symbolic inequality. The default definition is the symbolic negation
  -- ('symNot') of '(.==)'.
  (./=) :: a -> a -> SymBool
  a ./= b = symNot $ a .== b
  {-# INLINE (./=) #-}

  infix 4 ./=

  -- Each default is defined in terms of the other, so an instance must
  -- provide at least one of them to avoid an infinite loop.
  {-# MINIMAL (.==) | (./=) #-}

-- SEq instances
--
-- Concrete (non-symbolic) types: equality is decided by the Eq instance of
-- the type, and the resulting Bool is lifted into a SymBool with con.
-- Comments must stay outside the macro bodies, because CPP joins each body
-- into a single line and a comment there would swallow the rest of it.
#define CONCRETE_SEQ(type) \
instance SEq type where \
  l .== r = con $ l == r; \
  {-# INLINE (.==) #-}

-- Same as CONCRETE_SEQ, but for the width-indexed types (WordN, IntN below),
-- requiring a known type-level width n with 1 <= n.
#define CONCRETE_SEQ_BV(type) \
instance (KnownNat n, 1 <= n) => SEq (type n) where \
  l .== r = con $ l == r; \
  {-# INLINE (.==) #-}

-- The #if 1 wrapper is always enabled; it does not conditionally exclude
-- any of the instances below.
#if 1
CONCRETE_SEQ(Bool)
CONCRETE_SEQ(Integer)
CONCRETE_SEQ(Char)
CONCRETE_SEQ(Int)
CONCRETE_SEQ(Int8)
CONCRETE_SEQ(Int16)
CONCRETE_SEQ(Int32)
CONCRETE_SEQ(Int64)
CONCRETE_SEQ(Word)
CONCRETE_SEQ(Word8)
CONCRETE_SEQ(Word16)
CONCRETE_SEQ(Word32)
CONCRETE_SEQ(Word64)
CONCRETE_SEQ(B.ByteString)
CONCRETE_SEQ(T.Text)
CONCRETE_SEQ_BV(WordN)
CONCRETE_SEQ_BV(IntN)
#endif

-- Lists, 'Maybe' and 'Either' use the generic 'Default' instance. Values
-- built from different constructors compare as false; otherwise the
-- corresponding fields are compared and the results are conjoined.
deriving via (Default [a])
  instance (SEq a) => SEq [a]

deriving via (Default (Maybe a))
  instance (SEq a) => SEq (Maybe a)

deriving via (Default (Either e a))
  instance (SEq e, SEq a) => SEq (Either e a)

-- ExceptT
-- Two 'ExceptT' values are compared by comparing the wrapped
-- @m (Either e a)@ values with the 'SEq' instance of that type.
instance (SEq (m (Either e a))) => SEq (ExceptT e m a) where
  (ExceptT l) .== (ExceptT r) = l .== r
  {-# INLINE (.==) #-}

-- MaybeT
-- Two 'MaybeT' values are compared by comparing the wrapped
-- @m (Maybe a)@ values with the 'SEq' instance of that type.
instance (SEq (m (Maybe a))) => SEq (MaybeT m a) where
  (MaybeT l) .== (MaybeT r) = l .== r
  {-# INLINE (.==) #-}

-- ()
-- The unit type has a single value, so any two units are equal; the
-- result is the concrete 'True' lifted with 'con'.
instance SEq () where
  _ .== _ = con True
  {-# INLINE (.==) #-}

-- Tuples of arity 2 through 8 use the generic 'Default' instance. The
-- components are compared pairwise and the results are combined with '.&&'.
deriving via (Default (a, b))
  instance (SEq a, SEq b) => SEq (a, b)

deriving via (Default (a, b, c))
  instance (SEq a, SEq b, SEq c) => SEq (a, b, c)

deriving via (Default (a, b, c, d))
  instance (SEq a, SEq b, SEq c, SEq d) => SEq (a, b, c, d)

deriving via (Default (a, b, c, d, e))
  instance (SEq a, SEq b, SEq c, SEq d, SEq e) => SEq (a, b, c, d, e)

deriving via (Default (a, b, c, d, e, f))
  instance
    (SEq a, SEq b, SEq c, SEq d, SEq e, SEq f) =>
    SEq (a, b, c, d, e, f)

deriving via (Default (a, b, c, d, e, f, g))
  instance
    (SEq a, SEq b, SEq c, SEq d, SEq e, SEq f, SEq g) =>
    SEq (a, b, c, d, e, f, g)

deriving via (Default (a, b, c, d, e, f, g, h))
  instance
    (SEq a, SEq b, SEq c, SEq d, SEq e, SEq f, SEq g, SEq h) =>
    SEq (a, b, c, d, e, f, g, h)

-- Functor 'Sum' uses the generic 'Default' instance. Different injections
-- compare as false; matching injections compare the wrapped values.
deriving via (Default (Sum f g a))
  instance (SEq (f a), SEq (g a)) => SEq (Sum f g a)

-- Writer
-- Lazy 'WriterLazy.WriterT': compare the wrapped @m (a, s)@ values.
instance (SEq (m (a, s))) => SEq (WriterLazy.WriterT s m a) where
  (WriterLazy.WriterT l) .== (WriterLazy.WriterT r) = l .== r
  {-# INLINE (.==) #-}

-- Strict 'WriterStrict.WriterT': compare the wrapped @m (a, s)@ values.
instance (SEq (m (a, s))) => SEq (WriterStrict.WriterT s m a) where
  (WriterStrict.WriterT l) .== (WriterStrict.WriterT r) = l .== r
  {-# INLINE (.==) #-}

-- Identity
-- 'Identity' is transparent for equality: compare the wrapped values.
instance (SEq a) => SEq (Identity a) where
  (Identity l) .== (Identity r) = l .== r
  {-# INLINE (.==) #-}

-- IdentityT
-- 'IdentityT' is transparent for equality: compare the wrapped @m a@ values.
instance (SEq (m a)) => SEq (IdentityT m a) where
  (IdentityT l) .== (IdentityT r) = l .== r
  {-# INLINE (.==) #-}

-- Symbolic types
--
-- Each symbolic wrapper is unwrapped to its underlying terms. The terms are
-- combined with pevalEqTerm, and the result is rewrapped as a SymBool.
-- NOTE(review): unlike the concrete-type macros, these do not attach an
-- INLINE pragma to (.==).
#define SEQ_SIMPLE(symtype) \
instance SEq symtype where \
  (symtype l) .== (symtype r) = SymBool $ pevalEqTerm l r

-- Same as SEQ_SIMPLE, for the width-indexed symbolic types, requiring a
-- known type-level width n with 1 <= n.
#define SEQ_BV(symtype) \
instance (KnownNat n, 1 <= n) => SEq (symtype n) where \
  (symtype l) .== (symtype r) = SymBool $ pevalEqTerm l r

#if 1
SEQ_SIMPLE(SymBool)
SEQ_SIMPLE(SymInteger)
SEQ_BV(SymIntN)
SEQ_BV(SymWordN)
#endif

-- Exception types from Grisette.Internal.Core.Control.Exception are compared
-- structurally through the generic 'Default' instance.
deriving via (Default AssertionError)
  instance SEq AssertionError

deriving via (Default VerificationConditions)
  instance SEq VerificationConditions

-- | Auxiliary class for 'SEq' instance derivation.
--
-- This class is implemented for the generic representation types ('U1',
-- 'V1', 'K1', 'M1', ':+:', ':*:'). The @'SEq' ('Default' a)@ instance uses it
-- to compare values through their 'Rep'.
class SEq' f where
  -- | Auxiliary function for '(.==)' derivation: symbolic equality on a
  -- generic representation.
  (..==) :: f a -> f a -> SymBool

  infix 4 ..==

-- Nullary constructors carry no fields, so two of them are always equal.
-- The U1 constructor is not imported here, hence the wildcard patterns.
instance SEq' U1 where
  _ ..== _ = con True
  {-# INLINE (..==) #-}

-- Empty (void) representations: the arguments are never inspected, and the
-- result is the concrete 'True' lifted with 'con'.
instance SEq' V1 where
  _ ..== _ = con True
  {-# INLINE (..==) #-}

-- A field: leave the generic world and use the field's own 'SEq' instance.
instance (SEq c) => SEq' (K1 i c) where
  (K1 a) ..== (K1 b) = a .== b
  {-# INLINE (..==) #-}

-- Metadata wrappers (datatype, constructor and selector info) do not affect
-- equality, so they are stripped and the inner representation is compared.
instance (SEq' a) => SEq' (M1 i c a) where
  (M1 a) ..== (M1 b) = a ..== b
  {-# INLINE (..==) #-}

-- Sums (choice between constructors): values on the same side are compared
-- recursively, and values built with different constructors are never equal.
instance (SEq' a, SEq' b) => SEq' (a :+: b) where
  (L1 a) ..== (L1 b) = a ..== b
  (R1 a) ..== (R1 b) = a ..== b
  _ ..== _ = con False
  {-# INLINE (..==) #-}

-- Products (constructor fields): both components must be equal, so the
-- component-wise results are combined with the symbolic conjunction '.&&'.
instance (SEq' a, SEq' b) => SEq' (a :*: b) where
  (a1 :*: b1) ..== (a2 :*: b2) = (a1 ..== a2) .&& (b1 ..== b2)
  {-# INLINE (..==) #-}

-- Generic 'SEq' for any type with a 'Generic' instance. Both values are
-- converted to their representation with 'from' and compared with '(..==)'.
-- Intended for use with DerivingVia:
--
-- > data X = ... deriving Generic deriving SEq via (Default X)
instance (Generic a, SEq' (Rep a)) => SEq (Default a) where
  (Default l) .== (Default r) = from l ..== from r
  {-# INLINE (.==) #-}