Skip to content

Commit

Permalink
document functionalir.h and change visitors to derive from AbstractVi…
Browse files Browse the repository at this point in the history
…sitor. remove extraneous widths arguments from visitors.
  • Loading branch information
aiju committed Jul 17, 2024
1 parent 5800801 commit b4b36ea
Show file tree
Hide file tree
Showing 5 changed files with 400 additions and 170 deletions.
72 changes: 36 additions & 36 deletions backends/functional/cxx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ struct CxxStruct {
}
};

template<class NodePrinter> struct CxxPrintVisitor {
template<class NodePrinter> struct CxxPrintVisitor : public FunctionalIR::AbstractVisitor<void> {
using Node = FunctionalIR::Node;
CxxWriter &f;
NodePrinter np;
Expand All @@ -108,36 +108,36 @@ template<class NodePrinter> struct CxxPrintVisitor {
template<typename... Args> void print(const char *fmt, Args&&... args) {
f.print_with(np, fmt, std::forward<Args>(args)...);
}
void buf(Node, Node n) { print("{}", n); }
void slice(Node, Node a, int, int offset, int out_width) { print("{0}.slice<{2}>({1})", a, offset, out_width); }
void zero_extend(Node, Node a, int, int out_width) { print("{}.zero_extend<{}>()", a, out_width); }
void sign_extend(Node, Node a, int, int out_width) { print("{}.sign_extend<{}>()", a, out_width); }
void concat(Node, Node a, int, Node b, int) { print("{}.concat({})", a, b); }
void add(Node, Node a, Node b, int) { print("{} + {}", a, b); }
void sub(Node, Node a, Node b, int) { print("{} - {}", a, b); }
void mul(Node, Node a, Node b, int) { print("{} * {}", a, b); }
void unsigned_div(Node, Node a, Node b, int) { print("{} / {}", a, b); }
void unsigned_mod(Node, Node a, Node b, int) { print("{} % {}", a, b); }
void bitwise_and(Node, Node a, Node b, int) { print("{} & {}", a, b); }
void bitwise_or(Node, Node a, Node b, int) { print("{} | {}", a, b); }
void bitwise_xor(Node, Node a, Node b, int) { print("{} ^ {}", a, b); }
void bitwise_not(Node, Node a, int) { print("~{}", a); }
void unary_minus(Node, Node a, int) { print("-{}", a); }
void reduce_and(Node, Node a, int) { print("{}.all()", a); }
void reduce_or(Node, Node a, int) { print("{}.any()", a); }
void reduce_xor(Node, Node a, int) { print("{}.parity()", a); }
void equal(Node, Node a, Node b, int) { print("{} == {}", a, b); }
void not_equal(Node, Node a, Node b, int) { print("{} != {}", a, b); }
void signed_greater_than(Node, Node a, Node b, int) { print("{}.signed_greater_than({})", a, b); }
void signed_greater_equal(Node, Node a, Node b, int) { print("{}.signed_greater_equal({})", a, b); }
void unsigned_greater_than(Node, Node a, Node b, int) { print("{} > {}", a, b); }
void unsigned_greater_equal(Node, Node a, Node b, int) { print("{} >= {}", a, b); }
void logical_shift_left(Node, Node a, Node b, int, int) { print("{} << {}", a, b); }
void logical_shift_right(Node, Node a, Node b, int, int) { print("{} >> {}", a, b); }
void arithmetic_shift_right(Node, Node a, Node b, int, int) { print("{}.arithmetic_shift_right({})", a, b); }
void mux(Node, Node a, Node b, Node s, int) { print("{2}.any() ? {1} : {0}", a, b, s); }
void pmux(Node, Node a, Node b, Node s, int, int) { print("{0}.pmux({1}, {2})", a, b, s); }
void constant(Node, RTLIL::Const value) {
void buf(Node, Node n) override { print("{}", n); }
void slice(Node, Node a, int offset, int out_width) override { print("{0}.slice<{2}>({1})", a, offset, out_width); }
void zero_extend(Node, Node a, int out_width) override { print("{}.zero_extend<{}>()", a, out_width); }
void sign_extend(Node, Node a, int out_width) override { print("{}.sign_extend<{}>()", a, out_width); }
void concat(Node, Node a, Node b) override { print("{}.concat({})", a, b); }
void add(Node, Node a, Node b) override { print("{} + {}", a, b); }
void sub(Node, Node a, Node b) override { print("{} - {}", a, b); }
void mul(Node, Node a, Node b) override { print("{} * {}", a, b); }
void unsigned_div(Node, Node a, Node b) override { print("{} / {}", a, b); }
void unsigned_mod(Node, Node a, Node b) override { print("{} % {}", a, b); }
void bitwise_and(Node, Node a, Node b) override { print("{} & {}", a, b); }
void bitwise_or(Node, Node a, Node b) override { print("{} | {}", a, b); }
void bitwise_xor(Node, Node a, Node b) override { print("{} ^ {}", a, b); }
void bitwise_not(Node, Node a) override { print("~{}", a); }
void unary_minus(Node, Node a) override { print("-{}", a); }
void reduce_and(Node, Node a) override { print("{}.all()", a); }
void reduce_or(Node, Node a) override { print("{}.any()", a); }
void reduce_xor(Node, Node a) override { print("{}.parity()", a); }
void equal(Node, Node a, Node b) override { print("{} == {}", a, b); }
void not_equal(Node, Node a, Node b) override { print("{} != {}", a, b); }
void signed_greater_than(Node, Node a, Node b) override { print("{}.signed_greater_than({})", a, b); }
void signed_greater_equal(Node, Node a, Node b) override { print("{}.signed_greater_equal({})", a, b); }
void unsigned_greater_than(Node, Node a, Node b) override { print("{} > {}", a, b); }
void unsigned_greater_equal(Node, Node a, Node b) override { print("{} >= {}", a, b); }
void logical_shift_left(Node, Node a, Node b) override { print("{} << {}", a, b); }
void logical_shift_right(Node, Node a, Node b) override { print("{} >> {}", a, b); }
void arithmetic_shift_right(Node, Node a, Node b) override { print("{}.arithmetic_shift_right({})", a, b); }
void mux(Node, Node a, Node b, Node s) override { print("{2}.any() ? {1} : {0}", a, b, s); }
void pmux(Node, Node a, Node b, Node s) override { print("{0}.pmux({1}, {2})", a, b, s); }
void constant(Node, RTLIL::Const value) override {
std::stringstream ss;
bool multiple = value.size() > 32;
ss << "Signal<" << value.size() << ">(" << std::hex << std::showbase;
Expand All @@ -151,11 +151,11 @@ template<class NodePrinter> struct CxxPrintVisitor {
ss << ")";
print("{}", ss.str());
}
void input(Node, IdString name) { print("input.{}", input_struct[name]); }
void state(Node, IdString name) { print("current_state.{}", state_struct[name]); }
void memory_read(Node, Node mem, Node addr, int, int) { print("{}.read({})", mem, addr); }
void memory_write(Node, Node mem, Node addr, Node data, int, int) { print("{}.write({}, {})", mem, addr, data); }
void undriven(Node, int width) { print("Signal<{}>(0)", width); }
void input(Node, IdString name) override { print("input.{}", input_struct[name]); }
void state(Node, IdString name) override { print("current_state.{}", state_struct[name]); }
void memory_read(Node, Node mem, Node addr) override { print("{}.read({})", mem, addr); }
void memory_write(Node, Node mem, Node addr, Node data) override { print("{}.write({}, {})", mem, addr, data); }
void undriven(Node, int width) override { print("Signal<{}>(0)", width); }
};

struct CxxModule {
Expand Down
72 changes: 36 additions & 36 deletions backends/functional/smtlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class SmtStruct {
}
};

struct SmtPrintVisitor {
struct SmtPrintVisitor : public FunctionalIR::AbstractVisitor<SExpr> {
using Node = FunctionalIR::Node;
std::function<SExpr(Node)> n;
SmtStruct &input_struct;
Expand All @@ -134,60 +134,60 @@ struct SmtPrintVisitor {
return list(list("_", "extract", offset + out_width - 1, offset), std::move(arg));
}

SExpr buf(Node, Node a) { return n(a); }
SExpr slice(Node, Node a, int, int offset, int out_width) { return extract(n(a), offset, out_width); }
SExpr zero_extend(Node, Node a, int, int out_width) { return list(list("_", "zero_extend", out_width - a.width()), n(a)); }
SExpr sign_extend(Node, Node a, int, int out_width) { return list(list("_", "sign_extend", out_width - a.width()), n(a)); }
SExpr concat(Node, Node a, int, Node b, int) { return list("concat", n(b), n(a)); }
SExpr add(Node, Node a, Node b, int) { return list("bvadd", n(a), n(b)); }
SExpr sub(Node, Node a, Node b, int) { return list("bvsub", n(a), n(b)); }
SExpr mul(Node, Node a, Node b, int) { return list("bvmul", n(a), n(b)); }
SExpr unsigned_div(Node, Node a, Node b, int) { return list("bvudiv", n(a), n(b)); }
SExpr unsigned_mod(Node, Node a, Node b, int) { return list("bvurem", n(a), n(b)); }
SExpr bitwise_and(Node, Node a, Node b, int) { return list("bvand", n(a), n(b)); }
SExpr bitwise_or(Node, Node a, Node b, int) { return list("bvor", n(a), n(b)); }
SExpr bitwise_xor(Node, Node a, Node b, int) { return list("bvxor", n(a), n(b)); }
SExpr bitwise_not(Node, Node a, int) { return list("bvnot", n(a)); }
SExpr unary_minus(Node, Node a, int) { return list("bvneg", n(a)); }
SExpr reduce_and(Node, Node a, int) { return from_bool(list("=", n(a), literal(RTLIL::Const(State::S1, a.width())))); }
SExpr reduce_or(Node, Node a, int) { return from_bool(list("distinct", n(a), literal(RTLIL::Const(State::S0, a.width())))); }
SExpr reduce_xor(Node, Node a, int) {
SExpr buf(Node, Node a) override { return n(a); }
SExpr slice(Node, Node a, int offset, int out_width) override { return extract(n(a), offset, out_width); }
SExpr zero_extend(Node, Node a, int out_width) override { return list(list("_", "zero_extend", out_width - a.width()), n(a)); }
SExpr sign_extend(Node, Node a, int out_width) override { return list(list("_", "sign_extend", out_width - a.width()), n(a)); }
SExpr concat(Node, Node a, Node b) override { return list("concat", n(b), n(a)); }
SExpr add(Node, Node a, Node b) override { return list("bvadd", n(a), n(b)); }
SExpr sub(Node, Node a, Node b) override { return list("bvsub", n(a), n(b)); }
SExpr mul(Node, Node a, Node b) override { return list("bvmul", n(a), n(b)); }
SExpr unsigned_div(Node, Node a, Node b) override { return list("bvudiv", n(a), n(b)); }
SExpr unsigned_mod(Node, Node a, Node b) override { return list("bvurem", n(a), n(b)); }
SExpr bitwise_and(Node, Node a, Node b) override { return list("bvand", n(a), n(b)); }
SExpr bitwise_or(Node, Node a, Node b) override { return list("bvor", n(a), n(b)); }
SExpr bitwise_xor(Node, Node a, Node b) override { return list("bvxor", n(a), n(b)); }
SExpr bitwise_not(Node, Node a) override { return list("bvnot", n(a)); }
SExpr unary_minus(Node, Node a) override { return list("bvneg", n(a)); }
SExpr reduce_and(Node, Node a) override { return from_bool(list("=", n(a), literal(RTLIL::Const(State::S1, a.width())))); }
SExpr reduce_or(Node, Node a) override { return from_bool(list("distinct", n(a), literal(RTLIL::Const(State::S0, a.width())))); }
SExpr reduce_xor(Node, Node a) override {
vector<SExpr> s { "bvxor" };
for(int i = 0; i < a.width(); i++)
s.push_back(extract(n(a), i));
return s;
}
SExpr equal(Node, Node a, Node b, int) { return from_bool(list("=", n(a), n(b))); }
SExpr not_equal(Node, Node a, Node b, int) { return from_bool(list("distinct", n(a), n(b))); }
SExpr signed_greater_than(Node, Node a, Node b, int) { return from_bool(list("bvsgt", n(a), n(b))); }
SExpr signed_greater_equal(Node, Node a, Node b, int) { return from_bool(list("bvsge", n(a), n(b))); }
SExpr unsigned_greater_than(Node, Node a, Node b, int) { return from_bool(list("bvugt", n(a), n(b))); }
SExpr unsigned_greater_equal(Node, Node a, Node b, int) { return from_bool(list("bvuge", n(a), n(b))); }
SExpr equal(Node, Node a, Node b) override { return from_bool(list("=", n(a), n(b))); }
SExpr not_equal(Node, Node a, Node b) override { return from_bool(list("distinct", n(a), n(b))); }
SExpr signed_greater_than(Node, Node a, Node b) override { return from_bool(list("bvsgt", n(a), n(b))); }
SExpr signed_greater_equal(Node, Node a, Node b) override { return from_bool(list("bvsge", n(a), n(b))); }
SExpr unsigned_greater_than(Node, Node a, Node b) override { return from_bool(list("bvugt", n(a), n(b))); }
SExpr unsigned_greater_equal(Node, Node a, Node b) override { return from_bool(list("bvuge", n(a), n(b))); }

SExpr extend(SExpr &&a, int in_width, int out_width) {
if(in_width < out_width)
return list(list("_", "zero_extend", out_width - in_width), std::move(a));
else
return std::move(a);
}
SExpr logical_shift_left(Node, Node a, Node b, int, int) { return list("bvshl", n(a), extend(n(b), b.width(), a.width())); }
SExpr logical_shift_right(Node, Node a, Node b, int, int) { return list("bvlshr", n(a), extend(n(b), b.width(), a.width())); }
SExpr arithmetic_shift_right(Node, Node a, Node b, int, int) { return list("bvashr", n(a), extend(n(b), b.width(), a.width())); }
SExpr mux(Node, Node a, Node b, Node s, int) { return list("ite", to_bool(n(s)), n(b), n(a)); }
SExpr pmux(Node, Node a, Node b, Node s, int, int) {
SExpr logical_shift_left(Node, Node a, Node b) override { return list("bvshl", n(a), extend(n(b), b.width(), a.width())); }
SExpr logical_shift_right(Node, Node a, Node b) override { return list("bvlshr", n(a), extend(n(b), b.width(), a.width())); }
SExpr arithmetic_shift_right(Node, Node a, Node b) override { return list("bvashr", n(a), extend(n(b), b.width(), a.width())); }
SExpr mux(Node, Node a, Node b, Node s) override { return list("ite", to_bool(n(s)), n(b), n(a)); }
SExpr pmux(Node, Node a, Node b, Node s) override {
SExpr rv = n(a);
for(int i = 0; i < s.width(); i++)
rv = list("ite", to_bool(extract(n(s), i)), extract(n(b), a.width() * i, a.width()), rv);
return rv;
}
SExpr constant(Node, RTLIL::Const value) { return literal(value); }
SExpr memory_read(Node, Node mem, Node addr, int, int) { return list("select", n(mem), n(addr)); }
SExpr memory_write(Node, Node mem, Node addr, Node data, int, int) { return list("store", n(mem), n(addr), n(data)); }
SExpr constant(Node, RTLIL::Const value) override { return literal(value); }
SExpr memory_read(Node, Node mem, Node addr) override { return list("select", n(mem), n(addr)); }
SExpr memory_write(Node, Node mem, Node addr, Node data) override { return list("store", n(mem), n(addr), n(data)); }

SExpr input(Node, IdString name) { return input_struct.access("inputs", name); }
SExpr state(Node, IdString name) { return state_struct.access("state", name); }
SExpr input(Node, IdString name) override { return input_struct.access("inputs", name); }
SExpr state(Node, IdString name) override { return state_struct.access("state", name); }

SExpr undriven(Node, int width) { return literal(RTLIL::Const(State::S0, width)); }
SExpr undriven(Node, int width) override { return literal(RTLIL::Const(State::S0, width)); }
};

struct SmtModule {
Expand Down
2 changes: 1 addition & 1 deletion kernel/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct ComputeGraph
{
auto found = graph_->sparse_attrs.find(index_);
log_assert(found != graph_->sparse_attrs.end());
return *found;
return found->second;
}
};

Expand Down
Loading

0 comments on commit b4b36ea

Please sign in to comment.