Skip to content

Commit

Permalink
PR tensorflow#10395: Handle missing collective_permute in collective-…
Browse files Browse the repository at this point in the history
…permute-motion

Imported from GitHub PR openxla/xla#10395

Currently, the `MoveCollectivePermutes` function expects that if `FindMovableClusterAtBodyRoot` function returns a cluster, then the cluster's `collective_permute` field is not null. However, if `FindMovableClusterAtBodyRoot` does not find a collective_permute on when traversing an input, it happily returns cluster with `collective_permute == null`. This patch just returns `nullopt` for a 'cluster' without collective-permute.

Fixes tensorflow#10394.
Copybara import of the project:

--
bf69727505c62b0bba36ad28ddaad36f07f816e3 by Jaroslav Sevcik <[email protected]>:

Handle missing collective_permute

Merging this change closes tensorflow#10395

PiperOrigin-RevId: 615753210
  • Loading branch information
jaro-sevcik authored and tensorflower-gardener committed Mar 14, 2024
1 parent ae7ba31 commit 7e77bc3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/spmd/collective_permute_motion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ std::optional<MovableCluster> FindMovableClusterAtBodyRoot(
}
}
}
if (cluster.collective_permute == nullptr) {
return std::nullopt;
}
return cluster;
}

Expand Down
34 changes: 34 additions & 0 deletions third_party/xla/xla/service/spmd/collective_permute_motion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,40 @@ TEST_F(CollectivePermuteMotionTest, SimpleMove) {
EXPECT_THAT(output, op::Multiply(select, select));
}

TEST_F(CollectivePermuteMotionTest, NoCollectivePermute) {
absl::string_view hlo_string = R"(
HloModule test
body {
loop_var = (s32[], f32[], f32[]) parameter(0)
constant.1 = s32[] constant(1)
gte0 = s32[] get-tuple-element(loop_var), index=0
add = s32[] add(gte0, constant.1)
gte1 = f32[] get-tuple-element(loop_var), index=1
constant.4 = f32[] constant(4.0)
ROOT tuple = (s32[], f32[], f32[]) tuple(add, constant.4, gte1)
}
cond {
loop_var = (s32[], f32[], f32[]) parameter(0)
gte.cond = s32[] get-tuple-element(loop_var), index=0
constant.3 = s32[] constant(5)
ROOT lt = pred[] compare(gte.cond, constant.3), direction=LT
}
ENTRY main {
constant.2 = s32[] constant(0)
param = f32[] parameter(0)
param.1 = f32[] parameter(1)
tuple.1 = (s32[], f32[], f32[]) tuple(constant.2, param, param.1)
while = (s32[], f32[], f32[]) while(tuple.1), condition=cond, body=body
ROOT result = s32[] get-tuple-element(while), index=0
}
)";
// Test that the pass does not crash if there is no collective permute
// (but other conditions in FindMovableClusterAtBodyRoot are satisfied).
auto module = ParseAndReturnVerifiedModule(hlo_string).value();
CollectivePermuteMotion pass;
ASSERT_FALSE(pass.Run(&*module).value());
}

TEST_F(CollectivePermuteMotionTest, MoveWithElementwise) {
absl::string_view hlo_string = R"(
HloModule test
Expand Down

0 comments on commit 7e77bc3

Please sign in to comment.