/* * Souffle - A Datalog Compiler * Copyright (c) 2021, The Souffle Developers. All rights reserved * Licensed under the Universal Permissive License v 1.0 as shown at: * - https://opensource.org/licenses/UPL * - /licenses/SOUFFLE-UPL.txt */ /************************************************************************ * * @file ReadStream.h * ***********************************************************************/ #pragma once #include "souffle/RamTypes.h" #include "souffle/RecordTable.h" #include "souffle/SymbolTable.h" #include "souffle/io/SerialisationStream.h" #include "souffle/utility/ContainerUtil.h" #include "souffle/utility/MiscUtil.h" #include "souffle/utility/StringUtil.h" #include "souffle/utility/json11.h" #include #include #include #include #include #include #include #include namespace souffle { class ReadStream : public SerialisationStream { protected: ReadStream( const std::map& rwOperation, SymbolTable& symTab, RecordTable& recTab) : SerialisationStream(symTab, recTab, rwOperation) {} public: template void readAll(T& relation) { while (const auto next = readNextTuple()) { const RamDomain* ramDomain = next.get(); relation.insert(ramDomain); } } protected: /** * Read a record from a string. * * @param source - string containing a record * @param recordTypeName - record type. * @parem pos - start parsing from this position. * @param consumed - if not nullptr: number of characters read. * */ RamDomain readRecord(const std::string& source, const std::string& recordTypeName, std::size_t pos = 0, std::size_t* charactersRead = nullptr) { const std::size_t initial_position = pos; // Check if record type information are present auto&& recordInfo = types["records"][recordTypeName]; if (recordInfo.is_null()) { throw std::invalid_argument("Missing record type information: " + recordTypeName); } // Handle nil case consumeWhiteSpace(source, pos); if (source.substr(pos, 3) == "nil") { if (charactersRead != nullptr) { *charactersRead = 3; } return 0; } auto&& recordTypes = recordInfo["types"]; const std::size_t recordArity = recordInfo["arity"].long_value(); std::vector recordValues(recordArity); consumeChar(source, '[', pos); for (std::size_t i = 0; i < recordArity; ++i) { const std::string& recordType = recordTypes[i].string_value(); std::size_t consumed = 0; if (i > 0) { consumeChar(source, ',', pos); } consumeWhiteSpace(source, pos); switch (recordType[0]) { case 's': { recordValues[i] = symbolTable.encode(readSymbol(source, ",]", pos, &consumed)); break; } case 'i': { recordValues[i] = RamSignedFromString(source.substr(pos), &consumed); break; } case 'u': { recordValues[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed)); break; } case 'f': { recordValues[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed)); break; } case 'r': { recordValues[i] = readRecord(source, recordType, pos, &consumed); break; } case '+': { recordValues[i] = readADT(source, recordType, pos, &consumed); break; } default: fatal("Invalid type attribute"); } pos += consumed; } consumeChar(source, ']', pos); if (charactersRead != nullptr) { *charactersRead = pos - initial_position; } return recordTable.pack(recordValues.data(), recordValues.size()); } RamDomain readADT(const std::string& source, const std::string& adtName, std::size_t pos = 0, std::size_t* charactersRead = nullptr) { const std::size_t initial_position = pos; // Branch will are encoded as one of the: // [branchIdx, [branchValues...]] // [branchIdx, branchValue] // branchIdx RamDomain branchIdx = -1; auto&& adtInfo = types["ADTs"][adtName]; const auto& branches = adtInfo["branches"]; if (adtInfo.is_null() || !branches.is_array()) { throw std::invalid_argument("Missing ADT information: " + adtName); } // Consume initial character consumeChar(source, '$', pos); std::string constructor = readIdentifier(source, pos); json11::Json branchInfo = [&]() -> json11::Json { for (auto branch : branches.array_items()) { ++branchIdx; if (branch["name"].string_value() == constructor) { return branch; } } throw std::invalid_argument("Missing branch information: " + constructor); }(); assert(branchInfo["types"].is_array()); auto branchTypes = branchInfo["types"].array_items(); // Handle a branch without arguments. if (branchTypes.empty()) { if (charactersRead != nullptr) { *charactersRead = pos - initial_position; } if (adtInfo["enum"].bool_value()) { return branchIdx; } RamDomain emptyArgs = recordTable.pack(toVector().data(), 0); const RamDomain record[] = {branchIdx, emptyArgs}; return recordTable.pack(record, 2); } consumeChar(source, '(', pos); std::vector branchArgs(branchTypes.size()); for (std::size_t i = 0; i < branchTypes.size(); ++i) { auto argType = branchTypes[i].string_value(); assert(!argType.empty()); std::size_t consumed = 0; if (i > 0) { consumeChar(source, ',', pos); } consumeWhiteSpace(source, pos); switch (argType[0]) { case 's': { branchArgs[i] = symbolTable.encode(readSymbol(source, ",)", pos, &consumed)); break; } case 'i': { branchArgs[i] = RamSignedFromString(source.substr(pos), &consumed); break; } case 'u': { branchArgs[i] = ramBitCast(RamUnsignedFromString(source.substr(pos), &consumed)); break; } case 'f': { branchArgs[i] = ramBitCast(RamFloatFromString(source.substr(pos), &consumed)); break; } case 'r': { branchArgs[i] = readRecord(source, argType, pos, &consumed); break; } case '+': { branchArgs[i] = readADT(source, argType, pos, &consumed); break; } default: fatal("Invalid type attribute"); } pos += consumed; } consumeChar(source, ')', pos); if (charactersRead != nullptr) { *charactersRead = pos - initial_position; } // Store branch either as [branch_id, [arguments]] or [branch_id, argument]. RamDomain branchValue = [&]() -> RamDomain { if (branchArgs.size() != 1) { return recordTable.pack(branchArgs.data(), branchArgs.size()); } else { return branchArgs[0]; } }(); RamDomain rec[2] = {branchIdx, branchValue}; return recordTable.pack(rec, 2); } /** * Read the next alphanumeric + ('_', '?') sequence (corresponding to IDENT). * Consume preceding whitespace. * TODO (darth_tytus): use std::string_view? */ std::string readIdentifier(const std::string& source, std::size_t& pos) { consumeWhiteSpace(source, pos); if (pos >= source.length()) { throw std::invalid_argument("Unexpected end of input"); } const std::size_t bgn = pos; while (pos < source.length()) { unsigned char ch = static_cast(source[pos]); bool valid = std::isalnum(ch) || ch == '_' || ch == '?'; if (!valid) break; ++pos; } return source.substr(bgn, pos - bgn); } std::string readUntil(const std::string& source, const std::string& stopChars, const std::size_t pos, std::size_t* charactersRead) { std::size_t endOfSymbol = source.find_first_of(stopChars, pos); if (endOfSymbol == std::string::npos) { throw std::invalid_argument("Unexpected end of input"); } *charactersRead = endOfSymbol - pos; return source.substr(pos, *charactersRead); } std::string readQuotedSymbol(const std::string& source, std::size_t pos, std::size_t* charactersRead) { const std::size_t start = pos; const std::size_t end = source.length(); const char quoteMark = source[pos]; ++pos; const std::size_t startOfSymbol = pos; std::size_t endOfSymbol = std::string::npos; bool hasEscaped = false; bool escaped = false; while (pos < end) { if (escaped) { hasEscaped = true; escaped = false; ++pos; continue; } const char c = source[pos]; if (c == quoteMark) { endOfSymbol = pos; ++pos; break; } if (c == '\\') { escaped = true; } ++pos; } if (endOfSymbol == std::string::npos) { throw std::invalid_argument("Unexpected end of input"); } *charactersRead = pos - start; std::size_t lengthOfSymbol = endOfSymbol - startOfSymbol; // fast handling of symbol without escape sequence if (!hasEscaped) { return source.substr(startOfSymbol, lengthOfSymbol); } else { // slow handling of symbol with escape sequence std::string symbol; symbol.reserve(lengthOfSymbol); bool escaped = false; for (std::size_t pos = startOfSymbol; pos < endOfSymbol; ++pos) { char ch = source[pos]; if (escaped || ch != '\\') { symbol.push_back(ch); escaped = false; } else { escaped = true; } } return symbol; } } /** * Read the next symbol. * It is either a double-quoted symbol with backslash-escaped chars, or the * longuest sequence that do not contains any of the given stopChars. * */ std::string readSymbol(const std::string& source, const std::string& stopChars, const std::size_t pos, std::size_t* charactersRead) { if (source[pos] == '"') { return readQuotedSymbol(source, pos, charactersRead); } else { return readUntil(source, stopChars, pos, charactersRead); } } /** * Read past given character, consuming any preceding whitespace. */ void consumeChar(const std::string& str, char c, std::size_t& pos) { consumeWhiteSpace(str, pos); if (pos >= str.length()) { throw std::invalid_argument("Unexpected end of input"); } if (str[pos] != c) { std::stringstream error; error << "Expected: \'" << c << "\', got: " << str[pos]; throw std::invalid_argument(error.str()); } ++pos; } /** * Advance position in the string until first non-whitespace character. */ void consumeWhiteSpace(const std::string& str, std::size_t& pos) { while (pos < str.length() && std::isspace(static_cast(str[pos]))) { ++pos; } } virtual Own readNextTuple() = 0; }; class ReadStreamFactory { public: virtual Own getReader( const std::map&, SymbolTable&, RecordTable&) = 0; virtual const std::string& getName() const = 0; virtual ~ReadStreamFactory() = default; }; } /* namespace souffle */