fix(pipeline): fix zero bubble pipeline parallelism (#334)
li126com authored Sep 19, 2024
1 parent 1403550 commit f951cdd
Showing 2 changed files with 37 additions and 19 deletions.
44 changes: 29 additions & 15 deletions internlm/core/scheduler/pipeline_scheduler.py
@@ -115,6 +115,7 @@ class WeightGradStore:
When using zero bubble pp, WeightGradStore is used to store the args and func for computing weight grad.
"""

cache = []
weight_grad_queue = queue.Queue()

@classmethod
@@ -123,25 +124,37 @@ def size(cls):

@classmethod
def put(cls, weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args):
assert not gpc.is_first_rank(ParallelMode.PIPELINE), "pp rank 0 should not arrive here"
# Store the weight gradient computation of linear layers.
cls.weight_grad_queue.put((weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args))
cls.cache.append((weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args))

@classmethod
def flush(cls):
if gpc.is_first_rank(ParallelMode.PIPELINE):
return
# Collect all stored computations during backward as a W for each micro batch.
cls.weight_grad_queue.put(cls.cache)
cls.cache = []

@classmethod
def pop(cls):
# Run computation for a single W.
if gpc.is_first_rank(ParallelMode.PIPELINE):
return
assert cls.weight_grad_queue.qsize() > 0
weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args = cls.weight_grad_queue.get()
grad_weight, grad_bias = grad_compute_func(input_tensor, grad_output, has_d_bias)
if is_using_isp():
isp_grad_hook = args[0]
grad_weight, _ = isp_grad_hook(grad_weight, async_op=False, is_bias=False)
if grad_bias is not None:
grad_bias, _ = isp_grad_hook(grad_bias, async_op=False, is_bias=True)
stored_w_grad_computation = cls.weight_grad_queue.get()
# Run computation for a single W.
for weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args in stored_w_grad_computation:
grad_weight, grad_bias = grad_compute_func(input_tensor, grad_output, has_d_bias)
if is_using_isp():
isp_grad_hook = args[0]
grad_weight, _ = isp_grad_hook(grad_weight, async_op=False, is_bias=False)
if grad_bias is not None:
grad_bias, _ = isp_grad_hook(grad_bias, async_op=False, is_bias=True)

# Gradient Accumulation
weight.grad = weight.grad + grad_weight if weight.grad is not None else grad_weight
if has_d_bias:
bias.grad = bias.grad + grad_bias if bias.grad is not None else grad_bias
# Gradient Accumulation
weight.grad = weight.grad + grad_weight if weight.grad is not None else grad_weight
if has_d_bias:
bias.grad = bias.grad + grad_bias if bias.grad is not None else grad_bias


class PipelineScheduler(BaseScheduler):
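
For orientation, the change above turns WeightGradStore into a two-level structure: put() now appends each deferred weight-grad computation to a per-micro-batch cache, flush() pushes the whole cache into weight_grad_queue as a single W step, and pop() drains one such group and runs every computation in it. Below is a minimal, self-contained sketch of that pattern; MiniWeightGradStore is a hypothetical stand-in with the pipeline-rank checks and ISP hooks left out, not the InternEvo class.

import queue


class MiniWeightGradStore:
    """Toy cache/flush/pop store illustrating the grouping by micro batch."""

    cache = []                          # deferred computations of the current micro batch
    weight_grad_queue = queue.Queue()   # one entry (a list of computations) per flushed micro batch

    @classmethod
    def size(cls):
        return cls.weight_grad_queue.qsize()

    @classmethod
    def put(cls, compute_fn, *args):
        # Defer one linear layer's weight-grad computation.
        cls.cache.append((compute_fn, args))

    @classmethod
    def flush(cls):
        # Close the current micro batch: everything deferred during its backward becomes one W step.
        cls.weight_grad_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        # Execute one whole W step, i.e. all weight grads of one micro batch.
        for compute_fn, args in cls.weight_grad_queue.get():
            compute_fn(*args)


# Usage: two layers deferred during one backward, executed later as one W step.
MiniWeightGradStore.put(print, "dW of layer 1")
MiniWeightGradStore.put(print, "dW of layer 2")
MiniWeightGradStore.flush()
MiniWeightGradStore.pop()
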
@@ -951,6 +964,7 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
scatter_gather_tensors=self.scatter_gather_tensors,
)

WeightGradStore.flush()
if i >= gpc.get_local_rank(ParallelMode.PIPELINE):
WeightGradStore.pop()

@@ -976,8 +990,8 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
if not gpc.is_first_rank(ParallelMode.PIPELINE):
comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)

if WeightGradStore.size() > 0:
WeightGradStore.pop()
WeightGradStore.flush()
WeightGradStore.pop()

while WeightGradStore.size() > 0:
WeightGradStore.pop()
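Taken together, the scheduler edits make the per-micro-batch grouping explicit: flush() closes the W group of the micro batch whose backward just finished, pop() runs one pending group only once i >= the local pipeline rank (so earlier stages postpone their weight-grad work longer and use it to fill pipeline bubbles), and any groups still queued are drained at the end. The following skeleton is a rough, hypothetical rendering of that ordering only; p2p communication, warmup/cooldown bookkeeping, and loss handling are omitted, and run_forward/run_backward/store are stand-ins rather than the real engine API.

def zero_bubble_stage_schedule(num_microbatches, pipeline_rank, run_forward, run_backward, store):
    # Hypothetical control flow implied by the diff above (details omitted).
    for i in range(num_microbatches):
        run_forward(i)
        run_backward(i)          # linear backwards defer their dW into the store
        store.flush()            # close micro batch i's W group
        if i >= pipeline_rank:   # earlier stages delay their W steps longer
            store.pop()          # run one pending W group

    while store.size() > 0:      # drain the W groups that were postponed
        store.pop()


# Example wiring with the toy store sketched earlier:
# zero_bubble_stage_schedule(
#     num_microbatches=4, pipeline_rank=1,
#     run_forward=lambda i: None,
#     run_backward=lambda i: MiniWeightGradStore.put(print, f"dW, micro batch {i}"),
#     store=MiniWeightGradStore,
# )
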
12 changes: 8 additions & 4 deletions internlm/model/modules/linear.py
@@ -133,8 +133,10 @@ def backward(ctx, grad_output, *args):
handle_x.wait()

x = x.reshape(batch_dim, x.shape[-1])
if gpc.is_using_parallel_mode(ParallelMode.PIPELINE) and gpc.config.parallel["pipeline"].get(
"zero_bubble", False
if (
gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
and gpc.config.parallel["pipeline"].get("zero_bubble", False)
and not gpc.is_first_rank(ParallelMode.PIPELINE)
):
from internlm.core.scheduler.pipeline_scheduler import WeightGradStore

@@ -234,8 +236,10 @@ def backward(ctx, grad_output, *args):

total_weight = communicator.weight_hook(weight, module=module)

is_using_ZB = gpc.is_using_parallel_mode(ParallelMode.PIPELINE) and gpc.config.parallel["pipeline"].get(
"zero_bubble", False
is_using_ZB = (
gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
and gpc.config.parallel["pipeline"].get("zero_bubble", False)
and not gpc.is_first_rank(ParallelMode.PIPELINE)
)

# compute weight grad
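
On the model side, both backward() hunks now defer weight-grad work to WeightGradStore only when three conditions hold at once: pipeline parallelism is active, zero_bubble is enabled in the pipeline config, and the current stage is not the first pipeline rank, which put() never expects to see ("pp rank 0 should not arrive here"); the first stage therefore computes its weight gradients inline. A small hypothetical helper restating that predicate (in linear.py the gate is inlined twice rather than factored out like this), with gpc and ParallelMode passed in as the InternLM global context objects:

def should_defer_weight_grad(gpc, ParallelMode):
    # Defer dW to WeightGradStore only on non-first pipeline stages with zero bubble enabled.
    return (
        gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
        and gpc.config.parallel["pipeline"].get("zero_bubble", False)
        and not gpc.is_first_rank(ParallelMode.PIPELINE)
    )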
