{-# LANGUAGE BangPatterns              #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings         #-}
{-# LANGUAGE RecordWildCards           #-}
{-# LANGUAGE UndecidableInstances      #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE PatternGuards             #-}
{-# LANGUAGE DoAndIfThenElse           #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
{-# HLINT ignore "Use isNothing" #-}

-- | This module contains an SMTLIB2 interface for
--   1. checking the validity of formulas, and
--   2. computing satisfying assignments for formulas.
--   It is implemented as a binary interface over the SMTLIB2 format defined at
--   http://www.smt-lib.org/
--   http://www.grammatech.com/resource/smt/SMTLIBTutorial.pdf

module Language.Fixpoint.Smt.Interface (

    -- * Commands
      Command  (..)

    -- * Responses
    , Response (..)

    -- * Typeclass for SMTLIB2 conversion
    , SMTLIB2 (..)

    -- * Creating and killing SMTLIB2 Process
    , Context (..)
    , makeContext
    , makeContextNoLog
    , makeContextWithSEnv
    , cleanupContext

    -- * Execute Queries
    , command
    , smtSetMbqi

    -- * Query API
    , smtDecl
    , smtDecls
    , smtDefineFunc
    , smtAssert
    , smtAssertDecl
    , smtFuncDecl
    , smtAssertAxiom
    , smtCheckUnsat
    , smtBracket, smtBracketAt
    , smtDistinct
    , smtPush, smtPop
    , smtComment

    -- * Check Validity
    , checkValid
    , checkValid'
    , checkValidWithContext
    , checkValids

    , funcSortVars

    ) where

import           Language.Fixpoint.Types.Config ( SMTSolver (..), solverFlags
                                                , Config (solver, smtTimeout, noStringTheory, save, allowHO))
import qualified Language.Fixpoint.Misc          as Misc
import           Language.Fixpoint.Types.Errors
import           Language.Fixpoint.Utils.Files
import           Language.Fixpoint.Types         hiding (allowHO)
import qualified Language.Fixpoint.Types         as F
import           Language.Fixpoint.Smt.Types
import qualified Language.Fixpoint.Smt.Theories as Thy
import           Language.Fixpoint.Smt.Serialize ()
import           Control.Applicative      ((<|>))
import           Control.Monad
import           Control.Monad.State
import           Control.Exception
import           Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Builder as BS
import qualified Data.ByteString.Lazy as LBS
import qualified Data.ByteString.Lazy.Char8 as Char8
import           Data.Char
import qualified Data.HashMap.Strict      as M
import           Data.List                (uncons)
import           Data.Maybe              (fromMaybe)
import qualified Data.Text                as T
import qualified Data.Text.Encoding       as TE
import qualified Data.Text.IO
-- import           Data.Text.Format
import qualified Data.Text.Lazy.IO        as LTIO
import           System.Directory
import           System.Console.CmdArgs.Verbosity
import           System.FilePath
import           System.IO
import qualified Data.Attoparsec.Text     as A
-- import qualified Data.HashMap.Strict      as M
import           Data.Attoparsec.Internal.Types (Parser)
import           Text.PrettyPrint.HughesPJ (text)
import           Language.Fixpoint.SortCheck
import           Language.Fixpoint.Utils.Builder as Builder
-- import qualified Language.Fixpoint.Types as F
-- import           Language.Fixpoint.Types.PrettyPrint (tracepp)
import qualified SMTLIB.Backends
import qualified SMTLIB.Backends.Process as Process
import qualified Language.Fixpoint.Conditional.Z3 as Conditional.Z3
import Control.Concurrent.Async (async)
import GHC.Stack (HasCallStack)

{-
runFile f
  = readFile f >>= runString

runString str
  = runCommands $ rr str

runCommands cmds
  = do me   <- makeContext Z3
       mapM_ (T.putStrLn . smt2) cmds
       zs   <- mapM (command me) cmds
       return zs
-}

-- | Check whether @p => q@ is valid using the current solver context.
--   The declarations of @xts@ and the query are wrapped in 'smtBracket'
--   (a solver push/pop), so they do not outlive this check.
checkValidWithContext
  :: HasCallStack => [(Symbol, Sort)] -> Expr -> Expr -> SmtM Bool
checkValidWithContext xts p q =
  smtBracket "checkValidWithContext" (checkValid' xts p q)

-- | Check whether @p => q@ is valid under the declarations @xts@, using a
--   fresh solver context created from @cfg@ for query file @f@.
--
--   Conceptually:
--
--   > type ClosedPred E = {v:Pred | subset (vars v) (keys E) }
--   > checkValid :: e:Env -> ClosedPred e -> ClosedPred e -> IO Bool
--
--   The context is released with 'cleanupContext' when the query finishes
--   or throws. This stops the solver process, and, when 'save' is set,
--   flushes and closes the block-buffered @.smt2@ log.
checkValid
  :: HasCallStack
  => Config -> FilePath -> [(Symbol, Sort)] -> Expr -> Expr -> IO Bool
checkValid cfg f xts p q =
  bracket (makeContext cfg f) cleanupContext $
    evalStateT (checkValid' xts p q)

-- | Declare @xts@ in the current context, assert @p /\ not q@, and return
--   'True' exactly when that is unsatisfiable, i.e. when @p => q@ holds.
--   No push/pop is performed here; the declarations stay in the context.
checkValid' :: HasCallStack => [(Symbol, Sort)] -> Expr -> Expr -> SmtM Bool
checkValid' xts p q =
  smtDecls xts
    >> smtAssertDecl (pAnd [p, PNot q])
    >> smtCheckUnsat

-- | Check the validity of each predicate in @ps@ under the declarations @xts@.
--
--   All queries share a single solver context, so @xts@ is declared only
--   once; each predicate is then checked inside its own push/pop bracket.
--   This is the entry point for making MANY repeated queries.
--   Conceptually: @checkValids :: e:Env -> [ClosedPred e] -> IO [Bool]@.
--
--   The context is released with 'cleanupContext' once all queries finish,
--   or when one of them throws, so the solver process and log handle are
--   not leaked.
checkValids :: Config -> FilePath -> [(Symbol, Sort)] -> [Expr] -> IO [Bool]
checkValids cfg f xts ps =
  bracket (makeContext cfg f) cleanupContext $
    evalStateT (checkValids' xts ps)

-- | Declare @xts@ once, then check each predicate inside its own
--   'smtBracket'. A predicate is reported valid ('True') when its negation
--   is unsatisfiable.
checkValids' :: [(Symbol, Sort)] -> [Expr] -> SmtM [Bool]
checkValids' xts ps = smtDecls xts >> mapM isValid ps
  where
    isValid p = smtBracket "checkValids" (smtAssert (PNot p) >> smtCheckUnsat)

-- debugFile :: FilePath
-- debugFile = "DEBUG.smt2"

--------------------------------------------------------------------------------
-- | SMT IO --------------------------------------------------------------------
--------------------------------------------------------------------------------

-- | Send a command that expects a reply, parse the reply with 'responseP',
--   and echo the parsed response to the log handle (if any) and, when
--   @ctxVerbose@ is set, to stdout.
--
--   Bytes of the reply that are not valid UTF-8 are replaced by spaces before
--   parsing. An unparseable reply is a fatal error. The raw reply text is
--   included in the message because the parser error alone does not show
--   what the solver actually said.
commandRaw :: Maybe Handle -> SMTLIB.Backends.Solver -> Bool -> Builder -> IO Response
commandRaw ctxLog ctxSolver ctxVerbose cmdBS = do
  resp <- SMTLIB.Backends.command ctxSolver cmdBS
  let respTxt =
        TE.decodeUtf8With (const $ const $ Just ' ') $
        LBS.toStrict resp
  case A.parseOnly responseP respTxt of
    Left e  -> Misc.errorstar $
                 "SMTREAD:" ++ e ++ "\nunparseable SMT response: " ++ show respTxt
    Right r -> do
      let textResponse = "; SMT Says: " <> T.pack (show r)
      forM_ ctxLog $ \h ->
        Data.Text.IO.hPutStrLn h textResponse
      when ctxVerbose $
        Data.Text.IO.putStrLn textResponse
      return r

--------------------------------------------------------------------------------
-- Serialize @cmd@, mirror it to the log handle (if any), and send it.
-- Only 'CheckSat' and 'GetValue' read and parse a solver reply (through
-- 'commandRaw'); every other command goes out with
-- 'SMTLIB.Backends.command_' and is answered with 'Ok' locally.
{-# SCC command #-}
command  :: HasCallStack => Command -> SmtM Response
--------------------------------------------------------------------------------
command !cmd = do
  ctx   <- get
  bytes <- liftSym $ runSmt2 cmd
  let logH = ctxLog ctx
      slv  = ctxSolver ctx
  lift $ do
    forM_ logH $ \h -> do
      BS.hPutBuilder h bytes
      LBS.hPutStr h "\n"
    if needsReply cmd
      then commandRaw logH slv (ctxVerbose ctx) bytes
      else SMTLIB.Backends.command_ slv bytes >> return Ok
  where
    needsReply CheckSat     = True
    needsReply (GetValue _) = True
    needsReply _            = False

-- | Like 'command', but for a command that has already been serialized.
--   The bytes are mirrored to the log handle (if any) and sent with
--   'SMTLIB.Backends.command_'. No reply is read here, and the result is
--   always 'Ok'.
commandB :: Builder -> SmtM Response
--------------------------------------------------------------------------------
commandB cmdBS = do
  logH <- gets ctxLog
  slv  <- gets ctxSolver
  lift $ do
    mapM_ (\h -> BS.hPutBuilder h cmdBS >> LBS.hPutStr h "\n") logH
    SMTLIB.Backends.command_ slv cmdBS
    return Ok

-- | Send the 'SetMbqi' command and discard the response.
smtSetMbqi :: SmtM ()
smtSetMbqi = void (command SetMbqi)

-- | Attoparsec parsers over strict 'T.Text', used below to parse solver replies.
type SmtParser a = Parser T.Text a

-- | Parse one solver reply. This is either a parenthesized reply (an @error@
--   or a @get-value@ result, see 'sexpP') or one of the bare words @sat@,
--   @unsat@ and @unknown@.
responseP :: SmtParser Response
responseP = {- SCC "responseP" -}
        (A.char '(' *> sexpP)
    <|> (Sat     <$ A.string "sat")
    <|> (Unsat   <$ A.string "unsat")
    <|> (Unknown <$ A.string "unknown")

-- | Parse the body of a parenthesized reply, after its opening @(@. The body
--   is either @error "..."@ or a list of @get-value@ pairs.
sexpP :: SmtParser Response
sexpP = {- SCC "sexpP" -} errorReply <|> valuesReply
  where
    errorReply  = A.string "error" *> fmap Error errorP
    valuesReply = fmap Values valuesP

-- | Parse the quoted message of an @(error "...")@ reply, up to and
--   including the closing @")@.
--
--   In SMT-LIB string literals a double quote is written as @""@. Such
--   escapes are decoded to a single @"@ instead of making the reply
--   unparseable. An empty message (@(error "")@) is accepted as well.
errorP :: SmtParser T.Text
errorP = do
  A.skipSpace
  _      <- A.char '"'
  chunks <- A.takeWhile (/= '"') `A.sepBy` A.string "\"\""
  _      <- A.string "\")"
  return $ T.intercalate "\"" chunks

-- | Parse the @(symbol value)@ pairs of a @get-value@ reply (at least one),
--   followed by the closing @)@ of the enclosing list.
valuesP :: SmtParser [(Symbol, T.Text)]
valuesP = do
  pairs <- A.many1' pairP
  _     <- A.char ')'
  return pairs

-- | Parse one @(symbol value)@ pair, allowing leading whitespace and
--   whitespace between the symbol and the value.
pairP :: SmtParser (Symbol, T.Text)
pairP = {- SCC "pairP" -} do
  _      <- A.skipSpace *> A.char '('
  !key   <- symbolP
  !value <- A.skipSpace *> valueP <* A.char ')'
  pure (key, value)

-- | Parse a non-empty run of non-whitespace characters as a 'Symbol'.
symbolP :: SmtParser Symbol
symbolP = {- SCC "symbolP" -} fmap symbol (A.takeWhile1 (not . isSpace))

-- | Parse the value of a @get-value@ pair. This is either a parenthesized
--   term (see 'negativeP') or an atom that runs up to the next @)@ or
--   whitespace character.
valueP :: SmtParser T.Text
valueP = {- SCC "valueP" -} negativeP <|> A.takeWhile1 isAtomChar
  where
    isAtomChar c = c /= ')' && not (isSpace c)

-- | Parse a parenthesized value such as @(- 5)@ and return it verbatim,
--   parentheses included.
--
--   Parentheses are matched recursively, so compound values such as
--   @(/ (- 1.0) 2.0)@ are returned whole. Stopping at the first @)@ would
--   return a truncated value with unbalanced parentheses.
negativeP :: SmtParser T.Text
negativeP = do
  _     <- A.char '('
  parts <- A.many' (A.takeWhile1 (\c -> c /= '(' && c /= ')') <|> negativeP)
  _     <- A.char ')'
  return $ "(" <> T.concat parts <> ")"

--------------------------------------------------------------------------
-- | SMT Context ---------------------------------------------------------
--------------------------------------------------------------------------

--------------------------------------------------------------------------
-- | Start a solver context for query file @f@ and send the solver preamble.
--
--   When 'save' is set, every command sent is also written to the @.smt2@
--   file derived from @f@ (block-buffered, 64MiB buffer). If starting the
--   solver or sending the preamble throws (e.g. an unsupported string-theory
--   configuration), the resources acquired so far are released before the
--   exception propagates. These are the log handle and, once started, the
--   solver.
makeContext :: Config -> FilePath -> IO Context
--------------------------------------------------------------------------
makeContext cfg f
  = do mb_hLog <- if not (save cfg) then pure Nothing else do
           createDirectoryIfMissing True $ takeDirectory smtFile
           hLog <- openFile smtFile WriteMode
           hSetBuffering hLog $ BlockBuffering $ Just $ 1024 * 1024 * 64
           return $ Just hLog
       me   <- makeContext' cfg mb_hLog `onException` mapM_ hClose mb_hLog
       sendPreamble me mb_hLog `onException` cleanupContext me
       return me
    where
       smtFile = extFileName Smt2 f
       -- Send each preamble line to the solver, mirroring it to the log.
       sendPreamble ctx mbLog = do
           pre <- smtPreamble cfg (solver cfg) ctx
           forM_ pre $ \line -> do
               SMTLIB.Backends.command_ (ctxSolver ctx) line
               forM_ mbLog $ \hLog -> do
                   BS.hPutBuilder hLog line
                   LBS.hPutStr hLog "\n"

-- | Like 'makeContext', but also installs the given symbol environment and
--   defined functions in the context, then runs 'declare' to send their
--   declarations to the solver.
makeContextWithSEnv :: Config -> FilePath -> SymEnv -> DefinedFuns -> IO Context
makeContextWithSEnv cfg f env defns =
  makeContext cfg f >>= execStateT declare . withEnv
  where
    withEnv ctx = ctx { ctxSymEnv = env, ctxDefines = defns }

-- | Start a solver context without a query log and send the solver preamble.
--   If computing or sending the preamble throws (e.g. an unsupported
--   string-theory configuration), the solver is shut down with
--   'cleanupContext' before the exception propagates. Without this, the
--   solver process would be leaked.
makeContextNoLog :: Config -> IO Context
makeContextNoLog cfg = do
  me  <- makeContext' cfg Nothing
  (smtPreamble cfg (solver cfg) me >>= mapM_ (SMTLIB.Backends.command_ (ctxSolver me)))
    `onException` cleanupContext me
  return me

-- | Spawn a solver process from the given process configuration and wrap it
--   as an 'SMTLIB.Backends.Backend'. The second component closes the process.
--
--   The process's stderr is always drained by a background thread. When a
--   log handle is given, each stderr line is also copied to it; otherwise
--   the lines are discarded. Draining even without a log keeps a solver that
--   writes to stderr from blocking once the OS pipe buffer fills up. The
--   drain thread stops silently on any exception (such as end-of-file when
--   the process exits).
--
--   NOTE: the pattern below requires that the configuration creates a stderr
--   pipe (@hMaybeErr = Just _@); otherwise the match fails in 'IO'.
makeProcess
  :: Maybe Handle
  -> Process.Config
  -> IO (SMTLIB.Backends.Backend, IO ())
makeProcess ctxLog cfg
  = do handl@Process.Handle {hMaybeErr = Just hErr, ..} <- Process.new cfg
       void $ async $ forever
         (do errTxt <- LTIO.hGetLine hErr
             forM_ ctxLog $ \hLog ->
               LTIO.hPutStrLn hLog $ "OOPS, SMT solver error:" <> errTxt
         ) `catch` \ SomeException {} -> return ()
       let backend = Process.toBackend handl
       hSetBuffering hOut $ BlockBuffering $ Just $ 1024 * 1024 * 64
       hSetBuffering hIn $ BlockBuffering $ Just $ 1024 * 1024 * 64
       return (backend, Process.close handl)

-- | Launch the backend selected by @solver cfg@ and build a fresh 'Context'
--   around it. The solver is initialised in 'SMTLIB.Backends.Queuing' mode.
--   The symbol environment, the index stack and the defined functions start
--   out empty.
makeContext' :: Config -> Maybe Handle -> IO Context
makeContext' cfg ctxLog
  = do (backend, closeIO) <- startBackend (solver cfg)
       slv     <- SMTLIB.Backends.initSolver SMTLIB.Backends.Queuing backend
       verbose <- isLoud
       return Ctx { ctxSolver    = slv
                  , ctxElabF     = solverFlags cfg
                  , ctxClose     = closeIO
                  , ctxLog       = ctxLog
                  , ctxVerbose   = verbose
                  , ctxSymEnv    = mempty
                  , ctxIxs       = []
                  , ctxDefines   = mempty
                  -- This is a heuristic to avoid generating large sequences of unused
                  -- `lam_arg` symbols when there's no higher-order reasoning. It might
                  -- require some tuning on larger codebases if
                  -- `unknown function/constant lam_arg$XXX` errors are encountered.
                  , ctxLams      = allowHO cfg
                  , config       = cfg
                  }
  where
    -- Start an external solver executable via 'makeProcess'.
    spawn exe args =
      makeProcess ctxLog Process.defaultConfig { Process.exe = exe, Process.args = args }

    -- Other z3 invocations used historically:
    --   "z3 -smtc SOFT_TIMEOUT=1000 -in", "z3 -smtc -in MBQI=false"
    startBackend Z3      = spawn "z3"      ["-smt2", "-in"]
    startBackend Z3mem   = Conditional.Z3.makeZ3
    startBackend Mathsat = spawn "mathsat" ["-input=smt2"]
    startBackend Cvc4    = spawn "cvc4"    ["-L", "smtlib2"]
    startBackend Cvc5    = spawn "cvc5"    ["-L", "smtlib2", "--arrays-exp"]

-- | Release a context. The log handle (if any) is closed first, with
--   'hCloseMe' reporting close failures instead of throwing. Then the
--   backend's close action is run.
cleanupContext :: Context -> IO ()
cleanupContext ctx = do
  mapM_ (hCloseMe "ctxLog") (ctxLog ctx)
  ctxClose ctx

-- | Close a handle. Any 'IOException' is printed to stdout, tagged with
--   @msg@, instead of being propagated.
hCloseMe :: String -> Handle -> IO ()
hCloseMe msg h = catch (hClose h) report
  where
    report :: IOException -> IO ()
    report exn = putStrLn $ "OOPS, hClose breaks: " ++ msg ++ show exn

-- | Compute the preamble commands for solver @s@, after checking that the
--   string-theory setting is supported.
--
--   For z3 (both 'Z3' and 'Z3mem') the solver version is queried first. The
--   preamble then turns MBQI off, adds any timeout option, and ends with the
--   z3 theory preamble. Every other solver gets just its theory preamble and
--   is checked with an empty version list.
smtPreamble :: Config -> SMTSolver -> Context -> IO [Builder]
smtPreamble cfg s me
  | isZ3 = do
      v <- getZ3Version me
      checkValidStringFlag Z3 v cfg
      return $ mbqiOff ++ makeTimeout cfg ++ Thy.preamble cfg Z3
  | otherwise = do
      checkValidStringFlag s [] cfg
      return (Thy.preamble cfg s)
  where
    isZ3    = s `elem` [Z3, Z3mem]
    mbqiOff = ["\n(set-option :smt.mbqi false)"]

-- | Ask the solver for @(get-info :version)@ and return the dotted numeric
--   components of the quoted version. For example, a reply like
--   @(:version "4.8.15")@ gives @[4, 8, 15]@. Anything after the first
--   whitespace inside the quotes (such as a build hash) is ignored. Calls
--   'error' if the reply or any component cannot be parsed.
getZ3Version :: Context -> IO [Int]
getZ3Version me = do
  reply <- SMTLIB.Backends.command (ctxSolver me) "(get-info :version)"
  case Char8.split '"' reply of
    _ : quoted : _ ->
      let versionText = Char8.takeWhile (not . isSpace) quoted
      in  mapM parseComponent (Char8.split '.' versionText)
    xs -> error $ "Can't parse z3 (get-info :version): " ++ show xs
  where
    parseComponent comp =
      case reads (Char8.unpack comp) of
        [(n, "")] -> return n
        xs        -> error $ "Can't parse z3 version: " ++ show xs

-- | Abort (via 'die') when string theory is requested but the solver cannot
--   provide it, as decided by 'noString'. The message states the
--   requirements 'noString' actually enforces: z3 at version 4.4.2 or later,
--   or cvc5.
checkValidStringFlag :: SMTSolver -> [Int] -> Config -> IO ()
checkValidStringFlag smt v cfg
  = when (noString smt v cfg) $
      die $ err dummySpan (text "stringTheory is only supported by z3 version >=4.4.2 and cvc5")

-- | 'True' when string theory is requested ('noStringTheory' is off) but the
--   solver cannot provide it. Only cvc5, and z3 at version 4.4.2 or later,
--   are accepted. @v@ is the parsed solver version.
noString :: SMTSolver -> [Int] -> Config -> Bool
noString smt v cfg = wantsStrings && not supported
  where
    wantsStrings = not (noStringTheory cfg)
    supported    = smt == Cvc5 || (smt == Z3 && v >= [4, 4, 2])
-----------------------------------------------------------------------------
-- | SMT Commands -----------------------------------------------------------
-----------------------------------------------------------------------------

-- | Send 'Push' / 'Pop' to open or close a solver scope; the responses are
--   discarded.
smtPush, smtPop :: SmtM ()
smtPush = void (command Push)
smtPop  = void (command Pop)

-- | Emit a 'Comment' command carrying the given text.
smtComment :: T.Text -> SmtM ()
smtComment = interact' . Comment

-- | Declare each @(symbol, sort)@ pair, in order, with 'smtDecl'.
smtDecls :: [(Symbol, Sort)] -> SmtM ()
smtDecls xts = forM_ xts $ \(x, t) -> smtDecl x t

-- | Declare symbol @x@ of sort @t@ to the solver. A function sort is split
--   into argument and result sorts (see 'deconSort'). Each sort is converted
--   with 'sortSmtSort' against the data declarations of the current symbol
--   environment.
smtDecl :: Symbol -> Sort -> SmtM ()
smtDecl x t = do
  dataEnv <- gets (seData . ctxSymEnv)
  let toSmt = sortSmtSort False dataEnv
      decl  = Declare (symbolSafeText x) (map toSmt argSorts) (toSmt resSort)
  interact' (notracepp _msg decl)
  where
    (argSorts, resSort) = deconSort t
    _msg                = "smtDecl: " ++ showpp (x, t, argSorts, resSort)

-- | Declare a function by its already-encoded name and SMT sorts
--   (argument sorts, result sort).
smtFuncDecl :: T.Text -> ([SmtSort],  SmtSort) -> SmtM ()
smtFuncDecl name (argSorts, resSort) = interact' $ Declare name argSorts resSort

-- | Send one group of data declarations as a single 'DeclData' command.
smtDataDecl :: [DataDecl] -> SmtM ()
smtDataDecl = interact' . DeclData

-- | Split a sort into (argument sorts, result sort) using 'functionSort'.
--   A non-function sort @t@ becomes @([], t)@.
deconSort :: Sort -> ([Sort], Sort)
deconSort t = maybe ([], t) (\(_, ins, out) -> (ins, out)) (functionSort t)

-- | Assert @p@ without declaring any fresh apply/coerce/lambda symbols
--   (compare 'smtAssertDecl').
smtAssert :: Expr -> SmtM ()
smtAssert = interact' . Assert Nothing

-- | Assert @p@ through 'interactDecl''. Before the assertion is sent, this
--   declares any fresh `apply`, `coerce`, and `lambda` symbols needed for
--   the function sorts recorded in the symbol environment (see
--   'funcSortVars').
smtAssertDecl :: HasCallStack => Expr -> SmtM ()
smtAssertDecl p = interactDecl' $ Assert Nothing p

-- | Define an equation's function in the solver via 'smtDefineFunc', using
--   its name, parameters, result sort and body.
smtDefineEqn :: Equation -> SmtM ()
smtDefineEqn eq = smtDefineFunc (eqName eq) (eqArgs eq) (eqSort eq) (eqBody eq)

-- | Send a 'DefineFunc' command for @name@ with the given parameters,
--   result sort and body. Sorts are converted with 'sortSmtSort' against the
--   current data declarations. The command goes through 'interactDecl'', so
--   any fresh @apply@/@coerce@/@lambda@ symbols are declared first.
smtDefineFunc :: Symbol -> [(Symbol, F.Sort)] -> F.Sort -> Expr -> SmtM ()
smtDefineFunc name symList rsort e = do
  dataEnv <- gets (seData . ctxSymEnv)
  let toSmt  = sortSmtSort False dataEnv
      params = [ (x, toSmt s) | (x, s) <- symList ]
  interactDecl' $ DefineFunc name params (toSmt rsort) e

-----------------------------------------------------------------

-- | Assert an axiom together with its trigger information ('AssertAx').
smtAssertAxiom :: Triggered Expr -> SmtM ()
smtAssertAxiom = interact' . AssertAx

-- | Assert that the given expressions are pairwise distinct ('Distinct').
smtDistinct :: [Expr] -> SmtM ()
smtDistinct = interact' . Distinct

-- | Send 'CheckSat'. The result is 'True' for 'Unsat' and 'False' for 'Sat'
--   or 'Unknown'; any other reply crashes (see 'respSat').
smtCheckUnsat :: HasCallStack => SmtM Bool
smtCheckUnsat = fmap respSat (command CheckSat)

-- | Run the action in 'smtBracket'. Errors intercepted by 'catchSMT' are
--   handed to 'dieAt', which reports them at source span @sp@.
smtBracketAt :: SrcSpan -> String -> SmtM a -> SmtM a
smtBracketAt sp msg act = catchSMT (smtBracket msg act) (dieAt sp)

-- | Run an action inside a solver push/pop.
--
--   On entry this adds a new level to the apply stack and saves the current
--   fresh index ('seIx') on the index stack. On exit it pops that level and
--   restores the saved index; if the index stack is empty, it falls back to
--   @0@. The start and end are also emitted as comments tagged with @msg@.
--   The exit steps only run if the action returns normally.
smtBracket :: String -> SmtM a -> SmtM a
smtBracket msg act = do
  smtComment (T.pack $ "smtBracket - start: " ++ msg)
  smtPush
  modify enter
  result <- act
  smtPop
  smtComment (T.pack $ "smtBracket - end: " ++ msg)
  modify leave
  return result
  where
    enter ctx =
      let env = ctxSymEnv ctx
      in  ctx { ctxSymEnv = env { seAppls = pushAppls (seAppls env) }
              , ctxIxs    = seIx env : ctxIxs ctx }
    leave ctx =
      let env        = ctxSymEnv ctx
          (ix, rest) = fromMaybe (0, []) (uncons (ctxIxs ctx))
      in  ctx { ctxSymEnv = env { seAppls = popAppls (seAppls env), seIx = ix }
              , ctxIxs    = rest }

-- | Interpret a check-sat reply: 'True' for 'Unsat' and 'False' for 'Sat' or
--   'Unknown'. Any other response is a crash reported through 'die'.
respSat :: HasCallStack => Response -> Bool
respSat r = case r of
  Unsat   -> True
  Sat     -> False
  Unknown -> False
  _       -> die $ err dummySpan $ text ("crash: SMTLIB2 respSat = " ++ show r)

-- | Send a command through 'command' and discard the response.
interact' :: Command -> SmtM ()
interact' = void . command

-- | a variant of `interact'` which also emits fresh
--   `apply`, `coerce`, and `lambda` symbols
--
--   The steps are, in order:
--   1. serialize @cmd@;
--   2. declare the helper symbols computed by 'funcSortVars' from the
--      resulting symbol environment;
--   3. merge 'seApplsCur' into the top of 'seAppls' and clear 'seApplsCur';
--   4. send the serialized command with 'commandB'.
--   This way the declarations reach the solver before the command that
--   needs them.
interactDecl' :: HasCallStack => Command -> SmtM ()
interactDecl' cmd  = do
  -- Serialize first. Presumably this is where new function sorts get
  -- recorded in seApplsCur (TODO confirm in runSmt2).
  cmdBS <- liftSym $ runSmt2 cmd
  ctx <- get
  let env = ctxSymEnv ctx
  let ats = funcSortVars (ctxLams ctx) env
  forM_ ats $ uncurry smtFuncDecl
  -- NOTE(review): `ctx` was captured before the smtFuncDecl calls above, so
  -- any state they change (e.g. via liftSym in `command`) is overwritten by
  -- this `put`. Confirm that serializing a Declare never touches the env.
  put (ctx {ctxSymEnv = env {seAppls = mergeTopAppls (seApplsCur env) (seAppls env), seApplsCur = M.empty} })
  void $ commandB cmdBS

-- | Preamble lines that set the solver timeout from 'smtTimeout'. When no
--   timeout is configured nothing is emitted. Emitting an empty "command"
--   would only be sent to the solver for nothing and add a blank line to
--   the query log.
makeTimeout :: Config -> [Builder]
makeTimeout cfg
  | Just i <- smtTimeout cfg = [ "\n(set-option :timeout " <> fromString (show i) <> ")\n"]
  | otherwise                = []


--------------------------------------------------------------------------------
-- | Send the declarations for the context's symbol environment, in this order:
--   * data declarations;
--   * symbols whose theory interpretation is 'F.Uninterp';
--   * symbols with no theory entry, with their sorts elaborated;
--   * the context's defined functions;
--   * distinctness of the literals of each non-function sort;
--   * the axioms produced by 'Thy.axiomLiterals' for the literals.
declare :: SmtM ()
--------------------------------------------------------------------------------
declare = do
  ctx <- get
  let senv     = ctxSymEnv ctx
      sorted   = symbolSorts (F.seSort senv)
      elab     = elaborate (ElabParam (ctxElabF ctx) "declare" senv)
      lits     = F.toListSEnv (F.seLits senv)
      kindOf   = symKind senv . fst
      theoryXs = filter ((== Just F.Uninterp) . kindOf) sorted
      queryXs  = [ (x, elab t) | (x, t) <- sorted, symKind senv x == Nothing ]
      MkDefinedFuns eqns = ctxDefines ctx
  mapM_ smtDataDecl (dataDeclarations senv)
  mapM_ (uncurry smtDecl) theoryXs
  mapM_ (uncurry smtDecl) queryXs
  mapM_ smtDefineEqn eqns
  mapM_ smtDistinct (distinctLiterals lits)
  mapM_ smtAssert (Thy.axiomLiterals (config ctx) lits)

-- | List every binding of the sort environment. A sort @FObj a@ is replaced
--   by the sort bound to @a@ in the same environment, when there is one.
symbolSorts :: F.SEnv F.Sort -> [(F.Symbol, F.Sort)]
symbolSorts env = map (fmap resolve) (F.toListSEnv env)
 where
  resolve t = case t of
    FObj a -> fromMaybe t (F.lookupSEnv a env)
    _      -> t

-- | The data declarations of the symbol environment, grouped and ordered by
--   'orderDeclarations'.
dataDeclarations :: SymEnv -> [[DataDecl]]
dataDeclarations env = orderDeclarations [ d | (_, d) <- F.toListSEnv (F.seData env) ]

-- | See 'F.seApplsCur' for explanation.
--
--   Each function sort @(s, t)@ recorded in 'F.seApplsCur' with index @i@
--   yields declarations for its indexed helper symbols:
--
--   * @apply@  with sort @(Int, s) -> t@
--   * @coerce@ with sort @s -> t@
--   * @lambda@ with sort @(s, t) -> Int@
--   * if @lams@ is set and @t@ is 'F.SInt': @Thy.maxLamArg@ nullary symbols
--     built with 'lamArgSymbol', each of sort @s@.
funcSortVars :: Bool -> F.SymEnv -> [(T.Text, ([F.SmtSort], F.SmtSort))]
funcSortVars lams env = M.toList (F.seApplsCur env) >>= declsFor
  where
    declsFor ((argS, resS), i) =
        (symbolAtSortIndex applyName  i, ([F.SInt, argS], resS))
      : (symbolAtSortIndex coerceName i, ([argS], resS))
      : (symbolAtSortIndex lambdaName i, ([argS, resS], F.SInt))
      : lamArgs argS resS i

    lamArgs argS resS i
      | lams && resS == F.SInt =
          [ (symbolAtSortIndex (lamArgSymbol j) i, ([], argS)) | j <- [1 .. Thy.maxLamArg] ]
      | otherwise = []

-- | The theory interpretation of @x@, or 'Nothing' when @x@ has no theory
--   entry in the environment.
symKind :: F.SymEnv -> F.Symbol -> Maybe Sem
symKind env x = fmap F.tsInterp (F.symEnvTheory x env)

-- | `distinctLiterals` is used solely to determine the set of literals
--   (of each sort) that are *disequal* to each other, e.g. EQ, LT, GT,
--   or string literals "cat", "dog", "mouse". Literals whose sort is a
--   function sort are left out. The result has one group of expressions
--   per sort, as formed by 'Misc.groupList'.
distinctLiterals :: [(F.Symbol, F.Sort)] -> [[F.Expr]]
distinctLiterals xts = map snd (Misc.groupList nonFunLits)
  where
    nonFunLits = [ (t, F.expr x) | (x, t) <- xts, notFun t ]
    notFun     = not . F.isFunctionSortedReft . (`F.RR` F.trueReft)
