{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Internal.Core.Data.Class.Solver
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Internal.Core.Data.Class.Solver
  ( -- * Note for the examples

    --

    -- | The examples assume that the [z3](https://github.com/Z3Prover/z3)
    -- solver is available in @PATH@.

    -- * Solver interfaces
    SolvingFailure (..),
    MonadicSolver (..),
    monadicSolverSolve,
    SolverCommand (..),
    ConfigurableSolver (..),
    Solver (..),
    solverSolve,
    withSolver,
    solve,
    solverSolveMulti,
    solveMulti,

    -- * Union with exceptions
    UnionWithExcept (..),
    solverSolveExcept,
    solveExcept,
    solverSolveMultiExcept,
    solveMultiExcept,
  )
where

import Control.DeepSeq (NFData)
import Control.Exception (mask, onException)
import Control.Monad.Except (ExceptT, runExceptT)
import qualified Data.Binary as Binary
import Data.Bytes.Serial (Serial (deserialize, serialize))
import qualified Data.HashSet as S
import Data.Hashable (Hashable)
import Data.Maybe (fromJust)
import qualified Data.Serialize as Cereal
import qualified Data.Text as T
import GHC.Generics (Generic)
import Generics.Deriving (Default (Default))
import Grisette.Internal.Core.Data.Class.ExtractSym
  ( ExtractSym (extractSym),
  )
import Grisette.Internal.Core.Data.Class.LogicalOp (LogicalOp (symNot, (.||)))
import Grisette.Internal.Core.Data.Class.PPrint (PPrint)
import Grisette.Internal.Core.Data.Class.PlainUnion
  ( PlainUnion,
    simpleMerge,
  )
import Grisette.Internal.Core.Data.Class.Solvable (Solvable (con))
import Grisette.Internal.SymPrim.Prim.Model
  ( AnySymbolSet,
    Model,
    SymbolSet (unSymbolSet),
    equation,
  )
import Grisette.Internal.SymPrim.Prim.Term
  ( SomeTypedSymbol (SomeTypedSymbol),
  )
import Grisette.Internal.SymPrim.SymBool (SymBool (SymBool))
import Language.Haskell.TH.Syntax (Lift)

-- | A single-constructor marker type.
--
-- NOTE(review): this type is not exported and is not used in the visible
-- part of this module — confirm whether it is still needed.
data SolveInternal = SolveInternal
  deriving stock (Eq, Ord, Show, Generic, Lift)
  deriving anyclass (Hashable, NFData)

-- $setup
-- >>> import Grisette
-- >>> import Grisette.Core
-- >>> import Grisette.SymPrim
-- >>> import Grisette.Backend

-- | The failures that can be reported by the solver.
data SolvingFailure
  = -- | Unsatisfiable: no model is available.
    Unsat
  | -- | Unknown: the solver cannot determine whether the formula is
    -- satisfiable.
    Unk
  | -- | The solver has reached the maximum number of models to return.
    ResultNumLimitReached
  | -- | The solver has encountered an error, described by the message.
    SolvingError T.Text
  | -- | The solver has been terminated.
    Terminated
  deriving stock (Show, Eq, Generic, Lift)
  deriving anyclass (NFData, Hashable, Serial)
  deriving (PPrint) via (Default SolvingFailure)

-- | The cereal encoding reuses the 'Serial' instance of 'SolvingFailure'.
instance Cereal.Serialize SolvingFailure where
  get = deserialize
  put = serialize

-- | The binary encoding reuses the 'Serial' instance of 'SolvingFailure'.
instance Binary.Binary SolvingFailure where
  get = deserialize
  put = serialize

-- | A monadic solver interface.
--
-- This interface abstracts the monadic interface of a solver. All the
-- operations performed in the monad use a single solver instance, which is
-- managed by the monad's @run@ function.
class (Monad m) => MonadicSolver m where
  -- | Push the given number of levels.
  monadicSolverPush :: Int -> m ()

  -- | Pop the given number of levels.
  monadicSolverPop :: Int -> m ()

  -- | Reset all assertions in the solver.
  monadicSolverResetAssertions :: m ()

  -- | Assert a formula.
  monadicSolverAssert :: SymBool -> m ()

  -- | Check the satisfiability of the current assertions, returning a model
  -- on success.
  monadicSolverCheckSat :: m (Either SolvingFailure Model)

-- | Solve a single formula with a monadic solver. Find an assignment to it to
-- make it true.
--
-- The formula is asserted with 'monadicSolverAssert' (it is not popped
-- afterwards), then 'monadicSolverCheckSat' produces the result.
monadicSolverSolve ::
  (MonadicSolver m) => SymBool -> m (Either SolvingFailure Model)
monadicSolverSolve formula =
  monadicSolverAssert formula >> monadicSolverCheckSat

-- | The commands that can be sent to a solver.
data SolverCommand
  = -- | Assert a formula (the formula is stored strictly).
    SolverAssert !SymBool
  | -- | Request a satisfiability check.
    SolverCheckSat
  | -- | Push the given number of levels.
    SolverPush Int
  | -- | Pop the given number of levels.
    SolverPop Int
  | -- | Reset all assertions.
    SolverResetAssertions
  | -- | Terminate the solver.
    SolverTerminate

-- | A class that abstracts the solver interface.
--
-- Instances must provide 'solverRunCommand', 'solverCheckSat',
-- 'solverTerminate' and 'solverForceTerminate'. The other methods have
-- default implementations. Each default sends the matching 'SolverCommand'
-- through 'solverRunCommand', with a continuation that ignores the handle and
-- returns @Right ()@.
class Solver handle where
  -- | Run a solver command.
  --
  -- The first argument is a continuation over the solver handle.
  -- NOTE(review): when (and whether) the continuation is invoked is decided
  -- by each instance — confirm against the backend implementations.
  solverRunCommand ::
    (handle -> IO (Either SolvingFailure a)) ->
    handle ->
    SolverCommand ->
    IO (Either SolvingFailure a)

  -- | Assert a formula.
  solverAssert :: handle -> SymBool -> IO (Either SolvingFailure ())
  solverAssert h formula =
    solverRunCommand (\_ -> pure (Right ())) h (SolverAssert formula)

  -- | Check whether the assertions made so far are satisfiable, returning a
  -- 'Model' on success.
  solverCheckSat :: handle -> IO (Either SolvingFailure Model)

  -- | Push @n@ levels.
  solverPush :: handle -> Int -> IO (Either SolvingFailure ())
  solverPush h n =
    solverRunCommand (\_ -> pure (Right ())) h (SolverPush n)

  -- | Pop @n@ levels.
  solverPop :: handle -> Int -> IO (Either SolvingFailure ())
  solverPop h n =
    solverRunCommand (\_ -> pure (Right ())) h (SolverPop n)

  -- | Reset all assertions in the solver.
  --
  -- The solver keeps all the assertions used in the previous commands:
  --
  -- >>> solver <- newSolver z3
  -- >>> solverSolve solver "a"
  -- Right (Model {a -> true :: Bool})
  -- >>> solverSolve solver $ symNot "a"
  -- Left Unsat
  --
  -- You can clear the assertions using @solverResetAssertions@:
  --
  -- >>> solverResetAssertions solver
  -- Right ()
  -- >>> solverSolve solver $ symNot "a"
  -- Right (Model {a -> false :: Bool})
  solverResetAssertions :: handle -> IO (Either SolvingFailure ())
  solverResetAssertions h =
    solverRunCommand (\_ -> pure (Right ())) h SolverResetAssertions

  -- | Terminate the solver, wait until the last command is finished.
  solverTerminate :: handle -> IO ()

  -- | Force terminate the solver, do not wait for the last command to finish.
  solverForceTerminate :: handle -> IO ()

-- | Solve a single formula. Find an assignment to it to make it true.
--
-- The formula is asserted with 'solverAssert' and is not retracted
-- afterwards. If the assertion fails, that failure is returned without
-- checking satisfiability. Otherwise the result of 'solverCheckSat' is
-- returned.
solverSolve ::
  (Solver handle) => handle -> SymBool -> IO (Either SolvingFailure Model)
solverSolve solver formula =
  solverAssert solver formula
    >>= either (pure . Left) (const (solverCheckSat solver))

-- | Solve a single formula while returning multiple models to make it true.
-- The maximum number of desired models is given.
--
-- After each model is found, a blocking formula is solved with 'solverSolve'.
-- It requires at least one symbol of the original formula to differ from its
-- value in that model. Each further model therefore differs from the previous
-- one. All these formulas are asserted on the given solver handle and are not
-- popped afterwards.
--
-- The returned 'SolvingFailure' tells why the enumeration stopped:
--
-- * 'ResultNumLimitReached' when the requested number of models has been
--   produced, or when a non-positive number of models was requested (in that
--   case the solver is not queried at all);
-- * otherwise the failure reported by the solver (e.g. 'Unsat' once no
--   further model exists).
solverSolveMulti ::
  (Solver handle) =>
  -- | solver handle
  handle ->
  -- | maximum number of models to return
  Int ->
  -- | formula to solve, the solver will try to make it true
  SymBool ->
  IO ([Model], SolvingFailure)
solverSolveMulti solver numOfModelRequested formula
  -- Requesting zero (or fewer) models must not return any model.
  | numOfModelRequested <= 0 = return ([], ResultNumLimitReached)
  | otherwise = do
      firstModel <- solverSolve solver formula
      case firstModel of
        Left err -> return ([], err)
        Right model -> do
          (models, err) <- go model numOfModelRequested
          return (model : models, err)
  where
    allSymbols = extractSym formula :: AnySymbolSet
    -- @go prevModel n@: @prevModel@ has just been produced and is counted in
    -- @n@, so at most @n - 1@ further models may still be returned.
    go prevModel n
      | n <= 1 = return ([], ResultNumLimitReached)
      | otherwise = do
          let newFormula =
                S.foldl'
                  ( \acc (SomeTypedSymbol v) ->
                      -- NOTE(review): assumes the model assigns every symbol
                      -- of the formula ('fromJust' fails otherwise).
                      acc
                        .|| symNot (SymBool $ fromJust $ equation v prevModel)
                  )
                  (con False)
                  (unSymbolSet allSymbols)
          res <- solverSolve solver newFormula
          case res of
            Left err -> return ([], err)
            Right model -> do
              (models, err) <- go model (n - 1)
              return (model : models, err)

-- |
-- Solver procedure for programs with error handling.
--
-- Every result of the program (exception or value) is mapped to a formula
-- with the given function. These formulas are merged into a single 'SymBool'
-- with 'simpleMerge', which is then solved with 'solverSolve'.
solverSolveExcept ::
  ( UnionWithExcept t u e v,
    PlainUnion u,
    Functor u,
    Solver handle
  ) =>
  -- | solver handle
  handle ->
  -- | mapping the results to symbolic boolean formulas, the solver would try to
  -- find a model to make the formula true
  (Either e v -> SymBool) ->
  -- | the program to be solved, should be a union of exception and values
  t ->
  IO (Either SolvingFailure Model)
solverSolveExcept solver f =
  solverSolve solver . simpleMerge . fmap f . extractUnionExcept

-- |
-- Solver procedure for programs with error handling. Would return multiple
-- models if possible.
--
-- The results of the program are mapped to formulas and merged with
-- 'simpleMerge'. The merged formula is then passed to 'solverSolveMulti'.
solverSolveMultiExcept ::
  ( UnionWithExcept t u e v,
    PlainUnion u,
    Functor u,
    Solver handle
  ) =>
  -- | solver handle
  handle ->
  -- | maximum number of models to return
  Int ->
  -- | mapping the results to symbolic boolean formulas, the solver would try to
  -- find a model to make the formula true
  (Either e v -> SymBool) ->
  -- | the program to be solved, should be a union of exception and values
  t ->
  IO ([Model], SolvingFailure)
solverSolveMultiExcept solver n f =
  solverSolveMulti solver n . simpleMerge . fmap f . extractUnionExcept

-- | A class that abstracts the creation of a solver instance based on a
-- configuration.
--
-- The configuration type determines the handle type (@config -> handle@).
-- The solver instance returned by 'newSolver' must be terminated by the user
-- through the 'Solver' interface, or managed with 'withSolver'.
class (Solver handle) => ConfigurableSolver config handle | config -> handle where
  -- | Create a new solver instance from the configuration.
  newSolver :: config -> IO handle

-- | Start a solver, run a computation with the solver, and terminate the
-- solver after the computation finishes.
--
-- Creating and terminating the solver happen with asynchronous exceptions
-- masked. Only the computation runs with the caller's masking state restored.
-- When an exception happens in the computation, the solver is forcibly
-- terminated and the exception is rethrown.
--
-- Note: if Grisette is compiled with sbv < 10.10, the solver likely won't be
-- really terminated until it has finished the last action, and this will
-- result in long-running or zombie solver instances.
--
-- This was due to a bug in sbv, which is fixed in
-- https://github.com/LeventErkok/sbv/pull/695.
withSolver ::
  (ConfigurableSolver config handle) =>
  config ->
  (handle -> IO a) ->
  IO a
withSolver config action = mask $ \unmask ->
  newSolver config >>= \solver -> do
    result <-
      unmask (action solver) `onException` solverForceTerminate solver
    solverTerminate solver
    pure result

-- | Solve a single formula. Find an assignment to it to make it true.
--
-- A fresh solver is created from the configuration and terminated afterwards
-- (see 'withSolver').
--
-- >>> solve z3 ("a" .&& ("b" :: SymInteger) .== 1)
-- Right (Model {a -> true :: Bool, b -> 1 :: Integer})
-- >>> solve z3 ("a" .&& symNot "a")
-- Left Unsat
solve ::
  (ConfigurableSolver config handle) =>
  -- | solver configuration
  config ->
  -- | formula to solve, the solver will try to make it true
  SymBool ->
  IO (Either SolvingFailure Model)
solve config formula =
  withSolver config $ \solver -> solverSolve solver formula

-- | Solve a single formula while returning multiple models to make it true.
-- The maximum number of desired models is given.
--
-- A fresh solver is created from the configuration and terminated afterwards
-- (see 'withSolver').
--
-- > >>> solveMulti z3 4 ("a" .|| "b")
-- > ([Model {a -> true :: Bool, b -> false :: Bool},Model {a -> false :: Bool, b -> true :: Bool},Model {a -> true :: Bool, b -> true :: Bool}],Unsat)
--
-- (This example is not run as a doctest; the order of the models may vary.)
solveMulti ::
  (ConfigurableSolver config handle) =>
  -- | solver configuration
  config ->
  -- | maximum number of models to return
  Int ->
  -- | formula to solve, the solver will try to make it true
  SymBool ->
  IO ([Model], SolvingFailure)
solveMulti config numOfModelRequested formula = withSolver config run
  where
    run solver = solverSolveMulti solver numOfModelRequested formula

-- | A class that abstracts the union-like structures that contain exceptions.
--
-- The structure type @t@ determines the union type @u@, the exception type
-- @e@ and the value type @v@.
class UnionWithExcept t u e v | t -> u e v where
  -- | Extract a union of exceptions and values from the structure.
  extractUnionExcept :: t -> u (Either e v)

instance UnionWithExcept (ExceptT e u v) u e v where
  extractUnionExcept :: ExceptT e u v -> u (Either e v)
extractUnionExcept = ExceptT e u v -> u (Either e v)
forall e (u :: * -> *) v. ExceptT e u v -> u (Either e v)
runExceptT

-- |
-- Solver procedure for programs with error handling.
--
-- This creates a solver instance from the configuration with 'withSolver' and
-- delegates to 'solverSolveExcept'.
--
-- >>> import Control.Monad.Except
-- >>> let x = "x" :: SymInteger
-- >>> :{
--   res :: ExceptT AssertionError Union ()
--   res = do
--     symAssert $ x .> 0       -- constrain that x is positive
--     symAssert $ x .< 2       -- constrain that x is less than 2
-- :}
--
-- >>> :{
--   translate (Left _) = con False -- errors are not desirable
--   translate _ = con True         -- non-errors are desirable
-- :}
--
-- >>> solveExcept z3 translate res
-- Right (Model {x -> 1 :: Integer})
solveExcept ::
  ( UnionWithExcept t u e v,
    PlainUnion u,
    Functor u,
    ConfigurableSolver config handle
  ) =>
  -- | solver configuration
  config ->
  -- | mapping the results to symbolic boolean formulas, the solver would try to
  -- find a model to make the formula true
  (Either e v -> SymBool) ->
  -- | the program to be solved, should be a union of exception and values
  t ->
  IO (Either SolvingFailure Model)
solveExcept config f v =
  withSolver config $ \solver -> solverSolveExcept solver f v

-- |
-- Solver procedure for programs with error handling. Returns at most the
-- given number of models. The models are paired with a 'SolvingFailure' value.
--
-- This creates a solver instance from the configuration with 'withSolver' and
-- delegates to 'solverSolveMultiExcept'.
solveMultiExcept ::
  ( UnionWithExcept t u e v,
    PlainUnion u,
    Functor u,
    ConfigurableSolver config handle
  ) =>
  -- | solver configuration
  config ->
  -- | maximum number of models to return
  Int ->
  -- | mapping the results to symbolic boolean formulas, the solver would try to
  -- find a model to make the formula true
  (Either e v -> SymBool) ->
  -- | the program to be solved, should be a union of exception and values
  t ->
  IO ([Model], SolvingFailure)
solveMultiExcept config n f v =
  withSolver config $ \solver -> solverSolveMultiExcept solver n f v