Fix potential memory issues when using DeepSpeed ZeRO-3 (#6726)
I ran into an OOM problem when doing DPO training with ZeRO-3. DPO calls the
module twice in one training step, and the second call runs under no_grad().
The problem is caused by two bugs:
1. "__n_available_params", which helps control the volume of fetched
parameters, becomes negative after release_and_reset_all() runs.
2. module.ds_grads_remaining becomes negative in backward() if the module is
called more than once in one training step.

This commit adds two patches to fix these issues.
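
For reference, a minimal sketch of the call pattern that triggers both bugs
(illustrative; model_engine stands for any DeepSpeed ZeRO-3 engine, and
train_step/inputs/labels are placeholder names, not code from this repo):

    import torch

    def train_step(model_engine, inputs, labels):
        loss1 = model_engine(inputs, labels)      # grad-enabled forward
        with torch.no_grad():
            loss2 = model_engine(inputs, labels)  # second, grad-free forward
        model_engine.backward(loss1 + loss2)      # counters previously went negative here
        model_engine.step()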

---------

Signed-off-by: Wenbin Chen <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Hongwei Chen <[email protected]>
4 people authored Nov 21, 2024
1 parent f515104 commit cd20a3b
Showing 3 changed files with 56 additions and 3 deletions.
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
@@ -392,7 +392,8 @@ def _run_before_forward_function(input):
                                          _run_after_backward_hook, inputs)
 
         def _post_backward_module_hook(module, inputs):
-            module.ds_grads_remaining = 0
+            if not hasattr(module, "ds_grads_remaining"):
+                module.ds_grads_remaining = 0
 
             if not hasattr(module, "post_bwd_fn"):
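
The guard matters because ds_grads_remaining is a running counter: it is
incremented during forward for each output that requires grad and decremented
as gradients arrive in backward, and the module's partitioned parameters are
released when it reaches zero. A standalone toy sketch of that accounting
(illustrative only; ToyModule and the hook functions below are stand-ins, not
the actual DeepSpeed hook code):

    class ToyModule:
        """Stand-in for an nn.Module that DeepSpeed attaches state to."""
        pass

    def post_backward_module_hook(module):
        # Fixed behavior: initialize once instead of resetting on every
        # forward. The old unconditional `module.ds_grads_remaining = 0`
        # wiped the count accumulated by an earlier forward pass.
        if not hasattr(module, "ds_grads_remaining"):
            module.ds_grads_remaining = 0

    def on_grad_requiring_output(module):
        module.ds_grads_remaining += 1      # one more gradient expected

    def on_gradient_arrived(module):
        module.ds_grads_remaining -= 1
        if module.ds_grads_remaining == 0:
            print("all gradients in: release partitioned parameters")

    m = ToyModule()
    post_backward_module_hook(m)    # first (grad-enabled) forward
    on_grad_requiring_output(m)     # counter -> 1
    post_backward_module_hook(m)    # second forward under no_grad(); the old
                                    # reset would zero the counter here, so
                                    # backward would drive it to -1 and the
                                    # release condition would never fire
    on_gradient_arrived(m)          # backward: 1 -> 0, release fires once
    assert m.ds_grads_remaining == 0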
3 changes: 1 addition & 2 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -252,7 +252,6 @@ def reset_step(self) -> None:
         self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10))
         self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque())
         self.__step_id = 0
-        self.__n_available_params = 0
         self.__profiler.reset_events()
 
     def _dump_params(self, tag, sub_module, params, step_id=None):
@@ -430,7 +429,7 @@ def release_and_reset_all(self, module: Module) -> None:
             # there's a hook execution issue
             param.ds_active_sub_modules.clear()
             self.__release_param(param)
-
+        self.__n_available_params = 0
         for param in iter_params(module, recurse=True):
             if param.ds_status != ZeroParamStatus.NOT_AVAILABLE:
                 raise RuntimeError(f"{param.ds_summary()} expected to be released")
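
__n_available_params is the coordinator's running count of currently fetched
parameter elements, compared against the max_live_parameters budget when
deciding whether more may be prefetched. Once release_and_reset_all() has
released everything, zero is the only correct value, so the reset now happens
there rather than in reset_step(), where it could zero the counter while
parameters were still accounted as live and let later releases push it
negative. A toy sketch of the accounting (illustrative only; ToyCoordinator
and its methods are placeholders, not the actual coordinator code):

    class ToyCoordinator:

        def __init__(self, max_live_numel):
            self.max_live_numel = max_live_numel
            self.n_available = 0            # stands for __n_available_params

        def can_prefetch(self, numel):
            # A negative n_available would overstate the remaining budget.
            return self.n_available + numel <= self.max_live_numel

        def fetch(self, numel):
            self.n_available += numel

        def release(self, numel):
            self.n_available -= numel       # mirrors __release_param()

        def release_and_reset_all(self, live_numels):
            for numel in live_numels:
                self.release(numel)
            self.n_available = 0            # the fix: truly zero at this point

    c = ToyCoordinator(max_live_numel=4096)
    c.fetch(1024)
    c.release_and_reset_all([1024])
    assert c.n_available == 0               # consistent going into the next step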
53 changes: 53 additions & 0 deletions tests/unit/runtime/zero/test_zero_multiple_run.py
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed
import torch
from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import SimpleModel, random_dataloader


class TestZ3MultipleModelCall(DistributedTest):
    world_size = 1

    def test_z3_multiple_model_call(self):
        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "steps_per_print": 1,
            "zero_optimization": {
                "stage": 3
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            },
        }
        if preferred_dtype() is torch.float16:
            config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
        elif preferred_dtype() is torch.bfloat16:
            config_dict["bf16"] = {"enabled": True}
        hidden_dim, nlayers = 2048, 3
        model = SimpleModel(hidden_dim=hidden_dim, nlayers=nlayers)
        model_engine, _, _, _ = deepspeed.initialize(config=config_dict,
                                                     model=model,
                                                     model_parameters=model.parameters())
        data_loader = iter(
            random_dataloader(model=model_engine, total_samples=10, hidden_dim=hidden_dim, device=model_engine.device))

        for n, batch in enumerate(data_loader):
            loss1 = model_engine(batch[0], batch[1])
            with torch.no_grad():
                loss2 = model_engine(batch[0], batch[1])
            loss = loss1 + loss2
            model_engine.backward(loss)
            for name, submodule in model_engine.module.linears._modules.items():
                assert hasattr(submodule, "ds_grads_remaining"), \
                    f"linears.{name} does not have variable ds_grads_remaining"
                assert submodule.ds_grads_remaining == 0, \
                    f"ds_grads_remaining of linears.{name} is not 0 ({submodule.ds_grads_remaining})"
            model_engine.step()
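
The per-submodule assertions are the regression check: after a step that
re-enters the model under no_grad(), every linear layer must still carry
ds_grads_remaining and the counter must have returned to zero rather than gone
negative, while the loop over ten batches exercises the fetch/release
accounting across repeated steps.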
