From 19c6aa2532080fb20829d670a01fcf0d935d5db6 Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Thu, 4 Jul 2024 10:22:51 +0100 Subject: [PATCH] add generic writer class with formatting function to FunctionalTools --- backends/functional/cxx.cc | 180 +++++++++++--------------- backends/functional/cxx_runtime/sim.h | 4 +- kernel/functionalir.cc | 51 ++++++++ kernel/functionalir.h | 128 ++++++++++-------- 4 files changed, 203 insertions(+), 160 deletions(-) diff --git a/backends/functional/cxx.cc b/backends/functional/cxx.cc index d866dfea210..21e287a96c4 100644 --- a/backends/functional/cxx.cc +++ b/backends/functional/cxx.cc @@ -56,17 +56,7 @@ struct CxxType { } }; -struct CxxWriter { - std::ostream &f; - CxxWriter(std::ostream &out) : f(out) {} - void printf(const char *fmt, ...) - { - va_list va; - va_start(va, fmt); - f << vstringf(fmt, va); - va_end(va); - } -}; +using CxxWriter = FunctionalTools::Writer; struct CxxStruct { std::string name; @@ -74,111 +64,85 @@ struct CxxStruct { FunctionalTools::Scope scope; CxxStruct(std::string name) : name(name), scope(illegal_characters, reserved_keywords) { - scope.reserve("out"); - scope.reserve("dump"); + scope.reserve("fn"); + scope.reserve("visit"); } void insert(IdString name, CxxType type) { scope(name); types.insert({name, type}); } void print(CxxWriter &f) { - f.printf("\tstruct %s {\n", name.c_str()); + f.print("\tstruct {} {{\n", name); for (auto p : types) { - f.printf("\t\t%s %s;\n", p.second.to_string().c_str(), scope(p.first).c_str()); + f.print("\t\t{} {};\n", p.second.to_string(), scope(p.first)); } - f.printf("\n\t\ttemplate void visit(T &&fn) {\n"); + f.print("\n\t\ttemplate void visit(T &&fn) {{\n"); for (auto p : types) { - f.printf("\t\t\tfn(\"%s\", %s);\n", RTLIL::unescape_id(p.first).c_str(), scope(p.first).c_str()); + f.print("\t\t\tfn(\"{}\", {});\n", RTLIL::unescape_id(p.first), scope(p.first)); } - f.printf("\t\t}\n"); - f.printf("\t};\n\n"); + f.print("\t\t}}\n"); + f.print("\t}};\n\n"); }; std::string operator[](IdString field) { return scope(field); } }; -struct CxxTemplate { - vector> _v; -public: - CxxTemplate(std::string fmt) { - std::string buf; - for(auto it = fmt.begin(); it != fmt.end(); it++){ - if(*it == '%'){ - it++; - log_assert(it != fmt.end()); - if(*it == '%') - buf += *it; - else { - log_assert(*it >= '0' && *it <= '9'); - _v.emplace_back(std::move(buf)); - _v.emplace_back((int)(*it - '0')); - } - }else - buf += *it; - } - if(!buf.empty()) - _v.emplace_back(std::move(buf)); - } - template static std::string format(CxxTemplate fmt, Args&&... args) { - vector strs = {args...}; - std::string result; - for(auto &v : fmt._v){ - if(std::string *s = std::get_if(&v)) - result += *s; - else if(int *i = std::get_if(&v)) - result += strs[*i]; - else - log_error("missing case"); - } - return result; - } -}; - -template struct CxxPrintVisitor { +template struct CxxPrintVisitor { using Node = FunctionalIR::Node; - NodeNames np; + CxxWriter &f; + NodePrinter np; CxxStruct &input_struct; CxxStruct &state_struct; - CxxPrintVisitor(NodeNames np, CxxStruct &input_struct, CxxStruct &state_struct) : np(np), input_struct(input_struct), state_struct(state_struct) { } - template std::string arg_to_string(T n) { return std::to_string(n); } - std::string arg_to_string(std::string n) { return n; } - std::string arg_to_string(Node n) { return np(n); } - template std::string format(std::string fmt, Args&&... args) { - return CxxTemplate::format(fmt, arg_to_string(args)...); + CxxPrintVisitor(CxxWriter &f, NodePrinter np, CxxStruct &input_struct, CxxStruct &state_struct) : f(f), np(np), input_struct(input_struct), state_struct(state_struct) { } + template void print(const char *fmt, Args&&... args) { + f.print_with(np, fmt, std::forward(args)...); } - std::string buf(Node, Node n) { return np(n); } - std::string slice(Node, Node a, int, int offset, int out_width) { return format("%0.slice<%2>(%1)", a, offset, out_width); } - std::string zero_extend(Node, Node a, int, int out_width) { return format("%0.zero_extend<%1>()", a, out_width); } - std::string sign_extend(Node, Node a, int, int out_width) { return format("%0.sign_extend<%1>()", a, out_width); } - std::string concat(Node, Node a, int, Node b, int) { return format("%0.concat(%1)", a, b); } - std::string add(Node, Node a, Node b, int) { return format("%0 + %1", a, b); } - std::string sub(Node, Node a, Node b, int) { return format("%0 - %1", a, b); } - std::string bitwise_and(Node, Node a, Node b, int) { return format("%0 & %1", a, b); } - std::string bitwise_or(Node, Node a, Node b, int) { return format("%0 | %1", a, b); } - std::string bitwise_xor(Node, Node a, Node b, int) { return format("%0 ^ %1", a, b); } - std::string bitwise_not(Node, Node a, int) { return format("~%0", a); } - std::string unary_minus(Node, Node a, int) { return format("-%0", a); } - std::string reduce_and(Node, Node a, int) { return format("%0.all()", a); } - std::string reduce_or(Node, Node a, int) { return format("%0.any()", a); } - std::string reduce_xor(Node, Node a, int) { return format("%0.parity()", a); } - std::string equal(Node, Node a, Node b, int) { return format("%0 == %1", a, b); } - std::string not_equal(Node, Node a, Node b, int) { return format("%0 != %1", a, b); } - std::string signed_greater_than(Node, Node a, Node b, int) { return format("%0.signed_greater_than(%1)", a, b); } - std::string signed_greater_equal(Node, Node a, Node b, int) { return format("%0.signed_greater_equal(%1)", a, b); } - std::string unsigned_greater_than(Node, Node a, Node b, int) { return format("%0 > %1", a, b); } - std::string unsigned_greater_equal(Node, Node a, Node b, int) { return format("%0 >= %1", a, b); } - std::string logical_shift_left(Node, Node a, Node b, int, int) { return format("%0 << %1", a, b); } - std::string logical_shift_right(Node, Node a, Node b, int, int) { return format("%0 >> %1", a, b); } - std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { return format("%0.arithmetic_shift_right(%1)", a, b); } - std::string mux(Node, Node a, Node b, Node s, int) { return format("%2.any() ? %1 : %0", a, b, s); } - std::string pmux(Node, Node a, Node b, Node s, int, int) { return format("%0.pmux(%1, %2)", a, b, s); } - std::string constant(Node, RTLIL::Const value) { return format("Signal<%0>(%1)", value.size(), value.as_int()); } - std::string input(Node, IdString name) { return format("input.%0", input_struct[name]); } - std::string state(Node, IdString name) { return format("current_state.%0", state_struct[name]); } - std::string memory_read(Node, Node mem, Node addr, int, int) { return format("%0.read(%1)", mem, addr); } - std::string memory_write(Node, Node mem, Node addr, Node data, int, int) { return format("%0.write(%1, %2)", mem, addr, data); } - std::string undriven(Node, int width) { return format("Signal<%0>(0)", width); } + void buf(Node, Node n) { print("{}", n); } + void slice(Node, Node a, int, int offset, int out_width) { print("{0}.slice<{2}>({1})", a, offset, out_width); } + void zero_extend(Node, Node a, int, int out_width) { print("{}.zero_extend<{}>()", a, out_width); } + void sign_extend(Node, Node a, int, int out_width) { print("{}.sign_extend<{}>()", a, out_width); } + void concat(Node, Node a, int, Node b, int) { print("{}.concat({})", a, b); } + void add(Node, Node a, Node b, int) { print("{} + {}", a, b); } + void sub(Node, Node a, Node b, int) { print("{} - {}", a, b); } + void bitwise_and(Node, Node a, Node b, int) { print("{} & {}", a, b); } + void bitwise_or(Node, Node a, Node b, int) { print("{} | {}", a, b); } + void bitwise_xor(Node, Node a, Node b, int) { print("{} ^ {}", a, b); } + void bitwise_not(Node, Node a, int) { print("~{}", a); } + void unary_minus(Node, Node a, int) { print("-{}", a); } + void reduce_and(Node, Node a, int) { print("{}.all()", a); } + void reduce_or(Node, Node a, int) { print("{}.any()", a); } + void reduce_xor(Node, Node a, int) { print("{}.parity()", a); } + void equal(Node, Node a, Node b, int) { print("{} == {}", a, b); } + void not_equal(Node, Node a, Node b, int) { print("{} != {}", a, b); } + void signed_greater_than(Node, Node a, Node b, int) { print("{}.signed_greater_than({})", a, b); } + void signed_greater_equal(Node, Node a, Node b, int) { print("{}.signed_greater_equal({})", a, b); } + void unsigned_greater_than(Node, Node a, Node b, int) { print("{} > {}", a, b); } + void unsigned_greater_equal(Node, Node a, Node b, int) { print("{} >= {}", a, b); } + void logical_shift_left(Node, Node a, Node b, int, int) { print("{} << {}", a, b); } + void logical_shift_right(Node, Node a, Node b, int, int) { print("{} >> {}", a, b); } + void arithmetic_shift_right(Node, Node a, Node b, int, int) { print("{}.arithmetic_shift_right{})", a, b); } + void mux(Node, Node a, Node b, Node s, int) { print("{2}.any() ? {1} : {0}", a, b, s); } + void pmux(Node, Node a, Node b, Node s, int, int) { print("{0}.pmux({1}, {2})", a, b, s); } + void constant(Node, RTLIL::Const value) { + std::stringstream ss; + bool multiple = value.size() > 32; + ss << "Signal<" << value.size() << ">(" << std::hex << std::showbase; + if(multiple) ss << "{"; + while(value.size() > 32) { + ss << value.as_int() << ", "; + value = value.extract(32, value.size() - 32); + } + ss << value.as_int(); + if(multiple) ss << "}"; + ss << ")"; + print("{}", ss.str()); + } + void input(Node, IdString name) { print("input.{}", input_struct[name]); } + void state(Node, IdString name) { print("current_state.{}", state_struct[name]); } + void memory_read(Node, Node mem, Node addr, int, int) { print("{}.read({})", mem, addr); } + void memory_write(Node, Node mem, Node addr, Node data, int, int) { print("{}.write({}, {})", mem, addr, data); } + void undriven(Node, int width) { print("Signal<{}>(0)", width); } }; struct CxxModule { @@ -201,31 +165,35 @@ struct CxxModule { module_name = FunctionalTools::Scope(illegal_characters, reserved_keywords)(module->name); } void write_header(CxxWriter &f) { - f.printf("#include \"sim.h\"\n\n"); + f.print("#include \"sim.h\"\n\n"); } void write_struct_def(CxxWriter &f) { - f.printf("struct %s {\n", module_name.c_str()); + f.print("struct {} {{\n", module_name); input_struct.print(f); output_struct.print(f); state_struct.print(f); - f.printf("\tstatic void eval(Inputs const &, Outputs &, State const &, State &);\n"); - f.printf("};\n\n"); + f.print("\tstatic void eval(Inputs const &, Outputs &, State const &, State &);\n"); + f.print("}};\n\n"); } void write_eval_def(CxxWriter &f) { - f.printf("void %s::eval(%s::Inputs const &input, %s::Outputs &output, %s::State const ¤t_state, %s::State &next_state)\n{\n", module_name.c_str(), module_name.c_str(), module_name.c_str(), module_name.c_str(), module_name.c_str()); + f.print("void {0}::eval({0}::Inputs const &input, {0}::Outputs &output, {0}::State const ¤t_state, {0}::State &next_state)\n{{\n", module_name); FunctionalTools::Scope locals(illegal_characters, reserved_keywords); locals.reserve("input"); locals.reserve("output"); locals.reserve("current_state"); locals.reserve("next_state"); auto node_name = [&](FunctionalIR::Node n) { return locals(n.id(), n.name()); }; - for (auto node : ir) - f.printf("\t%s %s = %s;\n", CxxType(node.sort()).to_string().c_str(), node_name(node).c_str(), node.visit(CxxPrintVisitor(node_name, input_struct, state_struct)).c_str()); + CxxPrintVisitor printVisitor(f, node_name, input_struct, state_struct); + for (auto node : ir) { + f.print("\t{} {} = ", CxxType(node.sort()).to_string(), node_name(node)); + node.visit(printVisitor); + f.print(";\n"); + } for (auto [name, sort] : ir.state()) - f.printf("\tnext_state.%s = %s;\n", state_struct[name].c_str(), node_name(ir.get_state_next_node(name)).c_str()); + f.print("\tnext_state.{} = {};\n", state_struct[name], node_name(ir.get_state_next_node(name))); for (auto [name, sort] : ir.outputs()) - f.printf("\toutput.%s = %s;\n", output_struct[name].c_str(), node_name(ir.get_output_node(name)).c_str()); - f.printf("}\n"); + f.print("\toutput.{} = {};\n", output_struct[name], node_name(ir.get_output_node(name))); + f.print("}}\n"); } }; diff --git a/backends/functional/cxx_runtime/sim.h b/backends/functional/cxx_runtime/sim.h index 0614fb34cf7..1985322f1c8 100644 --- a/backends/functional/cxx_runtime/sim.h +++ b/backends/functional/cxx_runtime/sim.h @@ -143,11 +143,11 @@ class Signal { { for(size_t i = sizeof(T) * 8; i < n; i++) if(_bits[i]) - return ~0; + return ~((T)0); return as_numeric(); } - uint32_t as_int() { return as_numeric(); } + uint32_t as_int() const { return as_numeric(); } Signal operator ~() const { diff --git a/kernel/functionalir.cc b/kernel/functionalir.cc index 6ae66d6802b..3fd2d8240e7 100644 --- a/kernel/functionalir.cc +++ b/kernel/functionalir.cc @@ -443,4 +443,55 @@ void FunctionalIR::forward_buf() { _graph.permute(perm, alias); } +static std::string quote_fmt(const char *fmt) +{ + std::string r; + for(const char *p = fmt; *p != 0; p++) { + switch(*p) { + case '\n': r += "\\n"; break; + case '\t': r += "\\t"; break; + case '"': r += "\\\""; break; + case '\\': r += "\\\\"; break; + default: r += *p; break; + } + } + return r; +} + +void FunctionalTools::Writer::print_impl(const char *fmt, vector> &fns) +{ + size_t next_index = 0; + for(const char *p = fmt; *p != 0; p++) + switch(*p) { + case '{': + if(*++p == '{') { + *os << '{'; + } else { + char *pe; + size_t index = strtoul(p, &pe, 10); + if(*pe != '}') + log_error("invalid format string: expected {}, {} or {{, got \"%s\": \"%s\"\n", + quote_fmt(std::string(p - 1, pe - p + 2).c_str()).c_str(), + quote_fmt(fmt).c_str()); + if(p == pe) + index = next_index; + else + p = pe; + if(index >= fns.size()) + log_error("invalid format string: index %zu out of bounds (%zu): \"%s\"\n", index, fns.size(), quote_fmt(fmt).c_str()); + fns[index](); + next_index = index + 1; + } + break; + case '}': + p++; + if(*p != '}') + log_error("invalid format string: unescaped }: \"%s\"\n", quote_fmt(fmt).c_str()); + *os << '}'; + break; + default: + *os << *p; + } +} + YOSYS_NAMESPACE_END diff --git a/kernel/functionalir.h b/kernel/functionalir.h index f49a659f37b..2c0d0f55c2c 100644 --- a/kernel/functionalir.h +++ b/kernel/functionalir.h @@ -29,58 +29,6 @@ USING_YOSYS_NAMESPACE YOSYS_NAMESPACE_BEGIN -namespace FunctionalTools { - class Scope { - const char *_illegal_characters; - pool _used_names; - dict _by_id; - dict _by_name; - std::string allocate_name(IdString suggestion) { - std::string str = RTLIL::unescape_id(suggestion); - for(size_t i = 0; i < str.size(); i++) - if(strchr(_illegal_characters, str[i])) - str[i] = '_'; - if(_used_names.count(str) == 0) { - _used_names.insert(str); - return str; - } - for (int idx = 0 ; ; idx++){ - std::string suffixed = str + "_" + std::to_string(idx); - if(_used_names.count(suffixed) == 0) { - _used_names.insert(suffixed); - return suffixed; - } - } - } - public: - Scope(const char *illegal_characters = "", const char **keywords = nullptr) { - _illegal_characters = illegal_characters; - if(keywords != nullptr) - for(const char **p = keywords; *p != nullptr; p++) - reserve(*p); - } - void reserve(std::string name) { - _used_names.insert(std::move(name)); - } - std::string operator()(int id, IdString suggestion) { - auto it = _by_id.find(id); - if(it != _by_id.end()) - return it->second; - std::string str = allocate_name(suggestion); - _by_id.insert({id, str}); - return str; - } - std::string operator()(IdString idstring) { - auto it = _by_name.find(idstring); - if(it != _by_name.end()) - return it->second; - std::string str = allocate_name(idstring); - _by_name.insert({idstring, str}); - return str; - } - }; -} - class FunctionalIR { enum class Fn { invalid, @@ -418,6 +366,82 @@ class FunctionalIR { Iterator end() { return Iterator(this, _graph.size()); } }; +namespace FunctionalTools { + class Scope { + const char *_illegal_characters; + pool _used_names; + dict _by_id; + dict _by_name; + std::string allocate_name(IdString suggestion) { + std::string str = RTLIL::unescape_id(suggestion); + for(size_t i = 0; i < str.size(); i++) + if(strchr(_illegal_characters, str[i])) + str[i] = '_'; + if(_used_names.count(str) == 0) { + _used_names.insert(str); + return str; + } + for (int idx = 0 ; ; idx++){ + std::string suffixed = str + "_" + std::to_string(idx); + if(_used_names.count(suffixed) == 0) { + _used_names.insert(suffixed); + return suffixed; + } + } + } + public: + Scope(const char *illegal_characters = "", const char **keywords = nullptr) { + _illegal_characters = illegal_characters; + if(keywords != nullptr) + for(const char **p = keywords; *p != nullptr; p++) + reserve(*p); + } + void reserve(std::string name) { + _used_names.insert(std::move(name)); + } + std::string operator()(int id, IdString suggestion) { + auto it = _by_id.find(id); + if(it != _by_id.end()) + return it->second; + std::string str = allocate_name(suggestion); + _by_id.insert({id, str}); + return str; + } + std::string operator()(IdString idstring) { + auto it = _by_name.find(idstring); + if(it != _by_name.end()) + return it->second; + std::string str = allocate_name(idstring); + _by_name.insert({idstring, str}); + return str; + } + }; + class Writer { + std::ostream *os; + void print_impl(const char *fmt, vector>& fns); + public: + Writer(std::ostream &os) : os(&os) {} + template Writer& operator <<(T&& arg) { *os << std::forward(arg); return *this; } + template + void print(const char *fmt, Args&&... args) + { + vector> fns { [&]() { *this << args; }... }; + print_impl(fmt, fns); + } + template + void print_with(Fn fn, const char *fmt, Args&&... args) + { + vector> fns { [&]() { + if constexpr (std::is_invocable_v) + *this << fn(args); + else + *this << args; }... + }; + print_impl(fmt, fns); + } + }; +} + YOSYS_NAMESPACE_END #endif