{-# LANGUAGE OverloadedStrings #-}
module Language.Wasm.Script (
    runScript,
    OnAssertFail
) where

import qualified Data.Map as Map
import qualified Data.Vector as Vector
import qualified Data.Text.Lazy as TL
import qualified Data.Text.Lazy.Encoding as TLEncoding
import Numeric.IEEE (identicalIEEE)
import qualified Control.DeepSeq as DeepSeq
import Data.Maybe (fromJust, isNothing)

import Language.Wasm.Parser (
        Ident(..),
        Script,
        ModuleDef(..),
        Command(..),
        Action(..),
        Assertion(..)
    )

import qualified Language.Wasm.Interpreter as Interpreter
import qualified Language.Wasm.Validate as Validate
import qualified Language.Wasm.Structure as Struct
import qualified Language.Wasm.Parser as Parser
import qualified Language.Wasm.Lexer as Lexer
import qualified Language.Wasm.Binary as Binary

-- | Callback invoked when an assertion in the script does not hold.
-- It receives a human-readable description of the mismatch followed by
-- the assertion that failed.
type OnAssertFail
    = String
    -> Assertion
    -> IO ()

-- | State threaded from one script command to the next.
data ScriptState = ScriptState
    { store :: Interpreter.Store
      -- ^ Interpreter store passed to every instantiation and invocation.
    , lastModule :: Maybe Interpreter.ModuleInstance
      -- ^ Most recently instantiated module; used when a command names no module.
    , modules :: Map.Map TL.Text Interpreter.ModuleInstance
      -- ^ Instantiated modules that were given an identifier, keyed by it.
    , moduleRegistery :: Map.Map TL.Text Interpreter.ModuleInstance
      -- ^ Modules registered under an import name (e.g. "spectest"); their exports resolve imports.
    }

-- | Starting state: an empty store, no instantiated modules and an empty registry.
emptyState :: ScriptState
emptyState =
    ScriptState
        { store = Interpreter.emptyStore
        , lastModule = Nothing
        , modules = Map.empty
        , moduleRegistery = Map.empty
        }

-- | Execute the commands of a WebAssembly spec script in order.
--
-- Before the first command runs, a host module is created and registered
-- under the name "spectest". It exports no-op print functions, three
-- globals, a memory and a table for test modules to import. Failed
-- assertions are reported through @onAssertFail@. Once the callback
-- returns, execution continues with the next command.
runScript :: OnAssertFail -> Script -> IO ()
runScript onAssertFail script = do
    hostExports <- spectestExports
    (hostStore, spectestInstance) <- Interpreter.makeHostModule Interpreter.emptyStore hostExports
    go script $ emptyState { store = hostStore, moduleRegistery = Map.singleton "spectest" spectestInstance }
    where
        -- Host function with the given parameter types and no results; it ignores its arguments.
        hostPrint paramTypes = Interpreter.HostFunction (Struct.FuncType paramTypes []) (const $ return [])

        -- Items exported by the "spectest" host module. The three globals (all 666)
        -- are created in IO before the export list is assembled.
        -- NOTE(review): the globals come from makeMutGlobal; confirm that linking
        -- accepts them for imports declared immutable.
        spectestExports = do
            globI32 <- Interpreter.makeMutGlobal $ Interpreter.VI32 666
            globF32 <- Interpreter.makeMutGlobal $ Interpreter.VF32 666
            globF64 <- Interpreter.makeMutGlobal $ Interpreter.VF64 666
            return [
                    ("print", hostPrint []),
                    ("print_i32", hostPrint [Struct.I32]),
                    ("print_i32_f32", hostPrint [Struct.I32, Struct.F32]),
                    ("print_f64_f64", hostPrint [Struct.F64, Struct.F64]),
                    ("print_f32", hostPrint [Struct.F32]),
                    ("print_f64", hostPrint [Struct.F64]),
                    ("global_i32", Interpreter.HostGlobal globI32),
                    ("global_f32", Interpreter.HostGlobal globF32),
                    ("global_f64", Interpreter.HostGlobal globF64),
                    ("memory", Interpreter.HostMemory $ Struct.Limit 1 (Just 2)),
                    ("table", Interpreter.HostTable $ Struct.Limit 10 (Just 20))
                ]

        -- Run the commands left to right, giving each one the state produced by the previous one.
        go commands state =
            case commands of
                [] -> return ()
                command : rest -> runCommand state command >>= go rest
        
        -- Register a module under an import name so that later modules can import
        -- its exports. The module is the named one, or the last instantiated
        -- module when no identifier is given. Aborts with 'error' if it does not exist.
        addToRegistery :: TL.Text -> Maybe Ident -> ScriptState -> ScriptState
        addToRegistery name i st =
            maybe noSuchModule register (getModule st i)
            where
                register m = st { moduleRegistery = Map.insert name m (moduleRegistery st) }
                noSuchModule = error $ "Cannot register module with identifier '" ++ show i  ++ "'. No such module"

        -- Record an instantiated module under its script identifier.
        -- An anonymous module leaves the named-module map unchanged.
        addToStore :: Maybe Ident -> Interpreter.ModuleInstance -> ScriptState -> ScriptState
        addToStore ident m st =
            case ident of
                Just (Ident name) -> st { modules = Map.insert name m (modules st) }
                Nothing -> st

        -- Build the import table used to instantiate a new module. Every export of
        -- every registered module becomes importable under the key
        -- (registered module name, export name).
        buildImports :: ScriptState -> Interpreter.Imports
        buildImports st =
            Map.fromList $ concatMap toImports $ Map.toList $ moduleRegistery st
            where
                toImports :: (TL.Text, Interpreter.ModuleInstance) -> [((TL.Text, TL.Text), Interpreter.ExternalValue)]
                -- 'modInst' rather than 'mod', to avoid shadowing Prelude's 'mod'.
                toImports (modName, modInst) = map (asImport modName) $ Vector.toList $ Interpreter.exports modInst
                asImport :: TL.Text -> Interpreter.ExportInstance -> ((TL.Text, TL.Text), Interpreter.ExternalValue)
                asImport modName (Interpreter.ExportInstance name val) = ((modName, name), val)

        -- Validate a module and instantiate it against the currently registered
        -- imports. On success the instance becomes the last module, is stored
        -- under its identifier (if any), and the updated store is kept.
        -- A validation or instantiation failure aborts with 'error'.
        addModule :: Maybe Ident -> Struct.Module -> ScriptState -> IO ScriptState
        addModule ident m st =
            case Validate.validate m of
                Right validModule -> do
                    res <- Interpreter.instantiate (store st) (buildImports st) validModule
                    case res of
                        Right (modInst, store') -> return $ addToStore ident modInst $ st { lastModule = Just modInst, store = store' }
                        -- The module already passed validation, so this failure
                        -- comes from instantiation itself, not from an invalid module.
                        Left reason -> error $ "Module instantiation failed with reason: " ++ show reason
                Left reason -> error $ "Module instantiation failed due to invalid module with reason: " ++ show reason
        
        -- Resolve a module reference. An identifier is looked up among the named
        -- modules; no identifier means the most recently instantiated module.
        getModule :: ScriptState -> Maybe Ident -> Maybe Interpreter.ModuleInstance
        getModule st ident =
            case ident of
                Just (Ident name) -> Map.lookup name (modules st)
                Nothing -> lastModule st

        -- Convert a single-instruction constant expression from the script into an
        -- interpreter value. Any other expression shape aborts with 'error'.
        asArg :: Struct.Expression -> Interpreter.Value
        asArg expr =
            case expr of
                [Struct.I32Const v] -> Interpreter.VI32 v
                [Struct.F32Const v] -> Interpreter.VF32 v
                [Struct.I64Const v] -> Interpreter.VI64 v
                [Struct.F64Const v] -> Interpreter.VF64 v
                _ -> error "Only const instructions supported as arguments for actions"

        -- Perform an action on a module (the named one, or the last instantiated one).
        -- 'Invoke' calls an exported function with constant arguments and returns the
        -- interpreter's result; callers treat 'Nothing' as a trap. 'Get' reads an
        -- exported global and wraps its value in a singleton list.
        -- A missing module aborts with 'error'.
        runAction :: ScriptState -> Action -> IO (Maybe [Interpreter.Value])
        runAction st (Invoke ident name args) =
            case getModule st ident of
                Just m -> Interpreter.invokeExport (store st) m name $ map asArg args
                Nothing -> error $ "Cannot invoke function on module with identifier '" ++ show ident  ++ "'. No such module"
        runAction st (Get ident name) =
            case getModule st ident of
                Just m -> Just . (: []) <$> Interpreter.getGlobalValueByName (store st) m name
                Nothing -> error $ "Cannot get global from module with identifier '" ++ show ident  ++ "'. No such module"

        -- Exact value comparison used by assert_return. Integers are compared with
        -- '=='. Floats are compared with 'identicalIEEE' rather than '=='.
        -- Values of different types never match.
        isValueEqual :: Interpreter.Value -> Interpreter.Value -> Bool
        isValueEqual lhs rhs =
            case (lhs, rhs) of
                (Interpreter.VI32 a, Interpreter.VI32 b) -> a == b
                (Interpreter.VI64 a, Interpreter.VI64 b) -> a == b
                (Interpreter.VF32 a, Interpreter.VF32 b) -> identicalIEEE a b
                (Interpreter.VF64 a, Interpreter.VF64 b) -> identicalIEEE a b
                _ -> False

        -- Run an action and report a failure unless it returned exactly one f32 or
        -- f64 value that is a NaN. Any NaN is accepted; the payload is not inspected.
        isNaNReturned :: ScriptState -> Action -> Assertion -> IO ()
        isNaNReturned st action assert = runAction st action >>= check
            where
                check (Just [Interpreter.VF32 v]) = expectNaN v
                check (Just [Interpreter.VF64 v]) = expectNaN v
                check other = onAssertFail ("Expected NaN, but action returned " ++ show other) assert

                -- Shared by both float widths.
                expectNaN :: (RealFloat a, Show a) => a -> IO ()
                expectNaN v
                    | isNaN v = return ()
                    | otherwise = onAssertFail ("Expected NaN, but action returned " ++ show v) assert
        
        -- Turn a script module definition into its optional identifier and module AST.
        -- Text modules are lexed and parsed, and binary modules are decoded. If that
        -- fails, forcing the module raises an 'error' that includes the reason instead
        -- of an opaque pattern-match failure. The module is still built lazily, only
        -- when it is used.
        buildModule :: ModuleDef -> (Maybe Ident, Struct.Module)
        buildModule (RawModDef ident m) = (ident, m)
        buildModule (TextModDef ident textRep) =
            let parsed = Lexer.scanner (TLEncoding.encodeUtf8 textRep) >>= Parser.parseModule in
            (ident, either (\reason -> error $ "Cannot parse text module: " ++ show reason) id parsed)
        buildModule (BinaryModDef ident binaryRep) =
            let decoded = Binary.decodeModuleLazy binaryRep in
            (ident, either (\reason -> error $ "Cannot decode binary module: " ++ show reason) id decoded)

        -- Placeholder check that accepts every module and performs no action.
        -- NOTE(review): not referenced anywhere in runScript; candidate for removal.
        checkModuleInvalid :: Struct.Module -> IO ()
        checkModuleInvalid = const (return ())

        -- Spec-test failure strings that count as a match for a validation error.
        -- An assert_invalid passes when its expected string is in this list.
        -- Errors without a mapping produce a "not implemented <error>" placeholder.
        getFailureString :: Validate.ValidationError -> [TL.Text]
        getFailureString err =
            case err of
                Validate.TypeMismatch _ _ -> ["type mismatch"]
                Validate.ResultTypeDoesntMatch -> ["type mismatch"]
                Validate.MoreThanOneMemory -> ["multiple memories"]
                Validate.MoreThanOneTable -> ["multiple tables"]
                Validate.LocalIndexOutOfRange -> ["unknown local"]
                Validate.MemoryIndexOutOfRange -> ["unknown memory", "unknown memory 0"]
                Validate.TableIndexOutOfRange -> ["unknown table", "unknown table 0"]
                Validate.FunctionIndexOutOfRange -> ["unknown function", "unknown function 0"]
                Validate.GlobalIndexOutOfRange -> ["unknown global"]
                Validate.LabelIndexOutOfRange -> ["unknown label"]
                Validate.TypeIndexOutOfRange -> ["unknown type"]
                Validate.MinMoreThanMaxInMemoryLimit -> ["size minimum must not be greater than maximum"]
                Validate.MemoryLimitExceeded -> ["memory size must be at most 65536 pages (4GiB)"]
                Validate.AlignmentOverflow -> ["alignment", "alignment must not be larger than natural"]
                Validate.DuplicatedExportNames _ -> ["duplicate export name"]
                Validate.InvalidConstantExpr -> ["constant expression required"]
                Validate.InvalidResultArity -> ["invalid result arity"]
                Validate.GlobalIsImmutable -> ["global is immutable"]
                Validate.ImportedGlobalIsNotConst -> ["mutable globals cannot be imported"]
                Validate.ExportedGlobalIsNotConst -> ["mutable globals cannot be exported"]
                Validate.InvalidStartFunctionType -> ["start function"]
                other -> [TL.append "not implemented " (TL.pack (show other))]

        -- Check one assertion against the current state and report any mismatch
        -- through 'onAssertFail'. No new ScriptState is produced: instances created
        -- while checking module-level assertions are not kept. Expected failure
        -- strings are compared only for assert_invalid. The other assertions check
        -- only the kind of failure (trap, parse/decode error, link error).
        runAssert :: ScriptState -> Assertion -> IO ()
        runAssert st assert@(AssertReturn action expected) = do
            result <- runAction st action
            let expectedValues = map asArg expected
            case result of
                Just actual
                    | length actual == length expectedValues
                        && and (zipWith isValueEqual actual expectedValues) -> return ()
                    | otherwise -> onAssertFail ("Expected " ++ show expectedValues ++ ", but action returned " ++ show actual) assert
                Nothing -> onAssertFail ("Expected " ++ show expectedValues ++ ", but action returned Trap") assert
        -- Canonical and arithmetic NaN assertions are checked identically (any NaN passes).
        runAssert st assert@(AssertReturnCanonicalNaN action) = isNaNReturned st action assert
        runAssert st assert@(AssertReturnArithmeticNaN action) = isNaNReturned st action assert
        runAssert _ assert@(AssertInvalid moduleDef failureString) =
            let (_, m) = buildModule moduleDef in
            case Validate.validate m of
                Right _ -> onAssertFail "Invalid module pass validation" assert
                Left reason ->
                    let actualStrings = getFailureString reason in
                    if failureString `elem` actualStrings
                    then return ()
                    else
                        let msg = "Module invalid for other reason. Expected "
                                ++ show failureString
                                ++ ", but actual is "
                                ++ show actualStrings
                        in onAssertFail msg assert
        runAssert _ assert@(AssertMalformed (TextModDef _ textRep) failureString) =
            -- Force the whole parse result so that lazily deferred parse errors surface here.
            case DeepSeq.force $ Lexer.scanner (TLEncoding.encodeUtf8 textRep) >>= Parser.parseModule of
                Right _ -> onAssertFail ("Module parsing should fail with failure string " ++ show failureString) assert
                Left _ -> return ()
        runAssert _ assert@(AssertMalformed (BinaryModDef _ binaryRep) failureString) =
            case Binary.decodeModuleLazy binaryRep of
                Right _ -> onAssertFail ("Module decoding should fail with failure string " ++ show failureString) assert
                Left _ -> return ()
        -- A raw (already built) module cannot be malformed; nothing to check.
        runAssert _ (AssertMalformed (RawModDef _ _) _) = return ()
        runAssert st assert@(AssertUnlinkable moduleDef failureString) =
            let (_, m) = buildModule moduleDef in
            case Validate.validate m of
                Right validModule -> do
                    res <- Interpreter.instantiate (store st) (buildImports st) validModule
                    case res of
                        Left _ -> return ()
                        Right _ -> onAssertFail ("Module linking should fail with failure string " ++ show failureString) assert
                Left reason -> error $ "Module linking failed due to invalid module with reason: " ++ show reason
        runAssert st assert@(AssertTrap (Left action) _) = do
            result <- runAction st action
            case result of
                Nothing -> return ()
                Just values -> onAssertFail ("Expected trap, but action returned " ++ show values) assert
        runAssert st assert@(AssertTrap (Right moduleDef) _) =
            let (_, m) = buildModule moduleDef in
            case Validate.validate m of
                Right validModule -> do
                    res <- Interpreter.instantiate (store st) (buildImports st) validModule
                    case res of
                        -- Only a trap raised by the start function satisfies this assertion.
                        Left "Start function terminated with trap" -> return ()
                        _ -> onAssertFail ("Module linking should fail with trap during execution of a start function") assert
                Left reason -> error $ "Module linking failed due to invalid module with reason: " ++ show reason
        runAssert st assert@(AssertExhaustion action _) = do
            result <- runAction st action
            case result of
                Nothing -> return ()
                Just values -> onAssertFail ("Expected exhaustion, but action returned " ++ show values) assert

        -- Execute one script command and return the resulting state. Module
        -- definitions and registrations change the state. Actions and assertions
        -- run only for their effects. Any other command is ignored.
        runCommand :: ScriptState -> Command -> IO ScriptState
        runCommand st command =
            case command of
                ModuleDef moduleDef -> uncurry addModule (buildModule moduleDef) st
                Register name i -> return (addToRegistery name i st)
                Action action -> st <$ runAction st action
                Assertion assertion -> st <$ runAssert st assertion
                _ -> return st