Skip to content

Commit

Permalink
change smtlib backend to use list() function instead of SExpr{} const…
Browse files Browse the repository at this point in the history
…ructor (leads to weird constructor overloading resolution issues)
  • Loading branch information
aiju committed Jul 10, 2024
1 parent af83d8d commit e786722
Showing 1 changed file with 55 additions and 50 deletions.
105 changes: 55 additions & 50 deletions backends/functional/smtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ struct SmtScope : public FunctionalTools::Scope<int> {
};

class SExpr {
public:
std::variant<std::vector<SExpr>, std::string> _v;
public:
SExpr(std::string a) : _v(std::move(a)) {}
SExpr(const char *a) : _v(a) {}
SExpr(int n) : _v(std::to_string(n)) {}
SExpr(std::vector<SExpr> const &a) : _v(std::in_place_index<0>, a) {}
SExpr(std::vector<SExpr> &&a) : _v(std::in_place_index<0>, std::move(a)) {}
SExpr(std::initializer_list<SExpr> a) : _v(std::in_place_index<0>, a) {}
SExpr(std::vector<SExpr> const &l) : _v(l) {}
SExpr(std::vector<SExpr> &&l) : _v(std::move(l)) {}
bool is_atom() const { return std::holds_alternative<std::string>(_v); }
std::string const &atom() const { return std::get<std::string>(_v); }
bool is_list() const { return std::holds_alternative<std::vector<SExpr>>(_v); }
Expand All @@ -81,6 +81,9 @@ class SExpr {
return ss.str();
}
};
template<typename... Args> SExpr list(Args&&... args) {
return SExpr(std::vector<SExpr>{std::forward<Args>(args)...});
}

class SExprWriter {
std::ostream &os;
Expand Down Expand Up @@ -209,9 +212,9 @@ struct SmtSort {
SmtSort(FunctionalIR::Sort sort) : sort(sort) {}
SExpr to_sexpr() const {
if(sort.is_memory()) {
return SExpr{"Array", {"_", "BitVec", sort.addr_width()}, {"_", "BitVec", sort.data_width()}};
return list("Array", list("_", "BitVec", sort.addr_width()), list("_", "BitVec", sort.data_width()));
} else if(sort.is_signal()) {
return SExpr{"_", "BitVec", sort.width()};
return list("_", "BitVec", sort.width());
} else {
log_error("unknown sort");
}
Expand All @@ -235,15 +238,15 @@ class SmtStruct {
fields.emplace_back(Field{sort, accessor});
}
void write_definition(SExprWriter &w) {
w.open(SExpr{"declare-datatype", name});
w.open(SExpr({}));
w.open(SExpr{name});
w.open(list("declare-datatype", name));
w.open(list());
w.open(list(name));
for(const auto &field : fields)
w << SExpr{field.accessor, field.sort.to_sexpr()};
w << list(field.accessor, field.sort.to_sexpr());
w.close(3);
}
template<typename Fn> void write_value(SExprWriter &w, Fn fn) {
w.open(SExpr(std::initializer_list<SExpr>{name}));
w.open(list(name));
for(auto field_name : field_names) {
w << fn(field_name);
w.comment(RTLIL::unescape_id(field_name), true);
Expand All @@ -252,7 +255,7 @@ class SmtStruct {
}
SExpr access(SExpr record, IdString name) {
size_t i = field_names.at(name);
return SExpr{fields[i].accessor, std::move(record)};
return list(fields[i].accessor, std::move(record));
}
};

Expand All @@ -272,64 +275,64 @@ struct SmtPrintVisitor {
}

SExpr from_bool(SExpr &&arg) {
return SExpr{"ite", std::move(arg), "#b1", "#b0"};
return list("ite", std::move(arg), "#b1", "#b0");
}
SExpr to_bool(SExpr &&arg) {
return SExpr{"=", std::move(arg), "#b1"};
return list("=", std::move(arg), "#b1");
}
SExpr extract(SExpr &&arg, int offset, int out_width = 1) {
return SExpr{{"_", "extract", offset + out_width - 1, offset}, std::move(arg)};
return list(list("_", "extract", offset + out_width - 1, offset), std::move(arg));
}

SExpr buf(Node, Node a) { return n(a); }
SExpr slice(Node, Node a, int, int offset, int out_width) { return extract(n(a), offset, out_width); }
SExpr zero_extend(Node, Node a, int, int out_width) { return SExpr{{"_", "zero_extend", out_width - a.width()}, n(a)}; }
SExpr sign_extend(Node, Node a, int, int out_width) { return SExpr{{"_", "sign_extend", out_width - a.width()}, n(a)}; }
SExpr concat(Node, Node a, int, Node b, int) { return SExpr{"concat", n(a), n(b)}; }
SExpr add(Node, Node a, Node b, int) { return SExpr{"bvadd", n(a), n(b)}; }
SExpr sub(Node, Node a, Node b, int) { return SExpr{"bvsub", n(a), n(b)}; }
SExpr mul(Node, Node a, Node b, int) { return SExpr{"bvmul", n(a), n(b)}; }
SExpr unsigned_div(Node, Node a, Node b, int) { return SExpr{"bvudiv", n(a), n(b)}; }
SExpr unsigned_mod(Node, Node a, Node b, int) { return SExpr{"bvurem", n(a), n(b)}; }
SExpr bitwise_and(Node, Node a, Node b, int) { return SExpr{"bvand", n(a), n(b)}; }
SExpr bitwise_or(Node, Node a, Node b, int) { return SExpr{"bvor", n(a), n(b)}; }
SExpr bitwise_xor(Node, Node a, Node b, int) { return SExpr{"bvxor", n(a), n(b)}; }
SExpr bitwise_not(Node, Node a, int) { return SExpr{"bvnot", n(a)}; }
SExpr unary_minus(Node, Node a, int) { return SExpr{"bvneg", n(a)}; }
SExpr reduce_and(Node, Node a, int) { return from_bool(SExpr{"=", n(a), literal(RTLIL::Const(State::S1, a.width()))}); }
SExpr reduce_or(Node, Node a, int) { return from_bool(SExpr{"=", n(a), literal(RTLIL::Const(State::S0, a.width()))}); }
SExpr zero_extend(Node, Node a, int, int out_width) { return list(list("_", "zero_extend", out_width - a.width()), n(a)); }
SExpr sign_extend(Node, Node a, int, int out_width) { return list(list("_", "sign_extend", out_width - a.width()), n(a)); }
SExpr concat(Node, Node a, int, Node b, int) { return list("concat", n(a), n(b)); }
SExpr add(Node, Node a, Node b, int) { return list("bvadd", n(a), n(b)); }
SExpr sub(Node, Node a, Node b, int) { return list("bvsub", n(a), n(b)); }
SExpr mul(Node, Node a, Node b, int) { return list("bvmul", n(a), n(b)); }
SExpr unsigned_div(Node, Node a, Node b, int) { return list("bvudiv", n(a), n(b)); }
SExpr unsigned_mod(Node, Node a, Node b, int) { return list("bvurem", n(a), n(b)); }
SExpr bitwise_and(Node, Node a, Node b, int) { return list("bvand", n(a), n(b)); }
SExpr bitwise_or(Node, Node a, Node b, int) { return list("bvor", n(a), n(b)); }
SExpr bitwise_xor(Node, Node a, Node b, int) { return list("bvxor", n(a), n(b)); }
SExpr bitwise_not(Node, Node a, int) { return list("bvnot", n(a)); }
SExpr unary_minus(Node, Node a, int) { return list("bvneg", n(a)); }
SExpr reduce_and(Node, Node a, int) { return from_bool(list("=", n(a), literal(RTLIL::Const(State::S1, a.width())))); }
SExpr reduce_or(Node, Node a, int) { return from_bool(list("=", n(a), literal(RTLIL::Const(State::S0, a.width())))); }
SExpr reduce_xor(Node, Node a, int) {
vector<SExpr> s { "bvxor" };
for(int i = 0; i < a.width(); i++)
s.push_back(extract(n(a), i));
return s;
}
SExpr equal(Node, Node a, Node b, int) { return from_bool(SExpr{"=", n(a), n(b)}); }
SExpr not_equal(Node, Node a, Node b, int) { return from_bool(SExpr{"distinct", n(a), n(b)}); }
SExpr signed_greater_than(Node, Node a, Node b, int) { return from_bool(SExpr{"bvsgt", n(a), n(b)}); }
SExpr signed_greater_equal(Node, Node a, Node b, int) { return from_bool(SExpr{"bvsge", n(a), n(b)}); }
SExpr unsigned_greater_than(Node, Node a, Node b, int) { return from_bool(SExpr{"bvugt", n(a), n(b)}); }
SExpr unsigned_greater_equal(Node, Node a, Node b, int) { return from_bool(SExpr{"bvuge", n(a), n(b)}); }
SExpr equal(Node, Node a, Node b, int) { return from_bool(list("=", n(a), n(b))); }
SExpr not_equal(Node, Node a, Node b, int) { return from_bool(list("distinct", n(a), n(b))); }
SExpr signed_greater_than(Node, Node a, Node b, int) { return from_bool(list("bvsgt", n(a), n(b))); }
SExpr signed_greater_equal(Node, Node a, Node b, int) { return from_bool(list("bvsge", n(a), n(b))); }
SExpr unsigned_greater_than(Node, Node a, Node b, int) { return from_bool(list("bvugt", n(a), n(b))); }
SExpr unsigned_greater_equal(Node, Node a, Node b, int) { return from_bool(list("bvuge", n(a), n(b))); }

SExpr extend(SExpr &&a, int in_width, int out_width) {
if(in_width < out_width)
return SExpr{{"_", "zero_extend", out_width - in_width}, std::move(a)};
return list(list("_", "zero_extend", out_width - in_width), std::move(a));
else
return std::move(a);
}
SExpr logical_shift_left(Node, Node a, Node b, int, int) { return SExpr{"bvshl", n(a), extend(n(b), b.width(), a.width())}; }
SExpr logical_shift_right(Node, Node a, Node b, int, int) { return SExpr{"bvshr", n(a), extend(n(b), b.width(), a.width())}; }
SExpr arithmetic_shift_right(Node, Node a, Node b, int, int) { return SExpr{"bvasr", n(a), extend(n(b), b.width(), a.width())}; }
SExpr mux(Node, Node a, Node b, Node s, int) { return SExpr{"ite", to_bool(n(s)), n(a), n(b)}; }
SExpr logical_shift_left(Node, Node a, Node b, int, int) { return list("bvshl", n(a), extend(n(b), b.width(), a.width())); }
SExpr logical_shift_right(Node, Node a, Node b, int, int) { return list("bvshr", n(a), extend(n(b), b.width(), a.width())); }
SExpr arithmetic_shift_right(Node, Node a, Node b, int, int) { return list("bvasr", n(a), extend(n(b), b.width(), a.width())); }
SExpr mux(Node, Node a, Node b, Node s, int) { return list("ite", to_bool(n(s)), n(a), n(b)); }
SExpr pmux(Node, Node a, Node b, Node s, int, int) {
SExpr rv = n(a);
for(int i = 0; i < s.width(); i++)
rv = SExpr{"ite", to_bool(extract(n(s), i)), extract(n(b), a.width() * i, a.width()), rv};
rv = list("ite", to_bool(extract(n(s), i)), extract(n(b), a.width() * i, a.width()), rv);
return rv;
}
SExpr constant(Node, RTLIL::Const value) { return literal(value); }
SExpr memory_read(Node, Node mem, Node addr, int, int) { return SExpr{"select", n(mem), n(addr)}; }
SExpr memory_write(Node, Node mem, Node addr, Node data, int, int) { return SExpr{"store", n(mem), n(addr), n(data)}; }
SExpr memory_read(Node, Node mem, Node addr, int, int) { return list("select", n(mem), n(addr)); }
SExpr memory_write(Node, Node mem, Node addr, Node data, int, int) { return list("store", n(mem), n(addr), n(data)); }

SExpr input(Node, IdString name) { return input_struct.access("inputs", name); }
SExpr state(Node, IdString name) { return state_struct.access("state", name); }
Expand Down Expand Up @@ -370,13 +373,15 @@ struct SmtModule {
output_struct.write_definition(w);
state_struct.write_definition(w);

w << SExpr{"declare-datatypes", {{"Pair", 2}}, {{"par", {"X", "Y"}, {{"pair", {"first", "X"}, {"second", "Y"}}}}}};
w << list("declare-datatypes",
list(list("Pair", 2)),
list(list("par", list("X", "Y"), list(list("pair", list("first", "X"), list("second", "Y"))))));

w.push();
w.open(SExpr{"define-fun", name,
{{"inputs", input_struct.name},
{"state", state_struct.name}},
{"Pair", output_struct.name, state_struct.name}});
w.open(list("define-fun", name,
list(list("inputs", input_struct.name),
list("state", state_struct.name)),
list("Pair", output_struct.name, state_struct.name)));
auto inlined = [&](FunctionalIR::Node n) {
return n.fn() == FunctionalIR::Fn::constant ||
n.fn() == FunctionalIR::Fn::undriven;
Expand All @@ -391,10 +396,10 @@ struct SmtModule {
visitor.n = node_to_sexpr;
for(auto n : ir)
if(!inlined(n)) {
w.open(SExpr{"let", {{node_to_sexpr(n), n.visit(visitor)}}}, false);
w.open(list("let", list(list(node_to_sexpr(n), n.visit(visitor)))), false);
w.comment(SmtSort(n.sort()).to_sexpr().to_string(), true);
}
w.open(SExpr{"pair"});
w.open(list("pair"));
output_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.get_output_node(name)); });
state_struct.write_value(w, [&](IdString name) { return node_to_sexpr(ir.get_state_next_node(name)); });
w.pop();
Expand Down

0 comments on commit e786722

Please sign in to comment.