{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveLift #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :   Grisette.Core.Data.Class.Solver
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.Data.Class.Solver
  ( -- * Note for the examples

    --

    -- | The examples assume a [z3](https://github.com/Z3Prover/z3) solver is available in @PATH@.

    -- * Union with exceptions
    UnionWithExcept (..),

    -- * Solver interfaces
    Solver (..),
    solveExcept,
    solveMultiExcept,
  )
where

import Control.DeepSeq
import Control.Monad.Except
import Data.Hashable
import Generics.Deriving
import Grisette.Core.Control.Exception
import Grisette.Core.Data.Class.Bool
import Grisette.Core.Data.Class.Evaluate
import Grisette.Core.Data.Class.ExtractSymbolics
import Grisette.Core.Data.Class.SimpleMergeable
import Grisette.Core.Data.Class.Solvable
import Grisette.IR.SymPrim.Data.Prim.Model
import {-# SOURCE #-} Grisette.IR.SymPrim.Data.SymPrim
import Language.Haskell.TH.Syntax

-- | Internal unit marker type.
--
-- It is not exported from this module (see the export list above).
-- NOTE(review): nothing in this module appears to reference it; confirm
-- whether it is still needed elsewhere (e.g. via an hs-boot file).
data SolveInternal = SolveInternal
  deriving (Eq, Show, Ord, Generic, Hashable, Lift, NFData)

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim
-- >>> import Grisette.Backend.SBV
-- >>> :set -XOverloadedStrings

-- | Interface to a solver backend, selected by a configuration type.
--
-- The functional dependency @config -> failure@ means that a configuration
-- type fully determines the type of failure its solver reports.
class Solver config failure | config -> failure where
  -- | Look for a single assignment that makes the given formula true.
  --
  -- >>> solve (UnboundedReasoning z3) ("a" &&~ ("b" :: SymInteger) ==~ 1)
  -- Right (Model {a -> True :: Bool, b -> 1 :: Integer})
  -- >>> solve (UnboundedReasoning z3) ("a" &&~ nots "a")
  -- Left Unsat
  solve ::
    -- | configuration of the solver to use
    config ->
    -- | the formula that the solver should satisfy
    SymBool ->
    IO (Either failure Model)

  -- | Look for several assignments that each make the given formula true,
  -- returning at most the requested number of models.
  --
  -- > >>> solveMulti (UnboundedReasoning z3) 4 ("a" ||~ "b")
  -- > [Model {a -> True :: Bool, b -> False :: Bool},Model {a -> False :: Bool, b -> True :: Bool},Model {a -> True :: Bool, b -> True :: Bool}]
  solveMulti ::
    -- | configuration of the solver to use
    config ->
    -- | upper bound on the number of models returned
    Int ->
    -- | the formula that the solver should satisfy
    SymBool ->
    IO [Model]

  -- | Look for every assignment that makes the given formula true, with no
  -- bound on the number of models returned.
  --
  -- > >>> solveAll (UnboundedReasoning z3) ("a" ||~ "b")
  -- > [Model {a -> True :: Bool, b -> False :: Bool},Model {a -> False :: Bool, b -> True :: Bool},Model {a -> True :: Bool, b -> True :: Bool}]
  solveAll ::
    -- | configuration of the solver to use
    config ->
    -- | the formula that the solver should satisfy
    SymBool ->
    IO [Model]

-- | Union-like structures whose contents may be exceptions or values.
--
-- The structure type @t@ determines (via the functional dependency
-- @t -> u e v@) the underlying union container @u@, the exception type @e@,
-- and the value type @v@.
class UnionWithExcept t u e v | t -> u e v where
  -- | View the structure as a union whose branches are @'Either' e v@.
  extractUnionExcept :: t -> u (Either e v)

-- | An @'ExceptT' e u v@ is already a union @u@ of @'Either' e v@, so
-- extracting it is just unwrapping the transformer with 'runExceptT'.
instance UnionWithExcept (ExceptT e u v) u e v where
  extractUnionExcept = runExceptT

-- |
-- Solver procedure for programs with error handling.
--
-- The program is first viewed as a union of @'Either' e v@ via
-- 'extractUnionExcept'. The user-supplied mapping turns every branch into a
-- symbolic boolean. The branches are then merged into a single formula with
-- 'simpleMerge', and that formula is passed to 'solve'. The result is exactly
-- what 'solve' returns for the merged formula.
--
-- >>> :set -XLambdaCase
-- >>> import Control.Monad.Except
-- >>> let x = "x" :: SymInteger
-- >>> :{
--   res :: ExceptT AssertionError UnionM ()
--   res = do
--     symAssert $ x >~ 0       -- constrain that x is positive
--     symAssert $ x <~ 2       -- constrain that x is less than 2
-- :}
--
-- >>> :{
--   translate (Left _) = con False -- errors are not desirable
--   translate _ = con True         -- non-errors are desirable
-- :}
--
-- >>> solveExcept (UnboundedReasoning z3) translate res
-- Right (Model {x -> 1 :: Integer})
solveExcept ::
  ( UnionWithExcept t u e v,
    UnionPrjOp u,
    Functor u,
    Solver config failure
  ) =>
  -- | solver configuration
  config ->
  -- | mapping the results to symbolic boolean formulas, the solver would try to find a model to make the formula true
  (Either e v -> SymBool) ->
  -- | the program to be solved, should be a union of exception and values
  t ->
  IO (Either failure Model)
solveExcept config f v =
  solve config (simpleMerge (f <$> extractUnionExcept v))

-- |
-- Solver procedure for programs with error handling. Would return multiple
-- models if possible.
--
-- This builds the formula in the same way as 'solveExcept':
--
-- * the program is viewed as a union of @'Either' e v@ via
--   'extractUnionExcept';
-- * each branch is mapped to a symbolic boolean;
-- * the branches are merged with 'simpleMerge'.
--
-- The merged formula is then passed to 'solveMulti' together with the
-- requested maximum number of models.
solveMultiExcept ::
  ( UnionWithExcept t u e v,
    UnionPrjOp u,
    Functor u,
    Solver config failure
  ) =>
  -- | solver configuration
  config ->
  -- | maximum number of models to return
  Int ->
  -- | mapping the results to symbolic boolean formulas, the solver would try to find a model to make the formula true
  (Either e v -> SymBool) ->
  -- | the program to be solved, should be a union of exception and values
  t ->
  IO [Model]
solveMultiExcept config n f v =
  solveMulti config n (simpleMerge (f <$> extractUnionExcept v))