{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.ITEOp
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.ITEOp
  ( ITEOp (..),
  )
where

import Control.Monad.Identity (Identity (Identity))
import qualified Data.HashSet as HS
import GHC.TypeNats (KnownNat, type (<=))
import Grisette.Internal.SymPrim.FP (ValidFP)
import Grisette.Internal.SymPrim.GeneralFun (freshArgSymbol, substTerm, type (-->) (GeneralFun))
import Grisette.Internal.SymPrim.Prim.SomeTerm (SomeTerm (SomeTerm))
import Grisette.Internal.SymPrim.Prim.Term
  ( SupportedPrim (pevalITETerm),
    TypedConstantSymbol,
    symTerm,
  )
import Grisette.Internal.SymPrim.SymAlgReal (SymAlgReal (SymAlgReal))
import Grisette.Internal.SymPrim.SymBV
  ( SymIntN (SymIntN),
    SymWordN (SymWordN),
  )
import Grisette.Internal.SymPrim.SymBool (SymBool (SymBool))
import Grisette.Internal.SymPrim.SymFP
  ( SymFP (SymFP),
    SymFPRoundingMode (SymFPRoundingMode),
  )
import Grisette.Internal.SymPrim.SymGeneralFun (type (-~>) (SymGeneralFun))
import Grisette.Internal.SymPrim.SymInteger (SymInteger (SymInteger))
import Grisette.Internal.SymPrim.SymTabularFun (type (=~>) (SymTabularFun))

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim

-- | Symbolic if-then-else for solvable types (see
-- "Grisette.Core#g:solvable"), such as symbolic booleans, integers, and
-- bit-vectors.
--
-- >>> let a = "a" :: SymBool
-- >>> let b = "b" :: SymBool
-- >>> let c = "c" :: SymBool
-- >>> symIte a b c
-- (ite a b c)
class ITEOp v where
  -- | Symbolic if-then-else.
  symIte ::
    -- | The symbolic condition.
    SymBool ->
    -- | The value to use when the condition is true.
    v ->
    -- | The value to use when the condition is false.
    v ->
    v

-- ITEOp instances
--
-- Each CPP macro below expands to one ITEOp instance. Its symIte unwraps the
-- SymBool condition and both branches with the given data constructor,
-- combines the underlying terms with pevalITETerm, and rewraps the result.
--
-- CPP joins the backslash-continued lines of a macro body into a single line.
-- That is why the method equation and its INLINE pragma are separated by an
-- explicit semicolon. Never put a line comment inside a macro body: after
-- expansion it would comment out the rest of the instance.
--
-- ITEOP_SIMPLE(type): for a solvable type with no type parameters whose data
-- constructor has the same name as the type.
#define ITEOP_SIMPLE(type) \
instance ITEOp type where \
  symIte (SymBool c) (type t) (type f) = type $ pevalITETerm c t f; \
  {-# INLINE symIte #-}

-- ITEOP_BV(type): for bit-vector types indexed by a width n. The generated
-- instance requires KnownNat n and 1 <= n. The data constructor must have
-- the same name as the type constructor.
#define ITEOP_BV(type) \
instance (KnownNat n, 1 <= n) => ITEOp (type n) where \
  symIte (SymBool c) (type t) (type f) = type $ pevalITETerm c t f; \
  {-# INLINE symIte #-}

-- ITEOP_FUN(cop, op, cons): for symbolic function types (op sa sb) whose
-- values are wrapped by the data constructor cons.
-- NOTE(review): the cop argument is not referenced by this macro body.
-- Judging by the invocations, it is presumably the matching concrete
-- function type, kept for symmetry with similar macros. Confirm before
-- dropping it.
#define ITEOP_FUN(cop, op, cons) \
instance ITEOp (op sa sb) where \
  symIte (SymBool c) (cons t) (cons f) = cons $ pevalITETerm c t f; \
  {-# INLINE symIte #-}

-- Instantiate the macros for the built-in solvable types.
-- NOTE(review): the enclosing CPP conditional (if 1) is always true. It
-- presumably exists so that source formatters leave these macro invocations
-- untouched, since they are not valid Haskell before preprocessing. Confirm
-- before removing it.
#if 1
ITEOP_SIMPLE(SymBool)
ITEOP_SIMPLE(SymInteger)
ITEOP_SIMPLE(SymFPRoundingMode)
ITEOP_SIMPLE(SymAlgReal)
ITEOP_BV(SymIntN)
ITEOP_BV(SymWordN)
ITEOP_FUN((=->), (=~>), SymTabularFun)
ITEOP_FUN((-->), (-~>), SymGeneralFun)
#endif

-- | If-then-else on concrete general functions.
--
-- Each branch is a 'GeneralFun' carrying its own argument symbol and body
-- term. The result is a single 'GeneralFun' over a shared argument symbol:
-- each body has its own argument symbol replaced (via 'substTerm') by the
-- shared one, and the two rewritten bodies are combined with 'pevalITETerm'
-- under the condition.
--
-- The condition is placed inside the new function body as well, so it is
-- included in the terms handed to 'freshArgSymbol'. This is intended to keep
-- a symbol that occurs in the condition from being chosen as the bound
-- argument and captured by the new binder. (Assumes 'freshArgSymbol' avoids
-- the symbols occurring in the terms it is given.)
instance ITEOp (a --> b) where
  symIte
    (SymBool cond)
    (GeneralFun (thenArg :: TypedConstantSymbol a) thenBody)
    (GeneralFun elseArg elseBody) =
      GeneralFun argSymbol $
        pevalITETerm
          cond
          (substTerm thenArg (symTerm argSymbol) HS.empty thenBody)
          (substTerm elseArg (symTerm argSymbol) HS.empty elseBody)
      where
        -- Chosen with respect to every term that ends up under the new
        -- binder: the condition and both branch bodies.
        argSymbol :: TypedConstantSymbol a
        argSymbol =
          freshArgSymbol
            [SomeTerm cond, SomeTerm thenBody, SomeTerm elseBody]
  {-# INLINE symIte #-}

-- | If-then-else on symbolic floating-point values: the condition and both
-- branches are unwrapped, their terms are combined with 'pevalITETerm', and
-- the resulting term is wrapped back into 'SymFP'.
instance (ValidFP eb sb) => ITEOp (SymFP eb sb) where
  symIte (SymBool cond) (SymFP onTrue) (SymFP onFalse) =
    SymFP (pevalITETerm cond onTrue onFalse)
  {-# INLINE symIte #-}

-- | If-then-else lifted through 'Identity': the wrapped values are combined
-- with the 'ITEOp' instance of the payload type and the result is rewrapped.
instance (ITEOp v) => ITEOp (Identity v) where
  symIte cond (Identity onTrue) (Identity onFalse) =
    Identity (symIte cond onTrue onFalse)
  {-# INLINE symIte #-}