module Cryptol.Symbolic where
import Control.Monad (replicateM, when, zipWithM)
import Data.List (transpose, intercalate)
import qualified Data.Map as Map
import qualified Control.Exception as X
import qualified Data.SBV.Dynamic as SBV
import qualified Cryptol.ModuleSystem as M hiding (getPrimMap)
import qualified Cryptol.ModuleSystem.Env as M
import qualified Cryptol.ModuleSystem.Base as M
import qualified Cryptol.ModuleSystem.Monad as M
import Cryptol.Symbolic.Prims
import Cryptol.Symbolic.Value
import qualified Cryptol.Eval.Value as Eval
import qualified Cryptol.Eval.Type (evalValType, evalNumType)
import qualified Cryptol.Eval.Env (EvalEnv(..))
import Cryptol.TypeCheck.AST
import Cryptol.TypeCheck.Solver.InfNat (Nat'(..))
import Cryptol.Utils.Ident (Ident)
import Cryptol.Utils.PP
import Cryptol.Utils.Panic(panic)
import Prelude ()
import Prelude.Compat
-- | Association list from the prover names accepted by this module to the
-- SBV solver configuration used for each of them. The pseudo-provers
-- @"offline"@ and @"any"@ both map to 'SBV.defaultSMTCfg'.
proverConfigs :: [(String, SBV.SMTConfig)]
proverConfigs = realSolvers ++ pseudoSolvers
  where
    -- Names that correspond to one concrete solver configuration.
    realSolvers =
      [ ("cvc4", SBV.cvc4)
      , ("yices", SBV.yices)
      , ("z3", SBV.z3)
      , ("boolector", SBV.boolector)
      , ("mathsat", SBV.mathSAT)
      , ("abc", SBV.abc)
      ]
    -- Names that are given the default configuration.
    pseudoSolvers = [ (name, SBV.defaultSMTCfg) | name <- ["offline", "any"] ]
-- | The prover names listed in 'proverConfigs', in table order.
proverNames :: [String]
proverNames = [ name | (name, _) <- proverConfigs ]
-- | Look up the solver configuration for a prover name in 'proverConfigs'.
-- Panics when the name is not in the table.
lookupProver :: String -> SBV.SMTConfig
lookupProver s
  | Just cfg <- lookup s proverConfigs = cfg
  | otherwise = panic "Cryptol.Symbolic" [ "invalid prover: " ++ s ]
-- | One satisfying assignment: for every argument of the predicate, its
-- type, the value rendered as an expression, and the value itself.
type SatResult = [(Type, Expr, Eval.Value)]
-- | How many satisfying assignments a sat query asks for: all of them
-- ('AllSat') or at most the given number ('SomeSat').
data SatNum = AllSat | SomeSat Int
deriving (Show)
-- | The kind of query to run: a satisfiability query (with the number of
-- requested models) or a proof query.
data QueryType = SatQuery SatNum | ProveQuery
deriving (Show)
-- | Everything 'satProve' and 'satProveOffline' need to run one query.
data ProverCommand = ProverCommand {
-- | Whether this is a sat query or a proof query.
pcQueryType :: QueryType
-- | Name of the prover; looked up with 'lookupProver', except that
-- @"any"@ makes 'satProve' use every solver SBV reports as available.
, pcProverName :: String
-- | Print progress messages; 'satProve' also copies this flag into the
-- solver configurations' timing and verbose settings.
, pcVerbose :: Bool
-- | Declaration groups evaluated in addition to those of the loaded modules.
, pcExtraDecls :: [DeclGroup]
-- | Stored in the @smtFile@ field of the named prover's configuration.
, pcSmtFile :: Maybe FilePath
-- | The expression to evaluate symbolically.
, pcExpr :: Expr
-- | The schema of 'pcExpr'; checked with 'predArgTypes'.
, pcSchema :: Schema
}
-- | Outcome of a prover run.
--
-- 'satProve' returns 'AllSatResult' when the first solver result is
-- satisfiable, 'ThmResult' (carrying the predicate's argument types) when
-- the result is unsatisfiable or the result list is empty, and
-- 'ProverError' otherwise. 'EmptyResult' is not produced by the code
-- visible in this module.
data ProverResult = AllSatResult [SatResult]
| ThmResult [Type]
| EmptyResult
| ProverError String
-- | Unwrap an SBV sat result into a singleton list of SMT results.
satSMTResults :: SBV.SatResult -> [SBV.SMTResult]
satSMTResults wrapped =
  case wrapped of
    SBV.SatResult single -> [single]
-- | Extract the list of SMT results from an SBV all-sat result, discarding
-- the first component of the wrapped pair.
allSatSMTResults :: SBV.AllSatResult -> [SBV.SMTResult]
allSatSMTResults wrapped =
  case wrapped of
    SBV.AllSatResult payload -> snd payload
-- | Unwrap an SBV theorem result into a singleton list of SMT results.
thmSMTResults :: SBV.ThmResult -> [SBV.SMTResult]
thmSMTResults wrapped =
  case wrapped of
    SBV.ThmResult single -> [single]
-- | A module command that succeeds with a 'ProverError' carrying the given
-- message, leaving the module environment unchanged and emitting no
-- warnings.
proverError :: String -> M.ModuleCmd ProverResult
proverError msg modEnv = return (outcome, noWarnings)
  where
    outcome = Right (ProverError msg, modEnv)
    noWarnings = []
-- | Run the satisfiability or validity query described by a
-- 'ProverCommand' against one or more SBV solvers and package the outcome
-- as a 'ProverResult'. The whole command runs under 'protectStack', so a
-- stack overflow is reported through 'proverError'.
satProve :: ProverCommand -> M.ModuleCmd ProverResult
satProve ProverCommand {..} = protectStack proverError $ \modEnv ->
M.runModuleM modEnv $ do
-- Decode the query: is it a sat query, and is there a limit on the number
-- of models to report?
let (isSat, mSatNum) = case pcQueryType of
ProveQuery -> (False, Nothing)
SatQuery sn -> case sn of
SomeSat n -> (True, Just n)
AllSat -> (True, Nothing)
-- Declarations of all loaded modules followed by the caller's extras.
let extDgs = allDeclGroups modEnv ++ pcExtraDecls
-- "any" asks SBV for every available solver; any other name selects a
-- single configuration, with the requested SMT file attached.
provers <-
case pcProverName of
"any" -> M.io SBV.sbvAvailableSolvers
_ -> return [(lookupProver pcProverName) { SBV.smtFile = pcSmtFile }]
-- Copies of the configurations with timing/verbose output tied to pcVerbose.
let provers' = [ p { SBV.timing = pcVerbose, SBV.verbose = pcVerbose } | p <- provers ]
-- Sat queries use existential inputs, proof queries universal ones.
let tyFn = if isSat then existsFinType else forallFinType
-- runProver: run the query with exactly one solver; when there is not
-- exactly one, produce a ProofError per solver instead of running anything.
-- NOTE(review): this path matches on and runs 'provers', not 'provers'',
-- so the timing/verbose settings above are not applied here -- confirm
-- whether that is intended.
let runProver fn tag e = do
case provers of
[prover] -> do
when pcVerbose $ M.io $
putStrLn $ "Trying proof with " ++ show prover
res <- M.io (fn prover e)
when pcVerbose $ M.io $
putStrLn $ "Got result from " ++ show prover
return (tag res)
_ ->
return [ SBV.ProofError
prover
[":sat with option prover=any requires option satNum=1"]
| prover <- provers ]
-- runProvers: hand the query to all configured solvers at once; 'fn'
-- returns the solver that answered together with its result.
runProvers fn tag e = do
when pcVerbose $ M.io $
putStrLn $ "Trying proof with " ++
intercalate ", " (map show provers)
(firstProver, res) <- M.io (fn provers' e)
when pcVerbose $ M.io $
putStrLn $ "Got result from " ++ show firstProver
return (tag res)
-- Proof queries and single-model sat queries go to all solvers; every
-- other sat query uses allSat on a single solver.
let runFn = case pcQueryType of
ProveQuery -> runProvers SBV.proveWithAny thmSMTResults
SatQuery sn -> case sn of
SomeSat 1 -> runProvers SBV.satWithAny satSMTResults
_ -> runProver SBV.allSatWith allSatSMTResults
-- The schema must describe a monomorphic predicate over finite types.
case predArgTypes pcSchema of
Left msg -> return (ProverError msg)
Right ts -> do when pcVerbose $ M.io $ putStrLn "Simulating..."
let env = evalDecls mempty extDgs
let v = evalExpr env pcExpr
prims <- M.getPrimMap
-- Build one symbolic input per argument type and apply the predicate.
results' <- runFn $ do
args <- mapM tyFn ts
b <- return $! fromVBit (foldl fromVFun v args)
return b
-- Keep at most the requested number of results (all when no limit).
let results = maybe results' (\n -> take n results') mSatNum
-- Interpret the solver results.
esatexprs <- case results of
(SBV.Satisfiable {} : _) -> do
tevss <- mapM mkTevs results
return $ AllSatResult tevss
where
-- Turn one solver model into (type, expression, value) triples.
mkTevs result =
-- NOTE(review): partial pattern; a 'Left' from getModel fails when
-- 'cws' is demanded.
let Right (_, cws) = SBV.getModel result
(vs, _) = parseValues ts cws
sattys = unFinType <$> ts
satexprs = zipWithM (Eval.toExpr prims) sattys vs
in case zip3 sattys <$> satexprs <*> pure vs of
Nothing ->
panic "Cryptol.Symbolic.sat"
[ "unable to make assignment into expression" ]
Just tevs -> return $ tevs
[SBV.Unsatisfiable {}] ->
return $ ThmResult (unFinType <$> ts)
[] -> return $ ThmResult (unFinType <$> ts)
_ -> return $ ProverError (rshow results)
-- Render unexpected results through SBV's Show instances. 'boom' fills
-- the first component of the AllSatResult pair and panics if evaluated.
where rshow | isSat = show . SBV.AllSatResult . (boom,)
| otherwise = show . SBV.ThmResult . head
boom = panic "Cryptol.Symbolic.sat"
[ "attempted to evaluate bogus boolean for pretty-printing" ]
return esatexprs
-- | Compile the query described by a 'ProverCommand' to an SMT-Lib 2
-- script instead of running a solver.
--
-- The result is @Right script@ on success, or @Left msg@ when the schema
-- is rejected by 'predArgTypes' or symbolic evaluation overflows the stack
-- (see 'protectStack'). The module environment is returned unchanged.
satProveOffline :: ProverCommand -> M.ModuleCmd (Either String String)
satProveOffline ProverCommand {..} =
  protectStack onOverflow $ \modEnv ->
    case predArgTypes pcSchema of
      Left msg -> finish modEnv (Left msg)
      Right argTys -> do
        when pcVerbose $ putStrLn "Simulating..."
        let decls = allDeclGroups modEnv ++ pcExtraDecls
            predicate = evalExpr (evalDecls mempty decls) pcExpr
        script <- SBV.compileToSMTLib SBV.SMTLib2 isSat $ do
          symArgs <- mapM mkArg argTys
          return $! fromVBit (foldl fromVFun predicate symArgs)
        finish modEnv (Right script)
  where
    -- Report a stack overflow as a 'Left' result.
    onOverflow msg modEnv = finish modEnv (Left msg)

    -- Wrap a result together with the untouched environment, no warnings.
    finish modEnv r = return (Right (r, modEnv), [])

    isSat = case pcQueryType of
      SatQuery _ -> True
      ProveQuery -> False

    -- Sat queries use existential inputs, proof queries universal ones.
    mkArg
      | isSat = existsFinType
      | otherwise = forallFinType
-- | Run a module command, converting a 'X.StackOverflow' exception into
-- the command produced by the supplied error constructor. Every other
-- exception is left to propagate.
protectStack :: (String -> M.ModuleCmd a)
             -> M.ModuleCmd a
             -> M.ModuleCmd a
protectStack mkErr cmd modEnv =
  X.catchJust overflowOnly (cmd modEnv) recover
  where
    -- Select only the stack-overflow exception.
    overflowOnly ex =
      case ex of
        X.StackOverflow -> Just ()
        _ -> Nothing

    recover () = mkErr "Symbolic evaluation failed to terminate." modEnv
-- | Parse one value per finite type, left to right, from a list of solver
-- constants. Returns the parsed values and the unconsumed constants.
parseValues :: [FinType] -> [SBV.CW] -> ([Eval.Value], [SBV.CW])
parseValues tys cws0 =
  case tys of
    [] -> ([], cws0)
    t : rest ->
      let (v, cws1) = parseValue t cws0
          (vs, cws2) = parseValues rest cws1
      in (v : vs, cws2)
-- | Parse a single value of the given finite type from the front of a list
-- of solver constants, returning the value and the remaining constants.
--
-- A bit consumes one constant (panicking on an empty list). A zero-length
-- bit sequence consumes nothing. Other bit sequences are first parsed as
-- an unsigned bounded word and, failing that, bit by bit. Remaining
-- sequences, tuples and records are parsed component-wise.
parseValue :: FinType -> [SBV.CW] -> (Eval.Value, [SBV.CW])
parseValue ty cws =
  case ty of
    FTBit ->
      case cws of
        [] -> panic "Cryptol.Symbolic.parseValue" [ "empty FTBit" ]
        cw : rest -> (Eval.VBit (SBV.cwToBool cw), rest)
    FTSeq 0 FTBit -> (Eval.VWord (Eval.BV 0 0), cws)
    FTSeq n FTBit ->
      case SBV.genParse (SBV.KBounded False n) cws of
        Just (x, rest) -> (Eval.VWord (Eval.BV (toInteger n) x), rest)
        Nothing -> composite (Eval.VSeq True) (replicate n FTBit)
    FTSeq n t -> composite (Eval.VSeq False) (replicate n t)
    FTTuple ts -> composite Eval.VTuple ts
    FTRecord fs -> composite (Eval.VRecord . zip names) fieldTys
      where (names, fieldTys) = unzip fs
  where
    -- Parse the components in order and wrap them with the constructor.
    composite wrap parts =
      let (vs, rest) = parseValues parts cws
      in (wrap vs, rest)
-- | The declaration groups of every loaded module, concatenated in the
-- order given by 'M.loadedModules'.
allDeclGroups :: M.ModuleEnv -> [DeclGroup]
allDeclGroups modEnv = concat [ mDecls m | m <- M.loadedModules modEnv ]
-- | Finite types usable as arguments of a symbolic predicate: bits,
-- sequences with an 'Int' length and an element type, tuples, and records
-- with named fields.
data FinType
= FTBit
| FTSeq Int FinType
| FTTuple [FinType]
| FTRecord [(Ident, FinType)]
-- | Convert an 'Integer' to an 'Int' when it is non-negative and no larger
-- than @maxBound :: Int@; otherwise return 'Nothing'.
numType :: Integer -> Maybe Int
numType n
  | n < 0 = Nothing
  | n > toInteger (maxBound :: Int) = Nothing
  | otherwise = Just (fromInteger n)
-- | Convert an evaluated type to a 'FinType'. Bits, sequences whose length
-- passes 'numType', tuples and records are supported (recursively); every
-- other type yields 'Nothing'.
finType :: TValue -> Maybe FinType
finType (TVSeq n t) = do
  len <- numType n
  el <- finType t
  return (FTSeq len el)
finType (TVTuple ts) = fmap FTTuple (traverse finType ts)
finType (TVRec fields) = fmap FTRecord (traverse (traverseSnd finType) fields)
finType TVBit = Just FTBit
finType _ = Nothing
-- | Convert a 'FinType' back to the corresponding type-checker 'Type'.
unFinType :: FinType -> Type
unFinType FTBit = tBit
unFinType (FTSeq len el) = tSeq (tNum len) (unFinType el)
unFinType (FTTuple parts) = tTuple (map unFinType parts)
unFinType (FTRecord fs) =
  tRec [ (fst fld, unFinType (snd fld)) | fld <- fs ]
-- | Compute the finite argument types of a predicate schema.
--
-- The schema must have no type parameters and no constraints, and its type
-- must be a (possibly nullary) function from finite types to a bit.
-- Otherwise an error message containing the pretty-printed schema is
-- returned.
predArgTypes :: Schema -> Either String [FinType]
predArgTypes schema@(Forall ts ps ty)
  | not (null ts && null ps) = Left ("Not a monomorphic type:\n" ++ rendered)
  | otherwise =
      maybe (Left ("Not a valid predicate type:\n" ++ rendered))
            Right
            (argTypes (Cryptol.Eval.Type.evalValType mempty ty))
  where
    rendered = show (pp schema)

    -- Peel argument types off a function type that ends in a bit.
    argTypes :: TValue -> Maybe [FinType]
    argTypes tv =
      case tv of
        TVBit -> Just []
        TVFun arg res -> (:) <$> finType arg <*> argTypes res
        _ -> Nothing
-- | Create a symbolic value of the given finite type from universally
-- quantified solver variables. A zero-length bit sequence becomes the
-- literal empty word and introduces no variable.
forallFinType :: FinType -> SBV.Symbolic Value
forallFinType FTBit = fmap VBit forallSBool_
forallFinType (FTSeq 0 FTBit) = return (VWord (literalSWord 0 0))
forallFinType (FTSeq n FTBit) = fmap VWord (forallBV_ n)
forallFinType (FTSeq n t) = fmap (VSeq False) (replicateM n (forallFinType t))
forallFinType (FTTuple ts) = fmap VTuple (mapM forallFinType ts)
forallFinType (FTRecord fs) = fmap VRecord (mapM (traverseSnd forallFinType) fs)
-- | Create a symbolic value of the given finite type from existentially
-- quantified solver variables. A zero-length bit sequence becomes the
-- literal empty word and introduces no variable.
existsFinType :: FinType -> SBV.Symbolic Value
existsFinType FTBit = fmap VBit existsSBool_
existsFinType (FTSeq 0 FTBit) = return (VWord (literalSWord 0 0))
existsFinType (FTSeq n FTBit) = fmap VWord (existsBV_ n)
existsFinType (FTSeq n t) = fmap (VSeq False) (replicateM n (existsFinType t))
existsFinType (FTTuple ts) = fmap VTuple (mapM existsFinType ts)
existsFinType (FTRecord fs) = fmap VRecord (mapM (traverseSnd existsFinType) fs)
-- | Environment for symbolic evaluation.
data Env = Env
-- | Symbolic values bound to term-level names.
{ envVars :: Map.Map Name Value
-- | Bindings for type variables: 'Left' for a numeric type, 'Right' for a
-- value type.
, envTypes :: Map.Map TVar (Either Nat' TValue)
}
-- | Environments combine by left-biased union of both maps; the empty
-- environment has no bindings.
instance Monoid Env where
  mempty = Env Map.empty Map.empty

  -- Lazy patterns: the result constructor is available without forcing
  -- either argument, as with field selectors.
  mappend ~(Env varsL typesL) ~(Env varsR typesR) =
    Env (Map.union varsL varsR) (Map.union typesL typesR)
-- | Bind a name to a value in the environment, replacing any existing
-- binding for that name.
bindVar :: (Name, Value) -> Env -> Env
bindVar (name, val) env = env { envVars = extended }
  where
    extended = Map.insert name val (envVars env)
-- | Look up the value bound to a name, if any.
lookupVar :: Name -> Env -> Maybe Value
lookupVar n = Map.lookup n . envVars
-- | Bind a type variable to a numeric type ('Left') or a value type
-- ('Right'), replacing any existing binding for that variable.
bindType :: TVar -> (Either Nat' TValue) -> Env -> Env
bindType tv binding env = env { envTypes = extended }
  where
    extended = Map.insert tv binding (envTypes env)
-- | Look up the binding of a type variable, if any.
lookupType :: TVar -> Env -> Maybe (Either Nat' TValue)
lookupType tv = Map.lookup tv . envTypes
-- | Symbolically evaluate an expression in the given environment.
--
-- Panics on an unbound variable, on a type abstraction whose parameter is
-- neither of kind 'KType' nor 'KNum', and on a type application whose head
-- is not a polymorphic value. Proof abstractions, proof applications and
-- casts evaluate to their underlying expression.
evalExpr :: Env -> Expr -> Value
evalExpr env expr =
  case expr of
    EList es ty -> VSeq (tIsBit ty) [ go e | e <- es ]
    ETuple es -> VTuple [ go e | e <- es ]
    ERec fields -> VRecord [ (f, go e) | (f, e) <- fields ]
    ESel e sel -> evalSel sel (go e)
    EIf c thenE elseE -> iteValue (fromVBit (go c)) (go thenE) (go elseE)
    EComp ty body arms -> evalComp env (evalValType env ty) body arms
    EVar n ->
      maybe (panic "Cryptol.Symbolic.evalExpr" [ "Variable " ++ show n ++ " not found" ])
            id
            (lookupVar n env)
    ETAbs tv body ->
      -- Evaluate the body with the type parameter bound to the argument.
      let instantiate binding = evalExpr (bindType (tpVar tv) binding env) body
      in case tpKind tv of
           KType -> VPoly (\ty -> instantiate (Right ty))
           KNum -> VNumPoly (\n -> instantiate (Left n))
           k -> panic "[Symbolic] evalExpr" ["invalid kind on type abstraction", show k]
    ETApp e ty ->
      case go e of
        VPoly f -> f (evalValType env ty)
        VNumPoly f -> f (evalNumType env ty)
        _ -> panic "[Symbolic] evalExpr"
               [ "expected a polymorphic value"
               , show e, show ty
               ]
    EApp fun arg -> fromVFun (go fun) (go arg)
    EAbs n _ty body -> VFun (\x -> evalExpr (bindVar (n, x) env) body)
    EProofAbs _prop e -> go e
    EProofApp e -> go e
    ECast e _ty -> go e
    EWhere e ds -> evalExpr (evalDecls env ds) e
  where
    -- Evaluate a sub-expression in the unchanged environment.
    go = evalExpr env
-- | Evaluate a type to a value type, using the environment's type
-- bindings and an empty variable map for the concrete evaluator's
-- environment.
evalValType :: Env -> Type -> TValue
evalValType env =
  Cryptol.Eval.Type.evalValType
    (Cryptol.Eval.Env.EvalEnv Map.empty (envTypes env))
-- | Evaluate a type to a numeric type, using the environment's type
-- bindings and an empty variable map for the concrete evaluator's
-- environment.
evalNumType :: Env -> Type -> Nat'
evalNumType env =
  Cryptol.Eval.Type.evalNumType
    (Cryptol.Eval.Env.EvalEnv Map.empty (envTypes env))
-- | Apply a selector to a symbolic value.
--
-- Tuple and record selectors are applied pointwise through sequences,
-- streams and functions, and panic on any other mismatched value. A list
-- selector on a packed word ('VWord') tests a single bit; on any other
-- value it indexes the list produced by 'fromSeq'.
evalSel :: Selector -> Value -> Value
evalSel sel v =
  case sel of
    TupleSel n _ ->
      case v of
        VTuple xs -> xs !! n
        VSeq b xs -> VSeq b (map (evalSel sel) xs)
        VStream xs -> VStream (map (evalSel sel) xs)
        VFun f -> VFun (\x -> evalSel sel (f x))
        _ -> panic "Cryptol.Symbolic.evalSel" [ "Tuple selector applied to incompatible type" ]
    RecordSel n _ ->
      case v of
        VRecord bs ->
          case lookup n bs of
            Just x -> x
            _ -> panic "Cryptol.Symbolic.evalSel" [ "Selector " ++ show n ++ " not found" ]
        VSeq b xs -> VSeq b (map (evalSel sel) xs)
        VStream xs -> VStream (map (evalSel sel) xs)
        VFun f -> VFun (\x -> evalSel sel (f x))
        _ -> panic "Cryptol.Symbolic.evalSel" [ "Record selector applied to non-record" ]
    ListSel n _ ->
      case v of
        VWord s -> VBit (SBV.svTestBit s i)
          where
            -- Sequence index 0 is the leftmost (most significant) bit of
            -- the word, whereas 'SBV.svTestBit' numbers bits from the
            -- least significant end, so mirror the index within the
            -- word's width.
            i = SBV.intSizeOf s - 1 - n
        _ -> fromSeq v !! n
-- | Extend an environment with a list of declaration groups, processed
-- from left to right.
evalDecls :: Env -> [DeclGroup] -> Env
evalDecls env groups = foldl evalDeclGroup env groups
-- | Extend an environment with one declaration group.
--
-- A non-recursive declaration is evaluated in the incoming environment.
-- The members of a recursive group are evaluated in the extended
-- environment itself; each binding stored there is first passed through
-- 'copyBySchema' with the declaration's signature.
evalDeclGroup :: Env -> DeclGroup -> Env
evalDeclGroup env (NonRecursive d) = bindVar (evalDecl env d) env
evalDeclGroup env (Recursive ds) = recEnv
  where
    -- The environment that the group's own definitions are evaluated in.
    recEnv = foldr bindVar env copied

    evaluated = [ evalDecl recEnv d | d <- ds ]

    copied =
      [ (name, copyBySchema env (dSignature d) val)
      | (d, (name, val)) <- zip ds evaluated
      ]
-- | Evaluate a declaration to a name/value pair. An expression definition
-- is evaluated with 'evalExpr'; a primitive is resolved with 'evalPrim'.
evalDecl :: Env -> Decl -> (Name, Value)
evalDecl env d = (dName d, rhs (dDefinition d))
  where
    rhs (DExpr e) = evalExpr env e
    rhs DPrim = evalPrim d
-- | Rebuild a value following the shape of its schema: one 'VPoly' or
-- 'VNumPoly' wrapper per type parameter (binding that parameter in the
-- environment), then 'copyByType' on the schema's body type.
--
-- Panics when a type parameter has a kind other than 'KType' or 'KNum'.
copyBySchema :: Env -> Schema -> Value -> Value
copyBySchema env0 (Forall params _props ty) = go params env0
  where
    go [] env v = copyByType env (evalValType env ty) v
    go (p : ps) env v =
      case tpKind p of
        KType -> VPoly (\t -> go ps (bindType (tpVar p) (Right t) env) (fromVPoly v t))
        KNum -> VNumPoly (\t -> go ps (bindType (tpVar p) (Left t) env) (fromVNumPoly v t))
        -- Tag the panic with this module, matching the same failure in
        -- 'evalExpr', rather than with the concrete evaluator.
        k -> panic "[Symbolic] copyBySchema" ["invalid kind on type abstraction", show k]
-- | Rebuild the outer constructor of a value according to its type,
-- deferring the inspection of the original value to the projection
-- functions ('fromVBit', 'fromSeq', 'fromVFun', 'fromVTuple',
-- 'lookupRecord'). Function results, tuple components and record fields
-- are copied recursively.
copyByType :: Env -> TValue -> Value -> Value
copyByType _ TVBit v = VBit (fromVBit v)
copyByType _ (TVSeq _ ety) v = VSeq (isTBit ety) (fromSeq v)
copyByType _ (TVStream _) v = VStream (fromSeq v)
copyByType env (TVFun _ bty) v = VFun (copyByType env bty . fromVFun v)
copyByType env (TVTuple tys) v = VTuple (zipWith (copyByType env) tys (fromVTuple v))
copyByType env (TVRec fs) v =
  VRecord [ (f, copyByType env t (lookupRecord f v)) | (f, t) <- fs ]
-- | Evaluate a sequence comprehension.
--
-- Each arm produces a list of environments ('branchEnvs'). The arms are
-- combined position by position, stopping at the first position where
-- some arm has run out, and the body is evaluated once per combined
-- environment. Panics if the result type is not a sequence type.
evalComp :: Env -> TValue -> Expr -> [[Match]] -> Value
evalComp env seqty body ms =
  case Eval.isTSeq seqty of
    Nothing -> evalPanic "Cryptol.Eval" ["evalComp given a non sequence", show seqty]
    Just (len, el) -> toSeq len el (map (\e -> evalExpr e body) elemEnvs)
  where
    -- One list of environments per arm.
    perArm = [ branchEnvs env arm | arm <- ms ]

    -- A row is usable only if every arm contributed an environment.
    complete row = length row == length ms

    elemEnvs = [ mconcat row | row <- takeWhile complete (transpose perArm) ]
-- | All environments produced by one comprehension arm: each match extends
-- every environment produced by the matches before it.
branchEnvs :: Env -> [Match] -> [Env]
branchEnvs env [] = [env]
branchEnvs env (m : ms) =
  concatMap (\extended -> branchEnvs extended ms) (evalMatch env m)
-- | Environments produced by a single match: a generator binds its name to
-- each element of the evaluated sequence in turn, and a @let@ adds the one
-- evaluated declaration.
evalMatch :: Env -> Match -> [Env]
evalMatch env (From n _ty expr) =
  map (\v -> bindVar (n, v) env) (fromSeq (evalExpr env expr))
evalMatch env (Let d) = [ bindVar (evalDecl env d) env ]