module Language.Hasmtlib.Type.MonadSMT where

import Language.Hasmtlib.Internal.Expr
import Language.Hasmtlib.Type.Option
import Language.Hasmtlib.Type.SMTSort
import Language.Hasmtlib.Type.Solution
import Language.Hasmtlib.Codec
import Data.Proxy
import Control.Monad
import Control.Monad.State

-- | A 'MonadState' that holds an SMT-Problem.
class MonadState s m => MonadSMT s m where
  -- | Construct a variable.
  --   This is mainly intended for internal use.
  --   In the API use 'var'' instead.
  --
  -- @
  -- x :: SMTVar RealSort <- smtvar' (Proxy @RealSort)
  -- @
  smtvar' :: forall t. KnownSMTSort t => Proxy t -> m (SMTVar t)

  -- | Construct a variable as expression.
  --
  -- @
  -- x :: Expr RealSort <- var' (Proxy @RealSort)
  -- @
  var' :: forall t. KnownSMTSort t => Proxy t -> m (Expr t)

  -- | Assert a boolean expression.
  --
  -- @
  -- x :: Expr IntSort <- var @IntSort
  -- assert $ x + 5 === 42
  -- @
  assert :: Expr BoolSort -> m ()

  -- | Set an SMT-Solver-Option.
  --
  -- @
  -- setOption $ Incremental True
  -- @
  setOption :: SMTOption -> m ()

  -- | Set the logic for the SMT-Solver to use.
  --
  -- @
  -- setLogic \"QF_LRA\"
  -- @
  setLogic :: String -> m ()

-- | Wrapper for 'var'' which hides the 'Proxy'.
--   The sort of the new variable is selected with a type application
--   or inferred from how the resulting expression is used.
var :: forall t s m. (KnownSMTSort t, MonadSMT s m) => m (Expr t)
var = var' (Proxy @t)
{-# INLINE var #-}

-- | Wrapper for 'smtvar'' which hides the 'Proxy'.
--   This is mainly intended for internal use.
--   In the API use 'var' instead.
smtvar :: forall t s m. (KnownSMTSort t, MonadSMT s m) => m (SMTVar t)
smtvar = smtvar' (Proxy @t)
{-# INLINE smtvar #-}

-- | Create a constant expression from a Haskell value of the sort's
--   corresponding 'HaskellType'.
--
-- >>> constant True
-- Constant (BoolValue True)
--
-- >>> let x = 10 :: Integer in constant x
-- Constant (IntValue 10)
--
-- >>> constant @IntSort 5
-- Constant (IntValue 5)
--
-- >>> constant @(BvSort 8) 5
-- Constant (BvValue 00000101)
constant :: KnownSMTSort t => HaskellType t -> Expr t
constant = Constant . wrapValue
{-# INLINE constant #-}

-- | Maybe assert a boolean expression.
--   Asserts the given expression if the 'Maybe' is a 'Just'.
--   Does nothing otherwise.
assertMaybe :: MonadSMT s m => Maybe (Expr BoolSort) -> m ()
assertMaybe = maybe (return ()) assert

-- Quantified variables are assigned here rather than while rendering so that
-- the quantifier API can stay pure: 'renderSMTLib2' is pure, but creating a
-- fresh quantified variable is stateful.
-- | Assign quantified variables to all quantified subexpressions of an expression.
--   This shall only be used internally, usually before rendering an assert.
--
--   Each 'ForAll' / 'Exists' reached gets a fresh variable from 'smtvar'; its
--   body is instantiated with that variable, quantified recursively, and stored
--   as a constant function next to @'Just' var@.
--
--   The recursion only descends through 'Not', 'And', 'Or', 'Impl', 'Xor' and
--   'Ite'. Every other constructor is returned unchanged, so a quantifier
--   nested beneath any other constructor is left untouched.
quantify :: MonadSMT s m => Expr t -> m (Expr t)
quantify (Not x)      = Not  <$> quantify x
quantify (And x y)    = And  <$> quantify x <*> quantify y
quantify (Or x y)     = Or   <$> quantify x <*> quantify y
quantify (Impl x y)   = Impl <$> quantify x <*> quantify y
quantify (Xor x y)    = Xor  <$> quantify x <*> quantify y
quantify (Ite p a b)  = Ite  <$> quantify p <*> quantify a <*> quantify b
quantify (ForAll _ f) = do
  qVar  <- smtvar
  qBody <- quantify $ f $ Var qVar
  return $ ForAll (Just qVar) (const qBody)
quantify (Exists _ f) = do
  qVar  <- smtvar
  qBody <- quantify $ f $ Var qVar
  return $ Exists (Just qVar) (const qBody)
quantify expr         = return expr

-- | A 'MonadSMT' that allows incremental solving.
class MonadSMT s m => MonadIncrSMT s m where
  -- | Push a new context (one) to the solver's context-stack.
  push :: m ()

  -- | Pop the solver's context-stack by one.
  pop :: m ()

  -- | Run check-sat on the current problem.
  checkSat :: m Result

  -- | Run get-model on the current problem.
  --   This can be used to decode temporary models within the SMT-Problem.
  --
  -- @
  -- x <- var @RealSort
  -- y <- var
  -- assert $ x >? y && y <? (-1)
  -- res <- checkSat
  -- case res of
  --   Unsat -> liftIO $ print \"Unsat. Cannot get model.\"
  --   r     -> do
  --     model <- getModel
  --     liftIO $ print $ decode model x
  -- @
  getModel :: m Solution

  -- | Evaluate any expression's value in the solver's model.
  --   Requires a 'Sat' or 'Unknown' check-sat response beforehand.
  --
  -- @
  -- x <- var @RealSort
  -- assert $ x >? 10
  -- res <- checkSat
  -- case res of
  --   Unsat -> liftIO $ print \"Unsat. Cannot get value for 'x'.\"
  --   r     -> do
  --     x' <- getValue x
  --     liftIO $ print $ show r ++ \": x = \" ++ show x'
  -- @
  getValue :: KnownSMTSort t => Expr t -> m (Maybe (Decoded (Expr t)))

-- | First run 'checkSat' and then 'getModel' on the current problem.
--   'getModel' is run unconditionally, whatever 'Result' 'checkSat' returned.
--
--   NOTE(review): the body does not use the 'MonadIO' constraint; it is kept
--   so the signature stays unchanged. Also confirm that every instance
--   tolerates get-model after an 'Unsat' answer.
solve :: (MonadIncrSMT s m, MonadIO m) => m (Result, Solution)
solve = (,) <$> checkSat <*> getModel

-- | A 'MonadState' that holds an OMT-Problem.
--   An OMT-Problem is an SMT-Problem with additional optimization targets.
class MonadSMT s m => MonadOMT s m where
  -- | Minimizes a numerical expression within the OMT-Problem.
  --
  --   For example, below minimization:
  --
  -- @
  -- x <- var @IntSort
  -- assert $ x >? (-2)
  -- minimize x
  -- @
  --
  --   will give @x := -1@ as solution.
  minimize :: (KnownSMTSort t, Num (Expr t)) => Expr t -> m ()

  -- | Maximizes a numerical expression within the OMT-Problem.
  --
  --   For example, below maximization:
  --
  -- @
  -- x <- var @(BvSort 8)
  -- maximize x
  -- @
  --
  --   will give @x := 11111111@ as solution.
  maximize :: (KnownSMTSort t, Num (Expr t)) => Expr t -> m ()

  -- | Asserts a soft boolean expression.
  --   May take a weight and an identifier for grouping.
  --
  --   For example, below a soft constraint with weight 2.0 and identifier \"myId\" for grouping:
  --
  -- @
  -- x <- var @BoolSort
  -- assertSoft x (Just 2.0) (Just \"myId\")
  -- @
  --
  --   Passing 'Nothing' as the weight will default it to 1.0.
  --
  -- @
  -- x <- var @BoolSort
  -- y <- var @BoolSort
  -- assertSoft x Nothing Nothing
  -- assertSoft y Nothing (Just \"myId\")
  -- @
  assertSoft :: Expr BoolSort -> Maybe Double -> Maybe String -> m ()

-- | Like 'assertSoft' but forces a weight and omits the group-id.
--
-- @
-- assertSoftWeighted expr w == assertSoft expr (Just w) Nothing
-- @
assertSoftWeighted :: MonadOMT s m => Expr BoolSort -> Double -> m ()
assertSoftWeighted expr w = assertSoft expr (Just w) Nothing