{-# LANGUAGE LambdaCase     #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes     #-}
{-# LANGUAGE Safe           #-}
{-# LANGUAGE ViewPatterns   #-}

-- | Communication with SMT solvers or theorem provers.
--
-- A solver is a running process defined by a 'Backend'.
module Copilot.Theorem.Prover.SMTIO
  ( Solver
  , startNewSolver, stop
  , assume, entailed, declVars
  ) where

import Copilot.Theorem.IL
import Copilot.Theorem.Prover.Backend

import Control.Exception (finally)
import Control.Monad
import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.Maybe
import Data.Set ((\\), fromList, Set, union, empty, elems)
import System.IO
import System.Process

-- | A connection with a running SMT solver or theorem prover.
--
-- The type parameter is the one carried by the 'Backend' the solver was
-- started with.
data Solver a = Solver
  { solverName :: String
    -- ^ Name of the solver, used as a prefix in debugging output.
  , inh :: Handle
    -- ^ Handle on which commands are written to the solver process.
  , outh :: Handle
    -- ^ Handle from which the solver's answers are read.
  , process :: ProcessHandle
    -- ^ Handle of the solver process itself.
  , debugMode :: Bool
    -- ^ Whether debugging messages are printed.
  , vars :: Set VarDescr
    -- ^ Variables that have already been declared to the solver.
  , model :: Set Expr
    -- ^ Expressions that have already been registered as assumptions.
  , backend :: Backend a
    -- ^ Backend describing how to talk to this solver.
  }

-- | Print a debugging message to standard output, but only if debugging is
-- enabled for the solver.
--
-- When the first argument is 'True', the message is prefixed with the name of
-- the solver enclosed in angle brackets.
debug :: Bool -> Solver a -> String -> IO ()
debug printName s str = when (debugMode s) (putStrLn (prefix ++ str))
  where
    -- Optional "<name>  " tag identifying which solver the message is about.
    prefix
      | printName = "<" ++ solverName s ++ ">  "
      | otherwise = ""

-- | Send a command to the solver.
--
-- The command is rendered with 'show' and written to the solver's input
-- handle followed by a newline; the handle is flushed afterwards. The
-- rendered text is also passed to 'debug'. Commands that render to the empty
-- string are silently dropped.
send :: Show a => Solver a -> a -> IO ()
send s command =
  case show command of
    ""       -> return ()
    rendered -> do
      hPutStr (inh s) (rendered ++ "\n")
      debug True s rendered
      hFlush (inh s)

-- | Read the solver's answer.
--
-- Lines are read from the solver's output handle one at a time and handed to
-- the backend's 'interpret' function. Lines for which 'interpret' yields
-- 'Nothing' are skipped; the first line that yields a result ends the loop.
-- If the end of the output is reached first, the result is 'Unknown'.
--
-- Every line read (and the end-of-file condition) is reported via 'debug'.
receive :: Solver a -> IO SatResult
receive s = loop
  where
    loop :: IO SatResult
    loop = do
      eof <- hIsEOF (outh s)
      if eof
        then do
          debug True s "[received: EOF]"
          return Unknown
        else do
          ln <- hGetLine (outh s)
          debug True s ("[received: " ++ ln ++ "]")
          -- Keep reading until the backend recognises a line as an answer.
          maybe loop return (interpret (backend s) ln)

-- | Create a new solver implemented by the backend specified.
--
-- The solver process is launched with the command and options given by the
-- backend. Its error handle is closed immediately, so anything the process
-- writes there is discarded. The new solver starts with no declared variables
-- and no assumptions, and the logic is initialized as specified by the
-- backend.
--
-- The arguments are the solver name (used in debugging output), whether
-- debugging output is enabled, and the backend to use.
startNewSolver :: SmtFormat a => String -> Bool -> Backend a -> IO (Solver a)
startNewSolver name dbgMode b = do
  (i, o, e, p) <- runInteractiveProcess (cmd b) (cmdOpts b) Nothing Nothing
  hClose e
  let s = Solver
        { solverName = name
        , inh        = i
        , outh       = o
        , process    = p
        , debugMode  = dbgMode
        , vars       = empty
        , model      = empty
        , backend    = b
        }
  send s (setLogic (logic b))
  return s

-- | Stop a solver, closing all communication handles and terminating the
-- process.
--
-- The three cleanup steps are chained with 'finally', so each one is
-- attempted even if an earlier one throws (for instance, if closing the input
-- handle fails because the process has already exited). This avoids leaving
-- the output handle open or the process running after a failed close. Any
-- exception raised by a step still propagates to the caller once the
-- remaining steps have run.
stop :: Solver a -> IO ()
stop s =
  (hClose (inh s) `finally` hClose (outh s))
    `finally` terminateProcess (process s)

-- | Register the given expressions as assumptions or axioms with the solver.
--
-- Expressions that the solver already holds in its 'model' are not sent
-- again. The returned solver records the newly registered expressions.
assume :: SmtFormat a => Solver a -> [Expr] -> IO (Solver a)
assume s cs = do
  assume' s (elems fresh)
  return s { model = known `union` fresh }
  where
    -- Assumptions already registered with the solver.
    known = model s
    -- Assumptions from the argument list that are not registered yet.
    fresh = fromList cs \\ known

-- | Send each of the given expressions to the solver as an assertion.
--
-- Every expression is passed through 'bsimpl' and wrapped with 'assert'
-- before being sent. Unlike 'assume', this neither filters out expressions
-- that were sent before nor records them in the solver.
assume' :: SmtFormat a => Solver a -> [Expr] -> IO ()
assume' s = mapM_ (send s . assert . bsimpl)

-- | Check if a series of expressions are entailed by the axioms or assumptions
-- already registered with the solver.
--
-- The solver is sent the disjunction of the negations of the given
-- expressions, followed by a satisfiability check and the backend's input
-- terminator; the result is whatever answer is then read back from the
-- solver. If the list of expressions is empty, a warning is printed and the
-- constant true is asserted instead.
--
-- For incremental backends, the assertion is enclosed between a push and a
-- pop command; the pop is sent before the answer is read. For other backends
-- no pop is sent, so this function does not retract the assertion.
entailed :: SmtFormat a => Solver a -> [Expr] -> IO SatResult
entailed s cs = do
  whenIncremental push
  goal <- case cs of
    [] -> do
      putStrLn "Warning: no proposition to prove."
      return (ConstB True)
    _  -> return (foldl1 (Op2 Bool Or) (map (Op1 Bool Not) cs))
  assume' s [goal]
  send s checkSat
  inputTerminator solverBackend (inh s)

  whenIncremental pop
  receive s
  where
    solverBackend = backend s

    -- Send a command only if the backend is incremental.
    whenIncremental msg = when (incremental solverBackend) (send s msg)

-- | Register the given variables with the solver.
--
-- Variables that the solver already holds in 'vars' are not declared again.
-- The returned solver records the newly declared variables.
declVars :: SmtFormat a => Solver a -> [VarDescr] -> IO (Solver a)
declVars s decls = do
  mapM_ declare (elems fresh)
  return s { vars = known `union` fresh }
  where
    -- Variables already declared to the solver.
    known = vars s
    -- Variables from the argument list that have not been declared yet.
    fresh = fromList decls \\ known
    -- Send the declaration of a single variable.
    declare v = send s (declFun (varName v) (varType v) (args v))