{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- Module      :   Grisette.Internal.Backend.SymBiMap
-- Copyright   :   (c) Sirui Lu 2021-2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Backend.SymBiMap
  ( SymBiMap (..),
    emptySymBiMap,
    sizeBiMap,
    addBiMap,
    addBiMapIntermediate,
    findStringToSymbol,
    lookupTerm,
    QuantifiedSymbolInfo (..),
    attachNextQuantifiedSymbolInfo,
  )
where

import Control.DeepSeq (NFData)
import Data.Dynamic (Dynamic)
import qualified Data.HashMap.Strict as M
import Data.Hashable (Hashable)
import GHC.Generics (Generic)
import GHC.Stack (HasCallStack)
import Grisette.Internal.Backend.QuantifiedStack (QuantifiedStack)
import Grisette.Internal.Core.Data.Symbol
  ( Symbol (IndexedSymbol, SimpleSymbol),
    withInfo,
  )
import Grisette.Internal.SymPrim.Prim.SomeTerm
  ( SomeTerm,
  )
import Grisette.Internal.SymPrim.Prim.Term
  ( IsSymbolKind,
    SomeTypedAnySymbol,
    SomeTypedSymbol,
    TypedConstantSymbol,
    TypedSymbol (TypedSymbol),
    castSomeTypedSymbol,
  )
import Language.Haskell.TH.Syntax (Lift)

-- | A bidirectional map between symbolic Grisette terms and sbv terms.
--
-- The forward direction ('biMapToSBV') maps a Grisette term to a function
-- from the 'QuantifiedStack' to the lowered sbv value (wrapped in a
-- 'Dynamic'). The backward direction ('biMapFromSBV') maps a name string to
-- the Grisette symbol registered under it.
--
-- All fields are strict. This map is rebuilt on every insertion and every
-- counter bump, so lazy fields would let chains of pending 'M.insert' calls
-- and @+ 1@ increments pile up as thunks.
data SymBiMap = SymBiMap
  { -- | Grisette term to sbv value, parameterized by the quantifier stack.
    biMapToSBV :: !(M.HashMap SomeTerm (QuantifiedStack -> Dynamic)),
    -- | Name string to the Grisette symbol it was registered with.
    biMapFromSBV :: !(M.HashMap String SomeTypedAnySymbol),
    -- | Counter used to hand out fresh 'QuantifiedSymbolInfo' tags.
    quantifiedSymbolNum :: !Int
  }

-- | Information about a quantified symbol.
--
-- Wraps the counter value handed out by the 'SymBiMap' when a quantified
-- symbol is introduced. The value is attached to the symbol's identifier with
-- 'withInfo' (presumably so that successive quantified symbols get distinct
-- identifiers even when their base names coincide — confirm against callers).
--
-- 'Generic' and 'Lift' are derived with the stock mechanism. 'Hashable' and
-- 'NFData' reuse the underlying 'Int' instances through
-- GeneralizedNewtypeDeriving.
newtype QuantifiedSymbolInfo = QuantifiedSymbolInfo Int
  deriving (Generic, Ord, Eq, Show, Hashable, Lift, NFData)

-- | Allocate a fresh 'QuantifiedSymbolInfo' from the map's counter.
--
-- Returns a pair:
--
-- * the map, with 'quantifiedSymbolNum' incremented by one and both
--   'HashMap's untouched;
-- * an info tag carrying the counter's value from before the increment.
nextQuantifiedSymbolInfo :: SymBiMap -> (SymBiMap, QuantifiedSymbolInfo)
nextQuantifiedSymbolInfo (SymBiMap t f num) =
  (SymBiMap t f (num + 1), QuantifiedSymbolInfo num)

-- | Tag a constant symbol's identifier with the given 'QuantifiedSymbolInfo'.
--
-- For a 'SimpleSymbol' the identifier is tagged. For an 'IndexedSymbol' only
-- the identifier is tagged; the index is passed through unchanged.
--
-- Matching on the 'TypedSymbol' constructor brings its constraints into
-- scope, which is what lets the result be rebuilt with 'TypedSymbol' again.
attachQuantifiedSymbolInfo ::
  QuantifiedSymbolInfo -> TypedConstantSymbol a -> TypedConstantSymbol a
attachQuantifiedSymbolInfo info (TypedSymbol (SimpleSymbol ident)) =
  TypedSymbol $ SimpleSymbol $ withInfo ident info
attachQuantifiedSymbolInfo info (TypedSymbol (IndexedSymbol ident idx)) =
  TypedSymbol $ IndexedSymbol (withInfo ident info) idx

-- | Attach the next quantified symbol info to a symbol.
--
-- Allocates a fresh 'QuantifiedSymbolInfo' from the map's counter and tags
-- the symbol's identifier with it. Returns the map with its counter advanced
-- by one, together with the tagged symbol.
attachNextQuantifiedSymbolInfo ::
  SymBiMap -> TypedConstantSymbol a -> (SymBiMap, TypedConstantSymbol a)
attachNextQuantifiedSymbolInfo m s =
  let (m', info) = nextQuantifiedSymbolInfo m
   in (m', attachQuantifiedSymbolInfo info s)

-- | An empty bidirectional map.
--
-- Both directions are empty and the quantified-symbol counter starts at @0@.
emptySymBiMap :: SymBiMap
emptySymBiMap = SymBiMap M.empty M.empty 0

-- | The size of the bidirectional map.
--
-- This is the number of entries in the term-to-sbv direction
-- ('biMapToSBV'). Entries added by 'addBiMapIntermediate' are included,
-- since that function inserts only into this direction.
sizeBiMap :: SymBiMap -> Int
sizeBiMap = M.size . biMapToSBV

-- | Add a new entry to the bidirectional map.
--
-- Records the term in the forward direction with a value that ignores the
-- 'QuantifiedStack' (@const d@). Records the name in the backward direction,
-- mapped to the symbol. The quantified-symbol counter is left unchanged.
--
-- The symbol is cast to 'SomeTypedAnySymbol' for storage. That cast is
-- expected to always succeed; if it does not, this calls 'error'.
addBiMap ::
  (HasCallStack) =>
  SomeTerm ->
  Dynamic ->
  String ->
  SomeTypedSymbol knd ->
  SymBiMap ->
  SymBiMap
addBiMap s d n sb (SymBiMap t f num) =
  case castSomeTypedSymbol sb of
    Just sb' -> SymBiMap (M.insert s (const d) t) (M.insert n sb' f) num
    Nothing -> error "Casting to AnySymbol, should not fail"

-- | Add a new entry to the bidirectional map for intermediate values.
--
-- Only the forward (term-to-sbv) direction is extended, and the stored
-- function receives the 'QuantifiedStack'. The backward map and the
-- quantified-symbol counter are left unchanged.
addBiMapIntermediate ::
  (HasCallStack) => SomeTerm -> (QuantifiedStack -> Dynamic) -> SymBiMap -> SymBiMap
addBiMapIntermediate s d (SymBiMap t f num) = SymBiMap (M.insert s d t) f num

-- | Find a symbolic Grisette term from a string.
--
-- Looks the name up in the backward direction ('biMapFromSBV'). Returns
-- 'Nothing' if the name is absent, or if the stored symbol cannot be cast to
-- the requested symbol kind @knd@.
findStringToSymbol :: (IsSymbolKind knd) => String -> SymBiMap -> Maybe (SomeTypedSymbol knd)
findStringToSymbol s (SymBiMap _ f _) = do
  r <- M.lookup s f
  castSomeTypedSymbol r

-- | Look up an sbv value with a symbolic Grisette term in the bidirectional
-- map.
--
-- Searches only the forward direction ('biMapToSBV'). The result still needs
-- to be applied to a 'QuantifiedStack' to obtain the 'Dynamic' value.
-- Returns 'Nothing' if the term has not been added.
lookupTerm :: (HasCallStack) => SomeTerm -> SymBiMap -> Maybe (QuantifiedStack -> Dynamic)
lookupTerm t m = M.lookup t (biMapToSBV m)