/*
 * Copyright 2017 WebAssembly Community Group participants
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef wasm_ir_module_h
#define wasm_ir_module_h

#include "ir/find_all.h"
#include "ir/manipulation.h"
#include "ir/properties.h"
#include "pass.h"
#include "support/unique_deferring_queue.h"
#include "wasm.h"

namespace wasm {

namespace ModuleUtils {

inline Function* copyFunction(Function* func, Module& out) {
  auto* ret = new Function();
  ret->name = func->name;
  ret->sig = func->sig;
  ret->vars = func->vars;
  ret->localNames = func->localNames;
  ret->localIndices = func->localIndices;
  ret->debugLocations = func->debugLocations;
  ret->body = ExpressionManipulator::copy(func->body, out);
  ret->module = func->module;
  ret->base = func->base;
  // TODO: copy Stack IR
  assert(!func->stackIR);
  out.addFunction(ret);
  return ret;
}

inline Global* copyGlobal(Global* global, Module& out) {
  auto* ret = new Global();
  ret->name = global->name;
  ret->type = global->type;
  ret->mutable_ = global->mutable_;
  ret->module = global->module;
  ret->base = global->base;
  if (global->imported()) {
    ret->init = nullptr;
  } else {
    ret->init = ExpressionManipulator::copy(global->init, out);
  }
  out.addGlobal(ret);
  return ret;
}

inline Event* copyEvent(Event* event, Module& out) {
  auto* ret = new Event();
  ret->name = event->name;
  ret->attribute = event->attribute;
  ret->sig = event->sig;
  out.addEvent(ret);
  return ret;
}

inline void copyModule(const Module& in, Module& out) {
  // we use names throughout, not raw pointers, so simple copying is fine
  // for everything *but* expressions
  for (auto& curr : in.exports) {
    out.addExport(new Export(*curr));
  }
  for (auto& curr : in.functions) {
    copyFunction(curr.get(), out);
  }
  for (auto& curr : in.globals) {
    copyGlobal(curr.get(), out);
  }
  for (auto& curr : in.events) {
    copyEvent(curr.get(), out);
  }
  out.table = in.table;
  for (auto& segment : out.table.segments) {
    segment.offset = ExpressionManipulator::copy(segment.offset, out);
  }
  out.memory = in.memory;
  for (auto& segment : out.memory.segments) {
    segment.offset = ExpressionManipulator::copy(segment.offset, out);
  }
  out.start = in.start;
  out.userSections = in.userSections;
  out.debugInfoFileNames = in.debugInfoFileNames;
}

inline void clearModule(Module& wasm) {
  wasm.~Module();
  new (&wasm) Module;
}

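// Example of copying a module (an illustrative sketch; "input" is an assumed
// existing Module, and "output" is hypothetical):
//
//   Module output;
//   ModuleUtils::copyModule(input, output); // deep-copies all expressions
//   ...
//   ModuleUtils::clearModule(output);       // reset it to a fresh Module
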
// Renaming

// Rename functions along with all their uses.
// Note that for this to work the functions themselves don't necessarily need
// to exist. For example, it is possible to remove a given function and then
// call this to redirect all of its uses.
template<typename T> inline void renameFunctions(Module& wasm, T& map) {
  // Update the function itself.
  for (auto& pair : map) {
    if (Function* F = wasm.getFunctionOrNull(pair.first)) {
      assert(!wasm.getFunctionOrNull(pair.second) || F->name == pair.second);
      F->name = pair.second;
    }
  }
  wasm.updateMaps();
  // Update other global things.
  auto maybeUpdate = [&](Name& name) {
    auto iter = map.find(name);
    if (iter != map.end()) {
      name = iter->second;
    }
  };
  maybeUpdate(wasm.start);
  for (auto& segment : wasm.table.segments) {
    for (auto& name : segment.data) {
      maybeUpdate(name);
    }
  }
  for (auto& exp : wasm.exports) {
    if (exp->kind == ExternalKind::Function) {
      maybeUpdate(exp->value);
    }
  }
  // Update call instructions.
  for (auto& func : wasm.functions) {
    // TODO: parallelize
    if (!func->imported()) {
      FindAll<Call> calls(func->body);
      for (auto* call : calls.list) {
        maybeUpdate(call->target);
      }
    }
  }
}

inline void renameFunction(Module& wasm, Name oldName, Name newName) {
  std::map<Name, Name> map;
  map[oldName] = newName;
  renameFunctions(wasm, map);
}

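// Example of renaming (an illustrative sketch; assumes a Module "wasm" and the
// named functions are hypothetical):
//
//   // Rename a single function and redirect all uses of it.
//   ModuleUtils::renameFunction(wasm, "foo", "foo_renamed");
//
//   // Or rename several functions at once.
//   std::map<Name, Name> renames;
//   renames["bar"] = "bar_impl";
//   renames["baz"] = "baz_impl";
//   ModuleUtils::renameFunctions(wasm, renames);
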
// Convenient iteration over imported/non-imported module elements

template<typename T> inline void iterImportedMemories(Module& wasm, T visitor) {
  if (wasm.memory.exists && wasm.memory.imported()) {
    visitor(&wasm.memory);
  }
}

template<typename T> inline void iterDefinedMemories(Module& wasm, T visitor) {
  if (wasm.memory.exists && !wasm.memory.imported()) {
    visitor(&wasm.memory);
  }
}

template<typename T> inline void iterImportedTables(Module& wasm, T visitor) {
  if (wasm.table.exists && wasm.table.imported()) {
    visitor(&wasm.table);
  }
}

template<typename T> inline void iterDefinedTables(Module& wasm, T visitor) {
  if (wasm.table.exists && !wasm.table.imported()) {
    visitor(&wasm.table);
  }
}

template<typename T> inline void iterImportedGlobals(Module& wasm, T visitor) {
  for (auto& import : wasm.globals) {
    if (import->imported()) {
      visitor(import.get());
    }
  }
}

template<typename T> inline void iterDefinedGlobals(Module& wasm, T visitor) {
  for (auto& import : wasm.globals) {
    if (!import->imported()) {
      visitor(import.get());
    }
  }
}

template<typename T> inline void iterImportedFunctions(Module& wasm, T visitor) {
  for (auto& import : wasm.functions) {
    if (import->imported()) {
      visitor(import.get());
    }
  }
}

template<typename T> inline void iterDefinedFunctions(Module& wasm, T visitor) {
  for (auto& import : wasm.functions) {
    if (!import->imported()) {
      visitor(import.get());
    }
  }
}

template<typename T> inline void iterImportedEvents(Module& wasm, T visitor) {
  for (auto& import : wasm.events) {
    if (import->imported()) {
      visitor(import.get());
    }
  }
}

template<typename T> inline void iterDefinedEvents(Module& wasm, T visitor) {
  for (auto& import : wasm.events) {
    if (!import->imported()) {
      visitor(import.get());
    }
  }
}

template<typename T> inline void iterImports(Module& wasm, T visitor) {
  iterImportedMemories(wasm, visitor);
  iterImportedTables(wasm, visitor);
  iterImportedGlobals(wasm, visitor);
  iterImportedFunctions(wasm, visitor);
  iterImportedEvents(wasm, visitor);
}

// Helper class for performing an operation on all the functions in the module,
// in parallel, with an Info object for each one that can contain results of
// some computation that the operation performs.
// The operation performed should not modify the wasm module in any way.
// TODO: enforce this
template<typename T> struct ParallelFunctionAnalysis {
  Module& wasm;

  typedef std::map<Function*, T> Map;
  Map map;

  typedef std::function<void(Function*, T&)> Func;

  ParallelFunctionAnalysis(Module& wasm, Func work) : wasm(wasm) {
    // Fill in map, as we operate on it in parallel (each function to its own
    // entry).
    for (auto& func : wasm.functions) {
      map[func.get()];
    }

    // Run on the imports first. TODO: parallelize this too
    for (auto& func : wasm.functions) {
      if (func->imported()) {
        work(func.get(), map[func.get()]);
      }
    }

    struct Mapper : public WalkerPass<PostWalker<Mapper>> {
      bool isFunctionParallel() override { return true; }
      bool modifiesBinaryenIR() override { return false; }

      Mapper(Module& module, Map& map, Func work)
        : module(module), map(map), work(work) {}

      Mapper* create() override { return new Mapper(module, map, work); }

      void doWalkFunction(Function* curr) {
        assert(map.count(curr));
        work(curr, map[curr]);
      }

    private:
      Module& module;
      Map& map;
      Func work;
    };

    PassRunner runner(&wasm);
    Mapper(wasm, map, work).run(&runner, &wasm);
  }
};

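// Example (a minimal sketch; counts the direct calls in each defined function,
// in parallel, assuming a Module "wasm"):
//
//   ModuleUtils::ParallelFunctionAnalysis<Index> analysis(
//     wasm, [&](Function* func, Index& numCalls) {
//       if (func->imported()) {
//         return;
//       }
//       numCalls = FindAll<Call>(func->body).list.size();
//     });
//   // analysis.map now maps each Function* to its number of direct calls.
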
// Helper class for analyzing the call graph.
//
// Provides hooks for running some initial calculation on each function (which
// is done in parallel), writing to a FunctionInfo structure for each function.
// Then you can call propagateBack() to propagate a property of interest to the
// calling functions, transitively.
//
// For example, if some functions are known to call an import "foo", then you
// can use this to find which functions call something that might eventually
// reach foo, by initially marking the direct callers as "calling foo" and
// propagating that backwards.
template<typename T> struct CallGraphPropertyAnalysis {
  Module& wasm;

  // The basic information for each function about whom it calls and who is
  // called by it.
  struct FunctionInfo {
    std::set<Function*> callsTo;
    std::set<Function*> calledBy;
    // A non-direct call is any call that is not direct. That includes
    // CallIndirect and CallRef.
    bool hasNonDirectCall = false;
  };

  typedef std::map<Function*, T> Map;
  Map map;

  typedef std::function<void(Function*, T&)> Func;

  CallGraphPropertyAnalysis(Module& wasm, Func work) : wasm(wasm) {
    ParallelFunctionAnalysis<T> analysis(wasm, [&](Function* func, T& info) {
      work(func, info);
      if (func->imported()) {
        return;
      }
      struct Mapper : public PostWalker<Mapper> {
        Mapper(Module* module, T& info, Func work)
          : module(module), info(info), work(work) {}

        void visitCall(Call* curr) {
          info.callsTo.insert(module->getFunction(curr->target));
        }
        void visitCallIndirect(CallIndirect* curr) {
          info.hasNonDirectCall = true;
        }
        void visitCallRef(CallRef* curr) { info.hasNonDirectCall = true; }

      private:
        Module* module;
        T& info;
        Func work;
      } mapper(&wasm, info, work);
      mapper.walk(func->body);
    });

    map.swap(analysis.map);

    // Find what is called by what.
    for (auto& pair : map) {
      auto* func = pair.first;
      auto& info = pair.second;
      for (auto* target : info.callsTo) {
        map[target].calledBy.insert(func);
      }
    }
  }

  enum NonDirectCalls { IgnoreNonDirectCalls, NonDirectCallsHaveProperty };

  // Propagate a property from a function to those that call it.
  //
  // hasProperty() - Check if the property is present.
  // canHaveProperty() - Check if the property could be present.
  // addProperty() - Adds the property. This receives a second parameter which
  //                 is the function due to which we are adding the property.
  void propagateBack(std::function<bool(const T&)> hasProperty,
                     std::function<bool(const T&)> canHaveProperty,
                     std::function<void(T&, Function*)> addProperty,
                     NonDirectCalls nonDirectCalls) {
    // The work queue contains items we just learned can change the state.
    UniqueDeferredQueue<Function*> work;
    for (auto& func : wasm.functions) {
      if (hasProperty(map[func.get()]) ||
          (nonDirectCalls == NonDirectCallsHaveProperty &&
           map[func.get()].hasNonDirectCall)) {
        addProperty(map[func.get()], func.get());
        work.push(func.get());
      }
    }
    while (!work.empty()) {
      auto* func = work.pop();
      for (auto* caller : map[func].calledBy) {
        // If we don't already have the property, and we are not forbidden
        // from getting it, then it propagates back to us now.
        if (!hasProperty(map[caller]) && canHaveProperty(map[caller])) {
          addProperty(map[caller], func);
          work.push(caller);
        }
      }
    }
  }
};

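// Example (an illustrative sketch of the "calls an import" scenario described
// above; the Module "wasm", the Info struct, and callsImport are assumptions):
//
//   struct Info : ModuleUtils::CallGraphPropertyAnalysis<Info>::FunctionInfo {
//     bool callsImport = false;
//   };
//   ModuleUtils::CallGraphPropertyAnalysis<Info> analyzer(
//     wasm, [&](Function* func, Info& info) {
//       if (func->imported()) {
//         return;
//       }
//       for (auto* call : FindAll<Call>(func->body).list) {
//         if (wasm.getFunction(call->target)->imported()) {
//           info.callsImport = true;
//         }
//       }
//     });
//   // Mark every function that may transitively reach an import via direct
//   // calls.
//   analyzer.propagateBack(
//     [](const Info& info) { return info.callsImport; },
//     [](const Info& info) { return true; },
//     [](Info& info, Function* reason) { info.callsImport = true; },
//     analyzer.IgnoreNonDirectCalls);
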
// Helper function for collecting all the types that are declared in a module,
// which means the HeapTypes (that are non-basic, that is, not eqref etc., which
// do not need to be defined).
//
// Used when emitting or printing a module to give HeapTypes canonical
// indices. HeapTypes are sorted in order of decreasing frequency to minimize
// the size of their collective encoding. Both a vector mapping indices to
// HeapTypes and a map mapping HeapTypes to indices are produced.
inline void collectHeapTypes(Module& wasm,
                             std::vector<HeapType>& types,
                             std::unordered_map<HeapType, Index>& typeIndices) {
  struct Counts : public std::unordered_map<HeapType, size_t> {
    bool isRelevant(Type type) {
      if (type.isRef()) {
        return !type.getHeapType().isBasic();
      }
      return type.isRtt();
    }
    void note(HeapType type) { (*this)[type]++; }
    void maybeNote(Type type) {
      if (isRelevant(type)) {
        note(type.getHeapType());
      }
    }
  };

  // Collect the type use counts for a single function
  auto updateCounts = [&](Function* func, Counts& counts) {
    if (func->imported()) {
      return;
    }
    struct TypeCounter
      : PostWalker<TypeCounter, UnifiedExpressionVisitor<TypeCounter>> {
      Counts& counts;

      TypeCounter(Counts& counts) : counts(counts) {}

      void visitExpression(Expression* curr) {
        if (auto* call = curr->dynCast<CallIndirect>()) {
          counts.note(call->sig);
        } else if (curr->is<RefNull>()) {
          counts.maybeNote(curr->type);
        } else if (curr->is<RttCanon>() || curr->is<RttSub>()) {
          counts.note(curr->type.getRtt().heapType);
        } else if (auto* get = curr->dynCast<StructGet>()) {
          counts.maybeNote(get->ref->type);
        } else if (auto* set = curr->dynCast<StructSet>()) {
          counts.maybeNote(set->ref->type);
        } else if (Properties::isControlFlowStructure(curr)) {
          counts.maybeNote(curr->type);
          if (curr->type.isTuple()) {
            // TODO: Allow control flow to have input types as well
            counts.note(Signature(Type::none, curr->type));
          }
        }
      }
    };
    TypeCounter(counts).walk(func->body);
  };

  ModuleUtils::ParallelFunctionAnalysis<Counts> analysis(wasm, updateCounts);

  // Collect all the counts.
  Counts counts;
  for (auto& curr : wasm.functions) {
    counts.note(curr->sig);
    for (auto type : curr->vars) {
      counts.maybeNote(type);
      if (type.isTuple()) {
        for (auto t : type) {
          counts.maybeNote(t);
        }
      }
    }
  }
  for (auto& curr : wasm.events) {
    counts.note(curr->sig);
  }
  for (auto& curr : wasm.globals) {
    counts.maybeNote(curr->type);
  }
  for (auto& pair : analysis.map) {
    Counts& functionCounts = pair.second;
    for (auto& innerPair : functionCounts) {
      counts[innerPair.first] += innerPair.second;
    }
  }

  // A generic utility to traverse the child types of a type.
  // TODO: work with tlively to refactor this to a shared place
  auto walkRelevantChildren = [&](HeapType type,
                                  std::function<void(HeapType)> callback) {
    auto callIfRelevant = [&](Type type) {
      if (counts.isRelevant(type)) {
        callback(type.getHeapType());
      }
    };
    if (type.isSignature()) {
      auto sig = type.getSignature();
      for (Type type : {sig.params, sig.results}) {
        for (auto element : type) {
          callIfRelevant(element);
        }
      }
    } else if (type.isArray()) {
      callIfRelevant(type.getArray().element.type);
    } else if (type.isStruct()) {
      auto fields = type.getStruct().fields;
      for (auto field : fields) {
        callIfRelevant(field.type);
      }
    }
  };

  // Recursively traverse each reference type, which may have a child type that
  // is itself a reference type. This reflects an appearance in the binary
  // format that is in the type section itself.
  // As we do this we may find more and more types, as nested children of
  // previous ones. Each such type will appear in the type section once, so
  // we just need to visit it once.
  // TODO: handle struct and array fields
  std::unordered_set<HeapType> newTypes;
  for (auto& pair : counts) {
    newTypes.insert(pair.first);
  }
  while (!newTypes.empty()) {
    auto iter = newTypes.begin();
    auto type = *iter;
    newTypes.erase(iter);
    walkRelevantChildren(type, [&](HeapType type) {
      if (!counts.count(type)) {
        newTypes.insert(type);
      }
      counts.note(type);
    });
  }

  // We must sort all the dependencies of a type before it. For example,
  // (func (param (ref (func)))) must appear after (func). To do that, find the
  // depth of dependencies of each type. For example, if A depends on B
  // which depends on C, then A's depth is 2, B's is 1, and C's is 0 (assuming
  // no other dependencies).
  Counts depthOfDependencies;
  std::unordered_map<HeapType, std::unordered_set<HeapType>> isDependencyOf;
  // To calculate the depth of dependencies, we'll do a flow analysis, visiting
  // each type as we find out new things about it.
  std::set<HeapType> toVisit;
  for (auto& pair : counts) {
    auto type = pair.first;
    depthOfDependencies[type] = 0;
    toVisit.insert(type);
    walkRelevantChildren(type, [&](HeapType childType) {
      isDependencyOf[childType].insert(type); // XXX flip?
    });
  }
  while (!toVisit.empty()) {
    auto iter = toVisit.begin();
    auto type = *iter;
    toVisit.erase(iter);
    // Anything that depends on this has a depth of dependencies equal to this
    // type's, plus this type itself.
    auto newDepth = depthOfDependencies[type] + 1;
    if (newDepth > counts.size()) {
      Fatal() << "Cyclic types detected, cannot sort them.";
    }
    for (auto& other : isDependencyOf[type]) {
      if (depthOfDependencies[other] < newDepth) {
        // We found something new to propagate.
        depthOfDependencies[other] = newDepth;
        toVisit.insert(other);
      }
    }
  }

  // Sort by frequency and then simplicity, and also keeping every type
  // before things that depend on it.
  std::vector<std::pair<HeapType, size_t>> sorted(counts.begin(),
                                                  counts.end());
  std::sort(sorted.begin(), sorted.end(), [&](auto a, auto b) {
    if (depthOfDependencies[a.first] != depthOfDependencies[b.first]) {
      return depthOfDependencies[a.first] < depthOfDependencies[b.first];
    }
    if (a.second != b.second) {
      return a.second > b.second;
    }
    return a.first < b.first;
  });
  for (Index i = 0; i < sorted.size(); ++i) {
    typeIndices[sorted[i].first] = i;
    types.push_back(sorted[i].first);
  }
}

} // namespace ModuleUtils

} // namespace wasm

#endif // wasm_ir_module_h