{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Core.Data.Class.SubstituteSym
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.Data.Class.SubstituteSym
  ( -- * Substituting symbolic constants
    SubstituteSym (..),
    SubstituteSym' (..),
  )
where

import Control.Monad.Except (ExceptT (ExceptT))
import Control.Monad.Identity
  ( Identity (Identity),
    IdentityT (IdentityT),
  )
import Control.Monad.Trans.Maybe (MaybeT (MaybeT))
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import qualified Data.ByteString as B
import Data.Functor.Sum (Sum)
import Data.Int (Int16, Int32, Int64, Int8)
import qualified Data.Text as T
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeNats (KnownNat, type (<=))
import Generics.Deriving
  ( Default (Default, unDefault),
    Generic (Rep, from, to),
    K1 (K1),
    M1 (M1),
    U1,
    type (:*:) ((:*:)),
    type (:+:) (L1, R1),
  )
import Generics.Deriving.Instances ()
import Grisette.Core.Data.BV (IntN, SomeIntN, SomeWordN, WordN)
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term
  ( LinkedRep (underlyingTerm),
    SupportedPrim,
    TypedSymbol,
  )
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.TermSubstitution (substTerm)
import Grisette.IR.SymPrim.Data.SymPrim
  ( SomeSymIntN (SomeSymIntN),
    SomeSymWordN (SomeSymWordN),
    SymBool (SymBool),
    SymIntN (SymIntN),
    SymInteger (SymInteger),
    SymWordN (SymWordN),
    type (-~>) (SymGeneralFun),
    type (=~>) (SymTabularFun),
  )

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim

-- | Substitution of symbolic constants.
--
-- >>> a = "a" :: TypedSymbol Bool
-- >>> v = "x" .&& "y" :: SymBool
-- >>> substituteSym a v (["a" .&& "b", "a"] :: [SymBool])
-- [(&& (&& x y) b),(&& x y)]
--
-- __Note:__ This type class can be derived for algebraic data types.
-- You may need the @DerivingVia@ and @DerivingStrategies@ extensions.
--
-- > data X = ... deriving Generic deriving SubstituteSym via (Default X)
class SubstituteSym a where
  -- | Substitute every occurrence of a symbolic constant with a symbolic
  -- value.
  --
  -- The first argument is the symbol to replace, the second is the
  -- replacement value, and the third is the value in which the substitution
  -- is performed.
  --
  -- >>> substituteSym ("a" :: TypedSymbol Bool) ("c" .&& "d" :: SymBool) ["a" .&& "b" :: SymBool, "a"]
  -- [(&& (&& c d) b),(&& c d)]
  substituteSym :: (LinkedRep cb sb) => TypedSymbol cb -> sb -> a -> a

-- Defines a 'SubstituteSym' instance for a concrete (non-symbolic) type:
-- the symbol and the replacement value are ignored, and the input is
-- returned unchanged.
#define CONCRETE_SUBSTITUTESYM(type) \
instance SubstituteSym type where \
  substituteSym _ _ = id

-- Same as CONCRETE_SUBSTITUTESYM, but for a bit-vector type constructor
-- indexed by a width @n@. The generated instance requires @KnownNat n@ and
-- @1 <= n@.
#define CONCRETE_SUBSTITUTESYM_BV(type) \
instance (KnownNat n, 1 <= n) => SubstituteSym (type n) where \
  substituteSym _ _ = id

-- Identity instances for the concrete types supported by this module.
-- NOTE(review): the guard below is always true; presumably it only exists to
-- fence the macro invocations off from a source formatter. Confirm.
#if 1
CONCRETE_SUBSTITUTESYM(Bool)
CONCRETE_SUBSTITUTESYM(Integer)
CONCRETE_SUBSTITUTESYM(Char)
CONCRETE_SUBSTITUTESYM(Int)
CONCRETE_SUBSTITUTESYM(Int8)
CONCRETE_SUBSTITUTESYM(Int16)
CONCRETE_SUBSTITUTESYM(Int32)
CONCRETE_SUBSTITUTESYM(Int64)
CONCRETE_SUBSTITUTESYM(Word)
CONCRETE_SUBSTITUTESYM(Word8)
CONCRETE_SUBSTITUTESYM(Word16)
CONCRETE_SUBSTITUTESYM(Word32)
CONCRETE_SUBSTITUTESYM(Word64)
CONCRETE_SUBSTITUTESYM(SomeWordN)
CONCRETE_SUBSTITUTESYM(SomeIntN)
CONCRETE_SUBSTITUTESYM(B.ByteString)
CONCRETE_SUBSTITUTESYM(T.Text)
CONCRETE_SUBSTITUTESYM_BV(WordN)
CONCRETE_SUBSTITUTESYM_BV(IntN)
#endif

-- | @()@ contains no symbolic constants, so substitution leaves it unchanged.
instance SubstituteSym () where
  substituteSym _ _ = id

-- Instances for standard sum and product types, all obtained generically
-- through 'Default': every field is visited and substituted in turn, so each
-- instance requires 'SubstituteSym' for all of its type arguments.

-- Either
deriving via (Default (Either a b)) instance
  (SubstituteSym a, SubstituteSym b) => SubstituteSym (Either a b)

-- Maybe
deriving via (Default (Maybe a)) instance
  (SubstituteSym a) => SubstituteSym (Maybe a)

-- List
deriving via (Default [a]) instance
  (SubstituteSym a) => SubstituteSym [a]

-- Tuples of arity 2 through 8.
deriving via (Default (a, b)) instance
  (SubstituteSym a, SubstituteSym b) => SubstituteSym (a, b)

deriving via (Default (a, b, c)) instance
  (SubstituteSym a, SubstituteSym b, SubstituteSym c) =>
  SubstituteSym (a, b, c)

deriving via (Default (a, b, c, d)) instance
  (SubstituteSym a, SubstituteSym b, SubstituteSym c, SubstituteSym d) =>
  SubstituteSym (a, b, c, d)

deriving via (Default (a, b, c, d, e)) instance
  ( SubstituteSym a, SubstituteSym b, SubstituteSym c, SubstituteSym d,
    SubstituteSym e
  ) =>
  SubstituteSym (a, b, c, d, e)

deriving via (Default (a, b, c, d, e, f)) instance
  ( SubstituteSym a, SubstituteSym b, SubstituteSym c, SubstituteSym d,
    SubstituteSym e, SubstituteSym f
  ) =>
  SubstituteSym (a, b, c, d, e, f)

deriving via (Default (a, b, c, d, e, f, g)) instance
  ( SubstituteSym a, SubstituteSym b, SubstituteSym c, SubstituteSym d,
    SubstituteSym e, SubstituteSym f, SubstituteSym g
  ) =>
  SubstituteSym (a, b, c, d, e, f, g)

deriving via (Default (a, b, c, d, e, f, g, h)) instance
  ( SubstituteSym a, SubstituteSym b, SubstituteSym c, SubstituteSym d,
    SubstituteSym e, SubstituteSym f, SubstituteSym g, SubstituteSym h
  ) =>
  SubstituteSym (a, b, c, d, e, f, g, h)

-- MaybeT: substitution is delegated to the wrapped @m (Maybe a)@ value.
instance
  (SubstituteSym (m (Maybe a))) =>
  SubstituteSym (MaybeT m a)
  where
  substituteSym sym val (MaybeT v) = MaybeT $ substituteSym sym val v

-- ExceptT: substitution is delegated to the wrapped @m (Either e a)@ value.
instance
  (SubstituteSym (m (Either e a))) =>
  SubstituteSym (ExceptT e m a)
  where
  substituteSym sym val (ExceptT v) = ExceptT $ substituteSym sym val v

-- Sum: derived generically through 'Default', substituting inside whichever
-- of the two functor applications is present.
deriving via (Default (Sum f g a)) instance
  (SubstituteSym (f a), SubstituteSym (g a)) => SubstituteSym (Sum f g a)

-- WriterT (lazy): substitution is delegated to the wrapped @m (a, s)@ value.
instance
  (SubstituteSym (m (a, s))) =>
  SubstituteSym (WriterLazy.WriterT s m a)
  where
  substituteSym sym val (WriterLazy.WriterT v) =
    WriterLazy.WriterT $ substituteSym sym val v

-- WriterT (strict): substitution is delegated to the wrapped @m (a, s)@ value.
instance
  (SubstituteSym (m (a, s))) =>
  SubstituteSym (WriterStrict.WriterT s m a)
  where
  substituteSym sym val (WriterStrict.WriterT v) =
    WriterStrict.WriterT $ substituteSym sym val v

-- Identity: substitution is delegated to the wrapped value.
instance (SubstituteSym a) => SubstituteSym (Identity a) where
  substituteSym sym val (Identity a) = Identity $ substituteSym sym val a

-- IdentityT: substitution is delegated to the wrapped @m a@ value.
instance (SubstituteSym (m a)) => SubstituteSym (IdentityT m a) where
  substituteSym sym val (IdentityT a) = IdentityT $ substituteSym sym val a

-- Defines a 'SubstituteSym' instance for a symbolic type whose constructor
-- wraps a single term. The term is unwrapped, 'substTerm' replaces the symbol
-- with the underlying term of the replacement value, and the result is
-- wrapped again.
#define SUBSTITUTE_SYM_SIMPLE(symtype) \
instance SubstituteSym symtype where \
  substituteSym sym v (symtype t) = symtype $ substTerm sym (underlyingTerm v) t

-- Same as SUBSTITUTE_SYM_SIMPLE, for symbolic bit-vector types indexed by a
-- width @n@. The generated instance requires @KnownNat n@ and @1 <= n@.
#define SUBSTITUTE_SYM_BV(symtype) \
instance (KnownNat n, 1 <= n) => SubstituteSym (symtype n) where \
  substituteSym sym v (symtype t) = symtype $ substTerm sym (underlyingTerm v) t

-- Symbolic function types @sa op sb@ built with the constructor @cons@. Both
-- @sa@ and @sb@ must be linked (via 'LinkedRep') to supported concrete
-- primitives @ca@ and @cb@.
#define SUBSTITUTE_SYM_FUN(op, cons) \
instance (SupportedPrim ca, SupportedPrim cb, LinkedRep ca sa, LinkedRep cb sb) => SubstituteSym (sa op sb) where \
  substituteSym sym v (cons t) = cons $ substTerm sym (underlyingTerm v) t

-- Wrapper types @somety@ that hold a value built with @origty@. Both layers
-- are unwrapped, the inner term is substituted, and both layers are rebuilt.
#define SUBSTITUTE_SYM_BV_SOME(somety, origty) \
instance SubstituteSym somety where \
  substituteSym sym v (somety (origty t)) = somety $ origty $ substTerm sym (underlyingTerm v) t

-- Instances for the symbolic primitive types.
-- NOTE(review): the guard below is always true; presumably it only exists to
-- fence the macro invocations off from a source formatter. Confirm.
#if 1
SUBSTITUTE_SYM_SIMPLE(SymBool)
SUBSTITUTE_SYM_SIMPLE(SymInteger)
SUBSTITUTE_SYM_BV(SymIntN)
SUBSTITUTE_SYM_BV(SymWordN)
SUBSTITUTE_SYM_FUN(=~>, SymTabularFun)
SUBSTITUTE_SYM_FUN(-~>, SymGeneralFun)
SUBSTITUTE_SYM_BV_SOME(SomeSymIntN, SymIntN)
SUBSTITUTE_SYM_BV_SOME(SomeSymWordN, SymWordN)
#endif

-- | Generic-representation counterpart of 'SubstituteSym', used to derive
-- 'SubstituteSym' instances through 'Default'. Instances cover the generic
-- building blocks 'U1', 'K1', 'M1', ':+:' and ':*:'.
class SubstituteSym' a where
  -- | Substitute a symbolic constant inside a generic representation value;
  -- the helper behind the derived 'substituteSym'.
  substituteSym' ::
    (LinkedRep cb sb) =>
    TypedSymbol cb ->
    sb ->
    a p ->
    a p

-- | Generic derivation: unwrap the 'Default' newtype, convert to the generic
-- representation with 'from', substitute there with 'substituteSym'', then
-- convert back with 'to' and rewrap.
instance
  ( Generic a,
    SubstituteSym' (Rep a)
  ) =>
  SubstituteSym (Default a)
  where
  substituteSym sym val =
    Default . to . substituteSym' sym val . from . unDefault

-- | A constructor without fields has nothing to substitute.
instance SubstituteSym' U1 where
  substituteSym' _ _ = id

-- | A single field: substitution is delegated to the field's own
-- 'SubstituteSym' instance.
instance (SubstituteSym c) => SubstituteSym' (K1 i c) where
  substituteSym' sym val (K1 v) = K1 $ substituteSym sym val v

-- | Metadata wrapper: substitution recurses into the wrapped representation.
instance (SubstituteSym' a) => SubstituteSym' (M1 i c a) where
  substituteSym' sym val (M1 v) = M1 $ substituteSym' sym val v

-- | Sum of constructors: substitution recurses into whichever alternative is
-- present, preserving the 'L1' / 'R1' tag.
instance (SubstituteSym' a, SubstituteSym' b) => SubstituteSym' (a :+: b) where
  substituteSym' sym val (L1 l) = L1 $ substituteSym' sym val l
  substituteSym' sym val (R1 r) = R1 $ substituteSym' sym val r

-- | Product of fields: substitution is applied to both components.
instance (SubstituteSym' a, SubstituteSym' b) => SubstituteSym' (a :*: b) where
  substituteSym' sym val (l :*: r) =
    substituteSym' sym val l :*: substituteSym' sym val r