-- |
-- Module      :  $Header$
-- Copyright   :  (c) 2014-2016 Galois, Inc.
-- License     :  BSD3
-- Maintainer  :  cryptol@galois.com
-- Stability   :  provisional
-- Portability :  portable
--
-- Desugar into SMTLIB Terminology

{-# LANGUAGE Safe #-}

module Cryptol.TypeCheck.Solver.Numeric.SMT
  ( desugarProp
  , smtName
  , smtFinName
  , ifPropToSmtLib
  , cryImproveModel
  , getVal
  , getVals
  ) where

import           Cryptol.TypeCheck.AST (TVar(TVFree,TVBound))
import           Cryptol.TypeCheck.Solver.InfNat
import           Cryptol.TypeCheck.Solver.Numeric.AST
import           Cryptol.TypeCheck.Solver.Numeric.Simplify(crySimplify)
import           Cryptol.Utils.Misc ( anyJust )
import           Cryptol.Utils.Panic ( panic )

import           Data.List ( partition, unfoldr )
import           Data.Map ( Map )
import qualified Data.Map as Map
import qualified Data.Set as Set
import           SimpleSMT ( SExpr )
import qualified SimpleSMT as SMT
import           MonadLib


--------------------------------------------------------------------------------
-- Desugar to SMT
--------------------------------------------------------------------------------

-- XXX: Expanding the if-then-elses could make things large.
-- Perhaps keep them as first class things, in hope that the solver
-- can do something more clever with that?


-- | Rewrite the operators that have no direct SMTLIB counterpart ('Min',
-- 'Max', and 'LenFromThenTo' whose first two arguments are constants)
-- into trees of conditionals over simpler expressions.  Sub-expressions
-- are processed first; every other operator is rebuilt unchanged.
--
-- Assumes simplified, linear, finite, defined expressions.
desugarExpr :: Expr -> IfExpr Expr
desugarExpr expr =
  do args <- mapM desugarExpr (cryExprExprs expr)
     case (expr, args) of
       (Min {}, [l, r]) -> If (l :>: r) (return r) (return l)
       (Max {}, [l, r]) -> If (l :>: r) (return l) (return r)

       -- When the first two elements are equal neither guard holds, and
       -- the expression falls through to the generic case below.
       (LenFromThenTo {}, [ start@(K (Nat k0)), K (Nat k1), end ])
         | k0 > k1 -> countFrom end start (k0 - k1)   -- going down
         | k1 > k0 -> countFrom start end (k1 - k0)   -- going up

       _ -> return (cryRebuildExpr expr args)

  where
  -- @if lo > hi then 0 else (hi - lo) / d + 1@: the number of elements of
  -- a progression with stride @d@ that starts at one end of @[lo, hi]@
  -- and stays inside it.
  countFrom lo hi d =
    If (lo :>: hi) (return zero)
                   (return (Div (hi :- lo) (K (Nat d)) :+ one))


-- | Desugar every expression that occurs in a proposition (see
-- 'desugarExpr'), keeping the propositional structure.
--
-- Assumes simplified, linear, defined propositions: only @fin@ of a
-- variable and the linear comparisons ':==:' and ':>:' may appear as
-- atoms; any other atom is reported with 'panic'.
desugarProp :: Prop -> IfExpr Prop
desugarProp prop =
  case prop of
    PFalse      -> return PFalse
    PTrue       -> return PTrue
    Not p       -> fmap Not (desugarProp p)
    p :&& q     -> binary (:&&) desugarProp p q
    p :|| q     -> binary (:||) desugarProp p q
    Fin (Var _) -> return prop
    x :==: y    -> binary (:==:) desugarExpr x y
    x :>: y     -> binary (:>:)  desugarExpr x y

    -- Not expected in simplified input (see the precondition above).
    Fin _       -> unexpected
    _ :== _     -> unexpected
    _ :>= _     -> unexpected
    _ :> _      -> unexpected

  where
  -- Desugar both operands with @sub@ and recombine them with @con@.
  binary con sub l r = con `fmap` sub l `ap` sub r

  unexpected = panic "desugarProp" [ show (ppProp prop) ]


-- | Translate a tree of conditional propositions into one SMTLIB term:
-- each 'If' becomes an SMTLIB @ite@ and each leaf is translated with
-- 'propToSmtLib'.
ifPropToSmtLib :: IfExpr Prop -> SExpr
ifPropToSmtLib Impossible = SMT.bool False -- Shouldn't really matter
ifPropToSmtLib (Return p) = propToSmtLib p
ifPropToSmtLib (If c t e) =
  SMT.ite (propToSmtLib c) (ifPropToSmtLib t) (ifPropToSmtLib e)

-- | Translate a desugared proposition into an SMTLIB boolean term.
--
-- Only these shapes are accepted: the constants, conjunction, disjunction,
-- @fin@ of a variable, the linear comparisons ':==:' and ':>:', and the
-- negation of @fin@ or of ':>:'.  Anything else is reported with 'panic'.
--
-- If a linear comparison holds between finite terms, then all of the
-- variables in those terms must have been finite.  So every comparison is
-- conjoined with the @fin@ flag of each variable returned by 'cryPropFVS'
-- for the whole proposition.
--
-- XXX: Adding the @fin@ flags at the leaves leads to some duplication: we
-- could add them just once for each conjunction of simple formulas, but it
-- is not clear how much this matters.
propToSmtLib :: Prop -> SExpr
propToSmtLib prop =
  case prop of
    PFalse      -> SMT.bool False
    PTrue       -> SMT.bool True
    Not p       -> negated p
    p :&& q     -> SMT.and (propToSmtLib p) (propToSmtLib q)
    p :|| q     -> SMT.or  (propToSmtLib p) (propToSmtLib q)
    Fin (Var x) -> finFlag x
    x :==: y    -> withFin (SMT.eq (exprToSmtLib x) (exprToSmtLib y))
    x :>: y     -> withFin (SMT.gt (exprToSmtLib x) (exprToSmtLib y))

    Fin _       -> unexpected
    _ :== _     -> unexpected
    _ :>= _     -> unexpected
    _ :> _      -> unexpected

  where
  unexpected = panic "propToSmtLib" [ show (ppProp prop) ]

  -- Only a negated @fin@ or a negated @>@ is supported.  @not (x > y)@ is
  -- emitted as @y >= x@.  It is IMPORTANT that the fin constraints stay
  -- outside the negation.
  negated p =
    case p of
      Fin _   -> SMT.not (propToSmtLib p)
      x :>: y -> withFin (SMT.geq (exprToSmtLib y) (exprToSmtLib x))
      _       -> unexpected

  -- The boolean SMT variable standing for @fin x@.
  finFlag x = SMT.const (smtFinName x)

  -- Conjoin the fin flag of every free variable of the proposition.
  withFin e = foldr (SMT.and . finFlag) e (Set.toList (cryPropFVS prop))

-- | Translate a linear, finite expression into an SMTLIB integer term.
--
-- Infinity and the operators without an SMTLIB counterpart must already
-- have been eliminated (for example by 'desugarExpr').  Encountering one
-- of them is reported with 'panic'.
exprToSmtLib :: Expr -> SExpr
exprToSmtLib expr =
  case expr of
    K (Nat n)        -> SMT.int n
    Var a            -> SMT.const (smtName a)
    x :+ y           -> binary SMT.add x y
    x :- y           -> binary SMT.sub x y
    x :* y           -> binary SMT.mul x y
    Div x y          -> binary SMT.div x y
    Mod x y          -> binary SMT.mod x y

    K Inf            -> unexpected
    _ :^^ _          -> unexpected
    Min {}           -> unexpected
    Max {}           -> unexpected
    Width {}         -> unexpected
    LenFromThen {}   -> unexpected
    LenFromThenTo {} -> unexpected

  where
  -- Translate both operands and combine them with the SMTLIB operator.
  binary op x y = op (exprToSmtLib x) (exprToSmtLib y)

  unexpected    = panic "exprToSmtLib" [ show (ppExpr expr) ]


-- | The name of a variable in the SMT translation.
--
-- The kind of variable selects a prefix: @s@ for system names, @u@ for
-- free type variables and @k@ for bound ones.  The prefix is followed by
-- a letter for the variable's number modulo 26 and, when it is non-zero,
-- the quotient.  For example, system name 27 becomes @sb1@.
smtName :: Name -> String
smtName a =
  case a of
    SysName n                 -> encode 's' n
    UserName (TVFree n _ _ _) -> encode 'u' n
    UserName (TVBound n _)    -> encode 'k' n

  where
  encode tag n =
    let (q, r) = n `divMod` 26
        suffix = if q == 0 then "" else show q
    in tag : toEnum (fromEnum 'a' + r) : suffix


-- | The name of the boolean SMT variable that represents @fin x@: the
-- variable's own SMT name (see 'smtName') prefixed with @fin_@.
smtFinName :: Name -> String
smtFinName = ("fin_" ++) . smtName





--------------------------------------------------------------------------------
-- Models
--------------------------------------------------------------------------------

{- | Read the value of a variable from the solver's current model.
The variable's @fin@ flag is queried first.  When the flag is false the
result is 'Inf'; otherwise the integer value is read, and it must be
non-negative.  Unexpected solver values are reported with 'panic'.

  * Assumes that we are in a SAT state (i.e., there is a model)

  * Assumes that the name is in the model
-}
getVal :: SMT.Solver -> Name -> IO Nat'
getVal s a =
  do finV <- SMT.getConst s (smtFinName a)
     case finV of
       SMT.Bool False -> return Inf
       SMT.Bool True  -> finiteVal
       _              -> panic "cryCheck.isInf"
                               [ "Not a boolean value", show finV ]

  where
  finiteVal =
    do v <- SMT.getConst s (smtName a)
       case v of
         SMT.Int x | x >= 0 -> return (Nat x)
         _ -> panic "cryCheck.getVal" [ "Not a natural number", show v ]


-- | Read the values of the given names from the solver's current model
-- (one 'getVal' query per name, in list order), as a map keyed by name.
getVals :: SMT.Solver -> [Name] -> IO (Map Name Nat')
getVals s xs = (Map.fromList . zip xs) `fmap` mapM (getVal s) xs


-- | Convert a bunch of improving equations into an idempotent substitution.
-- The equations are repeatedly applied to their own right-hand sides, and
-- the last map produced before 'anyJust' reports no further change is
-- returned.
-- Assumes that the equations don't have loops.
toSubst :: Map Name Expr -> Subst
toSubst m0 = last (m0 : rounds)
  where
  rounds   = unfoldr refine m0
  refine m = fmap (\m' -> (m', m')) (anyJust (apSubst m) m)


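{- Given a model, compute an improving substitution implied by the model.
The entries in the substitution look like this:

  * @x = A@         variable @x@ must equal constant @A@

  * @x = y@         variable @x@ must equal variable @y@

  * @x = A * y + B@ variable @x@ is a linear function of @y@, where
                    @A@ and @B@ are integers and @A@ is non-zero.  When
                    one of them is negative, the expression is built with
                    subtraction, and side conditions that keep the
                    subtraction defined are returned alongside the
                    substitution.
-}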



{- | Use a model to find improvements: variables that must equal a
constant, must equal another variable, or must be a linear function of
another variable.  Returns the improvements as a substitution (built with
'toSubst'), together with the side conditions that came with linear
definitions (see 'cryCheckLinRel').

We are mostly interested in improving unification variables.
However, it is also useful to improve skolem variables, as this could
turn non-linear constraints into linear ones.  For example, if we
have a constraint @x * y = z@, and we can figure out that @x@ must be 5,
then we end up with a linear constraint @5 * y = z@.

NOTE(review): the solver is presumably expected to be in a 'Sat' state,
with @model@ read from it, as 'cryMustEqualK' requires; confirm with
callers.
-}
cryImproveModel :: SMT.Solver -> SMT.Logger -> Map Name Nat'
                -> IO (Map Name Expr, [Prop])
cryImproveModel solver logger model =
  do (imps,subGoals) <- go Map.empty [] initialTodo
     return (toSubst imps, subGoals)

  where
  -- Process unification variables first.  That way, if we get `x = y`, we'd
  -- prefer `x` to be a unification variable.
  initialTodo    = uncurry (++) $ partition (isUniVar . fst) $ Map.toList model
  isUniVar x     = case x of
                     UserName (TVFree {}) -> True
                     _                    -> False


  -- done:  the set of known improvements
  -- extra: the collection of inferred sub-goals
  -- Each variable is first checked against its value in the model.  If the
  -- solver finds a counter-example, that counter-example is used by `goV`
  -- to look for a relation with one of the remaining variables.
  go done extra []             = return (done,extra)
  go done extra ((x,e) : rest) =

    -- x = K?
    do mbCounter <- cryMustEqualK solver (Map.keys model) x e
       case mbCounter of
         Nothing -> go (Map.insert x (K e) done) extra rest
         Just ce -> goV ce done extra [] x e rest


  -- ce:    a counter example to `x = e`
  -- done:  known improvements
  -- extra: known sub-goals
  -- todo:  more work to process once we are done with `x`.
  -- Once `x` is related to some `y`, processing continues with
  -- `reverse todo ++ more`, which contains neither `x` nor `y`, so neither
  -- variable is considered again.
  goV _  done extra todo _ _ [] = go done extra (reverse todo)
  goV ce done extra todo x e ((y,e') : more)
    -- x = y?
    | e == e' = do yesK <- cryMustEqualV solver x y
                   if yesK then go (Map.insert x (Var y) done)
                                   extra
                                   (reverse todo ++ more)
                           else tryLR

    | otherwise = tryLR

    where
    -- `y` is not related to `x`: keep it for later, try the next candidate.
    next = goV ce done extra ((y,e'):todo) x e more

    -- First try to define `x` in terms of `y`, then `y` in terms of `x`.
    tryLR =
      do mb <- tryLR_with x e y e'
         case mb of
           Just (r,subGoals) -> go (Map.insert x r done)
                                   (subGoals ++ extra)
                                   (reverse todo ++ more)
           Nothing ->
             do mb1 <- tryLR_with y e' x e
                case mb1 of
                  Nothing -> next
                  Just (r, subGoals) -> go (Map.insert y r done)
                                           (subGoals ++ extra)
                                           (reverse todo ++ more)


    -- Try to define `v1` as a linear function of `v2`.  The two data points
    -- are the model values (`v1Expr`, `v2Expr`) and the counter-example
    -- values.  Only unification variables are defined this way.  Nothing is
    -- tried if the counter-example lacks either value (for example, it is
    -- empty when the solver answered unknown).
    -- `cryCheckLinRel` defines its SECOND name, hence `v2 v1` and the
    -- (v2 value, v1 value) pairs.
    tryLR_with v1 v1Expr v2 v2Expr =
      case ( isUniVar v1
           , v1Expr
           , v2Expr
           , Map.lookup v1 ce
           , Map.lookup v2 ce
           ) of
        (True, x1, y1, Just x2, Just y2) ->
          cryCheckLinRel solver logger v2 v1 (y1,x1) (y2,x2)
        _ -> return Nothing







-- | Check whether the given expression is unsatisfiable together with the
-- solver's current assertions.  To prove a formula @p@, pass its negation.
--
-- Returns 'True' only when the solver answers 'SMT.Unsat'.  Both 'SMT.Sat'
-- and 'SMT.Unknown' give 'False', so 'False' means "not proved" rather
-- than "refuted".  The expression is asserted inside a 'SMT.push' /
-- 'SMT.pop' scope, so it is not left among the solver's assertions.
checkUnsat :: SMT.Solver -> SExpr -> IO Bool
checkUnsat s e =
  do SMT.push s
     SMT.assert s e
     res <- SMT.check s
     SMT.pop s
     return (res == SMT.Unsat)


-- | Look for a model of the given expression under the solver's current
-- assertions.  Callers pass the negation of the goal they want to prove,
-- so a model is a counter-example to that goal.
--
-- Returns:
--
--   * 'Nothing' if the expression is unsatisfiable (the goal holds);
--
--   * @Just m@, with the values of the names @xs@ read via 'getVals', if
--     the expression is satisfiable;
--
--   * @Just@ an empty map if the solver answers 'SMT.Unknown'.
--
-- The expression is asserted inside a 'SMT.push' / 'SMT.pop' scope, so it
-- is not left among the solver's assertions.
getCounterExample :: SMT.Solver -> [Name] -> SExpr -> IO (Maybe (Map Name Nat'))
getCounterExample s xs e =
  do SMT.push s
     SMT.assert s e
     res <- SMT.check s
     ans <- case res of
              SMT.Unsat   -> return Nothing
              SMT.Unknown -> return (Just Map.empty)
              SMT.Sat     -> Just `fmap` getVals s xs
     SMT.pop s
     return ans





-- | Check whether the given value is the only one the variable can take
-- under the current assumptions.
-- Assumes that we are in a 'Sat' state.
--
-- Returns 'Nothing' if the variable must always equal the given value.
-- Otherwise returns a counter-example: values for the names @xs@ in which
-- the variable differs.  The map is empty if the solver answered unknown.
cryMustEqualK :: SMT.Solver -> [Name] ->
                 Name -> Nat' -> IO (Maybe (Map Name Nat'))
cryMustEqualK solver xs x val = getCounterExample solver xs refutation
  where
  finX = SMT.const (smtFinName x)

  -- A model of this formula is a counter-example to @x = val@.
  refutation =
    case val of
      Inf   -> finX
      Nat n -> SMT.not
             $ SMT.and finX (SMT.eq (SMT.const (smtName x)) (SMT.int n))



-- | Check whether two variables must always be equal under the current
-- assumptions: either both are infinite, or both are finite with the same
-- value.  Returns 'True' only if the solver proves this; an unknown answer
-- gives 'False'.
-- Assumes that we are in a 'Sat' state.
cryMustEqualV :: SMT.Solver -> Name -> Name -> IO Bool
cryMustEqualV solver x y =
  checkUnsat solver (SMT.not (SMT.or bothInf sameFinite))

  where
  bothInf    = SMT.and (SMT.not (fin x)) (SMT.not (fin y))
  sameFinite = SMT.and (SMT.and (fin x) (fin y)) (SMT.eq (var x) (var y))

  fin a = SMT.const (smtFinName a)
  var a = SMT.const (smtName a)


-- | Try to find a linear relation between the two variables, based
-- on two observed data points.
-- NOTE:  The variable being defined is the SECOND name.
--
-- A candidate @y = A * x + B@ is computed from two finite points (see
-- 'linRel').  If one of the given points has both variables infinite,
-- another finite point is requested from the solver.  The candidate is
-- accepted only if the solver shows that
--
--   * it holds whenever both @x@ and @y@ are finite, and
--
--   * @x@ is finite exactly when @y@ is finite.
--
-- The returned obligations keep any subtraction in the expression
-- defined (see @wellDefined@ below).  All queries are made inside
-- 'SMT.push' / 'SMT.pop' scopes.
cryCheckLinRel :: SMT.Solver -> SMT.Logger ->
         {- x -} Name {- ^ Definition in terms of this variable. -} ->
         {- y -} Name {- ^ Define this variable. -} ->
                 (Nat',Nat') {- ^ Values in one model (x,y) -} ->
                 (Nat',Nat') {- ^ Values in another model (x,y) -} ->
                 IO (Maybe (Expr,[Prop]))
                 {- ^ Either nothing, or an improving expression, and any
                      additional obligations -}
cryCheckLinRel s logger x y p1 p2 =
  -- First, try to find a linear relation that holds in all finite cases.
  do SMT.push s
     SMT.assert s (isFin x)
     SMT.assert s (isFin y)
     -- Two finite points are needed.  A point with both variables infinite
     -- is replaced by a finite one obtained from the solver.  Mixed points
     -- (one variable finite, the other not) are not handled.
     ans <- case (p1,p2) of
              ((Nat x1, Nat y1), (Nat x2, Nat y2)) ->
                  checkLR x1 y1 x2 y2

              ((Inf,    Inf),    (Nat x2, Nat y2)) ->
                 mbGoOn (getFinPt x2) $ \(x1,y1) -> checkLR x1 y1 x2 y2

              ((Nat x1, Nat y1), (Inf,    Inf)) ->
                 mbGoOn (getFinPt x1) $ \(x2,y2) -> checkLR x1 y1 x2 y2

              _ -> return Nothing

     SMT.pop s


     -- Next, check the infinite cases: if @y = A * x + B@, then
     -- either both @x@ and @y@ must be infinite or they both must be finite.
     -- Note that we don't consider relations where A = 0: because they
     -- would be handled when we checked that @y@ is a constant.
     -- Unsat below means `fin x` and `fin y` can never disagree.
     case ans of
       Nothing -> return Nothing
       Just e ->
         do SMT.push s
            SMT.assert s (SMT.not (SMT.eq (isFin x) (isFin y)))
            c <- SMT.check s
            SMT.pop s
            case c of
              SMT.Unsat -> return (Just e)
              _         -> return Nothing

  where
  isFin a = SMT.const (smtFinName a)

  -- XXX: Duplicates `cryDefined`
  -- The constraints are always of the form: x >= K, or K >= x
  -- For the positive constants produced by `checkLR`:
  --   A * t - B >= 0  iff  t >= ceiling (B / A)
  --   B - A * t >= 0  iff  t <= floor (B / A)
  wellDefined e =
    case e of
      (K (Nat a) :* t) :- K (Nat b) ->
        let c = div (b + a - 1) a
        in [ t :>= K (Nat c) ]

      K (Nat b) :- (K (Nat a) :* t)
        -> let c = div b a
           in [ K (Nat c) :>= t ]

      a  :- b -> [ a :>= b ]

      _ -> []


  -- Fit a line through the two points, then ask the solver whether
  -- @y = A * x + B@ holds in every model (within the current scope, which
  -- asserts that both variables are finite).
  checkLR x1 y1 x2 y2 =
    do SMT.logMessage logger ("checkLR: " ++ show (x1,y1) ++ "   "
                                          ++ show (x2,y2))
       mbGoOn (return (linRel (x1,y1) (x2,y2))) (\(a,b) ->
         -- Build A * x + B, using subtraction when A or B is negative.
         -- `linRel` never returns A = 0, nor A < 0 together with B < 0.
         do let sumTerm v
                   | b == 0    = v
                   | b < 0     = v :- K (Nat (negate b))
                   | otherwise = v :+ K (Nat b)

                expr
                  | a == 1     = sumTerm (Var x)
                  | a <  0     = K (Nat b) :- K (Nat (negate a)) :* Var x
                  | otherwise  = sumTerm (K (Nat a) :* Var x)

            SMT.logMessage logger ("candidate: " ++ show (ppProp (Var y :==: expr)))

            -- The relation is proved if its negation is unsatisfiable.
            proved <- checkUnsat s
                    $ propToSmtLib $ crySimplify
                    $ Not $ Var y :==: expr

            if not proved
               then SMT.logMessage logger "failed" >> return Nothing
               else return (Just (expr,wellDefined expr)))

  -- Try to get an example of another point, which is finite, and at
  -- different @x@ coordinate.
  -- Gives 'Nothing' unless the solver answers Sat with non-negative
  -- integer values for both variables.
  getFinPt otherX =
    do SMT.push s
       SMT.assert s (SMT.not (SMT.eq (SMT.const (smtName x)) (SMT.int otherX)))
       smtAns <- SMT.check s
       ans <- case smtAns of
                SMT.Sat ->
                  do vX <- SMT.getConst s (smtName x)
                     vY <- SMT.getConst s (smtName y)
                     case (vX, vY) of
                       (SMT.Int vx, SMT.Int vy)
                          | vx >= 0 && vy >= 0 -> return (Just (vx,vy))
                       _ -> return Nothing
                _ -> return Nothing
       SMT.pop s
       return ans

  -- Run `m`; stop with 'Nothing' if it yields 'Nothing', otherwise
  -- continue with `k`.
  mbGoOn m k = do ans <- m
                  case ans of
                    Nothing -> return Nothing
                    Just a  -> k a

{- | Compute a linear relation through two concrete points.
Try to find a relation of the form @y = A * x + B@ with integer @A@ and
@B@.  Depending on the signs of @A@ and @B@, additional checks are needed
to ensure that @A * x + B@ is valid.

y1 = A * x1 + B
y2 = A * x2 + B
(y2 - y1) = A * (x2 - x1)

A = (y2 - y1) / (x2 - x1)
B = y1 - A * x1

No relation is returned when the points share an @x@ coordinate, when the
slope is not an integer, when @A = 0@ (constants are handled elsewhere),
or when both @A@ and @B@ are negative (no way this gives a natural number).
-}
linRel :: (Integer,Integer)       {- ^ First point -} ->
          (Integer,Integer)       {- ^ Second point -} ->
          Maybe (Integer,Integer) {- ^ (A,B) -}
linRel (x1,y1) (x2,y2)
  | x1 == x2           = Nothing
  | r /= 0 || a == 0   = Nothing
  | a < 0 && b < 0     = Nothing
  | otherwise          = Just (a,b)
  where
  -- Only forced after the first guard, so there is no division by zero.
  (a,r) = (y2 - y1) `divMod` (x2 - x1)
  b     = y1 - a * x1