Skip to content

Commit

Permalink
Optimizations for UnionFind (#1334)
Browse files Browse the repository at this point in the history
Implements ranks & path compression for union find.

---------

Co-authored-by: Alexander McCord <[email protected]>
  • Loading branch information
birds3345 and alexmccord authored Jul 17, 2024
1 parent 623e1e3 commit 2874ca9
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
5 changes: 5 additions & 0 deletions EqSat/include/Luau/UnionFind.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@ struct UnionFind final
{
Id makeSet();
Id find(Id id) const;
Id find(Id id);
void merge(Id a, Id b);

private:
std::vector<Id> parents;
std::vector<int> ranks;

private:
Id canonicalize(Id id) const;
};

} // namespace Luau::EqSat
47 changes: 40 additions & 7 deletions EqSat/src/UnionFind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,59 @@ Id UnionFind::makeSet()
{
Id id{parents.size()};
parents.push_back(id);
ranks.push_back(0);

return id;
}

Id UnionFind::find(Id id) const
{
LUAU_ASSERT(size_t(id) < parents.size());
return canonicalize(id);
}

Id UnionFind::find(Id id)
{
Id set = canonicalize(id);

// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)])
id = parents[size_t(id)];

return id;
{
// Note: we don't update the ranks here since a rank
// represents the upper bound on the maximum depth of a tree
Id parent = parents[size_t(id)];
parents[size_t(id)] = set;
id = parent;
}

return set;
}

void UnionFind::merge(Id a, Id b)
{
LUAU_ASSERT(size_t(a) < parents.size());
LUAU_ASSERT(size_t(b) < parents.size());
Id aSet = find(a);
Id bSet = find(b);
if (aSet == bSet)
return;

parents[size_t(b)] = a;
// Ensure that the rank of set A is greater than the rank of set B
if (ranks[size_t(aSet)] < ranks[size_t(bSet)])
std::swap(aSet, bSet);

parents[size_t(bSet)] = aSet;

if (ranks[size_t(aSet)] == ranks[size_t(bSet)])
ranks[size_t(aSet)]++;
}

Id UnionFind::canonicalize(Id id) const
{
LUAU_ASSERT(size_t(id) < parents.size());

// An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
while (id != parents[size_t(id)])
id = parents[size_t(id)];

return id;
}

} // namespace Luau::EqSat

0 comments on commit 2874ca9

Please sign in to comment.