--------------------------------------------------------------------------------

{-# LANGUAGE GADTs, FlexibleInstances #-}
{-# LANGUAGE Safe #-}

-- | A backend to the SMT-Lib format, used to produce commands for
-- SMT-Lib-compliant solvers and to parse their results.
module Copilot.Theorem.Prover.SMTLib (SmtLib, interpret) where

import Copilot.Theorem.Prover.Backend (SmtFormat (..), SatResult (..))

import Copilot.Theorem.IL
import Copilot.Theorem.Misc.SExpr

import Text.Printf

--------------------------------------------------------------------------------

-- | A single SMT-Lib command, represented by the s-expression it consists of.
--
-- Build values of this type through the 'SmtFormat' interface rather than by
-- wrapping s-expressions by hand.
newtype SmtLib = SmtLib (SExpr String)

-- | Rendering is delegated entirely to the wrapped s-expression.
instance Show SmtLib where
  show cmd = case cmd of
    SmtLib sexpr -> show sexpr

-- | Name of the SMT-Lib sort used for an IL type.
--
-- 'Bool' and 'Real' map to the sorts of the same name; every other IL type
-- is modelled with the unbounded @Int@ sort.
smtTy :: Type -> String
smtTy ty = case ty of
  Bool -> "Bool"
  Real -> "Real"
  _    -> "Int"

--------------------------------------------------------------------------------

-- | SMT-Lib rendering of the backend-independent 'SmtFormat' commands.
--
-- Every method produces exactly one s-expression: @push@/@pop@ with a depth
-- of 1, @check-sat@, @set-logic@, @declare-fun@ and @assert@.
instance SmtFormat SmtLib where
  -- Open one new assertion scope.
  push = SmtLib (node "push" [atom "1"])

  -- Close the innermost assertion scope.
  pop = SmtLib (node "pop" [atom "1"])

  checkSat = SmtLib (singleton "check-sat")

  -- An empty logic name yields 'blank' instead of a set-logic command
  -- (presumably rendered as nothing — confirm against the SExpr module).
  setLogic logic
    | null logic = SmtLib blank
    | otherwise  = SmtLib (node "set-logic" [atom logic])

  -- (declare-fun <name> (<arg sorts>) <result sort>)
  declFun name retTy args =
    SmtLib (node "declare-fun" [atom name, list argSorts, atom (smtTy retTy)])
    where
      argSorts = [atom (smtTy t) | t <- args]

  assert c = SmtLib (node "assert" [expr c])

-- | Parse a line of solver output as a satisfiability result.
--
-- Only the exact strings @"sat"@ and @"unsat"@ are recognised; no whitespace
-- is trimmed. Any other string (for instance @"unknown"@) is reported as
-- 'Unknown', so this function never returns 'Nothing'.
interpret :: String -> Maybe SatResult
interpret answer = Just $ case answer of
  "sat"   -> Sat
  "unsat" -> Unsat
  _       -> Unknown

--------------------------------------------------------------------------------

-- | Translate an IL expression into the SMT-Lib s-expression denoting it.
--
-- The first field of 'ConstI', 'Ite', 'FunApp', 'Op1', 'Op2' and 'SVal'
-- (presumably a type annotation) is not used by the translation.
expr :: Expr -> SExpr String

expr (ConstB v) = atom $ if v then "true" else "false"

-- SMT-Lib numerals and decimals are unsigned: a token such as @-3@ is a
-- symbol, not a number. Negative constants are therefore emitted as an
-- application of unary minus, e.g. @(- 3)@.
expr (ConstI _ v)
  | v < 0     = negativeLiteral (show (negate v))
  | otherwise = atom (show v)

-- Negative zero is also routed through unary minus so that no leading '-'
-- ever reaches the output.
-- NOTE(review): NaN and infinities have no SMT-Lib decimal form and are still
-- printed verbatim by 'printf'.
expr (ConstR v)
  | v < 0 || isNegativeZero v = negativeLiteral (printf "%f" (negate v))
  | otherwise                 = atom (printf "%f" v)

expr (Ite _ cond e1 e2) = node "ite" [expr cond, expr e1, expr e2]

expr (FunApp _ funName args) = node funName $ map expr args

expr (Op1 _ op e) = node smtOp [expr e]
  where
    smtOp :: String
    smtOp = case op of
      Not   -> "not"
      Neg   -> "-"
      Abs   -> "abs"
      Exp   -> "exp"
      Sqrt  -> "sqrt"
      Log   -> "log"
      Sin   -> "sin"
      Tan   -> "tan"
      Cos   -> "cos"
      Asin  -> "asin"
      Atan  -> "atan"
      Acos  -> "acos"
      Sinh  -> "sinh"
      Tanh  -> "tanh"
      Cosh  -> "cosh"
      Asinh -> "asinh"
      Atanh -> "atanh"
      Acosh -> "acosh"

expr (Op2 _ op e1 e2) = node smtOp [expr e1, expr e2]
  where
    smtOp :: String
    smtOp = case op of
      Eq   -> "="
      Le   -> "<="
      Lt   -> "<"
      Ge   -> ">="
      Gt   -> ">"
      And  -> "and"
      Or   -> "or"
      Add  -> "+"
      Sub  -> "-"
      Mul  -> "*"
      Mod  -> "mod"
      Fdiv -> "/"
      Pow  -> "^"

-- A stream sample becomes a constant symbol: @f_<i>@ for a 'Fixed' index and
-- @f_n<off>@ for a 'Var' offset. These names presumably have to match the
-- symbols declared elsewhere with 'declFun' — keep the two in sync.
expr (SVal _ f ix) = atom $ case ix of
  Fixed i -> f ++ "_" ++ show i
  Var off -> f ++ "_n" ++ show off

-- | Wrap an unsigned literal in SMT-Lib unary minus, producing @(- lit)@.
negativeLiteral :: String -> SExpr String
negativeLiteral lit = node "-" [atom lit]

--------------------------------------------------------------------------------