Skip to content

Commit

Permalink
[hotfix] fix the bug that large tensor exceed the maximum capacity of…
Browse files Browse the repository at this point in the history
… TensorBucket (hpcaitech#5879)
  • Loading branch information
Hz188 authored Jul 2, 2024
1 parent 7c2f79f commit ea94c07
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,13 @@ def step(self, closure=None):
working_param = real_working_params[group_id][idx]
param_to_gather = master_param.to(device).to(self._dtype)
pg = self.param_to_pg[working_param]
if param_to_gather.numel() > self.pg_to_tensor_bucket[pg].max_size:
buffer_tensor = torch.empty_like(
torch.cat([param_to_gather for _ in range(dist.get_world_size(pg))])
)
dist.all_gather_into_tensor(buffer_tensor, param_to_gather, pg)
working_param.data.copy_(buffer_tensor[: working_param.numel()].reshape_as(working_param))
continue
try:
self.pg_to_tensor_bucket[pg].add_to_bucket(param_to_gather, write_back_tensor=working_param)
except RuntimeError:
Expand Down

0 comments on commit ea94c07

Please sign in to comment.