--------------------------------------------------------------------------------

{-# LANGUAGE LambdaCase, NamedFieldPuns, RankNTypes, ViewPatterns #-}
{-# LANGUAGE Safe #-}

-- | Communication with SMT solvers or theorem provers.
--
-- A solver is a running process defined by a 'Backend'.
module Copilot.Theorem.Prover.SMTIO
  ( Solver
  , startNewSolver, stop
  , assume, entailed, declVars
  ) where

import Copilot.Theorem.IL
import Copilot.Theorem.Prover.Backend

import System.IO
import System.Process
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.Maybe
import Data.Set ((\\), fromList, Set, union, empty, elems)

--------------------------------------------------------------------------------

-- | A connection with a running SMT solver or theorem prover.
--
-- The type parameter @a@ is the type of the commands written to the solver
-- (see 'send'); the functions in this module require it to be an instance of
-- 'SmtFormat'.
data Solver a = Solver
  { -- | Name shown in angle brackets in front of debug output.
    solverName :: String
    -- | Pipe connected to the solver process's standard input.
  , inh        :: Handle
    -- | Pipe connected to the solver process's standard output.
  , outh       :: Handle
    -- | Handle of the running solver process.
  , process    :: ProcessHandle
    -- | When 'True', traffic with the solver is echoed to stdout.
  , debugMode  :: Bool
    -- | Variables that have already been declared to the solver.
  , vars       :: Set VarDescr
    -- | Expressions that have already been asserted to the solver.
  , model      :: Set Expr
    -- | Backend describing how to launch and talk to the solver.
  , backend    :: Backend a
  }

--------------------------------------------------------------------------------

-- | Output a debugging message if debugging is enabled for the solver.
--
-- Nothing is printed unless the solver's 'debugMode' is set. When the first
-- argument is 'True', the message is prefixed with the solver's name in
-- angle brackets.
debug :: Bool -> Solver a -> String -> IO ()
debug printName s str
  | debugMode s = putStrLn (prefix ++ str)
  | otherwise   = return ()
  where
    prefix
      | printName = "<" ++ solverName s ++ ">  "
      | otherwise = ""

-- | Write one command to the solver's input, terminated by a newline, log it
-- via 'debug', and flush the pipe.
--
-- A command whose 'show' rendering is empty is skipped: nothing is written
-- and nothing is logged.
send :: Show a => Solver a -> a -> IO ()
send s command =
  case show command of
    ""       -> return ()
    rendered -> do
      hPutStr (inh s) (rendered ++ "\n")
      debug True s rendered
      hFlush (inh s)

-- | Read the solver's answer to a query.
--
-- Lines are read from the solver's output one at a time and passed to the
-- backend's 'interpret' function. Lines it does not recognise (it returns
-- 'Nothing') are skipped and the next line is read. Reaching end-of-file on
-- the output pipe yields 'Unknown'. Every received line, and EOF, is logged
-- through 'debug'.
--
-- This blocks until a recognised line or end-of-file is seen.
receive :: Solver a -> IO SatResult
receive s = loop
  where
    -- Plain recursion keeps this total: there is no need for a partial
    -- 'fromJust' to unwrap the result of an infinite search.
    loop :: IO SatResult
    loop = do
      eof <- hIsEOF (outh s)
      if eof
        then debug True s "[received: EOF]" >> return Unknown
        else do
          ln <- hGetLine (outh s)
          debug True s $ "[received: " ++ ln ++ "]"
          maybe loop return (interpret (backend s) ln)

--------------------------------------------------------------------------------

-- | Create a new solver implemented by the backend specified.
--
-- The backend's command is launched with its options. The process's error
-- handle is closed immediately, and the logic is initialized as specified by
-- the backend options. The returned solver starts with no declared variables
-- and no asserted expressions.
startNewSolver :: SmtFormat a => String -> Bool -> Backend a -> IO (Solver a)
startNewSolver name dbgMode b = do
  (toSolver, fromSolver, errPipe, ph) <-
    runInteractiveProcess (cmd b) (cmdOpts b) Nothing Nothing
  hClose errPipe
  let solver = Solver
        { solverName = name
        , inh        = toSolver
        , outh       = fromSolver
        , process    = ph
        , debugMode  = dbgMode
        , vars       = empty
        , model      = empty
        , backend    = b
        }
  send solver (setLogic (logic b))
  return solver

-- | Stop a solver. This closes all communication handles, terminates the
-- process, and waits for it to exit.
--
-- Waiting with 'waitForProcess' reaps the child. Without it, the terminated
-- solver remains a zombie process on POSIX systems until the parent exits.
--
-- NOTE(review): this blocks until the process has actually exited. Closing
-- its input and calling 'terminateProcess' (SIGTERM on POSIX) should make
-- that prompt for well-behaved solvers.
stop :: Solver a -> IO ()
stop s = do
  hClose (inh s)
  hClose (outh s)
  terminateProcess (process s)
  void (waitForProcess (process s))

--------------------------------------------------------------------------------

-- | Register the given expressions as assumptions or axioms with the solver.
--
-- Only expressions not already asserted (per the solver's 'model') are sent.
-- The returned solver records them, so later calls skip them.
assume :: SmtFormat a => Solver a -> [Expr] -> IO (Solver a)
assume s@(Solver { model = known }) cs = do
  let newAxioms = elems (fromList cs \\ known)
  assume' s newAxioms
  return s { model = known `union` fromList newAxioms }

-- | Assert every expression to the solver, simplifying each with 'bsimpl'
-- first.
--
-- Unlike 'assume', this neither consults nor updates the set of expressions
-- already asserted.
assume' :: SmtFormat a => Solver a -> [Expr] -> IO ()
assume' s cs = mapM_ sendAssertion cs
  where
    sendAssertion e = send s (assert (bsimpl e))

-- | Check if a series of expressions are entailed by the axioms or assumptions
-- already registered with the solver.
--
-- The disjunction of the negated expressions is asserted, then 'checkSat' is
-- sent and the backend's input terminator is applied. An empty list prints a
-- warning and asserts 'ConstB' 'True' instead.
--
-- For incremental backends the query is bracketed by 'push' and 'pop'
-- (presumably so the negated goal does not persist in the solver context).
-- 'pop' is sent before the answer is read with 'receive'.
entailed :: SmtFormat a => Solver a -> [Expr] -> IO SatResult
entailed s cs = do
  let isIncremental = incremental (backend s)
      query
        | null cs   = do
            putStrLn "Warning: no proposition to prove."
            assume' s [ConstB True]
        | otherwise =
            assume' s [foldl1 (Op2 Bool Or) (map (Op1 Bool Not) cs)]
  when isIncremental (send s push)
  query
  send s checkSat
  inputTerminator (backend s) (inh s)
  when isIncremental (send s pop)
  receive s

-- | Register the given variables with the solver.
--
-- Only variables not already declared (per the solver's 'vars') are sent,
-- each as a 'declFun' command. The returned solver records them, so later
-- calls skip them.
declVars :: SmtFormat a => Solver a -> [VarDescr] -> IO (Solver a)
declVars s@(Solver { vars = declared }) decls = do
  let fresh = elems (fromList decls \\ declared)
  mapM_ declare fresh
  return s { vars = declared `union` fromList fresh }
  where
    declare (VarDescr { varName = name, varType = ty, args = argTys }) =
      send s (declFun name ty argTys)

--------------------------------------------------------------------------------