{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# Language ImplicitParams #-}
{-# Language MultiWayIf #-}
{-# Language PatternSynonyms #-}
{-# Language ScopedTypeVariables #-}
{-# Language TypeFamilies #-}
{-# Language TypeOperators #-}
{-# Language ViewPatterns #-}

module Crux.Goal where

import Control.Concurrent.Async (async, asyncThreadId, waitAnyCatch)
import Control.Exception (throwTo, SomeException, displayException)
import Control.Lens ((^.), view)

import Control.Monad (forM, forM_, unless, when)
import Data.Either (partitionEithers)
import Data.IORef
import Data.Maybe (fromMaybe)
import qualified Data.Map as Map
import qualified Data.Parameterized.Map as MapF
import qualified Data.Text as Text
import           Data.Void
import           Prettyprinter
import           System.Exit (ExitCode(ExitSuccess))
import           System.FilePath ((<.>), splitExtension)
import           System.IO (Handle, IOMode(..), withFile)
import qualified System.Timeout as ST

import What4.Interface (notPred, getConfiguration, IsExprBuilder)
import What4.Config (setOpt, getOptionSetting, Opt, ConfigOption)
import What4.ProgramLoc (ProgramLoc)
import What4.SatResult(SatResult(..))
import What4.Expr (ExprBuilder, GroundEvalFn(..), BoolExpr, GroundValueWrapper(..))
import What4.Protocol.Online( OnlineSolver, SolverProcess, inNewFrame, inNewFrame2Open
                            , inNewFrame2Close, solverEvalFuns, solverConn
                            , check, getUnsatCore, getAbducts )
import What4.Protocol.SMTWriter( SMTReadWriter, mkFormula, assumeFormulaWithFreshName
                               , assumeFormula, smtExprGroundEvalFn )
import qualified What4.Solver as WS
import Lang.Crucible.Backend
import Lang.Crucible.Backend.Online
        ( OnlineBackend, withSolverProcess, enableOnlineBackend, solverInteractionFile )
import Lang.Crucible.Simulator.SimError
        ( SimError(..), SimErrorReason(..) )
import Lang.Crucible.Simulator.ExecutionTree
        (ctxSymInterface)
import Lang.Crucible.Panic (panic)

import Crux.Types
import Crux.Model
import Crux.Log as Log
import Crux.Config.Common
import Crux.ProgressBar


-- | Look up the configuration option @x@ in @sym@'s configuration and set it
-- to the value @y@.  Whatever 'setOpt' returns is discarded.
symCfg :: (IsExprBuilder sym, Opt t a) => sym -> ConfigOption t -> a -> IO ()
symCfg sym x y =
  getOptionSetting x (getConfiguration sym) >>= \setting ->
    setOpt setting y >> return ()


-- | Simplify the proved goals.
--
-- Walks an annotated goal tree, accumulating the assumptions introduced by
-- 'Assuming' nodes.  Each 'Prove' leaf is handed, together with every
-- assumption in scope at that point, to 'proveToGoal'.  Each 'ProveConj'
-- becomes a 'Branch' (left subtree processed first).
provedGoalsTree :: forall sym.
  ( IsSymInterface sym
  ) =>
  sym ->
  Maybe (Goals (Assumptions sym) (Assertion sym, [ProgramLoc], ProofResult sym)) ->
  IO (Maybe ProvedGoals)
provedGoalsTree sym = mapM (walk mempty)
  where
    walk :: Assumptions sym ->
            Goals (Assumptions sym) (Assertion sym, [ProgramLoc], ProofResult sym) ->
            IO ProvedGoals
    walk inScope (Assuming extra rest) =
      walk (inScope <> extra) rest
    walk inScope (Prove (goal, locs, result)) =
      proveToGoal sym inScope goal locs result
    walk inScope (ProveConj left right) =
      do l <- walk inScope left
         r <- walk inScope right
         pure (Branch l r)

-- | Convert the result of proving a single goal into a 'ProvedGoals' leaf.
--
-- * 'NotProved': every assumption in scope is flattened and listed next to
--   the goal message, the explanation, the locations, the optional model and
--   the abducts ('NotProvedGoal').
-- * 'Proved': only the assumptions that appear in the proof's core are
--   listed ('ProvedGoal').  The final flag is 'True' exactly when the core
--   contains no assertions, i.e. only assumptions.
proveToGoal ::
  (IsSymInterface sym) =>
  sym ->
  Assumptions sym ->
  Assertion sym ->
  [ProgramLoc] ->
  ProofResult sym ->
  IO ProvedGoals
proveToGoal sym allAsmps p locs pr =
  case pr of
    Proved xs ->
      let (usedAsmps, usedAsserts) = partitionEithers xs
       in pure (ProvedGoal (map forgetAssumption usedAsmps) goalMsg locs (null usedAsserts))
    NotProved ex cex s ->
      do flat <- flattenAssumptions sym allAsmps
         pure (NotProvedGoal (map forgetAssumption flat) goalMsg ex locs cex s)
  where
    goalMsg = p ^. labeledPredMsg


-- | Number of 'Prove' leaves in a goal tree.  'Assuming' nodes add nothing
-- themselves; a 'ProveConj' contributes the goals of both branches.
countGoals :: Goals a b -> Int
countGoals (Assuming _ rest)   = countGoals rest
countGoals (Prove _)           = 1
countGoals (ProveConj lhs rhs) = countGoals lhs + countGoals rhs

-- | 'True' when the goal's message is a 'SimError' whose reason is
-- 'ResourceExhausted'.
isResourceExhausted :: LabeledPred p SimError -> Bool
isResourceExhausted lp =
  case lp ^. labeledPredMsg of
    SimError _ (ResourceExhausted _) -> True
    _                                -> False

-- | Record the outcome of one goal in the running 'ProcessedGoals' tally.
--
-- Every call increments 'totalProcessedGoals'.  In addition:
--
-- * 'Proved' increments 'provedGoals';
-- * 'NotProved' on a goal whose message is a 'ResourceExhausted' error
--   increments 'incompleteGoals' (checked first);
-- * any other 'NotProved' that carries a model increments 'disprovedGoals';
-- * a 'NotProved' without a model (e.g. an unknown answer or a timeout)
--   changes nothing else.
updateProcessedGoals ::
  LabeledPred p SimError ->
  ProofResult a ->
  ProcessedGoals ->
  ProcessedGoals
updateProcessedGoals res outcome pgs =
  classify (pgs { totalProcessedGoals = 1 + totalProcessedGoals pgs })
  where
    classify = case outcome of
      Proved _ ->
        \g -> g { provedGoals = 1 + provedGoals g }
      NotProved _ mbModel _
        | isResourceExhausted res ->
            \g -> g { incompleteGoals = 1 + incompleteGoals g }
        | Just _ <- mbModel ->
            \g -> g { disprovedGoals = 1 + disprovedGoals g }
        | otherwise ->
            id

-- | A function that can be used to generate a pretty explanation of a
-- simulation error.
--
-- The first argument is an optional ground evaluation function for a
-- solver model (callers in this module pass 'Nothing' when no model is
-- available, e.g. on an unknown answer).  The second argument is the labeled
-- goal whose failure is being explained.
type Explainer sym t ann = Maybe (GroundEvalFn t)
                           -> LPred sym SimError
                           -> IO (Doc ann)

-- | The common shape of a goal-discharging procedure.
--
-- Given the Crux options, the simulator context, an 'Explainer' for failed
-- goals, and an optional tree of proof obligations, it returns a
-- 'ProcessedGoals' summary together with the obligation tree.  In that tree,
-- every goal is annotated with program locations and its 'ProofResult'.  The
-- callback must work for every extension, personality, and 'ExprBuilder'
-- instantiation of @sym@.
type ProverCallback sym =
  forall ext personality t st fs.
    (sym ~ ExprBuilder t st fs) =>
    CruxOptions ->
    SimCtxt personality sym ext ->
    Explainer sym t Void ->
    Maybe (Goals (Assumptions sym) (Assertion sym)) ->
    IO (ProcessedGoals, Maybe (Goals (Assumptions sym) (Assertion sym, [ProgramLoc], ProofResult sym)))

-- | Discharge a tree of proof obligations ('Goals') by using non-online
-- solvers.
--
-- This function traverses the 'Goals' tree while keeping track of a collection
-- of assumptions in scope for each goal.  For each proof goal encountered in
-- the tree, it starts one fresh solver run per provided solver adapter.  The
-- runs are raced with 'dispatchSolversOnGoalAsync'.
--
-- This is in contrast to 'proveGoalsOnline', which uses an online solver
-- connection with scoped assumption frames.  This function allows using a wider
-- variety of solvers (i.e., ones that don't have support for online solving)
-- and would in principle enable parallel goal evaluation (though the tree
-- structure makes that a bit trickier, it isn't too hard).
--
-- How each goal's dispatch outcome is recorded:
--
-- * a timeout becomes @NotProved "(timeout)" Nothing []@;
-- * a failure that reports no exception at all is treated like an 'Unknown'
--   solver answer (not proved, no model), rather than aborting with an
--   empty error message;
-- * a failure with exceptions counts the goal and then reports every
--   exception via 'fail'.
--
-- When 'proofGoalsFailFast' is set, remaining conjuncts are skipped once any
-- goal has been disproved.
--
-- Note that this function uses the same symbolic backend ('ExprBuilder') as the
-- symbolic execution phase, which should not be a problem.
proveGoalsOffline :: forall st sym p t fs personality msgs.
  ( sym ~ ExprBuilder t st fs
  , Logs msgs
  , SupportsCruxLogMessage msgs
  ) =>
  [WS.SolverAdapter st] ->
  CruxOptions ->
  SimCtxt personality sym p ->
  (Maybe (GroundEvalFn t) -> Assertion sym -> IO (Doc Void)) ->
  Maybe (Goals (Assumptions sym) (Assertion sym)) ->
  IO (ProcessedGoals, Maybe (Goals (Assumptions sym) (Assertion sym, [ProgramLoc], ProofResult sym)))
proveGoalsOffline _adapter _opts _ctx _explainFailure Nothing = return (ProcessedGoals 0 0 0 0, Nothing)
proveGoalsOffline adapters opts ctx explainFailure (Just gs0) = do
  goalNum <- newIORef (ProcessedGoals 0 0 0 0)
  (start,end,finish) <- proverMilestoneCallbacks gs0
  gs <- go (start,end) goalNum mempty gs0
  nms <- readIORef goalNum
  finish
  return (nms, Just gs)

  where
    sym = ctx^.ctxSymInterface

    failfast = proofGoalsFailFast opts

    go :: SupportsCruxLogMessage msgs
       => (ProverMilestoneStartGoal, ProverMilestoneEndGoal)
       -> IORef ProcessedGoals
       -> Assumptions sym
       -> Goals (Assumptions sym) (Assertion sym)
       -> IO (Goals (Assumptions sym) (Assertion sym, [ProgramLoc], ProofResult sym))
    go (start,end) goalNum assumptionsInScope gs =
      case gs of
        Assuming ps gs1 -> do
          res <- go (start,end) goalNum (assumptionsInScope <> ps) gs1
          return (Assuming ps res)

        ProveConj g1 g2 -> do
          g1' <- go (start,end) goalNum assumptionsInScope g1
          -- With fail-fast enabled, drop the second conjunct as soon as any
          -- goal has been disproved.
          numDisproved <- disprovedGoals <$> readIORef goalNum
          if failfast && numDisproved > 0
            then return g1'
            else ProveConj g1' <$> go (start,end) goalNum assumptionsInScope g2

        Prove p -> do
          goalNumber <- totalProcessedGoals <$> readIORef goalNum
          start goalNumber

          -- Ask the adapters whether the in-scope assumptions together with
          -- the *negated* goal are satisfiable; Unsat means the goal holds.
          assumptions <- assumptionsPred sym assumptionsInScope
          goal <- notPred sym (p ^. labeledPred)

          res <- dispatchSolversOnGoalAsync (goalTimeout opts) adapters
                           (runOneSolver p assumptionsInScope assumptions goal goalNumber)
          end goalNumber
          case res of
            -- The timeout elapsed before any decisive answer.
            Right Nothing -> do
              let details = NotProved "(timeout)" Nothing []
              let locs = assumptionsTopLevelLocs assumptionsInScope
              modifyIORef' goalNum (updateProcessedGoals p details)
              return (Prove (p, locs, details))

            Right (Just (locs,details)) -> do
              modifyIORef' goalNum (updateProcessedGoals p details)
              case details of
                NotProved _ (Just _) _ ->
                  when (failfast && not (isResourceExhausted p)) $
                    sayCrux Log.FoundCounterExample
                _ -> return ()
              return (Prove (p, locs, details))

            -- The dispatcher gave up without any exception to report (for
            -- instance, no solver produced a decisive answer).  Record the
            -- goal exactly like the 'Unknown' case of 'runOneSolver' instead
            -- of calling 'fail' with an empty message.
            Left [] -> do
              explain <- explainFailure Nothing p
              let details = NotProved explain Nothing []
              let locs = assumptionsTopLevelLocs assumptionsInScope
              modifyIORef' goalNum (updateProcessedGoals p details)
              return (Prove (p, locs, details))

            Left es -> do
              modifyIORef' goalNum (updateProcessedGoals p (NotProved mempty Nothing []))
              let allExceptions = unlines (displayException <$> es)
              fail allExceptions

    runOneSolver :: Assertion sym
                 -> Assumptions sym
                 -> BoolExpr t
                 -> BoolExpr t
                 -> Integer
                 -> WS.SolverAdapter st
                 -> IO ([ProgramLoc], ProofResult sym)
    runOneSolver p assumptionsInScope assumptions goal goalNumber adapter =
      -- Create a file to a single offline solver interaction, assuming the user
      -- has selected this option. A single Crux session might invoke an offline
      -- solver multiple times, so each file has unique suffixes to indicate the
      -- goal number and which solver in particular was used to solve the goal.
      -- (Recall that a Crux user can choose multiple offline solvers, so it is
      -- important to indicate which solver was picked for each goal.)
      withMaybeOfflineSolverOutputHandle (offlineSolverOutput opts) $ \mbLogHandle ->
      let logData = WS.defaultLogData { WS.logHandle = mbLogHandle } in
      WS.solver_adapter_check_sat adapter (ctx ^. ctxSymInterface) logData [assumptions, goal] $ \satRes ->
        case satRes of
          Unsat _ -> do
            -- NOTE: We don't have an easy way to get an unsat core here
            -- because we don't have a solver connection.
            as <- flattenAssumptions sym assumptionsInScope
            let core = fmap Left as ++ [ Right p ]
            let locs = assumptionsTopLevelLocs assumptionsInScope
            return (locs, Proved core)
          Sat (evalFn, _) -> do
            evs  <- concretizeEvents (groundEval evalFn) assumptionsInScope
            let vals = evalModelFromEvents evs
            explain <- explainFailure (Just evalFn) p
            let locs = map eventLoc evs
            return (locs, NotProved explain (Just (vals,evs)) [])
          Unknown -> do
            explain <- explainFailure Nothing p
            let locs = assumptionsTopLevelLocs assumptionsInScope
            return (locs, NotProved explain Nothing [])
      where
        -- Create a handle for a file based on the template. For instance, if
        -- the template is @solver-output.smt2@, then create a file named
        -- @solver-output-<goal number>-<solver name>.smt2@.
        withOfflineSolverOutputHandle :: FilePath -> (Handle -> IO r) -> IO r
        withOfflineSolverOutputHandle template k =
          let (templateName, templateExt) = splitExtension template
              fileName = templateName ++ "-" ++ show goalNumber
                                      ++ "-" ++ WS.solver_adapter_name adapter
                                     <.> templateExt in
          withFile fileName WriteMode k

        -- Lift @withOfflineSolverOutputHandle@ to a @Maybe FilePath@ argument.
        withMaybeOfflineSolverOutputHandle ::
          Maybe FilePath -> (Maybe Handle -> IO r) -> IO r
        withMaybeOfflineSolverOutputHandle mbTemplate k =
          case mbTemplate of
            Just template -> withOfflineSolverOutputHandle template (k . Just)
            Nothing       -> k Nothing

-- | Build a 'ModelView' from the 'CreateVariableEvent's in a list of
-- concretized events; every other kind of event is ignored.  Entries are
-- grouped by their type representative.  Within a group they appear in the
-- same order as in the input list.
evalModelFromEvents :: [CrucibleEvent GroundValueWrapper] -> ModelView
evalModelFromEvents evs = ModelView (foldr addEvent (modelVals emptyModelView) evs)
  where
    -- A right fold sees the later events first, so each entry is
    -- /prepended/ to those already collected for its type.  That preserves
    -- input order without re-copying a growing list on every insertion, which
    -- the previous append-based left fold did (quadratic per type).
    addEvent (CreateVariableEvent loc nm tpr (GVW v)) m =
      MapF.insertWith prepend tpr (Vals [Entry nm loc v]) m
    addEvent _ m = m

    -- Called as @prepend new old@ when the key is already present.
    prepend (Vals new) (Vals old) = Vals (new ++ old)

-- | Run every solver adapter concurrently on the same goal and keep the first
-- decisive answer.
--
-- Each adapter runs in its own 'async'.  The first result that is either
-- 'Proved' or 'NotProved' with a model wins, and every solver still running
-- is then stopped.  Results:
--
-- * @Right (Just r)@: a decisive result.  If no solver was decisive but at
--   least one finished without throwing, this is the first such inconclusive
--   result.
-- * @Right Nothing@: the optional timeout (in seconds) elapsed first.  The
--   timeout bounds the whole dispatch, not each individual wait.
-- * @Left es@: every solver threw (@es@ lists the exceptions, most recent
--   first).  @es@ can only be empty when @adapters@ is empty.
dispatchSolversOnGoalAsync :: forall a b s time.
                              (RealFrac time)
                           => Maybe time
                           -> [WS.SolverAdapter s]
                           -> (WS.SolverAdapter s -> IO (b,ProofResult a))
                           -> IO (Either [SomeException] (Maybe (b,ProofResult a)))
dispatchSolversOnGoalAsync mtimeoutSeconds adapters withAdapter = do
  asyncs <- forM adapters (async . withAdapter)
  mresult <- withTimeout (await asyncs [] Nothing)
  case mresult of
    Just (stillRunning, outcome) -> do
      mapM_ kill stillRunning
      return outcome
    Nothing -> do
      -- Timed out.  Some solvers may already have finished; signalling a
      -- thread that has already terminated has no effect.
      mapM_ kill asyncs
      return (Right Nothing)
  where
    -- Wait for the remaining solvers one at a time.  Returns the solvers
    -- that are still running together with the outcome.
    await pending es mbInconclusive =
      case pending of
        [] -> return ([], maybe (Left es) (Right . Just) mbInconclusive)
        _ -> do
          (a, result) <- waitAnyCatch pending
          let pending' = filter (/= a) pending
          case result of
            Left exc ->
              await pending' (exc : es) mbInconclusive
            Right r@(_, Proved _) ->
              return (pending', Right (Just r))
            Right r@(_, NotProved _ (Just _) _) ->
              return (pending', Right (Just r))
            Right r ->
              -- Inconclusive (e.g. unknown): remember the first one so it can
              -- be reported if nothing better arrives.
              await pending' es (Just (fromMaybe r mbInconclusive))

    withTimeout action
      | Just seconds <- mtimeoutSeconds = ST.timeout (toMicroseconds seconds) action
      | otherwise = Just <$> action

    -- Scale before rounding so fractional timeouts are honoured, and clamp
    -- to the range of 'Int' expected by 'ST.timeout'.
    toMicroseconds seconds =
      fromInteger (min (toInteger (maxBound :: Int)) (round (seconds * 1000000)))

    -- `cancel` from async blocks until the canceled thread has terminated.
    kill a = throwTo (asyncThreadId a) ExitSuccess


-- | Returns three actions, called respectively when the proving process for a
-- goal is started, when it is ended, and when the final goal is solved.  These
-- handlers should handle all necessary output / notifications to external
-- observers, based on the run options.
--
-- In quiet mode the silent callbacks are used.  Otherwise a "Checking: "
-- status display sized by 'countGoals' is prepared.  Either way, the start
-- and end actions also log 'Log.StartedGoal' / 'Log.EndedGoal'.
proverMilestoneCallbacks ::
  Log.Logs msgs =>
  Log.SupportsCruxLogMessage msgs =>
  Goals asmp ast -> IO ProverMilestoneCallbacks
proverMilestoneCallbacks goals = do
  let quietMode = view quiet ?outputConfig
  (onStart, onEnd, onFinish) <-
    if | quietMode -> pure silentProverMilestoneCallbacks
       | otherwise -> prepStatus "Checking: " (countGoals goals)
  pure ( \n -> onStart n *> sayCrux (Log.StartedGoal n)
       , \n -> onEnd n *> sayCrux (Log.EndedGoal n)
       , onFinish
       )


-- | Prove a collection of goals.  The result is a goal tree, where
-- each goal is annotated with the outcome of the proof.
--
-- NOTE: This function takes an explicit symbolic backend as an argument, even
-- though the symbolic backend used for symbolic execution is available in the
-- 'SimCtxt'.  We do that so that we can use separate solvers for path
-- satisfiability checking and goal discharge.
--
-- All goals share one online solver connection:
--
-- * Assumptions are asserted as the tree is walked.
-- * Each goal is checked by asserting its negation and calling 'check'.
--   'Unsat' means proved.  'Sat' yields a counter-example, plus abducts when
--   'getNAbducts' is positive.  'Unknown' is recorded as not proved, without
--   a model.
-- * The first conjunct of every 'ProveConj' runs in its own solver frame, so
--   its assumptions and negated goals are gone before the second conjunct is
--   processed.
--
-- A mutable name map records which solver-level name belongs to which
-- assumption or goal, so unsat cores can be translated back.  The map is
-- snapshotted before that first conjunct and restored afterwards, mirroring
-- the solver frame that is popped.  Without this, a proof's core could
-- mention assumptions or goals from a sibling branch.
proveGoalsOnline ::
  forall sym personality p msgs goalSolver s st fs.
  ( sym ~ ExprBuilder s st fs
  , OnlineSolver goalSolver
  , Logs msgs
  , SupportsCruxLogMessage msgs
  ) =>
  OnlineBackend goalSolver s st fs ->
  CruxOptions ->
  SimCtxt personality sym p ->
  (Maybe (GroundEvalFn s) -> Assertion sym -> IO (Doc Void)) ->
  Maybe (Goals (Assumptions sym) (Assertion sym)) ->
  IO (ProcessedGoals, Maybe (Goals (Assumptions sym) (Assertion sym, [ProgramLoc], ProofResult sym)))
proveGoalsOnline _ _opts _ctxt _explainFailure Nothing =
     return (ProcessedGoals 0 0 0 0, Nothing)

proveGoalsOnline bak opts _ctxt explainFailure (Just gs0) =
  do
     -- send solver interactions to the correct file
     mapM_ (symCfg sym solverInteractionFile) (fmap Text.pack (onlineSolverOutput opts))
     -- initial goal count
     goalNum <- newIORef (ProcessedGoals 0 0 0 0)
     -- nameMap is a mutable ref to a map from Text to Either (Assumption sym) (Assertion sym)
     nameMap <- newIORef Map.empty
     when (unsatCores opts && yicesMCSat opts) $
       sayCrux Log.SkippingUnsatCoresBecauseMCSatEnabled
     -- callbacks for starting a goal, ending a goal, and finishing all goals
     (start,end,finish) <- proverMilestoneCallbacks gs0
     -- make sure online features are enabled
     enableOpt <- getOptionSetting enableOnlineBackend (getConfiguration sym)
     _ <- setOpt enableOpt True
     -- @go@ traverses a proof tree, processing/solving each goal as it traverses it.
     -- It also updates goal count and nameMap
     res <- withSolverProcess bak (panic "proveGoalsOnline" ["Online solving not enabled!"]) $ \sp ->
              inNewFrame sp (go (start,end) sp mempty goalNum gs0 nameMap)
     nms <- readIORef goalNum
     finish
     return (nms, Just res)

  where
  sym = backendGetSym bak

  bindName nm p nameMap = modifyIORef nameMap (Map.insert nm p)

  hasUnsatCores = unsatCores opts && not (yicesMCSat opts)

  howManyAbducts = fromMaybe 0 (getNAbducts opts)

  usingAbducts = howManyAbducts > 0

  failfast = proofGoalsFailFast opts

  go (start,end) sp assumptionsInScope gn gs nameMap = do
    -- traverse goal tree
    case gs of
      -- case: assumption in context for all the contained goals
      Assuming asms gs1 ->
        do ps <- flattenAssumptions sym asms
           forM_ ps $ \asm ->
             unless (trivialAssumption asm) $
               -- extract predicate from assumption
               do let p = assumptionPred asm
                  -- create formula, assert to SMT solver, create new name and add to nameMap
                  nm <- doAssume =<< mkFormula conn p
                  bindName nm (Left asm) nameMap
           -- recursive call
           res <- go (start,end) sp (assumptionsInScope <> asms) gn gs1 nameMap
           return (Assuming (mconcat (map singleAssumption ps)) res)
      -- case: proof obligation in the context of all previously-made assumptions
      Prove p ->
        -- number of processed goals gives goal number to prove
        do goalNumber <- totalProcessedGoals <$> readIORef gn
           start goalNumber
           -- negate goal, create formula
           t <- mkFormula conn =<< notPred sym (p ^. labeledPred)
           -- assert formula to SMT solver, create new name and add to nameMap.
           -- This is done in a new assertion frame if abduction is turned on since
           -- this assertion would need to be removed before asking for abducts
           let inNewFrame2ForAbducts =
                 if usingAbducts then inNewFrame2 sp else id
           ret <- inNewFrame2ForAbducts $ do
             nm <- doAssume t
             bindName nm (Right p) nameMap
             -- check-sat with SMT solver, pattern match on result
             res <- check sp "proof"
             case res of
               Unsat () ->
                 -- build unsat core, which is the entire assertion set by default
                 do namemap <- readIORef nameMap
                    core <- if hasUnsatCores then
                               map (lookupnm namemap) <$> getUnsatCore sp
                            -- default unsat core: entire assertion set
                            else return (Map.elems namemap)
                    let locs = assumptionsTopLevelLocs assumptionsInScope
                    return $ UnsatResult core locs
               Sat ()  ->
                 do -- evaluate counter-example
                    f <- smtExprGroundEvalFn conn (solverEvalFuns sp)
                    evs <- concretizeEvents (groundEval f) assumptionsInScope
                    explain <- explainFailure (Just f) p
                    return $ SatResult explain evs
               Unknown ->
                 do explain <- explainFailure Nothing p
                    let locs = assumptionsTopLevelLocs assumptionsInScope
                    return $ UnknownResult explain locs
           end goalNumber
           smtResultToGoals p ret
      -- case: conjunction of goals
      ProveConj g1 g2 ->
        do -- Everything g1 binds lives in the solver frame that 'inNewFrame'
           -- pops, so its names must not leak into g2 (otherwise the default
           -- unsat core of later goals would list assumptions and goals that
           -- are no longer asserted).  Snapshot the map and restore it.
           savedNames <- readIORef nameMap
           g1' <- inNewFrame sp (go (start,end) sp assumptionsInScope gn g1 nameMap)
           writeIORef nameMap savedNames
           -- NB, we don't need 'inNewFrame' for g2 because
           --  we don't need to back up to this point again.
           numDisproved <- disprovedGoals <$> readIORef gn
           if failfast && numDisproved > 0 then
             return g1'
           else
             ProveConj g1' <$> go (start,end) sp assumptionsInScope gn g2 nameMap

    where
    conn = solverConn sp

    lookupnm namemap x =
      fromMaybe (error $ "Named predicate " ++ show x ++ " not found!")
                (Map.lookup x namemap)

    -- Assert a formula.  With unsat cores enabled the solver picks a fresh
    -- name; otherwise a local name derived from the current map size is used
    -- (it only serves as a key in the name map).
    doAssume formula = do
      namemap <- readIORef nameMap
      if hasUnsatCores
      then assumeFormulaWithFreshName conn formula
      else assumeFormula conn formula >> return (Text.pack ("x" ++ show (Map.size namemap)))

    -- Convert an 'SMTResult' to a 'ProofResult'. This function should be
    -- called /after/ 'inNewFrame2' so that the abducts can be queried properly
    -- in the SatResult case.
    smtResultToGoals :: LabeledPred (BoolExpr s) SimError
                     -> SMTResult sym
                     -> IO ( Goals asmp (LabeledPred (BoolExpr s) SimError
                           , [ProgramLoc]
                           , ProofResult sym)
                           )
    smtResultToGoals p smtRes = do
      (locs, gt) <- case smtRes of
        UnsatResult core locs -> do
          let pr = Proved core
          return (locs, pr)
        SatResult explain evs -> do
          let vals = evalModelFromEvents evs
          abds <- if usingAbducts then
                    getAbducts sp (fromIntegral howManyAbducts) "abd" (p ^. labeledPred)
                  else
                    return []
          let gt = NotProved explain (Just (vals,evs)) abds
          when (failfast && not (isResourceExhausted p)) $
            sayCrux Log.FoundCounterExample
          let locs = map eventLoc evs
          return (locs, gt)
        UnknownResult explain locs -> do
          let gt = NotProved explain Nothing []
          return (locs, gt)

      -- update goal count
      modifyIORef' gn (updateProcessedGoals p gt)
      return (Prove (p, locs, gt))

-- | Like 'inNewFrame', but specifically for frame @2@. This is used for the
-- purpose of generating abducts.
--
-- Runs 'inNewFrame2Open', then the action, then 'inNewFrame2Close', and
-- returns the action's result.  NOTE(review): this is not a bracket, so if
-- the action throws, the close call is skipped.
--
-- TODO: Upstream this to @what4@ (Issue what4#218).
inNewFrame2 :: SMTReadWriter solver => SolverProcess scope solver -> IO a -> IO a
inNewFrame2 sp action =
  inNewFrame2Open sp *> action <* inNewFrame2Close sp

-- | An intermediate data structure used in 'proveGoalsOnline'. This can be
-- thought of as a halfway point between a 'SatResult' and a 'ProofResult'.
data SMTResult sym
    -- | The negated goal was unsatisfiable, so the goal holds.  Carries the
    -- proof's core (the assumptions and assertions it is attributed to) and
    -- the top-level program locations of the assumptions in scope.
  = UnsatResult [Either (Assumption sym) (Assertion sym)]
                [ProgramLoc]
    -- | The negated goal was satisfiable.  Carries the explanation of the
    -- failure and the events concretized under the counter-example model.
  | SatResult (Doc Void)
              [CrucibleEvent GroundValueWrapper]
    -- | The solver could not decide.  Carries the explanation of the failure
    -- and the top-level program locations of the assumptions in scope.
  | UnknownResult (Doc Void)
                  [ProgramLoc]
