Skip to content

Commit

Permalink
Merge pull request #3946 from rmlarsen/toposort
Browse files Browse the repository at this point in the history
Speed up TopoSort by 2.7-3.3x.
  • Loading branch information
Ravenslofty authored Oct 17, 2023
2 parents 5f78d1d + bc0df04 commit d21c464
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 51 deletions.
130 changes: 87 additions & 43 deletions kernel/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,84 +128,128 @@ struct stackmap
// A simple class for topological sorting
// ------------------------------------------------

template<typename T, typename C = std::less<T>>
struct TopoSort
template <typename T, typename C = std::less<T>, typename OPS = hash_ops<T>> class TopoSort
{
bool analyze_loops, found_loops;
std::map<T, std::set<T, C>, C> database;
std::set<std::set<T, C>> loops;
public:
// We use this ordering of the edges in the adjacency matrix for
// exact compatibility with an older implementation.
struct IndirectCmp {
IndirectCmp(const std::vector<T> &nodes) : node_cmp_(), nodes_(nodes) {}
bool operator()(int a, int b) const
{
log_assert(static_cast<size_t>(a) < nodes_.size());
log_assert(static_cast<size_t>(b) < nodes_.size());
return node_cmp_(nodes_[a], nodes_[b]);
}
const C node_cmp_;
const std::vector<T> &nodes_;
};

bool analyze_loops;
std::map<T, int, C> node_to_index;
std::vector<std::set<int, IndirectCmp>> edges;
std::vector<T> sorted;
std::set<std::set<T, C>> loops;

TopoSort()
TopoSort() : indirect_cmp(nodes)
{
analyze_loops = true;
found_loops = false;
}

void node(T n)
int node(T n)
{
if (database.count(n) == 0)
database[n] = std::set<T, C>();
auto rv = node_to_index.emplace(n, static_cast<int>(nodes.size()));
if (rv.second) {
nodes.push_back(n);
edges.push_back(std::set<int, IndirectCmp>(indirect_cmp));
}
return rv.first->second;
}

void edge(T left, T right)
void edge(int l_index, int r_index) { edges[r_index].insert(l_index); }

void edge(T left, T right) { edge(node(left), node(right)); }

bool has_node(const T &node) { return node_to_index.find(node) != node_to_index.end(); }

bool sort()
{
node(left);
database[right].insert(left);
log_assert(GetSize(node_to_index) == GetSize(edges));
log_assert(GetSize(nodes) == GetSize(edges));

loops.clear();
sorted.clear();
found_loops = false;

std::vector<bool> marked_cells(edges.size(), false);
std::vector<bool> active_cells(edges.size(), false);
std::vector<int> active_stack;
sorted.reserve(edges.size());

for (const auto &it : node_to_index)
sort_worker(it.second, marked_cells, active_cells, active_stack);

log_assert(GetSize(sorted) == GetSize(nodes));

return !found_loops;
}

void sort_worker(const T &n, std::set<T, C> &marked_cells, std::set<T, C> &active_cells, std::vector<T> &active_stack)
// Build the more expensive representation of edges for
// a few passes that use it directly.
std::map<T, std::set<T, C>, C> get_database()
{
if (active_cells.count(n)) {
std::map<T, std::set<T, C>, C> database;
for (size_t i = 0; i < nodes.size(); ++i) {
std::set<T, C> converted_edge_set;
for (int other_node : edges[i]) {
converted_edge_set.insert(nodes[other_node]);
}
database.emplace(nodes[i], converted_edge_set);
}
return database;
}

private:
bool found_loops;
std::vector<T> nodes;
const IndirectCmp indirect_cmp;

void sort_worker(const int root_index, std::vector<bool> &marked_cells, std::vector<bool> &active_cells, std::vector<int> &active_stack)
{
if (active_cells[root_index]) {
found_loops = true;
if (analyze_loops) {
std::set<T, C> loop;
for (int i = GetSize(active_stack)-1; i >= 0; i--) {
loop.insert(active_stack[i]);
if (active_stack[i] == n)
for (int i = GetSize(active_stack) - 1; i >= 0; i--) {
const int index = active_stack[i];
loop.insert(nodes[index]);
if (index == root_index)
break;
}
loops.insert(loop);
}
return;
}

if (marked_cells.count(n))
if (marked_cells[root_index])
return;

if (!database.at(n).empty())
{
if (!edges[root_index].empty()) {
if (analyze_loops)
active_stack.push_back(n);
active_cells.insert(n);
active_stack.push_back(root_index);
active_cells[root_index] = true;

for (auto &left_n : database.at(n))
for (int left_n : edges[root_index])
sort_worker(left_n, marked_cells, active_cells, active_stack);

if (analyze_loops)
active_stack.pop_back();
active_cells.erase(n);
active_cells[root_index] = false;
}

marked_cells.insert(n);
sorted.push_back(n);
}

bool sort()
{
loops.clear();
sorted.clear();
found_loops = false;

std::set<T, C> marked_cells;
std::set<T, C> active_cells;
std::vector<T> active_stack;

for (auto &it : database)
sort_worker(it.first, marked_cells, active_cells, active_stack);

log_assert(GetSize(sorted) == GetSize(database));
return !found_loops;
marked_cells[root_index] = true;
sorted.push_back(nodes[root_index]);
}
};

Expand Down
2 changes: 1 addition & 1 deletion passes/cmds/glift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ struct GliftPass : public Pass {
for (auto cell : module->selected_cells()) {
RTLIL::Module *tpl = design->module(cell->type);
if (tpl != nullptr) {
if (topo_modules.database.count(tpl) == 0)
if (!topo_modules.has_node(tpl))
worklist.push_back(tpl);
topo_modules.edge(tpl, module);
non_top_modules.insert(cell->type);
Expand Down
16 changes: 11 additions & 5 deletions passes/opt/opt_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,13 +424,19 @@ void replace_const_cells(RTLIL::Design *design, RTLIL::Module *module, bool cons
for (auto &bit : sig)
outbit_to_cell[bit].insert(cell);
}
cells.node(cell);
cells.node(cell);
}

for (auto &it_right : cell_to_inbit)
for (auto &it_sigbit : it_right.second)
for (auto &it_left : outbit_to_cell[it_sigbit])
cells.edge(it_left, it_right.first);
// Build the graph for the topological sort.
for (auto &it_right : cell_to_inbit) {
const int r_index = cells.node(it_right.first);
for (auto &it_sigbit : it_right.second) {
for (auto &it_left : outbit_to_cell[it_sigbit]) {
const int l_index = cells.node(it_left);
cells.edge(l_index, r_index);
}
}
}

cells.sort();

Expand Down
2 changes: 1 addition & 1 deletion passes/opt/share.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ struct ShareWorker
}

bool found_scc = !toposort.sort();
topo_cell_drivers = std::move(toposort.database);
topo_cell_drivers = toposort.get_database();

if (found_scc && toposort.analyze_loops)
for (auto &loop : toposort.loops) {
Expand Down
2 changes: 1 addition & 1 deletion passes/techmap/flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ struct FlattenPass : public Pass {
for (auto cell : module->selected_cells()) {
RTLIL::Module *tpl = design->module(cell->type);
if (tpl != nullptr) {
if (topo_modules.database.count(tpl) == 0)
if (!topo_modules.has_node(tpl))
worklist.insert(tpl);
topo_modules.edge(tpl, module);
}
Expand Down

0 comments on commit d21c464

Please sign in to comment.