From 06c99b75de334060e586c62f6f30c5e6b1002d4b Mon Sep 17 00:00:00 2001 From: birds3345 <31601136+birds3345@users.noreply.github.com> Date: Wed, 17 Jul 2024 01:21:07 -0400 Subject: [PATCH 1/8] optimize unionfind --- EqSat/include/Luau/UnionFind.h | 2 ++ EqSat/src/UnionFind.cpp | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index dd886a440..4f9e56fd9 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -13,10 +13,12 @@ struct UnionFind final { Id makeSet(); Id find(Id id) const; + Id find(Id id); void merge(Id a, Id b); private: std::vector parents; + std::vector ranks; }; } // namespace Luau::EqSat diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 04d9ba743..f42b8a174 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -10,6 +10,8 @@ Id UnionFind::makeSet() { Id id{parents.size()}; parents.push_back(id); + ranks.push_back(0); + return id; } @@ -24,12 +26,37 @@ Id UnionFind::find(Id id) const return id; } +Id UnionFind::find(Id id) +{ + LUAU_ASSERT(size_t(id) < parents.size()); + + // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. + if (id != parents[size_t(id)]) + // Note: we don't update the ranks here since a rank + // represents the upper bound on the maximum depth of a tree + parents[size_t(id)] = find(parents[size_t(id)]); + + return parents[size_t(id)]; +} + void UnionFind::merge(Id a, Id b) { LUAU_ASSERT(size_t(a) < parents.size()); LUAU_ASSERT(size_t(b) < parents.size()); + + Id aSet = find(a); + Id bSet = find(b); + if (aSet == bSet) + return; + + // Ensure that the rank of set A is greater than the rank of set B + if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) + std::swap(a, b); parents[size_t(b)] = a; + + if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) + ranks[size_t(aSet)]++; } } // namespace Luau::EqSat From 59de424f9664bffd7d313cf11ebaca16263ddde0 Mon Sep 17 00:00:00 2001 From: birds3345 <31601136+birds3345@users.noreply.github.com> Date: Wed, 17 Jul 2024 13:37:49 -0400 Subject: [PATCH 2/8] remove asserts Co-authored-by: Alexander McCord <11488393+alexmccord@users.noreply.github.com> --- EqSat/src/UnionFind.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index f42b8a174..7b1cb670f 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -41,9 +41,6 @@ Id UnionFind::find(Id id) void UnionFind::merge(Id a, Id b) { - LUAU_ASSERT(size_t(a) < parents.size()); - LUAU_ASSERT(size_t(b) < parents.size()); - Id aSet = find(a); Id bSet = find(b); if (aSet == bSet) From 9f07cd5e71907bc72024455a80c14c7d118cf645 Mon Sep 17 00:00:00 2001 From: birds3345 <31601136+birds3345@users.noreply.github.com> Date: Wed, 17 Jul 2024 13:40:23 -0400 Subject: [PATCH 3/8] use correct variables --- EqSat/src/UnionFind.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 7b1cb670f..536de2710 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -48,9 +48,9 @@ void UnionFind::merge(Id a, Id b) // Ensure that the rank of set A is greater than the rank of set B if (ranks[size_t(aSet)] < ranks[size_t(bSet)]) - std::swap(a, b); + std::swap(aSet, bSet); - parents[size_t(b)] = a; + parents[size_t(bSet)] = aSet; if (ranks[size_t(aSet)] == ranks[size_t(bSet)]) ranks[size_t(aSet)]++; From aef6dd715bf45731b5bef39888a73ca5c75f295e Mon Sep 17 00:00:00 2001 From: birds3345 <31601136+birds3345@users.noreply.github.com> Date: Wed, 17 Jul 2024 13:46:00 -0400 Subject: [PATCH 4/8] make find iterative --- EqSat/src/UnionFind.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 536de2710..ec6386282 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -30,13 +30,18 @@ Id UnionFind::find(Id id) { LUAU_ASSERT(size_t(id) < parents.size()); + Id set = const_cast(this)->find(id); + // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. - if (id != parents[size_t(id)]) + while (id != parents[size_t(id)]) + { // Note: we don't update the ranks here since a rank // represents the upper bound on the maximum depth of a tree - parents[size_t(id)] = find(parents[size_t(id)]); + parents[size_t(id)] = set; + id = parents[size_t(id)]; + } - return parents[size_t(id)]; + return set; } void UnionFind::merge(Id a, Id b) From 9b401ae1d6942e6f0ff6a69deacccc173482ec56 Mon Sep 17 00:00:00 2001 From: birds3345 <31601136+birds3345@users.noreply.github.com> Date: Wed, 17 Jul 2024 13:58:05 -0400 Subject: [PATCH 5/8] fix mistake --- EqSat/src/UnionFind.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index ec6386282..09418a4a7 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -37,8 +37,9 @@ Id UnionFind::find(Id id) { // Note: we don't update the ranks here since a rank // represents the upper bound on the maximum depth of a tree + Id parent = parents[size_t(id)]; parents[size_t(id)] = set; - id = parents[size_t(id)]; + id = parent; } return set; From 1cb6973100480ecf2183b88f975574e33bdd0aad Mon Sep 17 00:00:00 2001 From: birds3345 <31601136+birds3345@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:09:02 -0400 Subject: [PATCH 6/8] remove const_cast --- EqSat/include/Luau/UnionFind.h | 3 +++ EqSat/src/UnionFind.cpp | 21 +++++++++++++-------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index 4f9e56fd9..e319b5f1a 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -19,6 +19,9 @@ struct UnionFind final private: std::vector parents; std::vector ranks; + +private: + Id canonicalize(Id id) const; }; } // namespace Luau::EqSat diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index 09418a4a7..d6f13a045 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -17,20 +17,14 @@ Id UnionFind::makeSet() Id UnionFind::find(Id id) const { - LUAU_ASSERT(size_t(id) < parents.size()); - - // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. - while (id != parents[size_t(id)]) - id = parents[size_t(id)]; - - return id; + return canonicalize(id); } Id UnionFind::find(Id id) { LUAU_ASSERT(size_t(id) < parents.size()); - Id set = const_cast(this)->find(id); + Id set = canonicalize(id); // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. while (id != parents[size_t(id)]) @@ -62,4 +56,15 @@ void UnionFind::merge(Id a, Id b) ranks[size_t(aSet)]++; } +Id UnionFind::canonicalize(Id id) const +{ + LUAU_ASSERT(size_t(id) < parents.size()); + + // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. + while (id != parents[size_t(id)]) + id = parents[size_t(id)]; + + return id; +} + } // namespace Luau::EqSat From db3fcb41a26abaa4693effe40c0c0657fd58adc7 Mon Sep 17 00:00:00 2001 From: birds3345 <31601136+birds3345@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:10:31 -0400 Subject: [PATCH 7/8] remove assert --- EqSat/src/UnionFind.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/EqSat/src/UnionFind.cpp b/EqSat/src/UnionFind.cpp index d6f13a045..5c01e968b 100644 --- a/EqSat/src/UnionFind.cpp +++ b/EqSat/src/UnionFind.cpp @@ -22,8 +22,6 @@ Id UnionFind::find(Id id) const Id UnionFind::find(Id id) { - LUAU_ASSERT(size_t(id) < parents.size()); - Id set = canonicalize(id); // An e-class id 𝑎 is canonical iff find(𝑎) = 𝑎. From d27477caa2e56ccb1e4123c4aa1f018d3567b7d8 Mon Sep 17 00:00:00 2001 From: Alexander McCord <11488393+alexmccord@users.noreply.github.com> Date: Wed, 17 Jul 2024 16:10:26 -0700 Subject: [PATCH 8/8] Update EqSat/include/Luau/UnionFind.h --- EqSat/include/Luau/UnionFind.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/EqSat/include/Luau/UnionFind.h b/EqSat/include/Luau/UnionFind.h index e319b5f1a..559ee119a 100644 --- a/EqSat/include/Luau/UnionFind.h +++ b/EqSat/include/Luau/UnionFind.h @@ -18,7 +18,7 @@ struct UnionFind final private: std::vector parents; - std::vector ranks; + std::vector ranks; private: Id canonicalize(Id id) const;