diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 214d3c78..715e8554 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -589,9 +589,9 @@ def generate_schedule_sequence(self, mod): else: continue if op.node.op == "call_module": - print(f'sch["{op.node.target}"].shard("weight", dim={dim})') + print(f'sch["{op.node.target}"].shard("weight", axis={dim})') if dim == 0: - print(f'sch["{op.node.target}"].shard("bias", dim={dim})') + print(f'sch["{op.node.target}"].shard("bias", axis={dim})') if ( results[f"{name}_0"] == ShardSpec("RS").id and results[f"{name}_1"] == ShardSpec("SR").id diff --git a/tests/test_autoshard.py b/tests/test_autoshard.py index dfb41df1..30cf113a 100644 --- a/tests/test_autoshard.py +++ b/tests/test_autoshard.py @@ -81,7 +81,8 @@ def test_attn(): logger.info(subsch.mod.graph, ranks=0) sol = Solver(subsch.mod, p=p) - results, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) + _, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) + assert max_cost == 3 * (bs * seq_len * hidden_size / p) + 2 if __name__ == "__main__":