module Cryptol.TypeCheck.Solver.Numeric.SMT
( desugarProp
, smtName
, smtFinName
, ifPropToSmtLib
, cryImproveModel
, getVal
, getVals
) where
import Cryptol.TypeCheck.AST (TVar(TVFree,TVBound))
import Cryptol.TypeCheck.Solver.InfNat
import Cryptol.TypeCheck.Solver.Numeric.AST
import Cryptol.TypeCheck.Solver.Numeric.Simplify(crySimplify)
import Cryptol.Utils.Misc ( anyJust )
import Cryptol.Utils.Panic ( panic )
import Data.List ( partition, unfoldr )
import Data.Map ( Map )
import qualified Data.Map as Map
import qualified Data.Set as Set
import SimpleSMT ( SExpr )
import qualified SimpleSMT as SMT
import MonadLib
desugarExpr :: Expr -> IfExpr Expr
desugarExpr expr =
do es <- mapM desugarExpr (cryExprExprs expr)
case (expr,es) of
(Min {}, [x,y]) -> If (x :>: y) (return y) (return x)
(Max {}, [x,y]) -> If (x :>: y) (return x) (return y)
(LenFromThenTo {}, [ x@(K (Nat a)), K (Nat b), z ])
| a > b -> If (z :>: x) (return zero)
(return (Div (x :- z) step :+ one))
| b > a -> If (x :>: z) (return zero)
(return (Div (z :- x) step :+ one))
where step = K (Nat (abs (a b)))
_ -> return (cryRebuildExpr expr es)
desugarProp :: Prop -> IfExpr Prop
desugarProp prop =
case prop of
PFalse -> return PFalse
PTrue -> return PTrue
Not p -> Not `fmap` desugarProp p
p :&& q -> (:&&) `fmap` desugarProp p `ap` desugarProp q
p :|| q -> (:||) `fmap` desugarProp p `ap` desugarProp q
Fin (Var _) -> return prop
x :==: y -> (:==:) `fmap` desugarExpr x `ap` desugarExpr y
x :>: y -> (:>:) `fmap` desugarExpr x `ap` desugarExpr y
Fin _ -> unexpected
_ :== _ -> unexpected
_ :>= _ -> unexpected
_ :> _ -> unexpected
where
unexpected = panic "desugarProp" [ show (ppProp prop) ]
ifPropToSmtLib :: IfExpr Prop -> SExpr
ifPropToSmtLib ifProp =
case ifProp of
Impossible -> SMT.bool False
Return p -> propToSmtLib p
If p q r -> SMT.ite (propToSmtLib p) (ifPropToSmtLib q) (ifPropToSmtLib r)
propToSmtLib :: Prop -> SExpr
propToSmtLib prop =
case prop of
PFalse -> SMT.bool False
PTrue -> SMT.bool True
Not p -> case p of
Fin _ -> SMT.not (propToSmtLib p)
x :>: y -> addFin $ SMT.geq (exprToSmtLib y)
(exprToSmtLib x)
_ -> unexpected
p :&& q -> SMT.and (propToSmtLib p) (propToSmtLib q)
p :|| q -> SMT.or (propToSmtLib p) (propToSmtLib q)
Fin (Var x) -> fin x
x :==: y -> addFin $ SMT.eq (exprToSmtLib x) (exprToSmtLib y)
x :>: y -> addFin $ SMT.gt (exprToSmtLib x) (exprToSmtLib y)
Fin _ -> unexpected
_ :== _ -> unexpected
_ :>= _ -> unexpected
_ :> _ -> unexpected
where
unexpected = panic "propToSmtLib" [ show (ppProp prop) ]
fin x = SMT.const (smtFinName x)
addFin e = foldr (\x' e' -> SMT.and (fin x') e') e
(Set.toList (cryPropFVS prop))
exprToSmtLib :: Expr -> SExpr
exprToSmtLib expr =
case expr of
K (Nat n) -> SMT.int n
Var a -> SMT.const (smtName a)
x :+ y -> SMT.add (exprToSmtLib x) (exprToSmtLib y)
x :- y -> SMT.sub (exprToSmtLib x) (exprToSmtLib y)
x :* y -> SMT.mul (exprToSmtLib x) (exprToSmtLib y)
Div x y -> SMT.div (exprToSmtLib x) (exprToSmtLib y)
Mod x y -> SMT.mod (exprToSmtLib x) (exprToSmtLib y)
K Inf -> unexpected
_ :^^ _ -> unexpected
Min {} -> unexpected
Max {} -> unexpected
Width {} -> unexpected
LenFromThen {} -> unexpected
LenFromThenTo {} -> unexpected
where
unexpected = panic "exprToSmtLib" [ show (ppExpr expr) ]
smtName :: Name -> String
smtName a = case a of
SysName n -> name "s" n
UserName tv -> case tv of
TVFree n _ _ _ -> name "u" n
TVBound n _ -> name "k" n
where
name p n = case divMod n 26 of
(q,r) -> p ++ toEnum (fromEnum 'a' + r) :
(if q == 0 then "" else show q)
smtFinName :: Name -> String
smtFinName x = "fin_" ++ smtName x
getVal :: SMT.Solver -> Name -> IO Nat'
getVal s a =
do yes <- isInf a
if yes
then return Inf
else do v <- SMT.getConst s (smtName a)
case v of
SMT.Int x | x >= 0 -> return (Nat x)
_ -> panic "cryCheck.getVal" [ "Not a natural number", show v ]
where
isInf v = do yes <- SMT.getConst s (smtFinName v)
case yes of
SMT.Bool ans -> return (not ans)
_ -> panic "cryCheck.isInf"
[ "Not a boolean value", show yes ]
getVals :: SMT.Solver -> [Name] -> IO (Map Name Nat')
getVals s xs =
do es <- mapM (getVal s) xs
return (Map.fromList (zip xs es))
toSubst :: Map Name Expr -> Subst
toSubst m0 = last (m0 : unfoldr step m0)
where step m = do m1 <- anyJust (apSubst m) m
return (m1,m1)
cryImproveModel :: SMT.Solver -> SMT.Logger -> Map Name Nat'
-> IO (Map Name Expr, [Prop])
cryImproveModel solver logger model =
do (imps,subGoals) <- go Map.empty [] initialTodo
return (toSubst imps, subGoals)
where
initialTodo = uncurry (++) $ partition (isUniVar . fst) $ Map.toList model
isUniVar x = case x of
UserName (TVFree {}) -> True
_ -> False
go done extra [] = return (done,extra)
go done extra ((x,e) : rest) =
do mbCounter <- cryMustEqualK solver (Map.keys model) x e
case mbCounter of
Nothing -> go (Map.insert x (K e) done) extra rest
Just ce -> goV ce done extra [] x e rest
goV _ done extra todo _ _ [] = go done extra (reverse todo)
goV ce done extra todo x e ((y,e') : more)
| e == e' = do yesK <- cryMustEqualV solver x y
if yesK then go (Map.insert x (Var y) done)
extra
(reverse todo ++ more)
else tryLR
| otherwise = tryLR
where
next = goV ce done extra ((y,e'):todo) x e more
tryLR =
do mb <- tryLR_with x e y e'
case mb of
Just (r,subGoals) -> go (Map.insert x r done)
(subGoals ++ extra)
(reverse todo ++ more)
Nothing ->
do mb1 <- tryLR_with y e' x e
case mb1 of
Nothing -> next
Just (r, subGoals) -> go (Map.insert y r done)
(subGoals ++ extra)
(reverse todo ++ more)
tryLR_with v1 v1Expr v2 v2Expr =
case ( isUniVar v1
, v1Expr
, v2Expr
, Map.lookup v1 ce
, Map.lookup v2 ce
) of
(True, x1, y1, Just x2, Just y2) ->
cryCheckLinRel solver logger v2 v1 (y1,x1) (y2,x2)
_ -> return Nothing
checkUnsat :: SMT.Solver -> SExpr -> IO Bool
checkUnsat s e =
do SMT.push s
SMT.assert s e
res <- SMT.check s
SMT.pop s
return (res == SMT.Unsat)
getCounterExample :: SMT.Solver -> [Name] -> SExpr -> IO (Maybe (Map Name Nat'))
getCounterExample s xs e =
do SMT.push s
SMT.assert s e
res <- SMT.check s
ans <- case res of
SMT.Unsat -> return Nothing
SMT.Unknown -> return (Just Map.empty)
SMT.Sat -> Just `fmap` getVals s xs
SMT.pop s
return ans
cryMustEqualK :: SMT.Solver -> [Name] ->
Name -> Nat' -> IO (Maybe (Map Name Nat'))
cryMustEqualK solver xs x val =
case val of
Inf -> getCounterExample solver xs (SMT.const (smtFinName x))
Nat n -> getCounterExample solver xs $
SMT.not $
SMT.and (SMT.const $ smtFinName x)
(SMT.eq (SMT.const (smtName x)) (SMT.int n))
cryMustEqualV :: SMT.Solver -> Name -> Name -> IO Bool
cryMustEqualV solver x y =
checkUnsat solver $
SMT.not $
SMT.or (SMT.not (fin x) `SMT.and` SMT.not (fin y))
(fin x `SMT.and` fin y `SMT.and` SMT.eq (var x) (var y))
where fin a = SMT.const (smtFinName a)
var a = SMT.const (smtName a)
cryCheckLinRel :: SMT.Solver -> SMT.Logger ->
Name ->
Name ->
(Nat',Nat') ->
(Nat',Nat') ->
IO (Maybe (Expr,[Prop]))
cryCheckLinRel s logger x y p1 p2 =
do SMT.push s
SMT.assert s (isFin x)
SMT.assert s (isFin y)
ans <- case (p1,p2) of
((Nat x1, Nat y1), (Nat x2, Nat y2)) ->
checkLR x1 y1 x2 y2
((Inf, Inf), (Nat x2, Nat y2)) ->
mbGoOn (getFinPt x2) $ \(x1,y1) -> checkLR x1 y1 x2 y2
((Nat x1, Nat y1), (Inf, Inf)) ->
mbGoOn (getFinPt x1) $ \(x2,y2) -> checkLR x1 y1 x2 y2
_ -> return Nothing
SMT.pop s
case ans of
Nothing -> return Nothing
Just e ->
do SMT.push s
SMT.assert s (SMT.not (SMT.eq (isFin x) (isFin y)))
c <- SMT.check s
SMT.pop s
case c of
SMT.Unsat -> return (Just e)
_ -> return Nothing
where
isFin a = SMT.const (smtFinName a)
wellDefined e =
case e of
(K (Nat a) :* t) :- K (Nat b) ->
let c = div (b + a 1) a
in [ t :>= K (Nat c) ]
K (Nat b) :- (K (Nat a) :* t)
-> let c = div b a
in [ K (Nat c) :>= t ]
a :- b -> [ a :>= b ]
_ -> []
checkLR x1 y1 x2 y2 =
do SMT.logMessage logger ("checkLR: " ++ show (x1,y1) ++ " "
++ show (x2,y2))
mbGoOn (return (linRel (x1,y1) (x2,y2))) (\(a,b) ->
do let sumTerm v
| b == 0 = v
| b < 0 = v :- K (Nat (negate b))
| otherwise = v :+ K (Nat b)
expr
| a == 1 = sumTerm (Var x)
| a < 0 = K (Nat b) :- K (Nat (negate a)) :* Var x
| otherwise = sumTerm (K (Nat a) :* Var x)
SMT.logMessage logger ("candidate: " ++ show (ppProp (Var y :==: expr)))
proved <- checkUnsat s
$ propToSmtLib $ crySimplify
$ Not $ Var y :==: expr
if not proved
then SMT.logMessage logger "failed" >> return Nothing
else return (Just (expr,wellDefined expr)))
getFinPt otherX =
do SMT.push s
SMT.assert s (SMT.not (SMT.eq (SMT.const (smtName x)) (SMT.int otherX)))
smtAns <- SMT.check s
ans <- case smtAns of
SMT.Sat ->
do vX <- SMT.getConst s (smtName x)
vY <- SMT.getConst s (smtName y)
case (vX, vY) of
(SMT.Int vx, SMT.Int vy)
| vx >= 0 && vy >= 0 -> return (Just (vx,vy))
_ -> return Nothing
_ -> return Nothing
SMT.pop s
return ans
mbGoOn m k = do ans <- m
case ans of
Nothing -> return Nothing
Just a -> k a
linRel :: (Integer,Integer) ->
(Integer,Integer) ->
Maybe (Integer,Integer)
linRel (x1,y1) (x2,y2) =
do guard (x1 /= x2)
let (a,r) = divMod (y2 y1) (x2 x1)
guard (r == 0 && a /= 0)
let b = y1 a * x1
guard $ not $ a < 0 && b < 0
return (a,b)