{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# HLINT ignore "Unused LANGUAGE pragma" #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Symbol
-- Copyright   :   (c) Sirui Lu 2024
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Symbol
  ( Identifier (..),
    identifier,
    withMetadata,
    withLocation,
    mapMetadata,
    uniqueIdentifier,
    Symbol (..),
    simple,
    indexed,
    symbolIdentifier,
    mapIdentifier,
  )
where

import Control.DeepSeq (NFData)
import qualified Data.Binary as Binary
import Data.Bytes.Serial (Serial (deserialize, serialize))
import Data.Hashable (Hashable (hashWithSalt))
import Data.IORef (IORef, atomicModifyIORef', newIORef)
import qualified Data.Serialize as Cereal
import Data.String (IsString (fromString))
import qualified Data.Text as T
import GHC.Generics (Generic)
import GHC.IO (unsafePerformIO)
import Grisette.Internal.Core.Data.SExpr
  ( SExpr (Atom, List, NumberAtom),
    fileLocation,
    showsSExprWithParens,
  )
import Language.Haskell.TH.Syntax (Lift)
import Language.Haskell.TH.Syntax.Compat (SpliceQ)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim

-- | Identifier type used for 'Grisette.Core.GenSym'
--
-- The constructor is intended to stay out of the public API (this internal
-- module still exports it). You can construct an identifier by:
--
--   * a raw identifier
--
--     The following two expressions will refer to the same identifier (the
--     solver won't distinguish them and would assign the same value to them).
--     The user may need to use unique names to avoid unintentional identifier
--     collision.
--
--     >>> identifier "a"
--     a
--
--     >>> "a" :: Identifier -- available when OverloadedStrings is enabled
--     a
--
--   * bundle the identifier with some user provided metadata
--
--     Identifiers created with different names or different additional
--     information will not be the same.
--
--     >>> withMetadata "a" (NumberAtom 1)
--     a:1
--
--   * bundle the calling file location with the identifier to ensure global
--     uniqueness
--
--     Identifiers created at different locations will not be the
--     same. The identifiers created at the same location will be the same.
--
--     >>> $$(withLocation "a") -- a sample result could be "a:[grisette-file-location <interactive> 18 (4 18)]"
--     a:[grisette-file-location <interactive>...]
data Identifier = Identifier
  { -- | The base name of the identifier.
    baseIdent :: T.Text,
    -- | Extra metadata. Two identifiers are equal only when both the base
    -- name and the metadata are equal. 'identifier' uses an empty
    -- @List []@, which 'show' renders as the bare base name.
    metadata :: SExpr
  }
  deriving stock (Eq, Ord, Generic, Lift)
  deriving anyclass (Hashable, NFData, Serial)

-- | @cereal@ encoding, delegating to the generic 'Serial' instance so that
-- all serialization backends share one wire format.
instance Cereal.Serialize Identifier where
  put = serialize
  get = deserialize

-- | @binary@ encoding, delegating to the generic 'Serial' instance so that
-- all serialization backends share one wire format.
instance Binary.Binary Identifier where
  put = serialize
  get = deserialize

-- | An identifier without metadata (@List []@) is shown as its bare base
-- name. Otherwise it is shown as @name:[metadata]@, with the metadata
-- rendered by 'showsSExprWithParens' using square brackets.
instance Show Identifier where
  showsPrec _ (Identifier name (List [])) = showString (T.unpack name)
  showsPrec _ (Identifier name meta) =
    -- 'meta' (not 'metadata') avoids shadowing the record selector.
    showString (T.unpack name)
      . showString ":"
      . showsSExprWithParens '[' ']' meta

-- | String literals become simple identifiers with no metadata, exactly as
-- 'identifier' would build them.
instance IsString Identifier where
  fromString = identifier . T.pack

-- | Simple identifier.
-- The same identifier refers to the same symbolic variable in the whole
-- program.
--
-- The user may need to use unique identifiers to avoid unintentional identifier
-- collision.
--
-- The resulting identifier carries empty metadata (@List []@).
identifier :: T.Text -> Identifier
identifier name = Identifier name (List [])

-- | Identifier with extra metadata.
--
-- The same identifier with the same metadata refers to the same symbolic
-- variable in the whole program.
--
-- The user may need to use unique identifiers or additional metadata to
-- avoid unintentional identifier collision.
--
-- This is the raw 'Identifier' constructor: the first argument is the base
-- name and the second is the metadata.
withMetadata :: T.Text -> SExpr -> Identifier
withMetadata = Identifier

-- | Identifier with the file location.
--
-- Produces a typed Template Haskell splice. The splice builds an identifier
-- with the given base name, whose metadata is the 'fileLocation' of the
-- splice site. Use it as @$$(withLocation "a")@.
withLocation :: T.Text -> SpliceQ Identifier
withLocation nm = [||withMetadata nm $$fileLocation||]

-- | Modify the metadata of an identifier, leaving its base name unchanged.
mapMetadata :: (SExpr -> SExpr) -> Identifier -> Identifier
mapMetadata f (Identifier name meta) = Identifier name (f meta)

-- Process-wide counter backing 'uniqueIdentifier'.
--
-- This is the top-level mutable-variable idiom: 'unsafePerformIO' allocates
-- the 'IORef' once. The NOINLINE pragma is required, because inlining the
-- definition could allocate several distinct 'IORef's and break uniqueness.
-- The counter is not exported; in this module it is only updated through
-- 'atomicModifyIORef'' in 'uniqueIdentifier'.
identifierCount :: IORef Int
identifierCount = unsafePerformIO $ newIORef 0
{-# NOINLINE identifierCount #-}

-- | Get a globally unique identifier within the 'IO' monad.
--
-- The result has the given base name and the metadata
-- @(grisette-unique n)@. Here @n@ is read from a process-wide counter that
-- is atomically incremented on every call, so each call in a process gets a
-- different @n@.
uniqueIdentifier :: T.Text -> IO Identifier
uniqueIdentifier ident = do
  -- Return the old value and store the incremented one, atomically.
  n <- atomicModifyIORef' identifierCount (\x -> (x + 1, x))
  pure $
    withMetadata
      ident
      (List [Atom "grisette-unique", NumberAtom $ toInteger n])

-- | Symbol types for a symbolic variable.
--
-- The symbols can be indexed with an integer.
data Symbol where
  -- | A symbol named by an identifier alone.
  SimpleSymbol :: Identifier -> Symbol
  -- | A symbol named by an identifier together with an integer index.
  IndexedSymbol :: Identifier -> Int -> Symbol
  deriving stock (Eq, Ord, Generic, Lift)
  deriving anyclass (NFData, Serial)

-- | @cereal@ encoding, delegating to the generic 'Serial' instance so that
-- all serialization backends share one wire format.
instance Cereal.Serialize Symbol where
  put = serialize
  get = deserialize

-- | @binary@ encoding, delegating to the generic 'Serial' instance so that
-- all serialization backends share one wire format.
instance Binary.Binary Symbol where
  put = serialize
  get = deserialize

-- | A simple symbol hashes exactly like its identifier. An indexed symbol
-- threads the salt through its identifier and then through its index.
instance Hashable Symbol where
  hashWithSalt salt (SimpleSymbol ident) = hashWithSalt salt ident
  hashWithSalt salt (IndexedSymbol ident idx) =
    salt `hashWithSalt` ident `hashWithSalt` idx
  {-# INLINE hashWithSalt #-}

-- | Get the identifier of a symbol, discarding the index of an indexed
-- symbol.
symbolIdentifier :: Symbol -> Identifier
symbolIdentifier (SimpleSymbol ident) = ident
symbolIdentifier (IndexedSymbol ident _) = ident

-- | Modify the identifier of a symbol, preserving its shape (simple or
-- indexed) and its index.
mapIdentifier :: (Identifier -> Identifier) -> Symbol -> Symbol
mapIdentifier f (SimpleSymbol ident) = SimpleSymbol (f ident)
mapIdentifier f (IndexedSymbol ident idx) = IndexedSymbol (f ident) idx

-- | A simple symbol is shown as its identifier. An indexed symbol is shown
-- as the identifier, then @\@@, then the index.
instance Show Symbol where
  show (SimpleSymbol ident) = show ident
  show (IndexedSymbol ident idx) = show ident ++ "@" ++ show idx

-- | String literals become simple symbols over a metadata-free identifier.
instance IsString Symbol where
  fromString = SimpleSymbol . fromString

-- | Create a simple (non-indexed) symbol from an identifier.
simple :: Identifier -> Symbol
simple = SimpleSymbol

-- | Create an indexed symbol from an identifier and an integer index.
indexed :: Identifier -> Int -> Symbol
indexed = IndexedSymbol