diff --git a/tests/shard_parallel/test_basic.py b/tests/shard_parallel/test_basic.py index d1ce7ef25..670954f1c 100644 --- a/tests/shard_parallel/test_basic.py +++ b/tests/shard_parallel/test_basic.py @@ -116,7 +116,6 @@ def loss_func(params): n_total, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir) assert n_total == n_allreduce == 1 - @unittest.skip("The support of Gather is broken after a rebase.") def test_gather(self): class Model(nn.Module): @@ -153,10 +152,10 @@ def loss_func(params): assert executable.auto_sharding_objective < 1e6 hlo_ir = executable.get_hlo_text() - assert "gather(f32[64,32]" in hlo_ir - assert "scatter(f32[64,32]" in hlo_ir - _, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir) - assert n_allreduce == 1 + assert "gather(f32[64,32]" in hlo_ir or "gather(f32[32,64]" in hlo_ir + assert "scatter(f32[64,32]" in hlo_ir or "scatter(f32[32,64]" in hlo_ir + n_total, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir) + assert n_total == n_allreduce == 1 def test_reshape_uneven_partition(self): # TODO(lmzheng): Support the uneven partition of reshape. diff --git a/third_party/tensorflow-alpa b/third_party/tensorflow-alpa index 272dc9ebb..cd865615b 160000 --- a/third_party/tensorflow-alpa +++ b/third_party/tensorflow-alpa @@ -1 +1 @@ -Subproject commit 272dc9ebbeaedcfe452e114e551a2aa45e604030 +Subproject commit cd865615b9b518bc507fbdc71dc44c7cc76618ac