From 5468bdf88ed0ceab87d7286bfccac60612aff18e Mon Sep 17 00:00:00 2001 From: Krystine Sherwin <93062060+KrystalDelusion@users.noreply.github.com> Date: Sat, 13 Jul 2024 13:06:55 +1200 Subject: [PATCH] smtr: Make Rosette compatible Convert most of the operators, except pmux and memory. Convert formatting for non-stateful modules. --- backends/functional/smtlib_rosette.cc | 201 ++++++++++---------------- 1 file changed, 77 insertions(+), 124 deletions(-) diff --git a/backends/functional/smtlib_rosette.cc b/backends/functional/smtlib_rosette.cc index 1c71f45e8ce..d291bb7e2d2 100644 --- a/backends/functional/smtlib_rosette.cc +++ b/backends/functional/smtlib_rosette.cc @@ -107,12 +107,12 @@ template struct SmtPrintVisitor { std::string slice(Node, Node a, int, int offset, int out_width) { - return format("((_ extract %2 %1) %0)", np(a), offset, offset + out_width - 1); + return format("(extract %2 %1 %0)", np(a), offset, offset + out_width - 1); } - std::string zero_extend(Node, Node a, int, int out_width) { return format("((_ zero_extend %1) %0)", np(a), out_width - a.width()); } + std::string zero_extend(Node, Node a, int, int out_width) { return format("(zero-extend %0 (bitvector %1))", np(a), out_width); } - std::string sign_extend(Node, Node a, int, int out_width) { return format("((_ sign_extend %1) %0)", np(a), out_width - a.width()); } + std::string sign_extend(Node, Node a, int, int out_width) { return format("(sign-extend %0 (bitvector %1))", np(a), out_width); } std::string concat(Node, Node a, int, Node b, int) { return format("(concat %0 %1)", np(a), np(b)); } @@ -136,137 +136,64 @@ template struct SmtPrintVisitor { std::string unary_minus(Node, Node a, int) { return format("(bvneg %0)", np(a)); } - std::string reduce_and(Node, Node a, int) { - std::stringstream ss; - // We use ite to set the result to bit vector, to ensure appropriate type - ss << "(ite (= " << np(a) << " #b" << std::string(a.width(), '1') << ") #b1 #b0)"; - return ss.str(); - } + std::string reduce_and(Node, Node a, int) { return format("(apply bvand (bitvector->bits %0))", np(a)); } - std::string reduce_or(Node, Node a, int) - { - std::stringstream ss; - // We use ite to set the result to bit vector, to ensure appropriate type - ss << "(ite (= " << np(a) << " #b" << std::string(a.width(), '0') << ") #b0 #b1)"; - return ss.str(); - } + std::string reduce_or(Node, Node a, int) { return format("(apply bvor (bitvector->bits %0))", np(a)); } - std::string reduce_xor(Node, Node a, int) { - std::stringstream ss; - ss << "(bvxor "; - for (int i = 0; i < a.width(); ++i) { - if (i > 0) ss << " "; - ss << "((_ extract " << i << " " << i << ") " << np(a) << ")"; - } - ss << ")"; - return ss.str(); - } + std::string reduce_xor(Node, Node a, int) { return format("(apply bvxor (bitvector->bits %0))", np(a)); } std::string equal(Node, Node a, Node b, int) { - return format("(ite (= %0 %1) #b1 #b0)", np(a), np(b)); + return format("(bool->bitvector (bveq %0 %1))", np(a), np(b)); } std::string not_equal(Node, Node a, Node b, int) { - return format("(ite (distinct %0 %1) #b1 #b0)", np(a), np(b)); + return format("(bool->bitvector (not (bveq %0 %1)))", np(a), np(b)); } std::string signed_greater_than(Node, Node a, Node b, int) { - return format("(ite (bvsgt %0 %1) #b1 #b0)", np(a), np(b)); + return format("(bool->bitvector (bvsgt %0 %1))", np(a), np(b)); } std::string signed_greater_equal(Node, Node a, Node b, int) { - return format("(ite (bvsge %0 %1) #b1 #b0)", np(a), np(b)); + return format("(bool->bitvector (bvsge %0 %1))", np(a), np(b)); } std::string unsigned_greater_than(Node, Node a, Node b, int) { - return format("(ite (bvugt %0 %1) #b1 #b0)", np(a), np(b)); + return format("(bool->bitvector (bvugt %0 %1))", np(a), np(b)); } std::string unsigned_greater_equal(Node, Node a, Node b, int) { - return format("(ite (bvuge %0 %1) #b1 #b0)", np(a), np(b)); - } - - std::string logical_shift_left(Node, Node a, Node b, int, int) { - // Get the bit-widths of a and b - int bit_width_a = a.width(); - int bit_width_b = b.width(); - - // Extend b to match the bit-width of a if necessary - std::ostringstream oss; - if (bit_width_a > bit_width_b) { - oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")"; - } else { - oss << np(b); // No extension needed if b's width is already sufficient - } - std::string b_extended = oss.str(); - - // Format the bvshl operation with the extended b - oss.str(""); // Clear the stringstream - oss << "(bvshl " << np(a) << " " << b_extended << ")"; - return oss.str(); - } - - std::string logical_shift_right(Node, Node a, Node b, int, int) { - // Get the bit-widths of a and b - int bit_width_a = a.width(); - int bit_width_b = b.width(); - - // Extend b to match the bit-width of a if necessary - std::ostringstream oss; - if (bit_width_a > bit_width_b) { - oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")"; - } else { - oss << np(b); // No extension needed if b's width is already sufficient - } - std::string b_extended = oss.str(); - - // Format the bvlshr operation with the extended b - oss.str(""); // Clear the stringstream - oss << "(bvlshr " << np(a) << " " << b_extended << ")"; - return oss.str(); + return format("(bool->bitvector (bvuge %0 %1))", np(a), np(b)); } - std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { - // Get the bit-widths of a and b - int bit_width_a = a.width(); - int bit_width_b = b.width(); + std::string logical_shift_left(Node, Node a, Node b, int, int) { return format("(bvshl %0 %1)", np(a), np(b)); } - // Extend b to match the bit-width of a if necessary - std::ostringstream oss; - if (bit_width_a > bit_width_b) { - oss << "((_ zero_extend " << (bit_width_a - bit_width_b) << ") " << np(b) << ")"; - } else { - oss << np(b); // No extension needed if b's width is already sufficient - } - std::string b_extended = oss.str(); + std::string logical_shift_right(Node, Node a, Node b, int, int) { return format("(bvlshr %0 %1)", np(a), np(b)); } - // Format the bvashr operation with the extended b - oss.str(""); // Clear the stringstream - oss << "(bvashr " << np(a) << " " << b_extended << ")"; - return oss.str(); - } + std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { return format("(bvashr %0 %1)", np(a), np(b)); } - std::string mux(Node, Node a, Node b, Node s, int) { - return format("(ite (= %2 #b1) %0 %1)", np(a), np(b), np(s)); - } + std::string mux(Node, Node a, Node b, Node s, int) { return format("(if %2 %0 %1)", np(a), np(b), np(s)); } + // How does pmux? std::string pmux(Node, Node a, Node b, Node s, int, int) { // Assume s is a bit vector, combine a and b based on the selection bits return format("(pmux %0 %1 %2)", np(a), np(b), np(s)); } - std::string constant(Node, RTLIL::Const value) { return format("#b%0", value.as_string()); } + std::string constant(Node, RTLIL::Const value) { return format("(bv #b%0 %1)", value.as_string(), value.size()); } std::string input(Node, IdString name) { return format("%0", scope[name]); } + // How does state? std::string state(Node, IdString name) { return format("(%0 current_state)", scope[name]); } + // How does memory? std::string memory_read(Node, Node mem, Node addr, int, int) { return format("(select %0 %1)", np(mem), np(addr)); } std::string memory_write(Node, Node mem, Node addr, Node data, int, int) { return format("(store %0 %1 %2)", np(mem), np(addr), np(data)); } - std::string undriven(Node, int width) { return format("#b%0", std::string(width, '0')); } + std::string undriven(Node, int width) { return format("(bv 0 %0)", width); } }; struct SmtModule { @@ -281,23 +208,46 @@ struct SmtModule { const bool stateful = ir.state().size() != 0; SmtWriter writer(out); - writer.print("(declare-fun %s () Bool)\n\n", name.c_str()); + // Rosette lang header + writer.print("#lang rosette\n\n"); + std::string end_part = "\n"; + std::string indent = "\t"; + + // Not sure if this is actually necessary or not, so make it optional I guess? + bool guarded = true; - writer.print("(declare-datatypes () ((Inputs (mk_inputs"); + // ??? + // writer.print("(declare-fun %s () Bool)\n\n", name.c_str()); + + // Inputs + std::stringstream input_list; + std::stringstream input_values; for (const auto &input : ir.inputs()) { - std::string input_name = scope[input.first]; - writer.print(" (%s (_ BitVec %d))", input_name.c_str(), input.second.width()); + auto input_name = scope[input.first]; + input_list << input_name << " "; + if (guarded) { + input_values << end_part << indent << indent << indent; + auto width = input.second.width(); + input_values << "(extract " << width-1 << " 0 (concat (bv 0 " << width << ") " << input_name << "))"; + } } - writer.print("))))\n\n"); + writer.print("(struct Inputs (%s)", input_list.str().c_str()); + if (guarded) { + writer.print("%s%s#:guard (lambda (%sname)%s", end_part.c_str(), indent.c_str(), input_list.str().c_str(), end_part.c_str()); + writer.print("%s%s(values%s))", indent.c_str(), indent.c_str(), input_values.str().c_str()); + } + writer.print(")\n"); - writer.print("(declare-datatypes () ((Outputs (mk_outputs"); + // Outputs + writer.print("(struct Outputs ("); for (const auto &output : ir.outputs()) { - std::string output_name = scope[output.first]; - writer.print(" (%s (_ BitVec %d))", output_name.c_str(), output.second.width()); + auto output_name = scope[output.first]; + writer.print("%s ", output_name.c_str()); } - writer.print("))))\n"); + writer.print("))\n"); if (stateful) { + // ? writer.print("(declare-datatypes () ((State (mk_state"); for (const auto &state : ir.state()) { std::string state_name = scope[state.first]; @@ -308,21 +258,21 @@ struct SmtModule { writer.print("(declare-datatypes () ((Pair (mk-pair (outputs Outputs) (next_state State)))))\n"); } - if (stateful) - writer.print("(define-fun %s_step ((current_state State) (inputs Inputs)) Pair", name.c_str()); - else - writer.print("(define-fun %s_step ((inputs Inputs)) Outputs", name.c_str()); + // Function start + writer.print("(define (%s_step inputs)%s", name.c_str(), end_part.c_str()); - writer.print(" (let ("); + // Bind inputs + writer.print("%s(let (", indent.c_str()); for (const auto &input : ir.inputs()) { - std::string input_name = scope[input.first]; - writer.print(" (%s (%s inputs))", input_name.c_str(), input_name.c_str()); + auto input_name = scope[input.first]; + writer.print("[%s (Inputs-%s inputs)] ", input_name.c_str(), input_name.c_str()); } - writer.print(" )"); + writer.print(")"); auto node_to_string = [&](FunctionalIR::Node n) { return scope[n.name()]; }; SmtPrintVisitor visitor(node_to_string, scope); + // Bind operators for (auto it = ir.begin(); it != ir.end(); ++it) { const FunctionalIR::Node &node = *it; @@ -332,10 +282,12 @@ struct SmtModule { std::string node_name = scope[node.name()]; std::string node_expr = node.visit(visitor); - writer.print(" (let ( (%s %s))", node_name.c_str(), node_expr.c_str()); + writer.print(" (let ([%s %s])", node_name.c_str(), node_expr.c_str()); } + // Bind next state if (stateful) { + // ? writer.print(" (let ( (next_state (mk_state "); for (const auto &state : ir.state()) { std::string state_name = scope[state.first]; @@ -345,7 +297,9 @@ struct SmtModule { writer.print(" )))"); } + // Bind outputs if (stateful) { + // ? writer.print(" (let ( (outputs (mk_outputs "); for (const auto &output : ir.outputs()) { std::string output_name = scope[output.first]; @@ -354,27 +308,26 @@ struct SmtModule { writer.print(" )))"); writer.print("(mk-pair outputs next_state)"); + writer.print(" )"); // Closing outputs let statement + writer.print(" )"); // Closing next_state let statement } else { - writer.print(" (mk_outputs "); + writer.print(" (Outputs "); for (const auto &output : ir.outputs()) { - std::string output_name = scope[output.first]; - writer.print(" %s", output_name.c_str()); + auto output_name = scope[output.first]; + writer.print("%s ", output_name.c_str()); } - writer.print(" )"); // Closing mk_outputs - } - if (stateful) { - writer.print(" )"); // Closing outputs let statement - writer.print(" )"); // Closing next_state let statement + writer.print(")"); // Closing outputs } + // Close the nested lets - for (size_t i = 0; i < ir.size() - ir.inputs().size(); ++i) { - writer.print(" )"); // Closing each node + for (auto i = ir.inputs().size(); i < ir.size(); ++i) { + writer.print(")"); // Closing each node } if (ir.size() == ir.inputs().size()) - writer.print(" )"); // Corner case + writer.print(")"); // Corner case - writer.print(" )"); // Closing inputs let statement + writer.print(")"); // Closing inputs let statement writer.print(")\n"); // Closing step function } };