diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h
index dd886a440..559ee119a 100644
--- a/EqSat/include/Luau/UnionFind.h
+++ b/EqSat/include/Luau/UnionFind.h
@@ -13,10 +13,19 @@ struct UnionFind final
 {
     Id makeSet();
+    // Returns the canonical id of the set containing `id`; never modifies the structure.
     Id find(Id id) const;
+    // Same result as the const overload, but also repoints the visited ids at the root.
+    Id find(Id id);
     void merge(Id a, Id b);
 
 private:
     std::vector<Id> parents;
+    // One rank per id; merge() compares two roots' ranks to pick the surviving root.
+    std::vector<int> ranks;
+
+private:
+    // Follows parent links from `id` up to the root without mutating anything.
+    Id canonicalize(Id id) const;
 };
 
 } // namespace Luau::EqSat
diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp
index 04d9ba743..5c01e968b 100644
--- a/EqSat/src/UnionFind.cpp
+++ b/EqSat/src/UnionFind.cpp
@@ -10,26 +10,62 @@ Id UnionFind::makeSet()
 {
     Id id{parents.size()};
     parents.push_back(id);
+    ranks.push_back(0);
+
     return id;
 }
 
 Id UnionFind::find(Id id) const
 {
-    LUAU_ASSERT(size_t(id) < parents.size());
+    return canonicalize(id);
+}
+
+// Resolves the root of id's set, then repoints every id on the walked path at that root.
+Id UnionFind::find(Id id)
+{
+    Id set = canonicalize(id);
 
     // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
     while (id != parents[size_t(id)])
-        id = parents[size_t(id)];
-
-    return id;
+    {
+        // Note: we don't update the ranks here since a rank
+        // represents the upper bound on the maximum depth of a tree
+        Id parent = parents[size_t(id)];
+        parents[size_t(id)] = set;
+        id = parent;
+    }
+
+    return set;
 }
 
 void UnionFind::merge(Id a, Id b)
 {
-    LUAU_ASSERT(size_t(a) < parents.size());
-    LUAU_ASSERT(size_t(b) < parents.size());
+    Id aSet = find(a);
+    Id bSet = find(b);
+    if (aSet == bSet)
+        return;
 
-    parents[size_t(b)] = a;
+    // Ensure that the rank of set A is greater than or equal to the rank of set B
+    if (ranks[size_t(aSet)] < ranks[size_t(bSet)])
+        std::swap(aSet, bSet);
+
+    parents[size_t(bSet)] = aSet;
+
+    // Equal ranks: hanging B under A can deepen A's tree by one, so raise A's rank.
+    if (ranks[size_t(aSet)] == ranks[size_t(bSet)])
+        ranks[size_t(aSet)]++;
+}
+
+// Walks parent links from id to the root of its set; read-only.
+Id UnionFind::canonicalize(Id id) const
+{
+    LUAU_ASSERT(size_t(id) < parents.size());
+
+    // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎.
+    while (id != parents[size_t(id)])
+        id = parents[size_t(id)];
+
+    return id;
 }
 
 } // namespace Luau::EqSat