module Cryptol.TypeCheck.Solver.SMT (proveImp,checkUnsolvable) where
import SimpleSMT (SExpr)
import qualified SimpleSMT as SMT
import Data.Map ( Map )
import qualified Data.Map as Map
import qualified Data.Set as Set
import Data.Maybe(catMaybes)
import Data.List(partition)
import Control.Exception(finally,onException)
import Control.Monad(msum,zipWithM)
import Cryptol.TypeCheck.Type
import Cryptol.TypeCheck.Solver.CrySAT
import Cryptol.TypeCheck.InferTypes
import Cryptol.TypeCheck.TypePat hiding ((~>),(~~>))
import Cryptol.Utils.Panic
-- | Try to discharge goals with the solver, under the given assumptions.
--
-- Only goals and assumptions accepted by 'isNumeric' are sent to the solver.
-- The result is the numeric goals that 'prove' handed back (i.e. did not
-- discharge), followed by the non-numeric goals, which are returned untouched.
--
-- All declarations and assumptions live in one solver frame.  That frame is
-- popped even when an exception escapes, so the shared solver is not left
-- with stale declarations or assertions.
proveImp :: Solver -> [Prop] -> [Goal] -> IO [Goal]
proveImp sol ps gs0 =
  debugBlock sol "PROVE IMP" $
  do let s         = rawSolver sol
         (gs,rest) = partition (isNumeric . goal) gs0
         numAsmp   = filter isNumeric ps
         vs        = Set.toList (fvs (numAsmp, map goal gs))
     tvs <- debugBlock sol "VARIABLES" $
       do SMT.push s
          -- The frame is open from here on: close it if a declaration fails.
          (Map.fromList <$> zipWithM (declareVar s) [ 0 .. ] vs)
            `onException` SMT.pop s
     -- Close the frame whether or not assuming/proving completes normally.
     gs' <- (do debugBlock sol "ASSUMPTIONS" $
                  mapM_ (assume s tvs) numAsmp
                mapM (prove sol tvs) gs)
            `finally` SMT.pop s
     return (catMaybes gs' ++ rest)
-- | Check whether the numeric goals, taken together, are reported as
-- unsatisfiable by the solver (see 'unsolvable').
--
-- Goals rejected by 'isNumeric' are ignored.  The solver frame opened for
-- the variable declarations is popped even when an exception escapes.
checkUnsolvable :: Solver -> [Goal] -> IO Bool
checkUnsolvable sol gs0 =
  debugBlock sol "CHECK UNSOLVABLE" $
  do let s  = rawSolver sol
         ps = filter isNumeric (map goal gs0)
         vs = Set.toList (fvs ps)
     tvs <- debugBlock sol "VARIABLES" $
       do SMT.push s
          -- The frame is open from here on: close it if a declaration fails.
          (Map.fromList <$> zipWithM (declareVar s) [ 0 .. ] vs)
            `onException` SMT.pop s
     unsolvable sol tvs ps `finally` SMT.pop s
-- | Declare a solver constant of sort 'cryInfNat' for a type variable and
-- assert @cryVar@ of it.  The constant is named @fv<n>@ when 'isFreeTV'
-- holds for the variable and @kv<n>@ otherwise, where @<n>@ is the given
-- index.  Returns the variable paired with its solver expression.
declareVar :: SMT.Solver -> Int -> TVar -> IO (TVar, SExpr)
declareVar s x v =
  do e <- SMT.declare s name cryInfNat
     SMT.assert s (SMT.fun "cryVar" [ e ])
     return (v, e)
  where
  name   = prefix ++ show x
  prefix | isFreeTV v = "fv"
         | otherwise  = "kv"
-- | Assert a proposition in the solver, wrapped in @cryAssume@.
assume :: SMT.Solver -> TVars -> Prop -> IO ()
assume s tvs p = SMT.assert s assumption
  where
  assumption = SMT.fun "cryAssume" [ toSMT tvs p ]
-- | Attempt a single goal in a fresh solver frame.
--
-- Asserts @cryProve@ of the translated goal and checks satisfiability.
-- Returns 'Nothing' when the solver answers 'SMT.Unsat' (the goal is
-- dropped), and @Just g@ for any other answer (the goal is kept).
--
-- The frame is popped even if translating, asserting, or checking throws,
-- so a failure here cannot leak assertions into later queries.
prove :: Solver -> TVars -> Goal -> IO (Maybe Goal)
prove sol tvs g =
  debugBlock sol "PROVE" $
  do let s = rawSolver sol
     SMT.push s
     res <- (do SMT.assert s (SMT.fun "cryProve" [ toSMT tvs (goal g) ])
                SMT.check s)
            `finally` SMT.pop s
     return $ case res of
                SMT.Unsat -> Nothing
                _         -> Just g
-- | Assume all the given propositions in a fresh solver frame and check
-- satisfiability.  Returns 'True' exactly when the solver answers
-- 'SMT.Unsat'; any other answer yields 'False'.
--
-- The frame is popped even if asserting or checking throws.
unsolvable :: Solver -> TVars -> [Prop] -> IO Bool
unsolvable sol tvs ps =
  debugBlock sol "UNSOLVABLE" $
  do let s = rawSolver sol
     SMT.push s
     res <- (mapM_ (assume s tvs) ps >> SMT.check s) `finally` SMT.pop s
     return $ case res of
                SMT.Unsat -> True
                _         -> False
-- | Decide whether a proposition should be sent to the solver.
--
-- A proposition qualifies when it matches one of the patterns @(|=|)@,
-- @(|>=|)@, 'aFin', or 'aTrue', or when it is a conjunction (matched by
-- 'aAnd', the pattern 'toSMT' translates as @cryAnd@) whose two sides both
-- qualify.  Anything else yields 'False'.
isNumeric :: Prop -> Bool
isNumeric ty = matchDefault False
             $ msum [ is (|=|), is (|>=|), is aFin, is aTrue, andNum ty ]
  where
  -- A conjunction is a property-level pattern, so it is matched with 'aAnd';
  -- 'anAdd' matches numeric addition and never applies to a proposition.
  andNum t = do (x,y) <- aAnd t
                return (isNumeric x && isNumeric y)
  is f = f ty >> return True
-- | Maps each type variable to the solver expression declared for it
-- (built from the pairs returned by 'declareVar').
type TVars = Map TVar SExpr
-- | The solver sort named @InfNat@; 'declareVar' declares every type
-- variable with this sort.
cryInfNat :: SExpr
cryInfNat = SMT.const "InfNat"
-- | Translate a type or proposition into a solver expression.
--
-- Each rule pairs a type pattern with the name of the solver function to
-- build from whatever the pattern extracts (see '~>' and 'Mk').  The rules
-- are combined with 'msum', in the order listed.  If no rule matches, the
-- default given to 'matchDefault' is a panic that reports the offending type.
toSMT :: TVars -> Type -> SExpr
toSMT tvs ty = matchDefault unexpected (msum [ rule tvs ty | rule <- rules ])
  where
  unexpected = panic "toSMT" [ "Unexpected type", show ty ]

  rules =
    [ aInf            ~> "cryInf"
    , aNat            ~> "cryNat"

    , aFin            ~> "cryFin"
    , (|=|)           ~> "cryEq"
    , (|>=|)          ~> "cryGeq"
    , aAnd            ~> "cryAnd"
    , aTrue           ~> "cryTrue"

    , anAdd           ~> "cryAdd"
    , (|-|)           ~> "crySub"
    , aMul            ~> "cryMul"
    , (|^|)           ~> "cryExp"
    , (|/|)           ~> "cryDiv"
    , (|%|)           ~> "cryMod"
    , aMin            ~> "cryMin"
    , aMax            ~> "cryMax"
    , aWidth          ~> "cryWidth"
    , aLenFromThen    ~> "cryLenFromThen"
    , aLenFromThenTo  ~> "cryLenFromThenTo"

    , anError KNum    ~> "cryErr"
    , anError KProp   ~> "cryErrProp"

      -- The 'Mk' instance for 'TVar' ignores the name and looks the
      -- variable up in the table instead.
    , aTVar           ~> "(unused)"
    ]
-- | Turn a type pattern and a solver function name into a translation rule:
-- when the pattern matches, the extracted components are handed to 'mk'
-- together with the variable table and the name.
(~>) :: Mk a => (Type -> Match a) -> String -> TVars -> Type -> Match SExpr
(pat ~> name) tvs t =
  do parts <- pat t
     return (mk tvs name parts)
-- | Things that can be turned into a solver expression, given the table of
-- declared type variables and the name of a solver function.  How (and
-- whether) the name is used is up to each instance.
class Mk t where
  mk :: TVars -> String -> t -> SExpr
-- | A pattern with no components becomes the constant with the given name.
instance Mk () where
  mk _ name _ = SMT.const name
-- | An integer becomes the named function applied to the integer literal.
instance Mk Integer where
  mk _ name n = SMT.fun name [ literal ]
    where
    literal = SMT.int n
-- | A type variable becomes the solver expression recorded for it in the
-- variable table; the function name is ignored.  A variable that is missing
-- from the table is reported with a panic that names the variable, rather
-- than failing inside the map lookup.
instance Mk TVar where
  mk tvs _ x =
    case Map.lookup x tvs of
      Just e  -> e
      Nothing -> panic "toSMT" [ "Undeclared type variable", show x ]
-- | A single type becomes the named function applied to its translation.
instance Mk Type where
  mk tvs name arg = SMT.fun name (map (toSMT tvs) [ arg ])
-- | An error message becomes the named function with no arguments; the
-- message itself is not passed to the solver.
instance Mk TCErrorMessage where
  mk _ name _ = SMT.fun name noArgs
    where
    noArgs = []
-- | A pair of types becomes the named function applied to both
-- translations, in order.
instance Mk (Type,Type) where
  mk tvs name (a,b) = SMT.fun name (map (toSMT tvs) [ a, b ])
-- | A triple of types becomes the named function applied to the three
-- translations, in order.
instance Mk (Type,Type,Type) where
  mk tvs name (a,b,c) = SMT.fun name (map (toSMT tvs) [ a, b, c ])