-- |
-- Module      :  Cryptol.TypeCheck.Solver.SMT
-- Copyright   :  (c) 2013-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable
{-# LANGUAGE Safe #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE FlexibleContexts #-}
{-# Language FlexibleInstances #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE TypeSynonymInstances #-}

module Cryptol.TypeCheck.Solver.SMT
  ( -- * Setup
    Solver
  , withSolver
  , isNumeric

    -- * Debugging
  , debugBlock
  , debugLog

    -- * Proving stuff
  , proveImp
  , checkUnsolvable
  , tryGetModel
  , shrinkModel
  ) where

import           SimpleSMT (SExpr)
import qualified SimpleSMT as SMT
import           Data.Map ( Map )
import qualified Data.Map as Map
import qualified Data.Set as Set
import           Data.Maybe(catMaybes)
import           Data.List(partition)
import           Control.Exception
import           Control.Monad(msum,zipWithM,void)
import           Data.Char(isSpace)
import           Text.Read(readMaybe)
import qualified System.IO.Strict as StrictIO
import           System.FilePath((</>))
import           System.Directory(doesFileExist)

import Cryptol.Prelude(cryptolTcContents)
import Cryptol.TypeCheck.Type
import Cryptol.TypeCheck.InferTypes
import Cryptol.TypeCheck.Solver.InfNat(Nat'(..))
import Cryptol.TypeCheck.TypePat hiding ((~>),(~~>))
import Cryptol.TypeCheck.Subst(Subst)
import Cryptol.Utils.Panic
import Cryptol.Utils.PP -- ( Doc )




-- | An SMT solver packed with a logger for debugging.
-- Values of this type are created by 'withSolver', which also stops the
-- solver once the computation using it has finished.
data Solver = Solver
  { solver    :: SMT.Solver
    -- ^ The actual solver; the SMT commands in this module are sent to it.

  , logger    :: SMT.Logger
    -- ^ For debugging: receives the messages written by 'debugLog' and
    -- the tab/untab requests made by 'debugBlock'.
  }

-- | Execute a computation with a fresh solver instance.
--
-- A logger is created when @solverVerbose > 0@ (otherwise a logger that
-- discards everything is used), and it is also handed to the SMT process
-- itself when @solverVerbose > 1@.  After the process is started, the
-- type-checking prelude is loaded (see 'loadTcPrelude').
--
-- The solver is stopped when the computation finishes, whether normally
-- or by an exception.  It is also stopped if configuring the solver or
-- loading the prelude fails, so a failed setup does not leave a solver
-- process behind.
withSolver :: SolverConfig -> (Solver -> IO a) -> IO a
withSolver SolverConfig{ .. } =
     bracket
       (do logger <- if solverVerbose > 0 then SMT.newLogger 0
                                          else return quietLogger
           let smtDbg = if solverVerbose > 1 then Just logger else Nothing
           solver <- SMT.newSolver solverPath solverArgs smtDbg
           let sol = Solver { .. }
           -- The release action of 'bracket' only runs once this whole
           -- acquire action has succeeded, so the solver has to be stopped
           -- here if anything below throws.
           (do _ <- SMT.setOptionMaybe solver ":global-decls" "false"
               -- SMT.setLogic solver "QF_LIA"
               loadTcPrelude sol solverPreludePath)
             `onException` SMT.stop solver
           return sol)
       (\s -> void $ SMT.stop (solver s))

  where
  -- A logger that ignores all messages and indentation requests.
  quietLogger = SMT.Logger { SMT.logMessage = \_ -> return ()
                           , SMT.logLevel   = return 0
                           , SMT.logSetLevel= \_ -> return ()
                           , SMT.logTab     = return ()
                           , SMT.logUntab   = return ()
                           }


-- | Load the definitions used for type checking.
-- The first directory in the list that contains a file named
-- @CryptolTC.z3@ wins; if no directory has one, the definitions that are
-- built into the executable ('cryptolTcContents') are used instead.
loadTcPrelude :: Solver -> [FilePath] {- ^ Search in these paths -} -> IO ()
loadTcPrelude s searchPaths =
  case searchPaths of
    [] -> loadString s cryptolTcContents
    dir : dirs ->
      do let candidate = dir </> "CryptolTC.z3"
         found <- doesFileExist candidate
         if found then loadFile s candidate
                  else loadTcPrelude s dirs


-- | Read a file and send the SMT commands it contains to the solver.
loadFile :: Solver -> FilePath -> IO ()
loadFile s file =
  do contents <- StrictIO.readFile file
     loadString s contents

-- | Send the SMT commands contained in a string to the solver, one
-- S-expression at a time.  On every line, everything from the first @;@
-- onwards is treated as a comment and discarded.  Panics if the remaining
-- text cannot be parsed as a sequence of S-expressions.
loadString :: Solver -> String -> IO ()
loadString s = sendAll . stripComments
  where
  -- Send commands until only white space is left.
  sendAll txt
    | all isSpace txt = return ()
    | otherwise =
        case SMT.readSExpr txt of
          Nothing       -> panic "loadFile" [ "Failed to parse SMT file."
                                            , txt
                                            ]
          Just (e,rest) -> do _ <- SMT.command (solver s) e
                              sendAll rest

  -- Keep only the part of each line that precedes the first semicolon.
  stripComments = unlines . map (takeWhile (/= ';')) . lines




--------------------------------------------------------------------------------
-- Debugging


-- | Run a computation as a named section of the debug log.
-- The name is logged first, then the logger is asked to tab in for the
-- duration of the computation and to untab afterwards.  The untab request
-- is issued even when the computation throws, so that an exception does
-- not leave the logger tabbed in for everything that is logged later.
debugBlock :: Solver -> String -> IO a -> IO a
debugBlock s@Solver { .. } name m =
  do debugLog s name
     SMT.logTab logger
     m `finally` SMT.logUntab logger

-- | Things that can be written to the solver's debug log.
class DebugLog t where
  -- | Log a single value.
  debugLog :: Solver -> t -> IO ()

  -- | Log a list of values.  Unless an instance overrides it, an empty
  -- list is logged as @(none)@ and any other list is logged one element
  -- at a time.
  debugLogList :: Solver -> [t] -> IO ()
  debugLogList s [] = debugLog s "(none)"
  debugLogList s ts = mapM_ (debugLog s) ts

-- | A single character is logged via 'show'; a list of characters
-- (i.e., a 'String') is logged verbatim as one message.
instance DebugLog Char where
  debugLog s c      = SMT.logMessage (logger s) (show c)
  debugLogList s cs = SMT.logMessage (logger s) cs

-- | Lists defer to the element type's 'debugLogList'.
instance DebugLog a => DebugLog [a] where
  debugLog s xs = debugLogList s xs

-- | 'Nothing' is logged as @(nothing)@, 'Just' logs its payload.
instance DebugLog a => DebugLog (Maybe a) where
  debugLog s = maybe (debugLog s "(nothing)") (debugLog s)

-- | Documents are rendered with 'show' and logged as text.
instance DebugLog Doc where
  debugLog s = debugLog s . show

-- | Types are pretty printed before being logged.
instance DebugLog Type where
  debugLog s = debugLog s . pp

-- | Only the proposition of a goal is logged.
instance DebugLog Goal where
  debugLog s = debugLog s . goal

-- | Substitutions are pretty printed before being logged.
instance DebugLog Subst where
  debugLog s = debugLog s . pp
--------------------------------------------------------------------------------





-- | Try to prove the given goals under the given assumptions.
-- Returns goals that were not proved.
--
-- Goals are first split at conjunctions.  Only the numeric pieces (see
-- 'isNumeric') are sent to the solver, together with the numeric parts of
-- the assumptions; the non-numeric pieces are returned as they are, after
-- the numeric goals that could not be proved.
proveImp :: Solver -> [Prop] -> [Goal] -> IO [Goal]
proveImp sol ps gs0 =
  debugBlock sol "PROVE IMP" $
  do let (numGoals, otherGoals) =
           partition (isNumeric . goal) (concatMap flatGoal gs0)
         numAsmps = [ p | p <- concatMap pSplitAnd ps, isNumeric p ]
         vars     = Set.toList (fvs (numAsmps, map goal numGoals))

     -- The variables and assumptions live in their own solver frame,
     -- which is discarded again once all goals have been tried.
     tvs <- debugBlock sol "VARIABLES" $
       do push sol
          decls <- zipWithM (declareVar sol) [ 0 .. ] vars
          return (Map.fromList decls)

     debugBlock sol "ASSUMPTIONS" $
       mapM_ (assume sol tvs) numAsmps

     unproved <- mapM (prove sol tvs) numGoals
     pop sol
     return (catMaybes unproved ++ otherGoals)

-- | Check if the given goals are known to be unsolvable.
-- Only the numeric parts of the goals (after splitting conjunctions) are
-- considered; the answer is 'True' when the solver reports that they
-- cannot all hold together.
checkUnsolvable :: Solver -> [Goal] -> IO Bool
checkUnsolvable sol gs0 =
  debugBlock sol "CHECK UNSOLVABLE" $
  do let numProps = [ goal g | g <- concatMap flatGoal gs0
                             , isNumeric (goal g) ]
         vars     = Set.toList (fvs numProps)

     tvs <- debugBlock sol "VARIABLES" $
       do push sol
          decls <- zipWithM (declareVar sol) [ 0 .. ] vars
          return (Map.fromList decls)

     ans <- unsolvable sol tvs numProps
     pop sol
     return ans

-- | Ask the solver for values of the given variables that satisfy the
-- given propositions.  The result is 'Nothing' if the solver does not
-- answer @sat@, or if one of the values it reports cannot be understood.
-- All declarations and assertions made here are popped before returning.
tryGetModel :: Solver -> [TVar] -> [Prop] -> IO (Maybe [(TVar,Nat')])
tryGetModel sol as ps =
  debugBlock sol "TRY GET MODEL" $
  do push sol
     tvs <- Map.fromList <$> zipWithM (declareVar sol) [ 0 .. ] as
     mapM_ (assume sol tvs) ps
     sat <- SMT.check (solver sol)
     su <- case sat of
             SMT.Sat
               -- With no variables there is nothing to ask the solver for.
               | null as   -> return (Just [])
               | otherwise -> readModel tvs
             _ -> return Nothing
     pop sol
     return su

  where
  -- Fetch the values of all declared variables and decode them.
  readModel tvs =
    do res <- SMT.getExprs (solver sol) (Map.elems tvs)
       return (mapM (valueOf tvs res) as)

  -- Find the decoded value that the solver reported for one variable.
  valueOf tvs res x =
    do e <- Map.lookup x tvs
       v <- lookup e res
       t <- parseNum v
       return (x, t)

  -- Decode a value of the form @(mk-infnat VALUE IS_FIN IS_ERR)@.
  -- Only values whose error flag is @false@ are accepted.
  parseNum a =
    case a of
      SMT.Other (SMT.List [ SMT.Atom "mk-infnat"
                          , SMT.Atom v
                          , SMT.Atom fin
                          , SMT.Atom "false"
                          ])
        | Just n <- readMaybe v -> Just (if fin == "false" then Inf else Nat n)
      _ -> Nothing

-- | Try to make the finite values in a model smaller.
-- The variables are processed in order.  For each finite value we keep
-- asking the solver for a model in which the variable is at most half of
-- its current value, until no such model exists.  The bound that was found
-- is then added to the propositions used for the remaining variables.
-- Values that are not finite are kept unchanged.
--
-- Note that the entries of the result are in the reverse order of the
-- entries of the input model.
shrinkModel :: Solver -> [TVar] -> [Prop] -> [(TVar,Nat')] -> IO [(TVar,Nat')]
shrinkModel sol as ps0 mdl = go [] ps0 mdl
  where
  go acc _ [] = return acc
  go acc ps ((x,val) : rest) =
    case val of
      Nat k -> do k' <- shrink1 ps x k
                  go ((x,Nat k') : acc) ((tNum k' >== TVar x) : ps) rest
      _     -> go ((x,val) : acc) ps rest

  -- Repeatedly halve the bound on a single variable.
  shrink1 ps x k
    | k == 0 = return 0
    | otherwise =
      do let half = k `div` 2
         mb <- tryGetModel sol as ((tNum half >== TVar x) : ps)
         case mb of
           Nothing -> return k
           Just newMdl
             | Just (Nat k2) <- lookup x newMdl -> shrink1 ps x k2
             | otherwise ->
                 panic "shrink" ["model is missing variable", show x]



--------------------------------------------------------------------------------

-- | Forward a @push@ request to the underlying SMT solver.
push :: Solver -> IO ()
push = SMT.push . solver

-- | Forward a @pop@ request to the underlying SMT solver.
pop :: Solver -> IO ()
pop = SMT.pop . solver


-- | Declare an SMT variable for a type variable and assert @cryVar@ of it.
-- The SMT name is @fv@ (for free type variables) or @kv@ (for all others),
-- followed by the given number.  Returns the type variable paired with the
-- SMT expression that now stands for it.
declareVar :: Solver -> Int -> TVar -> IO (TVar, SExpr)
declareVar s x v =
  do e <- SMT.declare (solver s) (prefix ++ show x) cryInfNat
     SMT.assert (solver s) (SMT.fun "cryVar" [ e ])
     return (v,e)
  where
  prefix | isFreeTV v = "fv"
         | otherwise  = "kv"

-- | Assert a proposition, wrapped in @cryAssume@, in the current solver
-- frame.
assume :: Solver -> TVars -> Prop -> IO ()
assume s tvs p = SMT.assert (solver s) wrapped
  where
  wrapped = SMT.fun "cryAssume" [ toSMT tvs p ]

-- | Try to prove a single goal in the current solver context.
-- The goal is asserted, wrapped in @cryProve@, in a temporary frame.
-- It counts as proved ('Nothing') only if the solver answers @unsat@;
-- for any other answer the goal is handed back.
prove :: Solver -> TVars -> Goal -> IO (Maybe Goal)
prove sol tvs g =
  debugBlock sol "PROVE" $
  do push sol
     SMT.assert smt (SMT.fun "cryProve" [ toSMT tvs (goal g) ])
     res <- SMT.check smt
     pop sol
     return $ case res of
                SMT.Unsat -> Nothing
                _         -> Just g
  where
  smt = solver sol


-- | Check if some numeric goals are known to be unsolvable.
-- The propositions are assumed together in a temporary frame; they are
-- reported as unsolvable only if the solver answers @unsat@.
unsolvable :: Solver -> TVars -> [Prop] -> IO Bool
unsolvable sol tvs ps =
  debugBlock sol "UNSOLVABLE" $
  do push sol
     mapM_ (assume sol tvs) ps
     res <- SMT.check (solver sol)
     pop sol
     return $ case res of
                SMT.Unsat -> True
                _         -> False



--------------------------------------------------------------------------------

-- | Split up the 'And' in a goal.
-- Each resulting goal is a copy of the original, with its proposition
-- replaced by one of the conjuncts.
flatGoal :: Goal -> [Goal]
flatGoal g = map withProp (pSplitAnd (goal g))
  where
  withProp p = g { goal = p }


-- | Is this one of the propositions that we send to the solver:
-- an equality, a disequality, a @>=@ comparison, or a @fin@ constraint?
-- Assumes no 'And'.
isNumeric :: Prop -> Bool
isNumeric ty = matchDefault False (msum candidates)
  where
  candidates = [ found (|=|), found (|/=|), found (|>=|), found aFin ]

  -- Succeeds with 'True' exactly when the pattern matches the proposition.
  found pat = pat ty >> return True


--------------------------------------------------------------------------------

-- | Associates type variables with the SMT expressions that stand for
-- them.  In this module the map is built from the pairs returned by
-- 'declareVar'.
type TVars = Map TVar SExpr

-- | The SMT name @InfNat@, used by 'declareVar' as the type of every
-- variable it declares.  Presumably defined by the type-checking prelude
-- that is loaded when the solver starts (to be confirmed).
cryInfNat :: SExpr
cryInfNat = SMT.const "InfNat"


-- | Translate a type or proposition to an SMT expression.
--
-- Each entry of the table pairs a pattern with the name of the SMT
-- function that is applied to the translated pieces of a match (see '~>'
-- and the 'Mk' class).  The entries are combined with 'msum'.
-- Panics if no entry matches the type.
toSMT :: TVars -> Type -> SExpr
toSMT tvs ty = matchDefault (panic "toSMT" [ "Unexpected type", show ty ])
  $ msum $ map (\f -> f tvs ty)
  -- Numeric constants
  [ aInf            ~> "cryInf"
  , aNat            ~> "cryNat"

  -- Propositions
  , aFin            ~> "cryFin"
  , (|=|)           ~> "cryEq"
  , (|/=|)          ~> "cryNeq"
  , (|>=|)          ~> "cryGeq"
  , aAnd            ~> "cryAnd"
  , aTrue           ~> "cryTrue"

  -- Numeric operations
  , anAdd           ~> "cryAdd"
  , (|-|)           ~> "crySub"
  , aMul            ~> "cryMul"
  , (|^|)           ~> "cryExp"
  , (|/|)           ~> "cryDiv"
  , (|%|)           ~> "cryMod"
  , aMin            ~> "cryMin"
  , aMax            ~> "cryMax"
  , aWidth          ~> "cryWidth"
  , aCeilDiv        ~> "cryCeilDiv"
  , aCeilMod        ~> "cryCeilMod"
  , aLenFromThenTo  ~> "cryLenFromThenTo"

  -- Errors
  , anError KNum    ~> "cryErr"
  , anError KProp   ~> "cryErrProp"

  -- The name is ignored here: the 'Mk' instance for 'TVar' looks the
  -- variable up in the map of declared variables instead.
  , aTVar           ~> "(unused)"
  ]

--------------------------------------------------------------------------------

-- | Pair a pattern with the name of an SMT function.  When the pattern
-- matches a type, the matched pieces are turned into an SMT expression
-- using 'mk' with the given name.
(~>) :: Mk a => (Type -> Match a) -> String -> TVars -> Type -> Match SExpr
(pat ~> name) tvs t =
  do matched <- pat t
     return (mk tvs name matched)

-- | Results of pattern matches that can be turned into SMT expressions.
class Mk t where
  -- | Build an SMT expression from the map of declared variables, the
  -- name of an SMT function, and the result of a pattern match.
  mk :: TVars -> String -> t -> SExpr

-- | A pattern with no pieces becomes the SMT constant with the given name.
instance Mk () where
  mk _tvs name _ = SMT.const name

-- | A number becomes the named function applied to an integer literal.
instance Mk Integer where
  mk _tvs name n = SMT.fun name [ SMT.int n ]

-- | A type variable becomes the SMT expression it was declared as; the
-- function name is ignored.  Panics, naming the offending variable, if
-- the variable is not in the map of declared variables.
instance Mk TVar where
  mk tvs _ x =
    case Map.lookup x tvs of
      Just e  -> e
      Nothing -> panic "toSMT" [ "Undeclared type variable", show x ]

-- | One sub-term: the named function applied to its translation.
instance Mk Type where
  mk tvs name t = SMT.fun name [ toSMT tvs t ]

-- | The error message itself is dropped; the named function is used
-- with no arguments.
instance Mk TCErrorMessage where
  mk _tvs name _msg = SMT.fun name []

-- | Two sub-terms: the named function applied to both translations.
instance Mk (Type,Type) where
  mk tvs name (a,b) = SMT.fun name (map (toSMT tvs) [ a, b ])

-- | Three sub-terms: the named function applied to all translations.
instance Mk (Type,Type,Type) where
  mk tvs name (a,b,c) = SMT.fun name (map (toSMT tvs) [ a, b, c ])

--------------------------------------------------------------------------------