From af83d8da19f4ee7900a556eb984a7d1ac82d71dd Mon Sep 17 00:00:00 2001 From: Emily Schmidt Date: Wed, 10 Jul 2024 14:28:48 +0100 Subject: [PATCH] rewrite smtlib pass to use SExpr class --- backends/functional/cxx.cc | 28 +- backends/functional/smtlib.cc | 725 +++++++++++++++++----------------- kernel/functionalir.h | 45 +-- 3 files changed, 409 insertions(+), 389 deletions(-) diff --git a/backends/functional/cxx.cc b/backends/functional/cxx.cc index 1d8caaaf440..4ab9d53a1d4 100644 --- a/backends/functional/cxx.cc +++ b/backends/functional/cxx.cc @@ -19,11 +19,11 @@ #include "kernel/yosys.h" #include "kernel/functionalir.h" +#include USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN -const char illegal_characters[] = "!\"#%&'()*+,-./:;<=>?@[]\\^`{|}~ "; const char *reserved_keywords[] = { "alignas","alignof","and","and_eq","asm","atomic_cancel","atomic_commit", "atomic_noexcept","auto","bitand","bitor","bool","break","case", @@ -42,6 +42,16 @@ const char *reserved_keywords[] = { nullptr }; +template struct CxxScope : public FunctionalTools::Scope { + CxxScope() { + for(const char **p = reserved_keywords; *p != nullptr; p++) + this->reserve(*p); + } + bool is_character_legal(char c) override { + return isascii(c) && (isalnum(c) || c == '_' || c == '$'); + } +}; + struct CxxType { FunctionalIR::Sort sort; CxxType(FunctionalIR::Sort sort) : sort(sort) {} @@ -61,30 +71,30 @@ using CxxWriter = FunctionalTools::Writer; struct CxxStruct { std::string name; dict types; - FunctionalTools::Scope scope; + CxxScope scope; CxxStruct(std::string name) - : name(name), scope(illegal_characters, reserved_keywords) { + : name(name) { scope.reserve("fn"); scope.reserve("visit"); } void insert(IdString name, CxxType type) { - scope(name); + scope(name, name); types.insert({name, type}); } void print(CxxWriter &f) { f.print("\tstruct {} {{\n", name); for (auto p : types) { - f.print("\t\t{} {};\n", p.second.to_string(), scope(p.first)); + f.print("\t\t{} {};\n", p.second.to_string(), scope(p.first, p.first)); } f.print("\n\t\ttemplate void visit(T &&fn) {{\n"); for (auto p : types) { - f.print("\t\t\tfn(\"{}\", {});\n", RTLIL::unescape_id(p.first), scope(p.first)); + f.print("\t\t\tfn(\"{}\", {});\n", RTLIL::unescape_id(p.first), scope(p.first, p.first)); } f.print("\t\t}}\n"); f.print("\t}};\n\n"); }; std::string operator[](IdString field) { - return scope(field); + return scope(field, field); } }; @@ -165,7 +175,7 @@ struct CxxModule { output_struct.insert(name, sort); for (auto [name, sort] : ir.state()) state_struct.insert(name, sort); - module_name = FunctionalTools::Scope(illegal_characters, reserved_keywords)(module->name); + module_name = CxxScope().unique_name(module->name); } void write_header(CxxWriter &f) { f.print("#include \"sim.h\"\n\n"); @@ -180,7 +190,7 @@ struct CxxModule { } void write_eval_def(CxxWriter &f) { f.print("void {0}::eval({0}::Inputs const &input, {0}::Outputs &output, {0}::State const ¤t_state, {0}::State &next_state)\n{{\n", module_name); - FunctionalTools::Scope locals(illegal_characters, reserved_keywords); + CxxScope locals; locals.reserve("input"); locals.reserve("output"); locals.reserve("current_state"); diff --git a/backends/functional/smtlib.cc b/backends/functional/smtlib.cc index 501d195ebde..c3e159eb31e 100644 --- a/backends/functional/smtlib.cc +++ b/backends/functional/smtlib.cc @@ -19,385 +19,406 @@ #include "kernel/functionalir.h" #include "kernel/yosys.h" +#include USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN -const char illegal_characters[] = "#:\\"; -const char *reserved_keywords[] = {nullptr}; - -struct SmtScope { - pool used_names; - dict name_map; - FunctionalTools::Scope scope; - SmtScope() : scope(illegal_characters, reserved_keywords) {} - - void reserve(const std::string &name) { used_names.insert(name); } - - std::string insert(IdString id) - { - - std::string name = scope(id); - if (used_names.count(name) == 0) { - used_names.insert(name); - name_map[id] = name; - return name; - } - for (int idx = 0;; ++idx) { - std::string new_name = name + "_" + std::to_string(idx); - if (used_names.count(new_name) == 0) { - used_names.insert(new_name); - name_map[id] = new_name; - return new_name; - } - } - } - - std::string operator[](IdString id) - { - if (name_map.count(id)) { - return name_map[id]; - } else { - return insert(id); - } - } +const char *reserved_keywords[] = { + "BINARY", "DECIMAL", "HEXADECIMAL", "NUMERAL", "STRING", "_", "!", "as", "let", "exists", "forall", "match", "par", + "assert", "check-sat", "check-sat-assuming", "declare-const", "declare-datatype", "declare-datatypes", + "declare-fun", "declare-sort", "define-fun", "define-fun-rec", "define-funs-rec", "define-sort", + "exit", "get-assertions", "symbol", "sort", "get-assignment", "get-info", "get-model", + "get-option", "get-proof", "get-unsat-assumptions", "get-unsat-core", "get-value", + "pop", "push", "reset", "reset-assertions", "set-info", "set-logic", "set-option", + + "pair", "Pair", "first", "second", + "inputs", "state", + nullptr }; -struct SmtWriter { - std::ostream &stream; - - SmtWriter(std::ostream &out) : stream(out) {} - - void print(const char *fmt, ...) - { - va_list args; - va_start(args, fmt); - stream << vstringf(fmt, args); - va_end(args); - } +struct SmtScope : public FunctionalTools::Scope { + SmtScope() { + for(const char **p = reserved_keywords; *p != nullptr; p++) + reserve(*p); + } + bool is_character_legal(char c) override { + return isascii(c) && (isalnum(c) || strchr("~!@$%^&*_-+=<>.?/", c)); + } }; -template struct SmtPrintVisitor { - using Node = FunctionalIR::Node; - NodeNames np; - SmtScope &scope; - - SmtPrintVisitor(NodeNames np, SmtScope &scope) : np(np), scope(scope) {} - - template std::string arg_to_string(T n) { return std::to_string(n); } - - std::string arg_to_string(std::string n) { return n; } - - std::string arg_to_string(Node n) { return np(n); } - - template std::string format(std::string fmt, Args &&...args) - { - std::vector arg_strings = {arg_to_string(std::forward(args))...}; - for (size_t i = 0; i < arg_strings.size(); ++i) { - std::string placeholder = "%" + std::to_string(i); - size_t pos = 0; - while ((pos = fmt.find(placeholder, pos)) != std::string::npos) { - fmt.replace(pos, placeholder.length(), arg_strings[i]); - pos += arg_strings[i].length(); - } - } - return fmt; - } - std::string buf(Node, Node n) { return np(n); } - - std::string slice(Node, Node a, int, int offset, int out_width) - { - return format("((_ extract %2 %1) %0)", np(a), offset, offset + out_width - 1); - } - - std::string zero_extend(Node, Node a, int, int out_width) { return format("((_ zero_extend %1) %0)", np(a), out_width - a.width()); } - - std::string sign_extend(Node, Node a, int, int out_width) { return format("((_ sign_extend %1) %0)", np(a), out_width - a.width()); } - - std::string concat(Node, Node a, int, Node b, int) { return format("(concat %0 %1)", np(a), np(b)); } - - std::string add(Node, Node a, Node b, int) { return format("(bvadd %0 %1)", np(a), np(b)); } - - std::string sub(Node, Node a, Node b, int) { return format("(bvsub %0 %1)", np(a), np(b)); } - - std::string mul(Node, Node a, Node b, int) { return format("(bvmul %0 %1)", np(a), np(b)); } - - std::string unsigned_div(Node, Node a, Node b, int) { return format("(bvudiv %0 %1)", np(a), np(b)); } - - std::string unsigned_mod(Node, Node a, Node b, int) { return format("(bvurem %0 %1)", np(a), np(b)); } - - std::string bitwise_and(Node, Node a, Node b, int) { return format("(bvand %0 %1)", np(a), np(b)); } - - std::string bitwise_or(Node, Node a, Node b, int) { return format("(bvor %0 %1)", np(a), np(b)); } - - std::string bitwise_xor(Node, Node a, Node b, int) { return format("(bvxor %0 %1)", np(a), np(b)); } - - std::string bitwise_not(Node, Node a, int) { return format("(bvnot %0)", np(a)); } - - std::string unary_minus(Node, Node a, int) { return format("(bvneg %0)", np(a)); } - - std::string reduce_and(Node, Node a, int) { - std::stringstream ss; - // We use ite to set the result to bit vector, to ensure appropriate type - ss << "(ite (= " << np(a) << " #b" << std::string(a.width(), '1') << ") #b1 #b0)"; - return ss.str(); - } - - std::string reduce_or(Node, Node a, int) - { - std::stringstream ss; - // We use ite to set the result to bit vector, to ensure appropriate type - ss << "(ite (= " << np(a) << " #b" << std::string(a.width(), '0') << ") #b0 #b1)"; - return ss.str(); - } - - std::string reduce_xor(Node, Node a, int) { - std::stringstream ss; - ss << "(bvxor "; - for (int i = 0; i < a.width(); ++i) { - if (i > 0) ss << " "; - ss << "((_ extract " << i << " " << i << ") " << np(a) << ")"; - } - ss << ")"; - return ss.str(); - } - - std::string equal(Node, Node a, Node b, int) { - return format("(ite (= %0 %1) #b1 #b0)", np(a), np(b)); - } - - std::string not_equal(Node, Node a, Node b, int) { - return format("(ite (distinct %0 %1) #b1 #b0)", np(a), np(b)); - } - - std::string signed_greater_than(Node, Node a, Node b, int) { - return format("(ite (bvsgt %0 %1) #b1 #b0)", np(a), np(b)); - } - - std::string signed_greater_equal(Node, Node a, Node b, int) { - return format("(ite (bvsge %0 %1) #b1 #b0)", np(a), np(b)); - } - - std::string unsigned_greater_than(Node, Node a, Node b, int) { - return format("(ite (bvugt %0 %1) #b1 #b0)", np(a), np(b)); - } - - std::string unsigned_greater_equal(Node, Node a, Node b, int) { - return format("(ite (bvuge %0 %1) #b1 #b0)", np(a), np(b)); - } - - std::string logical_shift_left(Node, Node a, Node b, int, int) { - // Get the bit-widths of a and b - int bit_width_a = a.width(); - int bit_width_b = b.width(); - - // Extend b to match the bit-width of a if necessary - std::ostringstream oss; - if (bit_width_a > bit_width_b) { - oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")"; - } else { - oss << np(b); // No extension needed if b's width is already sufficient - } - std::string b_extended = oss.str(); - - // Format the bvshl operation with the extended b - oss.str(""); // Clear the stringstream - oss << "(bvshl " << np(a) << " " << b_extended << ")"; - return oss.str(); - } - - std::string logical_shift_right(Node, Node a, Node b, int, int) { - // Get the bit-widths of a and b - int bit_width_a = a.width(); - int bit_width_b = b.width(); +class SExpr { + std::variant, std::string> _v; +public: + SExpr(std::string a) : _v(std::move(a)) {} + SExpr(const char *a) : _v(a) {} + SExpr(int n) : _v(std::to_string(n)) {} + SExpr(std::vector const &a) : _v(std::in_place_index<0>, a) {} + SExpr(std::vector &&a) : _v(std::in_place_index<0>, std::move(a)) {} + SExpr(std::initializer_list a) : _v(std::in_place_index<0>, a) {} + bool is_atom() const { return std::holds_alternative(_v); } + std::string const &atom() const { return std::get(_v); } + bool is_list() const { return std::holds_alternative>(_v); } + std::vector const &list() const { return std::get>(_v); } + friend std::ostream &operator<<(std::ostream &os, SExpr const &sexpr) { + if(sexpr.is_atom()) + os << sexpr.atom(); + else if(sexpr.is_list()){ + os << "("; + auto l = sexpr.list(); + for(size_t i = 0; i < l.size(); i++) { + if(i > 0) os << " "; + os << l[i]; + } + os << ")"; + }else + os << ""; + return os; + } + std::string to_string() const { + std::stringstream ss; + ss << *this; + return ss.str(); + } +}; - // Extend b to match the bit-width of a if necessary - std::ostringstream oss; - if (bit_width_a > bit_width_b) { - oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")"; - } else { - oss << np(b); // No extension needed if b's width is already sufficient +class SExprWriter { + std::ostream &os; + int _max_line_width; + int _indent = 0; + int _pos = 0; + bool _pending_nl = false; + vector _unclosed; + vector _unclosed_stack; + void nl_if_pending() { + if(_pending_nl) { + os << '\n'; + _pos = 0; + _pending_nl = false; + } + } + void puts(std::string const &s) { + if(s.empty()) return; + nl_if_pending(); + for(auto c : s) { + if(c == '\n') { + os << c; + _pos = 0; + } else { + if(_pos == 0) { + for(int i = 0; i < _indent; i++) + os << " "; + _pos = 2 * _indent; + } + os << c; + _pos++; + } + } } - std::string b_extended = oss.str(); - - // Format the bvlshr operation with the extended b - oss.str(""); // Clear the stringstream - oss << "(bvlshr " << np(a) << " " << b_extended << ")"; - return oss.str(); - } - - std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { - // Get the bit-widths of a and b - int bit_width_a = a.width(); - int bit_width_b = b.width(); - - // Extend b to match the bit-width of a if necessary - std::ostringstream oss; - if (bit_width_a > bit_width_b) { - oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")"; - } else { - oss << np(b); // No extension needed if b's width is already sufficient + int width(SExpr const &sexpr) { + if(sexpr.is_atom()) + return sexpr.atom().size(); + else if(sexpr.is_list()) { + int w = 2; + for(auto arg : sexpr.list()) + w += width(arg); + if(sexpr.list().size() > 1) + w += sexpr.list().size() - 1; + return w; + } else + return 0; } - std::string b_extended = oss.str(); - - // Format the bvashr operation with the extended b - oss.str(""); // Clear the stringstream - oss << "(bvashr " << np(a) << " " << b_extended << ")"; - return oss.str(); - } - - std::string mux(Node, Node a, Node b, Node s, int) { - return format("(ite (= %2 #b1) %0 %1)", np(a), np(b), np(s)); - } - - std::string pmux(Node, Node a, Node b, Node s, int, int) - { - // Assume s is a bit vector, combine a and b based on the selection bits - return format("(pmux %0 %1 %2)", np(a), np(b), np(s)); - } - - std::string constant(Node, RTLIL::Const value) { return format("#b%0", value.as_string()); } - - std::string input(Node, IdString name) { return format("%0", scope[name]); } - - std::string state(Node, IdString name) { return format("(%0 current_state)", scope[name]); } - - std::string memory_read(Node, Node mem, Node addr, int, int) { return format("(select %0 %1)", np(mem), np(addr)); } - - std::string memory_write(Node, Node mem, Node addr, Node data, int, int) { return format("(store %0 %1 %2)", np(mem), np(addr), np(data)); } - - std::string undriven(Node, int width) { return format("#b%0", std::string(width, '0')); } -}; - -struct SmtModule { - std::string name; - SmtScope scope; - FunctionalIR ir; - - SmtModule(const std::string &module_name, FunctionalIR ir) : name(module_name), ir(std::move(ir)) {} - - void write(std::ostream &out) - { - const bool stateful = ir.state().size() != 0; - SmtWriter writer(out); - - writer.print("(declare-fun %s () Bool)\n\n", name.c_str()); - - writer.print("(declare-datatypes () ((Inputs (mk_inputs"); - for (const auto &input : ir.inputs()) { - std::string input_name = scope[input.first]; - writer.print(" (%s (_ BitVec %d))", input_name.c_str(), input.second.width()); + void print(SExpr const &sexpr, bool close = true, bool indent_rest = true) { + if(sexpr.is_atom()) + puts(sexpr.atom()); + else if(sexpr.is_list()) { + auto args = sexpr.list(); + puts("("); + bool vertical = args.size() > 1 && _pos + width(sexpr) > _max_line_width; + if(vertical) _indent++; + for(size_t i = 0; i < args.size(); i++) { + if(i > 0) puts(vertical ? "\n" : " "); + print(args[i]); + } + _indent += (!close && indent_rest) - vertical; + if(close) + puts(")"); + else { + _unclosed.push_back(indent_rest); + _pending_nl = true; + } + }else + log_error("shouldn't happen in SExprWriter::print"); } - writer.print("))))\n\n"); - - writer.print("(declare-datatypes () ((Outputs (mk_outputs"); - for (const auto &output : ir.outputs()) { - std::string output_name = scope[output.first]; - writer.print(" (%s (_ BitVec %d))", output_name.c_str(), output.second.width()); +public: + SExprWriter(std::ostream &os, int max_line_width = 80) + : os(os) + , _max_line_width(max_line_width) + {} + void open(SExpr const &sexpr, bool indent_rest = true) { + log_assert(sexpr.is_list()); + print(sexpr, false, indent_rest); } - writer.print("))))\n"); - - if (stateful) { - writer.print("(declare-datatypes () ((State (mk_state"); - for (const auto &state : ir.state()) { - std::string state_name = scope[state.first]; - writer.print(" (%s (_ BitVec %d))", state_name.c_str(), state.second.width()); - } - writer.print("))))\n"); - - writer.print("(declare-datatypes () ((Pair (mk-pair (outputs Outputs) (next_state State)))))\n"); + void close(size_t n = 1) { + log_assert(_unclosed.size() - (_unclosed_stack.empty() ? 0 : _unclosed_stack.back()) >= n); + while(n-- > 0) { + bool indented = _unclosed[_unclosed.size() - 1]; + _unclosed.pop_back(); + _pending_nl = _pos >= _max_line_width; + if(indented) + _indent--; + puts(")"); + _pending_nl = true; + } } - - if (stateful) - writer.print("(define-fun %s_step ((current_state State) (inputs Inputs)) Pair", name.c_str()); - else - writer.print("(define-fun %s_step ((inputs Inputs)) Outputs", name.c_str()); - - writer.print(" (let ("); - for (const auto &input : ir.inputs()) { - std::string input_name = scope[input.first]; - writer.print(" (%s (%s inputs))", input_name.c_str(), input_name.c_str()); + void push() { + _unclosed_stack.push_back(_unclosed.size()); + } + void pop() { + auto t = _unclosed_stack.back(); + log_assert(_unclosed.size() >= t); + close(_unclosed.size() - t); + _unclosed_stack.pop_back(); + } + SExprWriter &operator <<(SExpr const &sexpr) { + print(sexpr); + _pending_nl = true; + return *this; } - writer.print(" )"); - - auto node_to_string = [&](FunctionalIR::Node n) { return scope[n.name()]; }; - SmtPrintVisitor visitor(node_to_string, scope); - - for (auto it = ir.begin(); it != ir.end(); ++it) { - const FunctionalIR::Node &node = *it; - - if (ir.inputs().count(node.name()) > 0) - continue; - - std::string node_name = scope[node.name()]; - std::string node_expr = node.visit(visitor); - - writer.print(" (let ( (%s %s))", node_name.c_str(), node_expr.c_str()); + void comment(std::string const &str, bool hanging = false) { + if(hanging) { + if(_pending_nl) { + _pending_nl = false; + puts(" "); + } + } + puts("; "); + puts(str); + puts("\n"); + } + ~SExprWriter() { + while(!_unclosed_stack.empty()) + pop(); + close(_unclosed.size()); + nl_if_pending(); } +}; - if (stateful) { - writer.print(" (let ( (next_state (mk_state "); - for (const auto &state : ir.state()) { - std::string state_name = scope[state.first]; - const std::string state_assignment = ir.get_state_next_node(state.first).name().c_str(); - writer.print(" %s", state_assignment.substr(1).c_str()); - } - writer.print(" )))"); - } +struct SmtSort { + FunctionalIR::Sort sort; + SmtSort(FunctionalIR::Sort sort) : sort(sort) {} + SExpr to_sexpr() const { + if(sort.is_memory()) { + return SExpr{"Array", {"_", "BitVec", sort.addr_width()}, {"_", "BitVec", sort.data_width()}}; + } else if(sort.is_signal()) { + return SExpr{"_", "BitVec", sort.width()}; + } else { + log_error("unknown sort"); + } + } +}; - if (stateful) { - writer.print(" (let ( (outputs (mk_outputs "); - for (const auto &output : ir.outputs()) { - std::string output_name = scope[output.first]; - writer.print(" %s", output_name.c_str()); - } - writer.print(" )))"); +class SmtStruct { + struct Field { + SmtSort sort; + std::string accessor; + }; + idict field_names; + vector fields; + SmtScope &scope; +public: + std::string name; + SmtStruct(std::string name, SmtScope &scope) : scope(scope), name(name) {} + void insert(IdString field_name, SmtSort sort) { + field_names(field_name); + auto accessor = scope.unique_name("\\" + name + "_" + RTLIL::unescape_id(field_name)); + fields.emplace_back(Field{sort, accessor}); + } + void write_definition(SExprWriter &w) { + w.open(SExpr{"declare-datatype", name}); + w.open(SExpr({})); + w.open(SExpr{name}); + for(const auto &field : fields) + w << SExpr{field.accessor, field.sort.to_sexpr()}; + w.close(3); + } + template void write_value(SExprWriter &w, Fn fn) { + w.open(SExpr(std::initializer_list{name})); + for(auto field_name : field_names) { + w << fn(field_name); + w.comment(RTLIL::unescape_id(field_name), true); + } + w.close(); + } + SExpr access(SExpr record, IdString name) { + size_t i = field_names.at(name); + return SExpr{fields[i].accessor, std::move(record)}; + } +}; - writer.print("(mk-pair outputs next_state)"); - } - else { - writer.print(" (mk_outputs "); - for (const auto &output : ir.outputs()) { - std::string output_name = scope[output.first]; - writer.print(" %s", output_name.c_str()); - } - writer.print(" )"); // Closing mk_outputs - } - if (stateful) { - writer.print(" )"); // Closing outputs let statement - writer.print(" )"); // Closing next_state let statement - } - // Close the nested lets - for (size_t i = 0; i < ir.size() - ir.inputs().size(); ++i) { - writer.print(" )"); // Closing each node - } - if (ir.size() == ir.inputs().size()) - writer.print(" )"); // Corner case +struct SmtPrintVisitor { + using Node = FunctionalIR::Node; + std::function n; + SmtStruct &input_struct; + SmtStruct &state_struct; + + SmtPrintVisitor(SmtStruct &input_struct, SmtStruct &state_struct) : input_struct(input_struct), state_struct(state_struct) {} + + std::string literal(RTLIL::Const c) { + std::string s = "#b"; + for(int i = c.size(); i-- > 0; ) + s += c[i] == State::S1 ? '1' : '0'; + return s; + } + + SExpr from_bool(SExpr &&arg) { + return SExpr{"ite", std::move(arg), "#b1", "#b0"}; + } + SExpr to_bool(SExpr &&arg) { + return SExpr{"=", std::move(arg), "#b1"}; + } + SExpr extract(SExpr &&arg, int offset, int out_width = 1) { + return SExpr{{"_", "extract", offset + out_width - 1, offset}, std::move(arg)}; + } + + SExpr buf(Node, Node a) { return n(a); } + SExpr slice(Node, Node a, int, int offset, int out_width) { return extract(n(a), offset, out_width); } + SExpr zero_extend(Node, Node a, int, int out_width) { return SExpr{{"_", "zero_extend", out_width - a.width()}, n(a)}; } + SExpr sign_extend(Node, Node a, int, int out_width) { return SExpr{{"_", "sign_extend", out_width - a.width()}, n(a)}; } + SExpr concat(Node, Node a, int, Node b, int) { return SExpr{"concat", n(a), n(b)}; } + SExpr add(Node, Node a, Node b, int) { return SExpr{"bvadd", n(a), n(b)}; } + SExpr sub(Node, Node a, Node b, int) { return SExpr{"bvsub", n(a), n(b)}; } + SExpr mul(Node, Node a, Node b, int) { return SExpr{"bvmul", n(a), n(b)}; } + SExpr unsigned_div(Node, Node a, Node b, int) { return SExpr{"bvudiv", n(a), n(b)}; } + SExpr unsigned_mod(Node, Node a, Node b, int) { return SExpr{"bvurem", n(a), n(b)}; } + SExpr bitwise_and(Node, Node a, Node b, int) { return SExpr{"bvand", n(a), n(b)}; } + SExpr bitwise_or(Node, Node a, Node b, int) { return SExpr{"bvor", n(a), n(b)}; } + SExpr bitwise_xor(Node, Node a, Node b, int) { return SExpr{"bvxor", n(a), n(b)}; } + SExpr bitwise_not(Node, Node a, int) { return SExpr{"bvnot", n(a)}; } + SExpr unary_minus(Node, Node a, int) { return SExpr{"bvneg", n(a)}; } + SExpr reduce_and(Node, Node a, int) { return from_bool(SExpr{"=", n(a), literal(RTLIL::Const(State::S1, a.width()))}); } + SExpr reduce_or(Node, Node a, int) { return from_bool(SExpr{"=", n(a), literal(RTLIL::Const(State::S0, a.width()))}); } + SExpr reduce_xor(Node, Node a, int) { + vector s { "bvxor" }; + for(int i = 0; i < a.width(); i++) + s.push_back(extract(n(a), i)); + return s; + } + SExpr equal(Node, Node a, Node b, int) { return from_bool(SExpr{"=", n(a), n(b)}); } + SExpr not_equal(Node, Node a, Node b, int) { return from_bool(SExpr{"distinct", n(a), n(b)}); } + SExpr signed_greater_than(Node, Node a, Node b, int) { return from_bool(SExpr{"bvsgt", n(a), n(b)}); } + SExpr signed_greater_equal(Node, Node a, Node b, int) { return from_bool(SExpr{"bvsge", n(a), n(b)}); } + SExpr unsigned_greater_than(Node, Node a, Node b, int) { return from_bool(SExpr{"bvugt", n(a), n(b)}); } + SExpr unsigned_greater_equal(Node, Node a, Node b, int) { return from_bool(SExpr{"bvuge", n(a), n(b)}); } + + SExpr extend(SExpr &&a, int in_width, int out_width) { + if(in_width < out_width) + return SExpr{{"_", "zero_extend", out_width - in_width}, std::move(a)}; + else + return std::move(a); + } + SExpr logical_shift_left(Node, Node a, Node b, int, int) { return SExpr{"bvshl", n(a), extend(n(b), b.width(), a.width())}; } + SExpr logical_shift_right(Node, Node a, Node b, int, int) { return SExpr{"bvshr", n(a), extend(n(b), b.width(), a.width())}; } + SExpr arithmetic_shift_right(Node, Node a, Node b, int, int) { return SExpr{"bvasr", n(a), extend(n(b), b.width(), a.width())}; } + SExpr mux(Node, Node a, Node b, Node s, int) { return SExpr{"ite", to_bool(n(s)), n(a), n(b)}; } + SExpr pmux(Node, Node a, Node b, Node s, int, int) { + SExpr rv = n(a); + for(int i = 0; i < s.width(); i++) + rv = SExpr{"ite", to_bool(extract(n(s), i)), extract(n(b), a.width() * i, a.width()), rv}; + return rv; + } + SExpr constant(Node, RTLIL::Const value) { return literal(value); } + SExpr memory_read(Node, Node mem, Node addr, int, int) { return SExpr{"select", n(mem), n(addr)}; } + SExpr memory_write(Node, Node mem, Node addr, Node data, int, int) { return SExpr{"store", n(mem), n(addr), n(data)}; } + + SExpr input(Node, IdString name) { return input_struct.access("inputs", name); } + SExpr state(Node, IdString name) { return state_struct.access("state", name); } + + SExpr undriven(Node, int width) { return literal(RTLIL::Const(State::S0, width)); } +}; - writer.print(" )"); // Closing inputs let statement - writer.print(")\n"); // Closing step function - } +struct SmtModule { + FunctionalIR ir; + SmtScope scope; + std::string name; + + SmtStruct input_struct; + SmtStruct output_struct; + SmtStruct state_struct; + + SmtModule(Module *module) + : ir(FunctionalIR::from_module(module)) + , scope() + , name(scope.unique_name(module->name)) + , input_struct(scope.unique_name(module->name.str() + "_Inputs"), scope) + , output_struct(scope.unique_name(module->name.str() + "_Outputs"), scope) + , state_struct(scope.unique_name(module->name.str() + "_State"), scope) + { + for (const auto &input : ir.inputs()) + input_struct.insert(input.first, input.second); + for (const auto &output : ir.outputs()) + output_struct.insert(output.first, output.second); + for (const auto &state : ir.state()) + state_struct.insert(state.first, state.second); + } + + void write(std::ostream &out) + { + SExprWriter w(out); + + input_struct.write_definition(w); + output_struct.write_definition(w); + state_struct.write_definition(w); + + w << SExpr{"declare-datatypes", {{"Pair", 2}}, {{"par", {"X", "Y"}, {{"pair", {"first", "X"}, {"second", "Y"}}}}}}; + + w.push(); + w.open(SExpr{"define-fun", name, + {{"inputs", input_struct.name}, + {"state", state_struct.name}}, + {"Pair", output_struct.name, state_struct.name}}); + auto inlined = [&](FunctionalIR::Node n) { + return n.fn() == FunctionalIR::Fn::constant || + n.fn() == FunctionalIR::Fn::undriven; + }; + SmtPrintVisitor visitor(input_struct, state_struct); + auto node_to_sexpr = [&](FunctionalIR::Node n) -> SExpr { + if(inlined(n)) + return n.visit(visitor); + else + return scope(n.id(), n.name()); + }; + visitor.n = node_to_sexpr; + for(auto n : ir) + if(!inlined(n)) { + w.open(SExpr{"let", {{node_to_sexpr(n), n.visit(visitor)}}}, false); + w.comment(SmtSort(n.sort()).to_sexpr().to_string(), true); + } + w.open(SExpr{"pair"}); + output_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.get_output_node(name)); }); + state_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.get_state_next_node(name)); }); + w.pop(); + } }; struct FunctionalSmtBackend : public Backend { - FunctionalSmtBackend() : Backend("functional_smt2", "Generate SMT-LIB from Functional IR") {} + FunctionalSmtBackend() : Backend("functional_smt2", "Generate SMT-LIB from Functional IR") {} - void help() override { log("\nFunctional SMT Backend.\n\n"); } + void help() override { log("\nFunctional SMT Backend.\n\n"); } - void execute(std::ostream *&f, std::string filename, std::vector args, RTLIL::Design *design) override - { - log_header(design, "Executing Functional SMT Backend.\n"); + void execute(std::ostream *&f, std::string filename, std::vector args, RTLIL::Design *design) override + { + log_header(design, "Executing Functional SMT Backend.\n"); - size_t argidx = 1; - extra_args(f, filename, args, argidx, design); + size_t argidx = 1; + extra_args(f, filename, args, argidx, design); - for (auto module : design->selected_modules()) { - log("Processing module `%s`.\n", module->name.c_str()); - auto ir = FunctionalIR::from_module(module); - SmtModule smt(RTLIL::unescape_id(module->name), ir); - smt.write(*f); - } - } + for (auto module : design->selected_modules()) { + log("Processing module `%s`.\n", module->name.c_str()); + SmtModule smt(module); + smt.write(*f); + } + } } FunctionalSmtBackend; PRIVATE_NAMESPACE_END diff --git a/kernel/functionalir.h b/kernel/functionalir.h index 3e7c55ef4ec..99fc872a81a 100644 --- a/kernel/functionalir.h +++ b/kernel/functionalir.h @@ -30,6 +30,7 @@ USING_YOSYS_NAMESPACE YOSYS_NAMESPACE_BEGIN class FunctionalIR { +public: enum class Fn { invalid, buf, @@ -69,7 +70,6 @@ class FunctionalIR { memory_read, memory_write }; -public: class Sort { std::variant> _v; public: @@ -185,6 +185,7 @@ class FunctionalIR { else return std::string("\\n") + std::to_string(id()); } + Fn fn() const { return _ref.function().fn(); } Sort sort() const { return _ref.attr().sort; } int width() const { return sort().width(); } Node arg(int n) const { return Node(_ref.arg(n)); } @@ -380,16 +381,22 @@ class FunctionalIR { }; namespace FunctionalTools { - class Scope { - const char *_illegal_characters; + template class Scope { + protected: + char substitution_character = '_'; + virtual bool is_character_legal(char) = 0; + private: pool _used_names; - dict _by_id; - dict _by_name; - std::string allocate_name(IdString suggestion) { + dict _by_id; + public: + void reserve(std::string name) { + _used_names.insert(std::move(name)); + } + std::string unique_name(IdString suggestion) { std::string str = RTLIL::unescape_id(suggestion); for(size_t i = 0; i < str.size(); i++) - if(strchr(_illegal_characters, str[i])) - str[i] = '_'; + if(!is_character_legal(str[i])) + str[i] = substitution_character; if(_used_names.count(str) == 0) { _used_names.insert(str); return str; @@ -402,32 +409,14 @@ namespace FunctionalTools { } } } - public: - Scope(const char *illegal_characters = "", const char **keywords = nullptr) { - _illegal_characters = illegal_characters; - if(keywords != nullptr) - for(const char **p = keywords; *p != nullptr; p++) - reserve(*p); - } - void reserve(std::string name) { - _used_names.insert(std::move(name)); - } - std::string operator()(int id, IdString suggestion) { + std::string operator()(Id id, IdString suggestion) { auto it = _by_id.find(id); if(it != _by_id.end()) return it->second; - std::string str = allocate_name(suggestion); + std::string str = unique_name(suggestion); _by_id.insert({id, str}); return str; } - std::string operator()(IdString idstring) { - auto it = _by_name.find(idstring); - if(it != _by_name.end()) - return it->second; - std::string str = allocate_name(idstring); - _by_name.insert({idstring, str}); - return str; - } }; class Writer { std::ostream *os;