diff --git a/backends/functional/smtlib_rosette.cc b/backends/functional/smtlib_rosette.cc index d291bb7e2d2..312fdba18a0 100644 --- a/backends/functional/smtlib_rosette.cc +++ b/backends/functional/smtlib_rosette.cc @@ -172,26 +172,13 @@ template struct SmtPrintVisitor { std::string arithmetic_shift_right(Node, Node a, Node b, int, int) { return format("(bvashr %0 %1)", np(a), np(b)); } - std::string mux(Node, Node a, Node b, Node s, int) { return format("(if %2 %0 %1)", np(a), np(b), np(s)); } - - // How does pmux? - std::string pmux(Node, Node a, Node b, Node s, int, int) - { - // Assume s is a bit vector, combine a and b based on the selection bits - return format("(pmux %0 %1 %2)", np(a), np(b), np(s)); - } + std::string mux(Node, Node a, Node b, Node s, int) { return format("(if (bitvector->bool %2) %1 %0)", np(a), np(b), np(s)); } std::string constant(Node, RTLIL::Const value) { return format("(bv #b%0 %1)", value.as_string(), value.size()); } std::string input(Node, IdString name) { return format("%0", scope[name]); } - // How does state? - std::string state(Node, IdString name) { return format("(%0 current_state)", scope[name]); } - - // How does memory? - std::string memory_read(Node, Node mem, Node addr, int, int) { return format("(select %0 %1)", np(mem), np(addr)); } - - std::string memory_write(Node, Node mem, Node addr, Node data, int, int) { return format("(store %0 %1 %2)", np(mem), np(addr), np(data)); } + std::string state(Node, IdString name) { return format("%0", scope[name]); } std::string undriven(Node, int width) { return format("(bv 0 %0)", width); } }; @@ -238,6 +225,27 @@ struct SmtModule { } writer.print(")\n"); + if (stateful) { + // State + std::stringstream state_list; + std::stringstream state_values; + for (const auto &state : ir.state()) { + auto state_name = scope[state.first]; + state_list << state_name << " "; + if (guarded) { + state_values << end_part << indent << indent << indent; + auto width = state.second.width(); + state_values << "(extract " << width-1 << " 0 (concat (bv 0 " << width << ") " << state_name << "))"; + } + } + writer.print("(struct State (%s)", state_list.str().c_str()); + if (guarded) { + writer.print("%s%s#:guard (lambda (%sname)%s", end_part.c_str(), indent.c_str(), state_list.str().c_str(), end_part.c_str()); + writer.print("%s%s(values%s))", indent.c_str(), indent.c_str(), state_values.str().c_str()); + } + writer.print(")\n"); + } + // Outputs writer.print("(struct Outputs ("); for (const auto &output : ir.outputs()) { @@ -246,20 +254,10 @@ struct SmtModule { } writer.print("))\n"); - if (stateful) { - // ? - writer.print("(declare-datatypes () ((State (mk_state"); - for (const auto &state : ir.state()) { - std::string state_name = scope[state.first]; - writer.print(" (%s (_ BitVec %d))", state_name.c_str(), state.second.width()); - } - writer.print("))))\n"); - - writer.print("(declare-datatypes () ((Pair (mk-pair (outputs Outputs) (next_state State)))))\n"); - } - // Function start - writer.print("(define (%s_step inputs)%s", name.c_str(), end_part.c_str()); + writer.print("(define (%s_step inputs", name.c_str()); + if (stateful) { writer.print(" current_state"); } + writer.print(")%s", end_part.c_str()); // Bind inputs writer.print("%s(let (", indent.c_str()); @@ -267,51 +265,49 @@ struct SmtModule { auto input_name = scope[input.first]; writer.print("[%s (Inputs-%s inputs)] ", input_name.c_str(), input_name.c_str()); } + // Bind states + for (const auto &state : ir.state()) { + auto state_name = scope[state.first]; + writer.print("[%s (State-%s current_state)] ", state_name.c_str(), state_name.c_str()); + } writer.print(")"); auto node_to_string = [&](FunctionalIR::Node n) { return scope[n.name()]; }; SmtPrintVisitor visitor(node_to_string, scope); + auto depth = 1; // Bind operators for (auto it = ir.begin(); it != ir.end(); ++it) { const FunctionalIR::Node &node = *it; - if (ir.inputs().count(node.name()) > 0) + // Skip input and state binds + if (ir.inputs().count(node.name()) > 0 || ir.state().count(node.name()) > 0) continue; std::string node_name = scope[node.name()]; std::string node_expr = node.visit(visitor); - writer.print(" (let ([%s %s])", node_name.c_str(), node_expr.c_str()); + writer.print(end_part.c_str()); + writer.print(std::string(++depth, indent.c_str()[0]).c_str()); + writer.print("(let ([%s %s])", node_name.c_str(), node_expr.c_str()); } - // Bind next state if (stateful) { - // ? - writer.print(" (let ( (next_state (mk_state "); + // Bind outputs and next state + writer.print(" (cons (Outputs"); + for (const auto &output : ir.outputs()) { + std::string output_name = scope[output.first]; + writer.print(" %s", output_name.c_str()); + } + writer.print(") (State"); for (const auto &state : ir.state()) { std::string state_name = scope[state.first]; const std::string state_assignment = ir.get_state_next_node(state.first).name().c_str(); writer.print(" %s", state_assignment.substr(1).c_str()); } - writer.print(" )))"); - } - - // Bind outputs - if (stateful) { - // ? - writer.print(" (let ( (outputs (mk_outputs "); - for (const auto &output : ir.outputs()) { - std::string output_name = scope[output.first]; - writer.print(" %s", output_name.c_str()); - } - writer.print(" )))"); - - writer.print("(mk-pair outputs next_state)"); - writer.print(" )"); // Closing outputs let statement - writer.print(" )"); // Closing next_state let statement - } - else { + writer.print("))"); + } else { + // Bind outputs writer.print(" (Outputs "); for (const auto &output : ir.outputs()) { auto output_name = scope[output.first]; @@ -321,13 +317,10 @@ struct SmtModule { } // Close the nested lets - for (auto i = ir.inputs().size(); i < ir.size(); ++i) { + for (auto i = 0; i < depth; ++i) { writer.print(")"); // Closing each node } - if (ir.size() == ir.inputs().size()) - writer.print(")"); // Corner case - writer.print(")"); // Closing inputs let statement writer.print(")\n"); // Closing step function } };