{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Core.Data.Class.ExtractSymbolics
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.Data.Class.ExtractSymbolics
  ( -- * Extracting symbolic constant set from a value
    ExtractSymbolics (..),
  )
where

import Control.Monad.Except (ExceptT (ExceptT))
import Control.Monad.Identity
  ( Identity (Identity),
    IdentityT (IdentityT),
  )
import Control.Monad.Trans.Maybe (MaybeT (MaybeT))
import qualified Control.Monad.Writer.Lazy as WriterLazy
import qualified Control.Monad.Writer.Strict as WriterStrict
import qualified Data.ByteString as B
import Data.Functor.Sum (Sum)
import Data.Int (Int16, Int32, Int64, Int8)
import qualified Data.Text as T
import Data.Word (Word16, Word32, Word64, Word8)
import GHC.TypeNats (KnownNat, type (<=))
import Generics.Deriving
  ( Default (Default, unDefault),
    Generic (Rep, from),
    K1 (unK1),
    M1 (unM1),
    U1,
    type (:*:) ((:*:)),
    type (:+:) (L1, R1),
  )
import Grisette.Core.Control.Exception (AssertionError, VerificationConditions)
import Grisette.Core.Data.BV (IntN, SomeIntN, SomeWordN, WordN)
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.Term
  ( LinkedRep,
    SupportedPrim,
  )
import Grisette.IR.SymPrim.Data.Prim.InternedTerm.TermUtils (extractSymbolicsTerm)
import Grisette.IR.SymPrim.Data.Prim.Model
  ( SymbolSet (SymbolSet),
  )
import Grisette.IR.SymPrim.Data.SymPrim
  ( SomeSymIntN (SomeSymIntN),
    SomeSymWordN (SomeSymWordN),
    SymBool (SymBool),
    SymIntN (SymIntN),
    SymInteger (SymInteger),
    SymWordN (SymWordN),
    type (-~>) (SymGeneralFun),
    type (=~>) (SymTabularFun),
  )

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim
-- >>> import Grisette.Lib.Base
-- >>> import Data.HashSet as HashSet
-- >>> import Data.List (sort)

-- | Extracts all the symbolic variables that are transitively contained in the given value.
--
-- >>> extractSymbolics ("a" :: SymBool) :: SymbolSet
-- SymbolSet {a :: Bool}
--
-- >>> extractSymbolics (mrgIf "a" (mrgReturn ["b"]) (mrgReturn ["c", "d"]) :: UnionM [SymBool]) :: SymbolSet
-- SymbolSet {a :: Bool, b :: Bool, c :: Bool, d :: Bool}
--
-- __Note 1:__ This type class can be derived for algebraic data types.
-- You may need the @DerivingVia@ and @DerivingStrategies@ extensions.
--
-- > data X = ... deriving Generic deriving ExtractSymbolics via (Default X)
class ExtractSymbolics a where
  -- | Returns the set of symbolic constants contained in the given value.
  extractSymbolics :: a -> SymbolSet

-- instances
--
-- Instances for concrete (non-symbolic) types. Every instance generated by
-- the macros below ignores its argument and returns mempty, the empty
-- SymbolSet.
--
-- CONCRETE_EXTRACT_SYMBOLICS(type): instance for a plain type of kind Type.
#define CONCRETE_EXTRACT_SYMBOLICS(type) \
instance ExtractSymbolics type where \
  extractSymbolics _ = mempty

-- CONCRETE_EXTRACT_SYMBOLICS_BV(type): instance for a type constructor
-- indexed by a type-level natural n, constrained by KnownNat n and 1 <= n.
#define CONCRETE_EXTRACT_SYMBOLICS_BV(type) \
instance (KnownNat n, 1 <= n) => ExtractSymbolics (type n) where \
  extractSymbolics _ = mempty

-- NOTE(review): the always-true #if 1 guard presumably exists only to keep
-- source formatters away from the macro invocations; confirm.
#if 1
CONCRETE_EXTRACT_SYMBOLICS(Bool)
CONCRETE_EXTRACT_SYMBOLICS(Integer)
CONCRETE_EXTRACT_SYMBOLICS(Char)
CONCRETE_EXTRACT_SYMBOLICS(Int)
CONCRETE_EXTRACT_SYMBOLICS(Int8)
CONCRETE_EXTRACT_SYMBOLICS(Int16)
CONCRETE_EXTRACT_SYMBOLICS(Int32)
CONCRETE_EXTRACT_SYMBOLICS(Int64)
CONCRETE_EXTRACT_SYMBOLICS(Word)
CONCRETE_EXTRACT_SYMBOLICS(Word8)
CONCRETE_EXTRACT_SYMBOLICS(Word16)
CONCRETE_EXTRACT_SYMBOLICS(Word32)
CONCRETE_EXTRACT_SYMBOLICS(Word64)
CONCRETE_EXTRACT_SYMBOLICS(SomeWordN)
CONCRETE_EXTRACT_SYMBOLICS(SomeIntN)
CONCRETE_EXTRACT_SYMBOLICS(B.ByteString)
CONCRETE_EXTRACT_SYMBOLICS(T.Text)
CONCRETE_EXTRACT_SYMBOLICS_BV(WordN)
CONCRETE_EXTRACT_SYMBOLICS_BV(IntN)
#endif

-- ()
-- The unit value carries no data, so the result is always the empty set.
instance ExtractSymbolics () where
  extractSymbolics _ = mempty

-- Either, Maybe and lists obtain their instances from the generic 'Default'
-- implementation: the symbols held by whichever constructor payload is
-- present are collected.

-- Either
deriving via (Default (Either a b)) instance
  (ExtractSymbolics a, ExtractSymbolics b) => ExtractSymbolics (Either a b)

-- Maybe
deriving via (Default (Maybe a)) instance
  ExtractSymbolics a => ExtractSymbolics (Maybe a)

-- List
deriving via (Default [a]) instance
  ExtractSymbolics a => ExtractSymbolics [a]

-- Tuples of arity 2 through 8 obtain their instances from the generic
-- 'Default' implementation, so the result is the union of the symbols of
-- every component.

-- (,)
deriving via (Default (a, b)) instance
  (ExtractSymbolics a, ExtractSymbolics b) => ExtractSymbolics (a, b)

-- (,,)
deriving via (Default (a, b, c)) instance
  (ExtractSymbolics a, ExtractSymbolics b, ExtractSymbolics c) =>
  ExtractSymbolics (a, b, c)

-- (,,,)
deriving via (Default (a, b, c, d)) instance
  ( ExtractSymbolics a, ExtractSymbolics b,
    ExtractSymbolics c, ExtractSymbolics d
  ) =>
  ExtractSymbolics (a, b, c, d)

-- (,,,,)
deriving via (Default (a, b, c, d, e)) instance
  ( ExtractSymbolics a, ExtractSymbolics b,
    ExtractSymbolics c, ExtractSymbolics d,
    ExtractSymbolics e
  ) =>
  ExtractSymbolics (a, b, c, d, e)

-- (,,,,,)
deriving via (Default (a, b, c, d, e, f)) instance
  ( ExtractSymbolics a, ExtractSymbolics b,
    ExtractSymbolics c, ExtractSymbolics d,
    ExtractSymbolics e, ExtractSymbolics f
  ) =>
  ExtractSymbolics (a, b, c, d, e, f)

-- (,,,,,,)
deriving via (Default (a, b, c, d, e, f, g)) instance
  ( ExtractSymbolics a, ExtractSymbolics b,
    ExtractSymbolics c, ExtractSymbolics d,
    ExtractSymbolics e, ExtractSymbolics f,
    ExtractSymbolics g
  ) =>
  ExtractSymbolics (a, b, c, d, e, f, g)

-- (,,,,,,,)
deriving via (Default (a, b, c, d, e, f, g, h)) instance
  ( ExtractSymbolics a, ExtractSymbolics b,
    ExtractSymbolics c, ExtractSymbolics d,
    ExtractSymbolics e, ExtractSymbolics f,
    ExtractSymbolics g, ExtractSymbolics h
  ) =>
  ExtractSymbolics (a, b, c, d, e, f, g, h)

-- MaybeT
-- Unwraps the transformer and delegates to the instance for the underlying
-- @m (Maybe a)@ value.
instance (ExtractSymbolics (m (Maybe a))) => ExtractSymbolics (MaybeT m a) where
  extractSymbolics (MaybeT v) = extractSymbolics v

-- ExceptT
-- Unwraps the transformer and delegates to the instance for the underlying
-- @m (Either e a)@ value.
instance
  (ExtractSymbolics (m (Either e a))) =>
  ExtractSymbolics (ExceptT e m a)
  where
  extractSymbolics (ExceptT v) = extractSymbolics v

-- Sum
-- Functor sums use the generic 'Default' implementation, which inspects
-- whichever of the @f a@ / @g a@ alternatives is present.
deriving via (Default (Sum f g a)) instance
  (ExtractSymbolics (f a), ExtractSymbolics (g a)) => ExtractSymbolics (Sum f g a)

-- WriterT
-- Lazy writer: unwraps the transformer and delegates to the instance for
-- the underlying @m (a, s)@ value.
instance
  (ExtractSymbolics (m (a, s))) =>
  ExtractSymbolics (WriterLazy.WriterT s m a)
  where
  extractSymbolics (WriterLazy.WriterT f) = extractSymbolics f

-- Strict writer: unwraps the transformer and delegates to the instance for
-- the underlying @m (a, s)@ value.
instance
  (ExtractSymbolics (m (a, s))) =>
  ExtractSymbolics (WriterStrict.WriterT s m a)
  where
  extractSymbolics (WriterStrict.WriterT f) = extractSymbolics f

-- Identity
-- Unwraps the value and delegates to the instance for @a@.
instance (ExtractSymbolics a) => ExtractSymbolics (Identity a) where
  extractSymbolics (Identity a) = extractSymbolics a

-- IdentityT
-- Unwraps the transformer and delegates to the instance for @m a@.
instance (ExtractSymbolics (m a)) => ExtractSymbolics (IdentityT m a) where
  extractSymbolics (IdentityT a) = extractSymbolics a

-- Instances for symbolic types. Every macro below pattern-matches the term
-- out of its wrapper and collects its symbols with extractSymbolicsTerm,
-- wrapping the result in SymbolSet.
--
-- EXTRACT_SYMBOLICS_SIMPLE(symtype): for a symbolic type whose constructor
-- has the same name as the type.
#define EXTRACT_SYMBOLICS_SIMPLE(symtype) \
instance ExtractSymbolics symtype where \
  extractSymbolics (symtype t) = SymbolSet $ extractSymbolicsTerm t

-- EXTRACT_SYMBOLICS_BV(symtype): as above, but for a type constructor
-- applied to a type-level natural n with KnownNat n and 1 <= n.
#define EXTRACT_SYMBOLICS_BV(symtype) \
instance (KnownNat n, 1 <= n) => ExtractSymbolics (symtype n) where \
  extractSymbolics (symtype t) = SymbolSet $ extractSymbolicsTerm t

-- EXTRACT_SYMBOLICS_FUN(op, cons): for a symbolic function type written
-- with the infix type operator op and built with constructor cons.
-- NOTE(review): ca and cb occur only in the context; presumably LinkedRep
-- determines them from sa and sb (functional dependency); confirm.
#define EXTRACT_SYMBOLICS_FUN(op, cons) \
instance (SupportedPrim ca, SupportedPrim cb, LinkedRep ca sa, LinkedRep cb sb) => ExtractSymbolics (sa op sb) where \
  extractSymbolics (cons t) = SymbolSet $ extractSymbolicsTerm t

-- EXTRACT_SYMBOLICS_BV_SOME(somety, origty): for a wrapper type whose
-- constructor (named like the type) holds an origty value.
#define EXTRACT_SYMBOLICS_BV_SOME(somety, origty) \
instance ExtractSymbolics somety where \
  extractSymbolics (somety (origty t)) = SymbolSet $ extractSymbolicsTerm t

#if 1
EXTRACT_SYMBOLICS_SIMPLE(SymBool)
EXTRACT_SYMBOLICS_SIMPLE(SymInteger)
EXTRACT_SYMBOLICS_BV(SymIntN)
EXTRACT_SYMBOLICS_BV(SymWordN)
EXTRACT_SYMBOLICS_FUN(=~>, SymTabularFun)
EXTRACT_SYMBOLICS_FUN(-~>, SymGeneralFun)
EXTRACT_SYMBOLICS_BV_SOME(SomeSymIntN, SymIntN)
EXTRACT_SYMBOLICS_BV_SOME(SomeSymWordN, SymWordN)
#endif

-- Exception
-- Both exception types obtain their instances from the generic 'Default'
-- implementation.
deriving via
  (Default AssertionError)
  instance
    ExtractSymbolics AssertionError

deriving via
  (Default VerificationConditions)
  instance
    ExtractSymbolics VerificationConditions

-- | Generic implementation used by @deriving via (Default X)@: unwraps the
-- value, converts it to its generic representation with 'from', and
-- collects the symbols of every field in that representation.
instance (Generic a, ExtractSymbolics' (Rep a)) => ExtractSymbolics (Default a) where
  extractSymbolics = extractSymbolics' . from . unDefault

-- | Helper class over generic representations (U1, K1, M1, sums and
-- products), used by the 'Default' instance of 'ExtractSymbolics'.
class ExtractSymbolics' a where
  -- | Collects the symbolic constants contained in a generic representation
  -- value.
  extractSymbolics' :: a c -> SymbolSet

-- A constructor without fields contributes no symbols.
instance ExtractSymbolics' U1 where
  extractSymbolics' _ = mempty

-- A single field: defer to the field type's own 'ExtractSymbolics' instance.
instance (ExtractSymbolics c) => ExtractSymbolics' (K1 i c) where
  extractSymbolics' = extractSymbolics . unK1

-- Metadata wrappers (datatype, constructor and selector information) are
-- transparent: recurse into the wrapped representation.
instance (ExtractSymbolics' a) => ExtractSymbolics' (M1 i c a) where
  extractSymbolics' = extractSymbolics' . unM1

-- Sum of constructors: only the alternative that is present is inspected.
instance
  (ExtractSymbolics' a, ExtractSymbolics' b) =>
  ExtractSymbolics' (a :+: b)
  where
  extractSymbolics' (L1 l) = extractSymbolics' l
  extractSymbolics' (R1 r) = extractSymbolics' r

-- Product of fields: the union of the symbols found on both sides.
instance
  (ExtractSymbolics' a, ExtractSymbolics' b) =>
  ExtractSymbolics' (a :*: b)
  where
  extractSymbolics' (l :*: r) = extractSymbolics' l <> extractSymbolics' r