module Language.Wasm.Script (
runScript,
OnAssertFail
) where
import qualified Data.Map as Map
import qualified Data.Vector as Vector
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TLEncoding
import Numeric.IEEE (identicalIEEE)
import qualified Control.DeepSeq as DeepSeq
import Data.Maybe (fromJust, isNothing)
import Language.Wasm.Parser (
Ident(..),
Script,
ModuleDef(..),
Command(..),
Action(..),
Assertion(..)
)
import qualified Language.Wasm.Interpreter as Interpreter
import qualified Language.Wasm.Validate as Validate
import qualified Language.Wasm.Structure as Struct
import qualified Language.Wasm.Parser as Parser
import qualified Language.Wasm.Lexer as Lexer
import qualified Language.Wasm.Binary as Binary
type OnAssertFail = String -> Assertion -> IO ()
data ScriptState = ScriptState {
store :: Interpreter.Store,
lastModule :: Maybe Interpreter.ModuleInstance,
modules :: Map.Map TL.Text Interpreter.ModuleInstance,
moduleRegistery :: Map.Map TL.Text Interpreter.ModuleInstance
}
emptyState :: ScriptState
emptyState = ScriptState {
store = Interpreter.emptyStore,
lastModule = Nothing,
modules = Map.empty,
moduleRegistery = Map.empty
}
runScript :: OnAssertFail -> Script -> IO ()
runScript onAssertFail script = do
(globI32, globF32, globF64) <- hostGlobals
(st, inst) <- Interpreter.makeHostModule Interpreter.emptyStore [
("print", hostPrint []),
("print_i32", hostPrint [Struct.I32]),
("print_i32_f32", hostPrint [Struct.I32, Struct.F32]),
("print_f64_f64", hostPrint [Struct.F64, Struct.F64]),
("print_f32", hostPrint [Struct.F32]),
("print_f64", hostPrint [Struct.F64]),
("global_i32", globI32),
("global_f32", globF32),
("global_f64", globF64),
("memory", Interpreter.HostMemory $ Struct.Limit 1 (Just 2)),
("table", Interpreter.HostTable $ Struct.Limit 10 (Just 20))
]
go script $ emptyState { store = st, moduleRegistery = Map.singleton "spectest" inst }
where
hostPrint paramTypes = Interpreter.HostFunction (Struct.FuncType paramTypes []) (\args -> return [])
hostGlobals = do
globI32 <- Interpreter.makeMutGlobal $ Interpreter.VI32 666
globF32 <- Interpreter.makeMutGlobal $ Interpreter.VF32 666
globF64 <- Interpreter.makeMutGlobal $ Interpreter.VF64 666
return (Interpreter.HostGlobal globI32, Interpreter.HostGlobal globF32, Interpreter.HostGlobal globF64)
go [] _ = return ()
go (c:cs) st = runCommand st c >>= go cs
addToRegistery :: TL.Text -> Maybe Ident -> ScriptState -> ScriptState
addToRegistery name i st =
case getModule st i of
Just m -> st { moduleRegistery = Map.insert name m $ moduleRegistery st }
Nothing -> error $ "Cannot register module with identifier '" ++ show i ++ "'. No such module"
addToStore :: Maybe Ident -> Interpreter.ModuleInstance -> ScriptState -> ScriptState
addToStore (Just (Ident ident)) m st = st { modules = Map.insert ident m $ modules st }
addToStore Nothing _ st = st
buildImports :: ScriptState -> Interpreter.Imports
buildImports st =
Map.fromList $ concat $ map toImports $ Map.toList $ moduleRegistery st
where
toImports :: (TL.Text, Interpreter.ModuleInstance) -> [((TL.Text, TL.Text), Interpreter.ExternalValue)]
toImports (modName, mod) = map (asImport modName) $ Vector.toList $ Interpreter.exports mod
asImport :: TL.Text -> Interpreter.ExportInstance -> ((TL.Text, TL.Text), Interpreter.ExternalValue)
asImport modName (Interpreter.ExportInstance name val) = ((modName, name), val)
addModule :: Maybe Ident -> Struct.Module -> ScriptState -> IO ScriptState
addModule ident m st =
case Validate.validate m of
Right m -> do
res <- Interpreter.instantiate (store st) (buildImports st) m
case res of
Right (modInst, store') -> return $ addToStore ident modInst $ st { lastModule = Just modInst, store = store' }
Left reason -> error $ "Module instantiation failed dut to invalid module with reason: " ++ show reason
Left reason -> error $ "Module instantiation failed dut to invalid module with reason: " ++ show reason
getModule :: ScriptState -> Maybe Ident -> Maybe Interpreter.ModuleInstance
getModule st (Just (Ident i)) = Map.lookup i (modules st)
getModule st Nothing = lastModule st
asArg :: Struct.Expression -> Interpreter.Value
asArg [Struct.I32Const v] = Interpreter.VI32 v
asArg [Struct.F32Const v] = Interpreter.VF32 v
asArg [Struct.I64Const v] = Interpreter.VI64 v
asArg [Struct.F64Const v] = Interpreter.VF64 v
asArg _ = error "Only const instructions supported as arguments for actions"
runAction :: ScriptState -> Action -> IO (Maybe [Interpreter.Value])
runAction st (Invoke ident name args) = do
case getModule st ident of
Just m -> Interpreter.invokeExport (store st) m name $ map asArg args
Nothing -> error $ "Cannot invoke function on module with identifier '" ++ show ident ++ "'. No such module"
runAction st (Get ident name) = do
case getModule st ident of
Just m -> Interpreter.getGlobalValueByName (store st) m name >>= return . Just . (: [])
Nothing -> error $ "Cannot invoke function on module with identifier '" ++ show ident ++ "'. No such module"
isValueEqual :: Interpreter.Value -> Interpreter.Value -> Bool
isValueEqual (Interpreter.VI32 v1) (Interpreter.VI32 v2) = v1 == v2
isValueEqual (Interpreter.VI64 v1) (Interpreter.VI64 v2) = v1 == v2
isValueEqual (Interpreter.VF32 v1) (Interpreter.VF32 v2) = identicalIEEE v1 v2
isValueEqual (Interpreter.VF64 v1) (Interpreter.VF64 v2) = identicalIEEE v1 v2
isValueEqual _ _ = False
isNaNReturned :: ScriptState -> Action -> Assertion -> IO ()
isNaNReturned st action assert = do
result <- runAction st action
case result of
Just [Interpreter.VF32 v] ->
if isNaN v
then return ()
else onAssertFail ("Expected NaN, but action returned " ++ show v) assert
Just [Interpreter.VF64 v] ->
if isNaN v
then return ()
else onAssertFail ("Expected NaN, but action returned " ++ show v) assert
_ -> onAssertFail ("Expected NaN, but action returned " ++ show result) assert
buildModule :: ModuleDef -> (Maybe Ident, Struct.Module)
buildModule (RawModDef ident m) = (ident, m)
buildModule (TextModDef ident textRep) =
let Right m = Lexer.scanner (TLEncoding.encodeUtf8 textRep) >>= Parser.parseModule in
(ident, m)
buildModule (BinaryModDef ident binaryRep) =
let Right m = Binary.decodeModuleLazy binaryRep in
(ident, m)
checkModuleInvalid :: Struct.Module -> IO ()
checkModuleInvalid _ = return ()
getFailureString :: Validate.ValidationError -> [TL.Text]
getFailureString (Validate.TypeMismatch _ _) = ["type mismatch"]
getFailureString Validate.ResultTypeDoesntMatch = ["type mismatch"]
getFailureString Validate.MoreThanOneMemory = ["multiple memories"]
getFailureString Validate.MoreThanOneTable = ["multiple tables"]
getFailureString Validate.LocalIndexOutOfRange = ["unknown local"]
getFailureString Validate.MemoryIndexOutOfRange = ["unknown memory", "unknown memory 0"]
getFailureString Validate.TableIndexOutOfRange = ["unknown table", "unknown table 0"]
getFailureString Validate.FunctionIndexOutOfRange = ["unknown function", "unknown function 0"]
getFailureString Validate.GlobalIndexOutOfRange = ["unknown global"]
getFailureString Validate.LabelIndexOutOfRange = ["unknown label"]
getFailureString Validate.TypeIndexOutOfRange = ["unknown type"]
getFailureString Validate.MinMoreThanMaxInMemoryLimit = ["size minimum must not be greater than maximum"]
getFailureString Validate.MemoryLimitExceeded = ["memory size must be at most 65536 pages (4GiB)"]
getFailureString Validate.AlignmentOverflow = ["alignment", "alignment must not be larger than natural"]
getFailureString (Validate.DuplicatedExportNames _) = ["duplicate export name"]
getFailureString Validate.InvalidConstantExpr = ["constant expression required"]
getFailureString Validate.InvalidResultArity = ["invalid result arity"]
getFailureString Validate.GlobalIsImmutable = ["global is immutable"]
getFailureString Validate.ImportedGlobalIsNotConst = ["mutable globals cannot be imported"]
getFailureString Validate.ExportedGlobalIsNotConst = ["mutable globals cannot be exported"]
getFailureString Validate.InvalidStartFunctionType = ["start function"]
getFailureString r = [TL.concat ["not implemented ", (TL.pack $ show r)]]
runAssert :: ScriptState -> Assertion -> IO ()
runAssert st assert@(AssertReturn action expected) = do
result <- runAction st action
case result of
Just result -> do
if length result == length expected && (all id $ zipWith isValueEqual result (map asArg expected))
then return ()
else onAssertFail ("Expected " ++ show (map asArg expected) ++ ", but action returned " ++ show result) assert
Nothing -> onAssertFail ("Expected " ++ show (map asArg expected) ++ ", but action returned Trap") assert
runAssert st assert@(AssertReturnCanonicalNaN action) = isNaNReturned st action assert
runAssert st assert@(AssertReturnArithmeticNaN action) = isNaNReturned st action assert
runAssert st assert@(AssertInvalid moduleDef failureString) =
let (_, m) = buildModule moduleDef in
case Validate.validate m of
Right _ -> onAssertFail "Invalid module pass validation" assert
Left reason ->
if failureString `elem` getFailureString reason
then return ()
else
let msg = "Module invalid for other reason. Expected "
++ show failureString
++ ", but actual is "
++ show (getFailureString reason)
in onAssertFail msg assert
runAssert st assert@(AssertMalformed (TextModDef _ textRep) failureString) =
case DeepSeq.force $ Lexer.scanner (TLEncoding.encodeUtf8 textRep) >>= Parser.parseModule of
Right _ -> onAssertFail ("Module parsing should fail with failure string " ++ show failureString) assert
Left _ -> return ()
runAssert st assert@(AssertMalformed (BinaryModDef ident binaryRep) failureString) =
case Binary.decodeModuleLazy binaryRep of
Right _ -> onAssertFail ("Module decoding should fail with failure string " ++ show failureString) assert
Left _ -> return ()
runAssert st assert@(AssertMalformed (RawModDef _ _) failureString) = return ()
runAssert st assert@(AssertUnlinkable moduleDef failureString) =
let (_, m) = buildModule moduleDef in
case Validate.validate m of
Right m -> do
res <- Interpreter.instantiate (store st) (buildImports st) m
case res of
Left err -> return ()
Right _ -> onAssertFail ("Module linking should fail with failure string " ++ show failureString) assert
Left reason -> error $ "Module linking failed dut to invalid module with reason: " ++ show reason
runAssert st assert@(AssertTrap (Left action) failureString) = do
result <- runAction st action
if isNothing result
then return ()
else onAssertFail ("Expected trap, but action returned " ++ show (fromJust result)) assert
runAssert st assert@(AssertTrap (Right moduleDef) failureString) =
let (_, m) = buildModule moduleDef in
case Validate.validate m of
Right m -> do
res <- Interpreter.instantiate (store st) (buildImports st) m
case res of
Left "Start function terminated with trap" -> return ()
_ -> onAssertFail ("Module linking should fail with trap during execution of a start function") assert
Left reason -> error $ "Module linking failed dut to invalid module with reason: " ++ show reason
runAssert st assert@(AssertExhaustion action failureString) = do
result <- runAction st action
if isNothing result
then return ()
else onAssertFail ("Expected exhaustion, but action returned " ++ show (fromJust result)) assert
runCommand :: ScriptState -> Command -> IO ScriptState
runCommand st (ModuleDef moduleDef) =
let (ident, m) = buildModule moduleDef in
addModule ident m st
runCommand st (Register name i) = return $ addToRegistery name i st
runCommand st (Action action) = runAction st action >> return st
runCommand st (Assertion assertion) = runAssert st assertion >> return st
runCommand st _ = return st