Skip to content

Commit

Permalink
Fix test and output
Browse files Browse the repository at this point in the history
  • Loading branch information
chhzh123 committed May 26, 2023
1 parent 0b1de58 commit ef2c565
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
4 changes: 2 additions & 2 deletions slapo/sharding/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,9 @@ def generate_schedule_sequence(self, mod):
else:
continue
if op.node.op == "call_module":
print(f'sch["{op.node.target}"].shard("weight", dim={dim})')
print(f'sch["{op.node.target}"].shard("weight", axis={dim})')
if dim == 0:
print(f'sch["{op.node.target}"].shard("bias", dim={dim})')
print(f'sch["{op.node.target}"].shard("bias", axis={dim})')
if (
results[f"{name}_0"] == ShardSpec("RS").id
and results[f"{name}_1"] == ShardSpec("SR").id
Expand Down
3 changes: 2 additions & 1 deletion tests/test_autoshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def test_attn():
logger.info(subsch.mod.graph, ranks=0)

sol = Solver(subsch.mod, p=p)
results, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)])
_, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)])
assert max_cost == 3 * (bs * seq_len * hidden_size / p) + 2


if __name__ == "__main__":
Expand Down

0 comments on commit ef2c565

Please sign in to comment.