-- |
-- Module      :  Cryptol.Symbolic
-- Copyright   :  (c) 2013-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ViewPatterns #-}

module Cryptol.Symbolic where

import Control.Monad.IO.Class
import Control.Monad (replicateM, when, zipWithM, foldM)
import Control.Monad.Writer (WriterT, runWriterT, tell, lift)
import Data.List (intercalate, genericLength)
import Data.IORef(IORef)
import qualified Control.Exception as X

import qualified Data.SBV.Dynamic as SBV
import           Data.SBV (Timing(SaveTiming))
import           Data.SBV.Internals (showTDiff)

import qualified Cryptol.ModuleSystem as M hiding (getPrimMap)
import qualified Cryptol.ModuleSystem.Env as M
import qualified Cryptol.ModuleSystem.Base as M
import qualified Cryptol.ModuleSystem.Monad as M

import Cryptol.Symbolic.Prims
import Cryptol.Symbolic.Value

import qualified Cryptol.Eval as Eval
import qualified Cryptol.Eval.Monad as Eval
import qualified Cryptol.Eval.Type as Eval
import qualified Cryptol.Eval.Value as Eval
import           Cryptol.Eval.Env (GenEvalEnv(..))
import Cryptol.TypeCheck.AST
import Cryptol.Utils.Ident (Ident)
import Cryptol.Utils.PP
import Cryptol.Utils.Panic(panic)
import Cryptol.Utils.Logger(logPutStrLn)

import Prelude ()
import Prelude.Compat

import Data.Time (NominalDiffTime)

-- | Alias for 'GenEvalEnv' instantiated at the 'SBool' and 'SWord' types
-- used by this module.
type EvalEnv = GenEvalEnv SBool SWord


-- External interface ----------------------------------------------------------

-- | Association list from user-facing prover names to SBV solver
-- configurations.  The pseudo-provers @offline@ and @any@ both map to
-- 'SBV.defaultSMTCfg'; they are listed after the concrete solvers.
proverConfigs :: [(String, SBV.SMTConfig)]
proverConfigs = concreteSolvers ++ pseudoSolvers
  where
    concreteSolvers =
      [ ("cvc4", SBV.cvc4)
      , ("yices", SBV.yices)
      , ("z3", SBV.z3)
      , ("boolector", SBV.boolector)
      , ("mathsat", SBV.mathSAT)
      , ("abc", SBV.abc)
      ]
    pseudoSolvers =
      [ (pseudo, SBV.defaultSMTCfg) | pseudo <- ["offline", "any"] ]

-- | The names of all entries in 'proverConfigs', in the same order.
proverNames :: [String]
proverNames = [ name | (name, _) <- proverConfigs ]

-- | Find the SBV configuration registered under the given name in
-- 'proverConfigs'.  Panics on an unknown name; the UI that sets the
-- prover user variable is expected to have rejected it already.
lookupProver :: String -> SBV.SMTConfig
lookupProver s
  | Just cfg <- lookup s proverConfigs = cfg
  | otherwise = panic "Cryptol.Symbolic" [ "invalid prover: " ++ s ]

-- | A single model: for each argument of the predicate, its type, an
-- expression denoting the assigned value, and the value itself.
type SatResult = [(Type, Expr, Eval.Value)]

-- | How many models a satisfiability query should ask for.
data SatNum
  = AllSat
    -- ^ No bound on the number of models
  | SomeSat Int
    -- ^ Ask for at most this many models
  deriving (Show)

-- | The kind of question put to the prover.
data QueryType
  = SatQuery SatNum
    -- ^ Look for satisfying assignments, as many as the 'SatNum' allows
  | ProveQuery
    -- ^ Try to prove the predicate
  deriving (Show)

-- | All of the inputs needed to run a single prover query.
data ProverCommand = ProverCommand
  { pcQueryType   :: QueryType
    -- ^ Whether this is a satisfiability or a proof query
  , pcProverName  :: String
    -- ^ Name of the prover to run; must be a key of 'proverConfigs'
  , pcVerbose     :: Bool
    -- ^ Verbosity setting, forwarded to SBV
  , pcValidate    :: Bool
    -- ^ Model validation setting, forwarded to SBV
  , pcProverStats :: !(IORef ProverStats)
    -- ^ Reference in which timing information is recorded
  , pcExtraDecls  :: [DeclGroup]
    -- ^ Additional declarations made available during symbolic
    -- simulation
  , pcSmtFile     :: Maybe FilePath
    -- ^ When present, the SMTLIB query is also written to this file
  , pcExpr        :: Expr
    -- ^ The typechecked expression to be evaluated
  , pcSchema      :: Schema
    -- ^ The 'Schema' of @pcExpr@
  }

-- | Timing information recorded for a prover run, as a time difference.
type ProverStats = NominalDiffTime

-- | The outcome of running a prover query.
data ProverResult
  = AllSatResult [SatResult]
    -- ^ Models found by the solver (satisfying assignments for a sat
    -- query, counterexamples for a proof query).  This list is LAZY.
  | ThmResult [Type]
    -- ^ No model exists; carries the argument types of the predicate
  | EmptyResult
    -- ^ Nothing to report (eg for the offline prover)
  | ProverError String
    -- ^ The query failed, with an error message

-- | Unwrap a 'SBV.SatResult' into a singleton list of raw SMT results.
satSMTResults :: SBV.SatResult -> [SBV.SMTResult]
satSMTResults wrapped =
  case wrapped of
    SBV.SatResult result -> [result]

-- | Extract the list of raw SMT results from a 'SBV.AllSatResult',
-- discarding the three accompanying components.
allSatSMTResults :: SBV.AllSatResult -> [SBV.SMTResult]
allSatSMTResults wrapped =
  case wrapped of
    SBV.AllSatResult (_, _, _, results) -> results

-- | Unwrap a 'SBV.ThmResult' into a singleton list of raw SMT results.
thmSMTResults :: SBV.ThmResult -> [SBV.SMTResult]
thmSMTResults wrapped =
  case wrapped of
    SBV.ThmResult result -> [result]

-- | A module command that ignores its evaluation options, leaves the
-- module environment untouched, and reports the given message as a
-- 'ProverError' with no solver attached and no warnings.
proverError :: String -> M.ModuleCmd (Maybe SBV.Solver, ProverResult)
proverError msg (_, modEnv) = return (Right (failure, modEnv), [])
  where
    failure = (Nothing, ProverError msg)

-- | Run a satisfiability or proof query through SBV.
--
-- The expression in the 'ProverCommand' is evaluated symbolically in the
-- context of all loaded non-parameterized modules plus 'pcExtraDecls',
-- applied to fresh symbolic arguments, and handed to the selected
-- solver(s).  The result pairs the solver that answered (if any) with a
-- 'ProverResult'.
--
-- Solver selection:
--
--   * proof queries and @SomeSat 1@ queries use the \"any\"-style SBV
--     entry points, which accept a list of solvers;
--
--   * every other sat query uses 'SBV.allSatWith', which needs exactly one
--     solver; otherwise a 'SBV.ProofError' is reported for each solver.
--
-- A stack overflow during the run is turned into a 'ProverError' by
-- 'protectStack'.
satProve :: ProverCommand -> M.ModuleCmd (Maybe SBV.Solver, ProverResult)
satProve ProverCommand {..} =
  protectStack proverError $ \(evo,modEnv) ->

  M.runModuleM (evo,modEnv) $ do
  -- isSat selects existential vs. universal arguments; mSatNum bounds the
  -- number of models requested from allSat (Nothing = unbounded).
  let (isSat, mSatNum) = case pcQueryType of
        ProveQuery -> (False, Nothing)
        SatQuery sn -> case sn of
          SomeSat n -> (True, Just n)
          AllSat    -> (True, Nothing)
  let extDgs = allDeclGroups modEnv ++ pcExtraDecls
  provers <-
    case pcProverName of
      "any" -> M.io SBV.sbvAvailableSolvers
      _ -> return [(lookupProver pcProverName) { SBV.transcript = pcSmtFile
                                               , SBV.allSatMaxModelCount = mSatNum
                                               }]


  -- The configurations actually handed to SBV: the base configurations
  -- plus the per-command timing, verbosity and validation settings.
  let provers' = [ p { SBV.timing = SaveTiming pcProverStats
                     , SBV.verbose = pcVerbose
                     , SBV.validateModel = pcValidate
                     } | p <- provers ]
  let tyFn = if isSat then existsFinType else forallFinType
  let lPutStrLn = M.withLogger logPutStrLn
  let doEval :: MonadIO m => Eval.Eval a -> m a
      doEval m  = liftIO $ Eval.runEval evo m
  -- Single-solver runner.  It must use the fully configured provers' so
  -- that timing, verbosity and model validation also apply on this path.
  let runProver fn tag e = do
        case provers' of
          [prover] -> do
            when pcVerbose $
              lPutStrLn $ "Trying proof with " ++
                                        show (SBV.name (SBV.solver prover))
            res <- M.io (fn prover e)
            when pcVerbose $
              lPutStrLn $ "Got result from " ++
                                        show (SBV.name (SBV.solver prover))
            return (Just (SBV.name (SBV.solver prover)), tag res)
          _ ->
            return ( Nothing
                   , [ SBV.ProofError
                         prover
                         [":sat with option prover=any requires option satNum=1"]
                         Nothing
                     | prover <- provers' ]
                   )
      -- Multi-solver runner: SBV reports which solver answered first.
      runProvers fn tag e = do
        when pcVerbose $
          lPutStrLn $ "Trying proof with " ++
                  intercalate ", " (map (show . SBV.name . SBV.solver) provers)
        (firstProver, timeElapsed, res) <- M.io (fn provers' e)
        when pcVerbose $
          lPutStrLn $ "Got result from " ++ show firstProver ++
                                            ", time: " ++ showTDiff timeElapsed
        return (Just firstProver, tag res)
  let runFn = case pcQueryType of
        ProveQuery -> runProvers SBV.proveWithAny thmSMTResults
        SatQuery sn -> case sn of
          SomeSat 1 -> runProvers SBV.satWithAny satSMTResults
          _         -> runProver SBV.allSatWith allSatSMTResults
  -- Side conditions on the symbolic arguments are premises of a proof
  -- goal (asm => goal) but conjuncts of a sat goal (asm && goal).
  let addAsm = case pcQueryType of
        ProveQuery -> \x y -> SBV.svOr (SBV.svNot x) y
        SatQuery _ -> \x y -> SBV.svAnd x y
  case predArgTypes pcSchema of
    Left msg -> return (Nothing, ProverError msg)
    Right ts -> do when pcVerbose $ lPutStrLn "Simulating..."
                   v <- doEval $ do env <- Eval.evalDecls extDgs mempty
                                    Eval.evalExpr env pcExpr
                   prims <- M.getPrimMap
                   runRes <- runFn $ do
                               (args, asms) <- runWriterT (mapM tyFn ts)
                               b <- doEval (fromVBit <$>
                                      foldM fromVFun v (map Eval.ready args))
                               return (foldr addAsm b asms)
                   let (firstProver, results) = runRes
                   esatexprs <- case results of
                     -- allSat can return more than one as long as
                     -- they're satisfiable
                     (SBV.Satisfiable {} : _) -> do
                       tevss <- mapM mkTevs results
                       return $ AllSatResult tevss
                       where
                         mkTevs result = do
                           -- Kept as a lazy binding: the assignment is
                           -- only inspected when a value is parsed.  A
                           -- missing model is reported with SBV's message
                           -- instead of a bare pattern-match failure.
                           let cvs = case SBV.getModelAssignment result of
                                       Right (_, assignment) -> assignment
                                       Left err ->
                                         panic "Cryptol.Symbolic.sat"
                                           [ "unable to extract model assignment"
                                           , err ]
                               (vs, _) = parseValues ts cvs
                               sattys = unFinType <$> ts
                           satexprs <-
                             doEval (zipWithM (Eval.toExpr prims) sattys vs)
                           case zip3 sattys <$> (sequence satexprs) <*> pure vs of
                             Nothing ->
                               panic "Cryptol.Symbolic.sat"
                                 [ "unable to make assignment into expression" ]
                             Just tevs -> return $ tevs
                     -- prove returns only one
                     [SBV.Unsatisfiable {}] ->
                       return $ ThmResult (unFinType <$> ts)
                     -- unsat returns empty
                     [] -> return $ ThmResult (unFinType <$> ts)
                     -- otherwise something is wrong
                     _ -> return $ ProverError (rshow results)
                            -- 'head' is safe here: the empty list is
                            -- handled by the previous alternative.
                            where rshow | isSat = show .  SBV.AllSatResult . (False,False,False,)
                                        | otherwise = show . SBV.ThmResult . head
                   return (firstProver, esatexprs)

-- | Build the SMT benchmark for a query without running a solver.
--
-- The result is @Left msg@ when the schema is not a valid monomorphic
-- predicate type or symbolic evaluation overflows the stack, and
-- @Right script@ with the text produced by 'SBV.generateSMTBenchmark'
-- otherwise.  The module environment is returned unchanged.
satProveOffline :: ProverCommand -> M.ModuleCmd (Either String String)
satProveOffline ProverCommand {..} =
  protectStack onOverflow $ \(evOpts, modEnv) ->
    let finish outcome = return (Right (outcome, modEnv), [])
        isSat = case pcQueryType of
          SatQuery _ -> True
          ProveQuery -> False
        extDgs = allDeclGroups modEnv ++ pcExtraDecls
        -- sat queries quantify arguments existentially, proofs universally
        tyFn | isSat     = existsFinType
             | otherwise = forallFinType
        -- side conditions: conjuncts for sat, premises for prove
        addAsm | isSat     = SBV.svAnd
               | otherwise = \asm goal -> SBV.svOr (SBV.svNot asm) goal
    in case predArgTypes pcSchema of
         Left msg -> finish (Left msg)
         Right ts -> do
           when pcVerbose $
             logPutStrLn (Eval.evalLogger evOpts) "Simulating..."
           v <- liftIO $ Eval.runEval evOpts $ do
                  env <- Eval.evalDecls extDgs mempty
                  Eval.evalExpr env pcExpr
           smtlib <- SBV.generateSMTBenchmark isSat $ do
             (args, asms) <- runWriterT (mapM tyFn ts)
             b <- liftIO $ Eval.runEval evOpts
                    (fromVBit <$> foldM fromVFun v (map Eval.ready args))
             return (foldr addAsm b asms)
           finish (Right smtlib)
  where
    -- Used by 'protectStack' when evaluation overflows the stack.
    onOverflow msg (_, modEnv) = return (Right (Left msg, modEnv), [])

-- | Run a module command, replacing it with a fallback if it throws
-- 'X.StackOverflow'.  The fallback is built from a fixed error message
-- and is run on the same input as the original command.  All other
-- exceptions propagate unchanged.
protectStack :: (String -> M.ModuleCmd a)
             -> M.ModuleCmd a
             -> M.ModuleCmd a
protectStack mkErr cmd input =
  X.catchJust selectOverflow (cmd input) (\() -> mkErr overflowMsg input)
  where
    overflowMsg = "Symbolic evaluation failed to terminate."

    selectOverflow exn =
      case exn of
        X.StackOverflow -> Just ()
        _               -> Nothing

-- | Parse one value per type, left to right, threading the list of
-- remaining solver values through.  Returns the parsed values together
-- with whatever solver values were not consumed.
parseValues :: [FinType] -> [SBV.CV] -> ([Eval.Value], [SBV.CV])
parseValues tys cvs =
  case tys of
    [] -> ([], cvs)
    ty : rest ->
      let (val, afterFirst) = parseValue ty cvs
          (vals, leftover)  = parseValues rest afterFirst
      in (val : vals, leftover)

-- | Parse a single value of the given finite type from the front of a
-- list of solver values, returning the value and the unconsumed rest.
--
-- Panics if a bit is requested from an empty list or if an integer
-- cannot be parsed.  Values of type 'FTIntMod' are parsed as plain
-- integers.  Bit sequences are parsed as a single bounded unsigned word
-- when possible, falling back to parsing the bits one at a time.
parseValue :: FinType -> [SBV.CV] -> (Eval.Value, [SBV.CV])
parseValue fty cvs =
  case fty of
    FTBit ->
      case cvs of
        []          -> panic "Cryptol.Symbolic.parseValue" [ "empty FTBit" ]
        cv : rest   -> (Eval.VBit (SBV.cvToBool cv), rest)

    FTInteger ->
      case SBV.genParse SBV.KUnbounded cvs of
        Just (x, rest) -> (Eval.VInteger x, rest)
        Nothing        -> panic "Cryptol.Symbolic.parseValue" [ "no integer" ]

    FTIntMod _ -> parseValue FTInteger cvs

    -- the empty word consumes no solver values
    FTSeq 0 FTBit -> (Eval.word 0 0, cvs)

    FTSeq n FTBit ->
      case SBV.genParse (SBV.KBounded False n) cvs of
        Just (x, rest) -> (Eval.word (toInteger n) x, rest)
        Nothing ->
          let (bits, rest) = parseValues (replicate n FTBit) cvs
              packed       = Eval.packWord (map fromVBit bits)
          in (VWord (genericLength bits) $ return $ Eval.WordVal $ packed, rest)

    FTSeq n t ->
      let (elts, rest) = parseValues (replicate n t) cvs
      in ( Eval.VSeq (toInteger n) $ Eval.finiteSeqMap (map Eval.ready elts)
         , rest
         )

    FTTuple ts ->
      let (comps, rest) = parseValues ts cvs
      in (Eval.VTuple (map Eval.ready comps), rest)

    FTRecord fs ->
      let (labels, ts)   = unzip fs
          (comps, rest)  = parseValues ts cvs
      in (Eval.VRecord (zip labels (map Eval.ready comps)), rest)

-- | Collect the declaration groups of every loaded non-parameterized
-- module in the environment, concatenated in module order.
allDeclGroups :: M.ModuleEnv -> [DeclGroup]
allDeclGroups modEnv = concatMap mDecls (M.loadedNonParamModules modEnv)

-- | The types of values that can be used as arguments of a predicate
-- sent to the solver.
data FinType
  = FTBit
    -- ^ A single bit
  | FTInteger
    -- ^ An unbounded integer
  | FTIntMod Integer
    -- ^ An integer with the given modulus
  | FTSeq Int FinType
    -- ^ A sequence with the given length and element type
  | FTTuple [FinType]
    -- ^ A tuple with the given component types
  | FTRecord [(Ident, FinType)]
    -- ^ A record with the given labelled field types

-- | Convert an 'Integer' to an 'Int', succeeding only when it lies in
-- the range @0 .. maxBound :: Int@.
numType :: Integer -> Maybe Int
numType n =
  if n < 0 || n > upperLimit
    then Nothing
    else Just (fromInteger n)
  where
    upperLimit = toInteger (maxBound :: Int)

-- | Translate an evaluated type into a 'FinType', if it has one.
-- Sequences are accepted only when their length passes 'numType'.
-- Compound types are accepted only when all of their components are.
-- Every other type, including abstract types, yields 'Nothing'.
finType :: TValue -> Maybe FinType
finType Eval.TVBit          = Just FTBit
finType Eval.TVInteger      = Just FTInteger
finType (Eval.TVIntMod n)   = Just (FTIntMod n)
finType (Eval.TVSeq n t)    = do
  len <- numType n
  elt <- finType t
  return (FTSeq len elt)
finType (Eval.TVTuple ts)   = fmap FTTuple (traverse finType ts)
finType (Eval.TVRec fields) = fmap FTRecord (traverse (traverseSnd finType) fields)
finType Eval.TVAbstract {}  = Nothing
finType _                   = Nothing

-- | Translate a 'FinType' back into the corresponding syntactic 'Type'.
unFinType :: FinType -> Type
unFinType FTBit          = tBit
unFinType FTInteger      = tInteger
unFinType (FTIntMod n)   = tIntMod (tNum n)
unFinType (FTSeq l ety)  = tSeq (tNum l) (unFinType ety)
unFinType (FTTuple ftys) = tTuple (map unFinType ftys)
unFinType (FTRecord fs)  = tRec (map field fs)
  where
    -- keep the label, translate the field type
    field fld = (fst fld, unFinType (snd fld))

-- | Compute the argument types of a predicate from its schema.
--
-- The schema must have no type parameters and no constraints, and its
-- type must evaluate to a (possibly nullary) function whose arguments
-- all have a 'FinType' and whose final result is a bit.  Otherwise an
-- error message that includes the pretty-printed schema is returned.
predArgTypes :: Schema -> Either String [FinType]
predArgTypes schema@(Forall tparams props body)
  | not (null tparams) || not (null props) =
      Left ("Not a monomorphic type:\n" ++ rendered)
  | otherwise =
      case Eval.evalType mempty body of
        Right tv | Just fts <- argTypes tv -> Right fts
        _ -> Left ("Not a valid predicate type:\n" ++ rendered)
  where
    rendered = show (pp schema)

    -- Peel off function arguments until the result type is reached.
    argTypes :: TValue -> Maybe [FinType]
    argTypes tv =
      case tv of
        Eval.TVBit         -> Just []
        Eval.TVFun arg res -> (:) <$> finType arg <*> argTypes res
        _                  -> Nothing

-- | The symbolic condition @0 <= x && x < n@.
inBoundsIntMod :: Integer -> SInteger -> SBool
inBoundsIntMod n x = SBV.svAnd atLeastZero belowModulus
  where
    atLeastZero  = SBV.svLessEq (Eval.integerLit 0) x
    belowModulus = SBV.svLessThan x (Eval.integerLit n)

-- | Create a fresh universally quantified symbolic value of the given
-- type.  Side conditions on the fresh variables (the range constraint
-- for 'FTIntMod') are emitted through the writer.
forallFinType :: FinType -> WriterT [SBool] SBV.Symbolic Value
forallFinType FTBit = fmap VBit (lift forallSBool_)
forallFinType FTInteger = fmap VInteger (lift forallSInteger_)
forallFinType (FTIntMod n) = do
  x <- lift forallSInteger_
  tell [inBoundsIntMod n x]
  return (VInteger x)
-- the empty word needs no solver variable
forallFinType (FTSeq 0 FTBit) = return (Eval.word 0 0)
forallFinType (FTSeq n FTBit) = do
  w <- lift (forallBV_ n)
  return (VWord (toInteger n) (return (Eval.WordVal w)))
forallFinType (FTSeq n t) = do
  elts <- replicateM n (forallFinType t)
  return (VSeq (toInteger n) (Eval.finiteSeqMap (map Eval.ready elts)))
forallFinType (FTTuple ts) =
  fmap VTuple (mapM (\t -> Eval.ready <$> forallFinType t) ts)
forallFinType (FTRecord fs) =
  fmap VRecord (mapM (traverseSnd (\t -> Eval.ready <$> forallFinType t)) fs)

-- | Create a fresh existentially quantified symbolic value of the given
-- type.  Side conditions on the fresh variables (the range constraint
-- for 'FTIntMod') are emitted through the writer.
existsFinType :: FinType -> WriterT [SBool] SBV.Symbolic Value
existsFinType FTBit = fmap VBit (lift existsSBool_)
existsFinType FTInteger = fmap VInteger (lift existsSInteger_)
existsFinType (FTIntMod n) = do
  x <- lift existsSInteger_
  tell [inBoundsIntMod n x]
  return (VInteger x)
-- the empty word needs no solver variable
existsFinType (FTSeq 0 FTBit) = return (Eval.word 0 0)
existsFinType (FTSeq n FTBit) = do
  w <- lift (existsBV_ n)
  return (VWord (toInteger n) (return (Eval.WordVal w)))
existsFinType (FTSeq n t) = do
  elts <- replicateM n (existsFinType t)
  return (VSeq (toInteger n) (Eval.finiteSeqMap (map Eval.ready elts)))
existsFinType (FTTuple ts) =
  fmap VTuple (mapM (\t -> Eval.ready <$> existsFinType t) ts)
existsFinType (FTRecord fs) =
  fmap VRecord (mapM (traverseSnd (\t -> Eval.ready <$> existsFinType t)) fs)