From ad15301d5586d8ab82763554b2148944cb77b957 Mon Sep 17 00:00:00 2001 From: Jannis Harder Date: Thu, 11 Apr 2024 13:48:25 +0200 Subject: [PATCH] ComputeGraph datatype for the upcoming functional backend --- kernel/functional.h | 369 ++++++++++++++++++++++++++++++++++++++ passes/cmds/example_dt.cc | 178 ++++++++++++++---- 2 files changed, 515 insertions(+), 32 deletions(-) create mode 100644 kernel/functional.h diff --git a/kernel/functional.h b/kernel/functional.h new file mode 100644 index 00000000000..e5ee8824099 --- /dev/null +++ b/kernel/functional.h @@ -0,0 +1,369 @@ +/* + * yosys -- Yosys Open SYnthesis Suite + * + * Copyright (C) 2024 Jannis Harder + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + * + */ + +#ifndef FUNCTIONAL_H +#define FUNCTIONAL_H + +#include +#include "kernel/yosys.h" + +YOSYS_NAMESPACE_BEGIN + +template< + typename Fn, // Function type (deduplicated across whole graph) + typename Attr = std::tuple<>, // Call attributes (present in every node) + typename SparseAttr = std::tuple<>, // Sparse call attributes (optional per node) + typename Key = std::tuple<> // Stable keys to refer to nodes +> +struct ComputeGraph +{ + struct Ref; +private: + + // Functions are deduplicated by assigning unique ids + idict functions; + + struct Node { + int fn_index; + int arg_offset; + int arg_count; + Attr attr; + + Node(int fn_index, Attr &&attr, int arg_offset, int arg_count = 0) + : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(std::move(attr)) {} + + Node(int fn_index, Attr const &attr, int arg_offset, int arg_count = 0) + : fn_index(fn_index), arg_offset(arg_offset), arg_count(arg_count), attr(attr) {} + }; + + + std::vector nodes; + std::vector args; + dict keys_; + dict sparse_attrs; + +public: + template + struct BaseRef + { + protected: + friend struct ComputeGraph; + Graph *graph_; + int index_; + BaseRef(Graph *graph, int index) : graph_(graph), index_(index) { + log_assert(index_ >= 0); + check(); + } + + void check() const { log_assert(index_ < graph_->size()); } + + Node const &deref() const { check(); return graph_->nodes[index_]; } + + public: + ComputeGraph const &graph() const { return graph_; } + int index() const { return index_; } + + int size() const { return deref().arg_count; } + + BaseRef arg(int n) const + { + Node const &node = deref(); + log_assert(n >= 0 && n < node.arg_count); + return BaseRef(graph_, graph_->args[node.arg_offset + n]); + } + + std::vector::const_iterator arg_indices_cbegin() const + { + Node const &node = deref(); + return graph_->args.cbegin() + node.arg_offset; + } + + std::vector::const_iterator arg_indices_cend() const + { + Node const &node = deref(); + return graph_->args.cbegin() + node.arg_offset + node.arg_count; + } + + Fn const &function() const { return graph_->functions[deref().fn_index]; } + Attr const &attr() const { return deref().attr; } + + bool has_sparse_attr() const { return graph_->sparse_attrs.count(index_); } + + SparseAttr const &sparse_attr() const + { + auto found = graph_->sparse_attrs.find(index_); + log_assert(found != graph_->sparse_attrs.end()); + return *found; + } + }; + + using ConstRef = BaseRef; + + struct Ref : public BaseRef + { + private: + friend struct ComputeGraph; + Ref(ComputeGraph *graph, int index) : BaseRef(graph, index) {} + Node &deref() const { this->check(); return this->graph_->nodes[this->index_]; } + + public: + void set_function(Fn const &function) const + { + deref().fn_index = this->graph_->functions(function); + } + + Attr &attr() const { return deref().attr; } + + void append_arg(ConstRef arg) const + { + log_assert(arg.graph_ == this->graph_); + append_arg(arg.index()); + } + + void append_arg(int arg) const + { + log_assert(arg >= 0 && arg < this->graph_->size()); + Node &node = deref(); + if (node.arg_offset + node.arg_count != GetSize(this->graph_->args)) + move_args(node); + this->graph_->args.push_back(arg); + node.arg_count++; + } + + operator ConstRef() const + { + return ConstRef(this->graph_, this->index_); + } + + SparseAttr &sparse_attr() const + { + return this->graph_->sparse_attrs[this->index_]; + } + + void clear_sparse_attr() const + { + this->graph_->sparse_attrs.erase(this->index_); + } + + void assign_key(Key const &key) const + { + this->graph_->keys_.emplace(key, this->index_); + } + + private: + void move_args(Node &node) const + { + auto &args = this->graph_->args; + int old_offset = node.arg_offset; + node.arg_offset = GetSize(args); + for (int i = 0; i != node.arg_count; ++i) + args.push_back(args[old_offset + i]); + } + + }; + + bool has_key(Key const &key) const + { + return keys_.count(key); + } + + dict const &keys() const + { + return keys_; + } + + ConstRef operator()(Key const &key) const + { + auto it = keys_.find(key); + log_assert(it != keys_.end()); + return (*this)[it->second]; + } + + Ref operator()(Key const &key) + { + auto it = keys_.find(key); + log_assert(it != keys_.end()); + return (*this)[it->second]; + } + + int size() const { return GetSize(nodes); } + + ConstRef operator[](int index) const { return ConstRef(this, index); } + Ref operator[](int index) { return Ref(this, index); } + + Ref add(Fn const &function, Attr &&attr) + { + int index = GetSize(nodes); + int fn_index = functions(function); + nodes.emplace_back(fn_index, std::move(attr), GetSize(args)); + return Ref(this, index); + } + + Ref add(Fn const &function, Attr const &attr) + { + int index = GetSize(nodes); + int fn_index = functions(function); + nodes.emplace_back(fn_index, attr, GetSize(args)); + return Ref(this, index); + } + + template + Ref add(Fn const &function, Attr const &attr, T const &args) + { + Ref added = add(function, attr); + for (auto arg : args) + added.append_arg(arg); + return added; + } + + template + Ref add(Fn const &function, Attr &&attr, T const &args) + { + Ref added = add(function, std::move(attr)); + for (auto arg : args) + added.append_arg(arg); + return added; + } + + template + Ref add(Fn const &function, Attr const &attr, T begin, T end) + { + Ref added = add(function, attr); + for (; begin != end; ++begin) + added.append_arg(*begin); + return added; + } + + void permute(std::vector const &perm) + { + log_assert(perm.size() <= nodes.size()); + std::vector inv_perm; + inv_perm.resize(nodes.size(), -1); + for (int i = 0; i < GetSize(perm); ++i) + { + int j = perm[i]; + log_assert(j >= 0 && j < GetSize(perm)); + log_assert(inv_perm[j] == -1); + inv_perm[j] = i; + } + permute(perm, inv_perm); + } + + void permute(std::vector const &perm, std::vector const &inv_perm) + { + log_assert(inv_perm.size() == nodes.size()); + std::vector new_nodes; + new_nodes.reserve(perm.size()); + dict new_sparse_attrs; + for (int i : perm) + { + int j = GetSize(new_nodes); + new_nodes.emplace_back(std::move(nodes[i])); + auto found = sparse_attrs.find(i); + if (found != sparse_attrs.end()) + new_sparse_attrs.emplace(j, std::move(found->second)); + } + + std::swap(nodes, new_nodes); + std::swap(sparse_attrs, new_sparse_attrs); + + for (int &arg : args) + { + log_assert(arg < GetSize(inv_perm)); + arg = inv_perm[arg]; + } + + for (auto &key : keys_) + { + log_assert(key.second < GetSize(inv_perm)); + key.second = inv_perm[key.second]; + } + } + + struct SccAdaptor + { + private: + ComputeGraph const &graph_; + std::vector indices_; + public: + SccAdaptor(ComputeGraph const &graph) : graph_(graph) + { + indices_.resize(graph.size(), -1); + } + + + typedef int node_type; + + struct node_enumerator { + private: + friend struct SccAdaptor; + int current, end; + node_enumerator(int current, int end) : current(current), end(end) {} + + public: + + bool finished() const { return current == end; } + node_type next() { + log_assert(!finished()); + node_type result = current; + ++current; + return result; + } + }; + + node_enumerator enumerate_nodes() { + return node_enumerator(0, GetSize(indices_)); + } + + + struct successor_enumerator { + private: + friend struct SccAdaptor; + std::vector::const_iterator current, end; + successor_enumerator(std::vector::const_iterator current, std::vector::const_iterator end) : + current(current), end(end) {} + + public: + bool finished() const { return current == end; } + node_type next() { + log_assert(!finished()); + node_type result = *current; + ++current; + return result; + } + }; + + successor_enumerator enumerate_successors(int index) const { + auto const &ref = graph_[index]; + return successor_enumerator(ref.arg_indices_cbegin(), ref.arg_indices_cend()); + } + + int &dfs_index(node_type const &node) { return indices_[node]; } + + std::vector const &dfs_indices() { return indices_; } + }; + +}; + + + +YOSYS_NAMESPACE_END + + +#endif diff --git a/passes/cmds/example_dt.cc b/passes/cmds/example_dt.cc index de84fa3cda8..dec554d6c71 100644 --- a/passes/cmds/example_dt.cc +++ b/passes/cmds/example_dt.cc @@ -1,6 +1,7 @@ #include "kernel/yosys.h" #include "kernel/drivertools.h" #include "kernel/topo_scc.h" +#include "kernel/functional.h" USING_YOSYS_NAMESPACE PRIVATE_NAMESPACE_BEGIN @@ -38,86 +39,137 @@ struct ExampleDtPass : public Pass ExampleWorker worker(module); DriverMap dm; + struct ExampleFn { + IdString name; + dict parameters; + + ExampleFn(IdString name) : name(name) {} + ExampleFn(IdString name, dict parameters) : name(name), parameters(parameters) {} + + bool operator==(ExampleFn const &other) const { + return name == other.name && parameters == other.parameters; + } + + unsigned int hash() const { + return mkhash(name.hash(), parameters.hash()); + } + }; + + typedef ComputeGraph ExampleGraph; + + ExampleGraph compute_graph; + + dm.add(module); idict queue; idict cells; IntGraph edges; + std::vector graph_nodes; + auto enqueue = [&](DriveSpec const &spec) { + int index = queue(spec); + if (index == GetSize(graph_nodes)) + graph_nodes.emplace_back(compute_graph.add(ID($pending), index).index()); + //if (index >= GetSize(graph_nodes)) + return compute_graph[graph_nodes[index]]; + }; for (auto cell : module->cells()) { if (cell->type.in(ID($assert), ID($assume), ID($cover), ID($check))) - queue(DriveBitMarker(cells(cell), 0)); + enqueue(DriveBitMarker(cells(cell), 0)); } for (auto wire : module->wires()) { if (!wire->port_output) continue; - queue(DriveChunk(DriveChunkWire(wire, 0, wire->width))); + enqueue(DriveChunk(DriveChunkWire(wire, 0, wire->width))).assign_key(wire->name); } -#define emit log -// #define emit(X...) do {} while (false) - for (int i = 0; i != GetSize(queue); ++i) { - emit("n%d: ", i); DriveSpec spec = queue[i]; + ExampleGraph::Ref node = compute_graph[i]; + if (spec.chunks().size() > 1) { - emit("concat %s <-\n", log_signal(spec)); + node.set_function(ID($$concat)); + for (auto const &chunk : spec.chunks()) { - emit(" * %s\n", log_signal(chunk)); - edges.add_edge(i, queue(chunk)); + node.append_arg(enqueue(chunk)); } } else if (spec.chunks().size() == 1) { DriveChunk chunk = spec.chunks()[0]; if (chunk.is_wire()) { DriveChunkWire wire_chunk = chunk.wire(); if (wire_chunk.is_whole()) { + node.sparse_attr() = wire_chunk.wire->name; if (wire_chunk.wire->port_input) { - emit("input %s\n", log_signal(spec)); + node.set_function(ExampleFn(ID($$input), {{wire_chunk.wire->name, {}}})); } else { DriveSpec driver = dm(DriveSpec(wire_chunk)); - edges.add_edge(i, queue(driver)); - emit("wire driver %s <- %s\n", log_signal(spec), log_signal(driver)); + node.set_function(ID($$buf)); + + node.append_arg(enqueue(driver)); } } else { DriveChunkWire whole_wire(wire_chunk.wire, 0, wire_chunk.width); - edges.add_edge(i, queue(whole_wire)); - emit("wire slice %s <- %s\n", log_signal(spec), log_signal(DriveSpec(whole_wire))); + node.set_function(ExampleFn(ID($$slice), {{ID(offset), wire_chunk.offset}, {ID(width), wire_chunk.width}})); + node.append_arg(enqueue(whole_wire)); } } else if (chunk.is_port()) { DriveChunkPort port_chunk = chunk.port(); if (port_chunk.is_whole()) { if (dm.celltypes.cell_output(port_chunk.cell->type, port_chunk.port)) { - int cell_marker = queue(DriveBitMarker(cells(port_chunk.cell), 0)); - if (!port_chunk.cell->type.in(ID($dff), ID($ff))) - edges.add_edge(i, cell_marker); - emit("cell output %s %s\n", log_id(port_chunk.cell), log_id(port_chunk.port)); + if (port_chunk.cell->type.in(ID($dff), ID($ff))) + { + Cell *cell = port_chunk.cell; + node.set_function(ExampleFn(ID($$state), {{cell->name, {}}})); + for (auto const &conn : cell->connections()) { + if (!dm.celltypes.cell_input(cell->type, conn.first)) + continue; + enqueue(DriveChunkPort(cell, conn)).assign_key(cell->name); + } + } + else + { + node.set_function(ExampleFn(ID($$cell_output), {{port_chunk.port, {}}})); + node.append_arg(enqueue(DriveBitMarker(cells(port_chunk.cell), 0))); + } } else { + node.set_function(ID($$buf)); + DriveSpec driver = dm(DriveSpec(port_chunk)); - edges.add_edge(i, queue(driver)); - emit("cell port driver %s <- %s\n", log_signal(spec), log_signal(driver)); + node.append_arg(enqueue(driver)); } } else { DriveChunkPort whole_port(port_chunk.cell, port_chunk.port, 0, GetSize(port_chunk.cell->connections().at(port_chunk.port))); - edges.add_edge(i, queue(whole_port)); - emit("port slice %s <- %s\n", log_signal(spec), log_signal(DriveSpec(whole_port))); + node.set_function(ID($$buf)); + node.append_arg(enqueue(whole_port)); } } else if (chunk.is_constant()) { - emit("constant %s <- %s\n", log_signal(spec), log_const(chunk.constant())); + node.set_function(ExampleFn(ID($$const), {{ID(value), chunk.constant()}})); + + } else if (chunk.is_multiple()) { + node.set_function(ID($$multi)); + for (auto const &driver : chunk.multiple().multiple()) + node.append_arg(enqueue(driver)); } else if (chunk.is_marker()) { Cell *cell = cells[chunk.marker().marker]; - emit("cell %s %s\n", log_id(cell->type), log_id(cell)); + + node.set_function(ExampleFn(cell->type, cell->parameters)); for (auto const &conn : cell->connections()) { if (!dm.celltypes.cell_input(cell->type, conn.first)) continue; - emit(" * %s <- %s\n", log_id(conn.first), log_signal(conn.second)); - edges.add_edge(i, queue(DriveChunkPort(cell, conn))); + + node.append_arg(enqueue(DriveChunkPort(cell, conn))); } + } else if (chunk.is_none()) { + node.set_function(ID($$undriven)); + } else { + log_error("unhandled drivespec: %s\n", log_signal(chunk)); log_abort(); } } else { @@ -125,13 +177,75 @@ struct ExampleDtPass : public Pass } } - topo_sorted_sccs(edges, [&](int *begin, int *end) { - emit("scc:"); - for (int *i = begin; i != end; ++i) - emit(" n%d", *i); - emit("\n"); - }); + // Perform topo sort and detect SCCs + ExampleGraph::SccAdaptor compute_graph_scc(compute_graph); + + + std::vector perm; + topo_sorted_sccs(compute_graph_scc, [&](int *begin, int *end) { + perm.insert(perm.end(), begin, end); + if (end > begin + 1) + { + log_warning("SCC:"); + for (int *i = begin; i != end; ++i) + log(" %d", *i); + log("\n"); + } + }, /* sources_first */ true); + compute_graph.permute(perm); + + + // Forward $$buf unless we have a name in the sparse attribute + std::vector alias; + perm.clear(); + + for (int i = 0; i < compute_graph.size(); ++i) + { + if (compute_graph[i].function().name == ID($$buf) && !compute_graph[i].has_sparse_attr() && compute_graph[i].arg(0).index() < i) + { + + alias.push_back(alias[compute_graph[i].arg(0).index()]); + } + else + { + alias.push_back(GetSize(perm)); + perm.push_back(i); + } + } + compute_graph.permute(perm, alias); + + // Dump the compute graph + for (int i = 0; i < compute_graph.size(); ++i) + { + auto ref = compute_graph[i]; + log("n%d ", i); + log("%s", log_id(ref.function().name)); + for (auto const ¶m : ref.function().parameters) + { + if (param.second.empty()) + log("[%s]", log_id(param.first)); + else + log("[%s=%s]", log_id(param.first), log_const(param.second)); + } + log("("); + + for (int i = 0, end = ref.size(); i != end; ++i) + { + if (i > 0) + log(", "); + log("n%d", ref.arg(i).index()); + } + log(")\n"); + if (ref.has_sparse_attr()) + log("// wire %s\n", log_id(ref.sparse_attr())); + log("// was #%d %s\n", ref.attr(), log_signal(queue[ref.attr()])); + } + + for (auto const &key : compute_graph.keys()) + { + log("return %d as %s \n", key.second, log_id(key.first)); + } } log("Plugin test passed!\n"); }