Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
[Fix] Fix the auto-sharding bug of gather (#787)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Nov 27, 2022
1 parent a194f59 commit 181de4f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
9 changes: 4 additions & 5 deletions tests/shard_parallel/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def loss_func(params):
n_total, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir)
assert n_total == n_allreduce == 1

@unittest.skip("The support of Gather is broken after a rebase.")
def test_gather(self):

class Model(nn.Module):
Expand Down Expand Up @@ -153,10 +152,10 @@ def loss_func(params):
assert executable.auto_sharding_objective < 1e6

hlo_ir = executable.get_hlo_text()
assert "gather(f32[64,32]" in hlo_ir
assert "scatter(f32[64,32]" in hlo_ir
_, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir)
assert n_allreduce == 1
assert "gather(f32[64,32]" in hlo_ir or "gather(f32[32,64]" in hlo_ir
assert "scatter(f32[64,32]" in hlo_ir or "scatter(f32[32,64]" in hlo_ir
n_total, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir)
assert n_total == n_allreduce == 1

def test_reshape_uneven_partition(self):
# TODO(lmzheng): Support the uneven partition of reshape.
Expand Down

0 comments on commit 181de4f

Please sign in to comment.