{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}

-- | A module providing a backend that sends commands to Z3 using its C API.
module SMTLIB.Backends.Z3
  ( Handle,
    new,
    close,
    with,
    toBackend,
  )
where

import Control.Exception (bracket)
import Data.ByteString.Builder.Extra
  ( defaultChunkSize,
    smallChunkSize,
    toLazyByteStringWith,
    untrimmedStrategy,
  )
import qualified Data.ByteString.Char8 as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.Map as M
import Foreign.ForeignPtr (ForeignPtr, finalizeForeignPtr, newForeignPtr)
import Foreign.Ptr (Ptr)
import qualified Language.C.Inline as C
import qualified Language.C.Inline.Context as C
import qualified Language.C.Inline.Unsafe as CU
import qualified Language.C.Types as C
import SMTLIB.Backends (Backend (..))

-- | Uninhabited phantom type standing in for Z3's opaque context structure.
-- It is only ever manipulated through a 'Ptr' / 'ForeignPtr'.
data LogicalContext

-- Build the inline-c context: the base, foreign-pointer and bytestring
-- marshalling contexts, plus one extra type mapping so that the C type
-- @Z3_context@ appearing in quasi-quotes is read as @Ptr LogicalContext@.
C.context
  ( mconcat
      [ C.baseCtx,
        C.fptrCtx,
        C.bsCtx,
        mempty
          { C.ctxTypesTable =
              M.fromList
                [(C.TypeName "Z3_context", [t|Ptr LogicalContext|])]
          }
      ]
  )

C.include "z3.h"

-- | A live Z3 instance. Obtain one with 'new' (or 'with') and turn it into a
-- solver backend with 'toBackend'.
data Handle = Handle
  { -- | A black-box representing the internal state of the solver: the
    -- Z3 context, wrapped so that its finalizer releases it.
    context :: !(ForeignPtr LogicalContext)
  }

-- | Create a new solver instance.
--
-- The returned 'Handle' owns a fresh Z3 context. The context is deleted by
-- 'close', or by the garbage collector once the handle becomes unreachable,
-- whichever comes first.
new :: IO Handle
new = do
  -- Finalizer attached to the ForeignPtr: deletes the Z3 context.
  let ctxFinalizer =
        [C.funPtr| void free_context(Z3_context ctx) {
                 Z3_del_context(ctx);
                 } |]

  {-
  We disable Z3's error handler (by registering a null one). That way, if an
  error happens, it doesn't cause the whole program to crash. The error message
  is simply transmitted to the Haskell layer inside the output of the 'send'
  method.

  A null handler is used instead of a no-op C callback. Defining such a
  callback inside this block would need GCC's nested-function extension, which
  clang rejects and which requires an executable stack.
  -}
  ctx <-
    newForeignPtr ctxFinalizer
      =<< [CU.block| Z3_context {
                 Z3_config cfg = Z3_mk_config();
                 Z3_context ctx = Z3_mk_context(cfg);
                 Z3_del_config(cfg);

                 Z3_set_error_handler(ctx, 0);

                 return ctx;
                 } |]
  return $ Handle ctx

-- | Release the resources associated with a Z3 instance.
--
-- This runs the context's finalizer immediately instead of waiting for the
-- garbage collector. The handle (and any backend built from it) must not be
-- used afterwards.
close :: Handle -> IO ()
close = finalizeForeignPtr . context

-- | Create a Z3 instance, use it to run a computation and release its resources.
--
-- The instance is closed even if the computation throws an exception, so the
-- 'Handle' must not escape the callback.
with :: (Handle -> IO a) -> IO a
with = bracket new close

-- | Create a solver backend out of a Z3 instance.
--
-- Each command is rendered to a strict bytestring and evaluated with
-- @Z3_eval_smtlib2_string@. The C string Z3 returns is copied into
-- Haskell-managed memory with 'BS.packCString'.
toBackend :: Handle -> Backend
toBackend handle = Backend $ \cmd -> do
  let ctx = context handle
      -- The C side reads the command as a C string, so it must end with a NUL
      -- byte; 'toLazyByteStringWith' appends it as the tail of the output.
      cmd' =
        LBS.toStrict $
          toLazyByteStringWith
            (untrimmedStrategy smallChunkSize defaultChunkSize)
            "\NUL"
            cmd
  -- This is a safe foreign call (C.exp, not CU.exp). Evaluating a command such
  -- as (check-sat) may take arbitrarily long, and an unsafe call would block
  -- every other Haskell thread and the garbage collector until it returns.
  LBS.fromStrict
    <$> ( BS.packCString
            =<< [C.exp| const char* {
                   Z3_eval_smtlib2_string($fptr-ptr:(Z3_context ctx), $bs-ptr:cmd')
                   }|]
        )