module Language.Hasmtlib.Type.Solver
  ( WithSolver(..)
  , solveWith
  , interactiveWith, debugInteractiveWith
  , solveMinimized, solveMinimizedDebug
  , solveMaximized, solveMaximizedDebug
  )
where

import Language.Hasmtlib.Type.MonadSMT
import Language.Hasmtlib.Internal.Expr
import Language.Hasmtlib.Type.SMTSort
import Language.Hasmtlib.Type.Solution
import Language.Hasmtlib.Type.Pipe
import Language.Hasmtlib.Orderable
import Language.Hasmtlib.Codec
import qualified SMTLIB.Backends as Backend
import qualified SMTLIB.Backends.Process as Process
import Data.Default
import Control.Monad.State

-- | Data that can be built from a 'Backend.Solver' which may be debugged.
class WithSolver a where
  -- | Create a datum from a 'Backend.Solver' and a 'Bool' that tells whether
  --   the 'Backend.Solver' should be debugged.
  withSolver :: Backend.Solver -> Bool -> a

-- | Builds a fresh 'Pipe' around a live backend solver. The leading 'Int'
--   field starts at @0@ and the 'Maybe' 'String' field starts as 'Nothing';
--   the backend and the debug flag are stored unchanged.
--   NOTE(review): the meaning of the first two 'Pipe' fields is defined in
--   "Language.Hasmtlib.Type.Pipe" — confirm there.
instance WithSolver Pipe where
  withSolver backend debugging = Pipe 0 Nothing backend debugging

-- | @'solveWith' solver prob@ runs the problem description @prob@, starting
--   from the 'def'ault state, and hands the resulting problem to @solver@.
--   The outcome is a pair of:
--
-- 1. A 'Result' telling whether @prob@ is satisfiable ('Sat'),
--    unsatisfiable ('Unsat'), or whether the solver could not decide
--    ('Unknown').
--
-- 2. The value returned by @prob@, 'decode'd with the 'Solution' the solver
--    produced. It is only meaningful when the 'Result' is 'Sat' or 'Unknown'
--    and decoding yielded a 'Just'.
--
-- A small example using 'solveWith':
--
-- @
-- import Language.Hasmtlib
--
-- main :: IO ()
-- main = do
--   res <- solveWith @SMT (solver cvc5) $ do
--     setLogic \"QF_LIA\"
--
--     x <- var @IntSort
--
--     assert $ x >? 0
--
--     return x
--
--   print res
-- @
solveWith :: (Default s, Monad m, Codec a) => Solver s m -> StateT s m a -> m (Result, Maybe (Decoded a))
solveWith solver m = runStateT m def >>= solveAndDecode
  where
    -- Feed the final problem state to the solver, then decode the problem's
    -- own return value against the model that came back.
    solveAndDecode (answer, problem) = decodeWith answer <$> solver problem
    decodeWith answer (result, solution) = (result, decode solution answer)

-- | Runs an SMT problem interactively against an already running solver.
--   The problem state is created with 'withSolver' with debugging disabled.
--   When the problem finishes, the solver process 'Process.Handle' is closed.
--
--   A small example that solves a problem using the solver's incremental stack:
--
-- @
-- import Language.Hasmtlib
-- import Control.Monad.IO.Class
--
-- main :: IO ()
-- main = do
--   cvc5Living <- interactiveSolver cvc5
--   interactiveWith @Pipe cvc5Living $ do
--     setOption $ Incremental True
--     setOption $ ProduceModels True
--     setLogic \"QF_LIA\"
--
--     x <- var @IntSort
--
--     assert $ x >? 0
--
--     (res, sol) <- solve
--     liftIO $ print res
--     liftIO $ print $ decode sol x
--
--     push
--     y <- var @IntSort
--
--     assert $ y <? 0
--     assert $ x === y
--
--     res' <- checkSat
--     liftIO $ print res'
--     pop
--
--     res'' <- checkSat
--     liftIO $ print res''
--
--   return ()
-- @
interactiveWith :: (WithSolver s, MonadIO m) => (Backend.Solver, Process.Handle) -> StateT s m () -> m ()
interactiveWith (backend, process) problem =
  runStateT problem (withSolver backend False) >> liftIO (Process.close process)

-- | Like 'interactiveWith', but the problem state is created with the debug
--   flag of 'withSolver' set to 'True'. It is intended to print all
--   communication with the solver to the console.
--   When the problem finishes, the solver process 'Process.Handle' is closed.
debugInteractiveWith :: (WithSolver s, MonadIO m) => (Backend.Solver, Process.Handle) -> StateT s m () -> m ()
debugInteractiveWith (backend, process) problem =
  runStateT problem (withSolver backend True) >> liftIO (Process.close process)

-- | Solves the current problem while searching for a model in which the given
--   numerical expression is minimal.
--
--   This does not rely on MaxSMT/OMT. Instead it refines iteratively: after
--   every satisfying model it asserts @target '<?' v@, where @v@ is the
--   target's value in that model, and solves again until the solver stops
--   reporting 'Sat'.
--
--   Use 'solveMinimizedDebug' to observe the intermediate models.
solveMinimized :: (MonadIncrSMT Pipe m, MonadIO m, KnownSMTSort t, Orderable (Expr t))
  => Expr t
  -> m (Result, Solution)
solveMinimized target = solveOptimized Nothing (<?) target

-- | Like 'solveMinimized', but calls @debug@ on every intermediate satisfying
--   model whose target value could be decoded, before refining further.
solveMinimizedDebug :: (MonadIncrSMT Pipe m, MonadIO m, KnownSMTSort t, Orderable (Expr t))
  => (Solution -> IO ())
  -> Expr t
  -> m (Result, Solution)
solveMinimizedDebug debug target = solveOptimized (pure debug) (<?) target

-- | Solves the current problem while searching for a model in which the given
--   numerical expression is maximal.
--
--   This does not rely on MaxSMT/OMT. Instead it refines iteratively: after
--   every satisfying model it asserts @target '>?' v@, where @v@ is the
--   target's value in that model, and solves again until the solver stops
--   reporting 'Sat'.
--
--   Use 'solveMaximizedDebug' to observe the intermediate models.
solveMaximized :: (MonadIncrSMT Pipe m, MonadIO m, KnownSMTSort t, Orderable (Expr t))
  => Expr t
  -> m (Result, Solution)
solveMaximized target = solveOptimized Nothing (>?) target

-- | Like 'solveMaximized', but calls @debug@ on every intermediate satisfying
--   model whose target value could be decoded, before refining further.
solveMaximizedDebug :: (MonadIncrSMT Pipe m, MonadIO m, KnownSMTSort t, Orderable (Expr t))
  => (Solution -> IO ())
  -> Expr t
  -> m (Result, Solution)
solveMaximizedDebug debug target = solveOptimized (pure debug) (>?) target

-- | Iteratively refines a model of the current problem with respect to a
--   numerical @target@ expression.
--
--   After each 'Sat' answer whose target value can be decoded, the optional
--   debug callback receives the model. The constraint @target `op` value@ is
--   then asserted and the problem is solved again. This repeats until the
--   solver stops answering 'Sat'.
--
--   All refinement constraints live inside a single 'push'/'pop' scope, which
--   is popped on every exit path. Without that, the constraints would stay on
--   the solver's assertion stack after returning. The last of them rules out
--   the optimum itself, so any later 'checkSat' in the same session would be
--   'Unsat'.
--
--   Returns:
--
--   * @('Sat', best)@ with the last model found, if at least one model was found.
--   * The solver's own first answer (e.g. 'Unsat' or 'Unknown'), if the first
--     solve is not 'Sat'.
--   * @('Sat', model)@ with that model, if the target cannot be decoded from a
--     model. Refinement stops at that point.
solveOptimized :: (MonadIncrSMT Pipe m, MonadIO m, KnownSMTSort t)
  => Maybe (Solution -> IO ())
  -> (Expr t -> Expr t -> Expr BoolSort)
  -> Expr t
  -> m (Result, Solution)
solveOptimized mDebug op target = do
  push
  result <- go Nothing
  pop
  return result
  where
    -- mBest holds the last satisfying model, if any has been found yet.
    go mBest = do
      (res, sol) <- solve
      case res of
        Sat -> case decode sol target of
          -- The target cannot be evaluated in this model, so refinement
          -- cannot continue; keep the model rather than dropping it.
          Nothing        -> return (Sat, sol)
          Just targetSol -> do
            maybe (pure ()) (\debug -> liftIO $ debug sol) mDebug
            assert $ target `op` encode targetSol
            go (Just sol)
        -- No (better) model exists: report the best one found so far, or the
        -- solver's own answer if none was ever found.
        _ -> return $ maybe (res, sol) (\best -> (Sat, best)) mBest