{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}

{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}

{-# LANGUAGE UndecidableInstances #-}

-- | This module contains functionality for creating SMTLIB expressions and interacting
--   with an SMT solver.
module Language.REST.SMT
  (
    checkSat
  , checkSat'
  , getModel
  , parseModel
  , killZ3
  , spawnZ3
  , smtAdd
  , smtAnd
  , smtFalse
  , smtGTE
  , smtTrue
  , withZ3
  , SolverHandle
  , SMTExpr(..)
  , SMTVar(..)
  , ToSMT(..)
  , ToSMTVar(..)
  , Z3Model
) where

import Control.Exception (bracket)
import Control.Monad.IO.Class
import Data.Hashable
import qualified Data.Map as M
import qualified Data.List as L
import qualified Data.Set as S
import qualified Data.Text as T
import System.Process
import Text.Parsec (endBy)
import Text.Parsec.Prim
import Text.ParserCombinators.Parsec.Char
import GHC.Generics (Generic)
import GHC.IO.Handle

-- | A model returned by Z3 corresponding to a satisfiable
--   set of constraints. Untyped: each key is the name bound by a
--   @define-fun@ entry and each value is the raw text of that entry's value.
type Z3Model = M.Map String String

-- | @parens p@ expects a literal @(@, runs @p@, expects a literal @)@ and
--   returns the result of @p@. No whitespace is skipped around the delimiters.
parens :: Text.Parsec.Prim.Stream s m Char => ParsecT s u m a -> ParsecT s u m a
parens p = do
  _ <- char '('
  r <- p
  _ <- char ')'
  return r

-- | Parses one parenthesised @define-fun@ entry of a model and returns
--   @(name, value)@. The argument list and the sort are each read as a run of
--   non-space characters and discarded.
--
--   The value is returned as raw text. It may contain balanced parentheses,
--   so a value such as @(- 1)@ is kept whole rather than being cut at its
--   first closing parenthesis.
parseFunDef :: Text.Parsec.Prim.Stream s m Char => ParsecT s u m (String, String)
parseFunDef = parens $ do
  _     <- string "define-fun "
  var   <- many (noneOf " ")
  _     <- spaces
  _     <- many (noneOf " ") -- args
  _     <- spaces
  _     <- many (noneOf " ") -- type
  _     <- spaces
  value <- balanced
  return (var, value)
  where
    -- Text up to (not including) the first unmatched ')'.
    balanced = concat <$> many (nested <|> plain)
    -- A single character that is not a parenthesis.
    plain    = (: []) <$> noneOf "()"
    -- A parenthesised group, reproduced verbatim including its delimiters.
    nested   = do
      _     <- char '('
      inner <- balanced
      _     <- char ')'
      return ("(" ++ inner ++ ")")

-- | Parses a whole model: an outer pair of parentheses containing zero or
--   more 'parseFunDef' entries, each followed by optional whitespace. The
--   entries are collected into a map; 'M.fromList' keeps the last entry when
--   a name occurs more than once.
modelParser :: Text.Parsec.Prim.Stream s m Char => ParsecT s u m Z3Model
modelParser = parens $ do
  spaces
  defs <- endBy parseFunDef spaces
  return $ M.fromList defs

-- | Parses Z3's model string into a 'Z3Model'.
--
--   Partial: calls 'error' with the rendered parse error when the string is
--   not a well-formed model.
parseModel :: String -> Z3Model
parseModel str = case parse modelParser "" str of
  Left err -> error (show err)
  Right t  -> t

-- | An SMT variable, identified by its name. The type parameter @a@ is a
--   phantom recording the sort of the variable; equality and ordering compare
--   only the name.
newtype SMTVar a = SMTVar T.Text deriving (Eq, Ord)

-- | SMTLib expressions. The type index records whether the expression
--   denotes a @Bool@ or an @Int@.
data SMTExpr a where
    -- | N-ary conjunction; the empty conjunction is rendered as @true@.
    And     :: [SMTExpr Bool] -> SMTExpr Bool
    -- | N-ary integer sum.
    Add     :: [SMTExpr Int]  -> SMTExpr Int
    -- | N-ary disjunction; the empty disjunction is rendered as @false@.
    Or      :: [SMTExpr Bool] -> SMTExpr Bool
    -- | Equality of all listed expressions; rendered as @true@ when fewer
    --   than two are given.
    Equal   :: [SMTExpr a]    -> SMTExpr Bool
    -- | Strict integer comparison, rendered with @>@.
    Greater :: SMTExpr Int    -> SMTExpr Int  -> SMTExpr Bool
    -- | Non-strict integer comparison, rendered with @>=@.
    GTE     :: SMTExpr Int    -> SMTExpr Int  -> SMTExpr Bool
    -- | Implication, rendered with @=>@.
    Implies :: SMTExpr Bool   -> SMTExpr Bool -> SMTExpr Bool
    -- | A variable of the indexed sort.
    Var     :: SMTVar a       -> SMTExpr a
    -- | An integer literal.
    Const   :: Int            -> SMTExpr Int


-- | A copy of 'SMTExpr' with the type index erased, one constructor per
--   'SMTExpr' constructor. Being an ordinary algebraic type it can derive the
--   instances that the 'SMTExpr' instances delegate to.
data UntypedExpr =
    UAnd [UntypedExpr]
  | UAdd [UntypedExpr]
  | UOr  [UntypedExpr]
  | UEqual  [UntypedExpr]
  | UGreater UntypedExpr UntypedExpr
  | UGTE UntypedExpr UntypedExpr
  | UImplies UntypedExpr UntypedExpr
  | UVar T.Text
  | UConst Int
  deriving (Show, Eq, Ord, Hashable, Generic)

-- | Erases the type index of an expression, mapping every constructor to its
--   @U@-prefixed counterpart and recursing into all sub-expressions.
toUntyped :: SMTExpr a -> UntypedExpr
toUntyped (And xs) = UAnd (map toUntyped xs)
toUntyped (Add xs) = UAdd (map toUntyped xs)
toUntyped (Or xs)  = UOr (map toUntyped xs)
toUntyped (Equal xs) = UEqual (map toUntyped xs)
toUntyped (Greater t u) = UGreater (toUntyped t) (toUntyped u)
toUntyped (GTE t u) = UGTE (toUntyped t) (toUntyped u)
toUntyped (Implies t u) = UImplies (toUntyped t) (toUntyped u)
toUntyped (Var (SMTVar text)) = UVar text
toUntyped (Const i) = UConst i

-- | Structural equality, decided on the type-erased form.
instance (Eq (SMTExpr a)) where
  t == u = toUntyped t == toUntyped u

-- | Ordering inherited from the derived ordering of the type-erased form.
--   'compare' is defined directly so that each comparison converts its
--   operands once, instead of falling back to the default built from '=='
--   and '<='.
instance (Ord (SMTExpr a)) where
  compare t u = compare (toUntyped t) (toUntyped u)
  t <= u = toUntyped t <= toUntyped u

-- | Hashes the type-erased form, so the hash agrees with the 'Eq' instance.
instance Hashable (SMTExpr a) where
  hashWithSalt salt e = hashWithSalt salt (toUntyped e)

-- | Shows the human-readable rendering produced by 'toFormula' (not SMTLib
--   syntax).
instance Show (SMTExpr a) where
  show = T.unpack . toFormula


-- | Renders an expression in infix mathematical notation for display. This is
--   what 'show' uses; the SMTLib rendering sent to the solver is 'exprString'.
--
--   Every constructor is handled, so rendering never fails. The empty
--   conjunction is @⊤@, the empty disjunction is @⊥@, and an equality over
--   fewer than two expressions is @⊤@ (mirroring 'exprString').
toFormula :: SMTExpr a -> T.Text
toFormula = go False where
  -- The flag says whether the rendered term must be wrapped in parentheses.
  -- Conjunctions, disjunctions and sums pass the negated flag to their
  -- children; comparisons, equalities and implications always parenthesise
  -- their operands.
  go :: Bool -> SMTExpr a -> T.Text
  go _ (And [])         = "⊤"
  go p (And ts)         = eparens p $ T.intercalate " ∧ " $ map (go (not p)) ts
  go _ (Or [])          = "⊥"
  go p (Or ts)          = eparens p $ T.intercalate " ∨ " $ map (go (not p)) ts
  go p (Add ts)         = eparens p $ T.intercalate " + " $ map (go (not p)) ts
  go _ (Equal ts)
    | null (drop 1 ts)  = "⊤"
  go p (Equal ts)       = eparens p $ T.intercalate " = " $ map (go True) ts
  go p (GTE t u)        = eparens p $ T.intercalate " ≥ " $ map (go True) [t, u]
  go p (Greater t u)    = eparens p $ T.intercalate " > " $ map (go True) [t, u]
  go p (Implies t u)    = eparens p $ T.intercalate " → " $ map (go True) [t, u]
  go _ (Var (SMTVar v)) = v
  go _ (Const c)        = T.pack (show c)

  -- Wraps the text in parentheses when the flag is set.
  eparens :: Bool -> T.Text -> T.Text
  eparens True t = T.concat ["(", t, ")"]
  eparens False t = t

-- | Collects the names of all variables occurring anywhere in an expression.
vars :: SMTExpr a -> S.Set T.Text
vars (And ts)        = S.unions (map vars ts)
vars (Add ts)        = S.unions (map vars ts)
vars (Or ts)         = S.unions (map vars ts)
vars (Equal ts)      = S.unions (map vars ts)
vars (Greater t u)   = S.union (vars t) (vars u)
vars (GTE t u)       = S.union (vars t) (vars u)
vars (Var (SMTVar var)) = S.singleton var
vars (Implies e1 e2) = S.union (vars e1) (vars e2)
vars (Const _)       = S.empty

-- | The solver commands this module can emit: assert a boolean expression,
--   declare a variable by name, check satisfiability, and push / pop an
--   assertion frame. 'commandString' gives the SMTLib text of each.
data SMTCommand = SMTAssert (SMTExpr Bool) | DeclareVar T.Text | CheckSat | Push | Pop

-- | The constant false, represented as the empty disjunction.
smtFalse :: SMTExpr Bool
smtFalse = Or []

-- | The constant true, represented as the empty conjunction.
smtTrue :: SMTExpr Bool
smtTrue  = And []

-- | Returns an SMT expression that adds all elements in the list. If the list is empty,
--   returns @Const 0@.
smtAdd :: [SMTExpr Int] -> SMTExpr Int
smtAdd [] = Const 0
smtAdd ts = Add ts

-- | `smtAnd t u` returns an smt expression representing \( t \land u \).
--
--   When either argument is already an 'And', its conjuncts are merged into
--   the result instead of being nested, and duplicate conjuncts are removed.
--   When neither argument is an 'And', the result is @And [t, u]@ as given.
smtAnd :: SMTExpr Bool -> SMTExpr Bool -> SMTExpr Bool
smtAnd (And xs) (And ys) = And $ L.nub (xs ++ ys)
smtAnd (And xs) e        = And $ L.nub (xs ++ [e])
smtAnd e        (And ys) = And $ L.nub (e:ys)
smtAnd t        u        = And [t, u]

-- | `smtGTE t u` returns an SMT expression \( t \geqslant u \). If @t == u@, returns 'smtTrue'.
smtGTE :: SMTExpr Int -> SMTExpr Int -> SMTExpr Bool
smtGTE t u | t == u    = smtTrue
smtGTE t u  = GTE t u

-- | @app op trms@ renders the SMTLib application @(op t1 t2 ...)@, with every
--   argument rendered by 'exprString' and separated by single spaces.
app :: T.Text -> [SMTExpr a] -> T.Text
app op trms = T.concat ["(", op, " ", T.intercalate " " (map exprString trms), ")"]

-- | Renders an expression in SMTLib prefix syntax, as sent to the solver.
--
--   Degenerate cases are rendered as constants rather than as applications
--   with too few arguments: the empty conjunction is @true@, the empty
--   disjunction is @false@, the empty sum is @0@ (consistent with 'smtAdd'),
--   and an equality over fewer than two expressions is @true@.
exprString :: SMTExpr a -> T.Text
exprString (And [])           = "true"
exprString (Add [])           = "0"
exprString (Add es)           = app "+" es
exprString (Or [])            = "false"
exprString (And   es)         = app "and" es
exprString (Or    es)         = app "or" es
exprString (Equal xs) | length xs < 2 = "true"
exprString (Equal es)         = app "=" es
exprString (Greater e1 e2)    = app ">" [e1, e2]
exprString (GTE e1 e2)        = app ">=" [e1, e2]
exprString (Implies e1 e2)    = app "=>" [e1, e2]
exprString (Var (SMTVar var)) = var
exprString (Const i)          = T.pack (show i)

-- | Renders a solver command as SMTLib text. Note that 'DeclareVar' always
--   declares the variable with sort @Int@.
commandString :: SMTCommand -> T.Text
commandString (SMTAssert expr) = app "assert" [expr]
commandString (DeclareVar var) = T.concat ["(declare-const ", var,  " Int)"]
commandString CheckSat = "(check-sat)"
commandString Push     = "(push)"
commandString Pop      = "(pop)"

-- | The command sequence that asks whether @expr@ is satisfiable: one
--   declaration per variable occurring in @expr@, then the assertion of
--   @expr@, then a satisfiability check.
askCmds :: SMTExpr Bool -> [SMTCommand]
askCmds expr = varDecls ++ [SMTAssert expr, CheckSat] where
  varDecls = map DeclareVar $ S.toList (vars expr)

-- | The handle (stdIn, stdOut) used for interacting with Z3: the first
--   component is the pipe commands are written to, the second is the pipe
--   replies are read from.
type SolverHandle = (Handle, Handle)

-- | Instantiates a Z3 instance, returning the solver handle for interaction.
--
--   Runs the executable @z3@ with the argument @-in@, with its standard input
--   and standard output connected to fresh pipes. The process handle is not
--   retained; use 'killZ3' to shut the instance down.
spawnZ3 :: IO SolverHandle
spawnZ3 = do
  (Just stdIn, Just stdOut, _, _) <- createProcess (proc "z3" ["-in"]) {std_in = CreatePipe, std_out = CreatePipe}
  return (stdIn, stdOut)

-- | Kills the Z3 instance by closing the standard input stream. The output
--   handle is left untouched.
killZ3 :: SolverHandle -> IO ()
killZ3 (stdIn, _) = hClose stdIn

-- | @withZ3 f@ instantiates a Z3 instance, runs @f@ with that instance,
--   and then closes the instance and returns the result.
--
--   The instance is closed only when @f@ completes normally; if @f@ throws,
--   'killZ3' is not run.
withZ3 :: MonadIO m => (SolverHandle -> m b) -> m b
withZ3 f =
  do
    z3     <- liftIO spawnZ3
    result <- f z3
    liftIO $ killZ3 z3
    return result

-- | @getModel@ instructs an instantiated SMT solver to produce its model.
--
--   The argument is the solver's input handle. The command is written and
--   flushed; the reply is not read here.
getModel :: Handle -> IO ()
getModel stdIn = do
  hPutStr stdIn "(get-model)\n"
  hFlush stdIn

-- | @checkSat' handles expr@ checks satisfiability of @expr@ in an instantiated SMT solver.
--   This is wrapped in a @push@ / @pop@, so it does not change the SMT environment.
--
--   Returns 'True' when the solver answers @sat@ and 'False' when it answers
--   @unsat@. Any other reply line is raised with 'error'. The @pop@ is sent
--   before the reply is interpreted, so the pushed frame is removed even when
--   the reply is unexpected.
checkSat' :: (Handle,  Handle) -> SMTExpr Bool -> IO Bool
checkSat' (stdIn, stdOut) expr = do
  sendCommands $ Push:askCmds expr
  result <- hGetLine stdOut
  sendCommands [Pop]
  case result of
    "sat"   -> return True
    "unsat" -> return False
    other   -> error other
  where
    -- Writes the commands one per line and flushes so the solver sees them.
    sendCommands cmds = do
      hPutStr stdIn $ T.unpack (T.intercalate "\n" (map commandString cmds)) ++ "\n"
      hFlush stdIn

-- | @checkSat expr@ launches Z3 to check satisfiability of @expr@, terminating Z3
--   afterwards. Just a utility wrapper for `checkSat'`.
--
--   The instance is shut down with 'killZ3' even when `checkSat'` throws
--   (for example on an unexpected solver reply).
checkSat :: SMTExpr Bool -> IO Bool
checkSat expr = bracket spawnZ3 killZ3 (\z3 -> checkSat' z3 expr)

-- | This class allows elements of type @a@ to be used as SMT /variables/ of type @b@.
--   For example, the instance @ToSMTVar Op Int@ allows 'RuntimeTerm' operators to be
--   represented as 'Int' variables. The functional dependency means each @a@
--   determines a single variable sort @b@.
class ToSMTVar a b | a -> b where
  -- | The SMT variable standing for the given element.
  toSMTVar :: a -> SMTVar b

-- | This class allows elements of type @a@ to be used as SMT expressions of type
--   @b@.
class ToSMT a b where
  -- | The SMT expression standing for the given element.
  toSMT :: a -> SMTExpr b

-- | Integers are embedded as integer literals.
instance ToSMT Int Int where
  toSMT = Const

-- | Fallback instance: anything usable as an SMT variable is usable as the
--   expression consisting of just that variable. Marked overlappable so that
--   more specific instances take precedence.
instance {-# OVERLAPPABLE #-} (ToSMTVar a b) => ToSMT a b where
  toSMT :: a -> SMTExpr b
  toSMT op = Var $ toSMTVar op