From f951cdd76283130d3084d39dc268e1d64ea75e7c Mon Sep 17 00:00:00 2001
From: jiaxingli <43110891+li126com@users.noreply.github.com>
Date: Thu, 19 Sep 2024 12:49:56 +0800
Subject: [PATCH] fix(pipeline): fix zero bubble pipeline parallelism (#334)

---
 internlm/core/scheduler/pipeline_scheduler.py | 44 ++++++++++++-------
 internlm/model/modules/linear.py              | 12 +++--
 2 files changed, 37 insertions(+), 19 deletions(-)

diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 92bbc191c..a2e089143 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -115,6 +115,7 @@ class WeightGradStore:
     When using zero bubble pp, WeightGradStore is used to store the args and func for computating weight grad.
     """
 
+    cache = []
     weight_grad_queue = queue.Queue()
 
     @classmethod
@@ -123,25 +124,37 @@ def size(cls):
 
     @classmethod
     def put(cls, weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args):
+        assert not gpc.is_first_rank(ParallelMode.PIPELINE), "pp rank 0 should not arrive here"
         # Store the weight gradient computation of linear layers.
-        cls.weight_grad_queue.put((weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args))
+        cls.cache.append((weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args))
+
+    @classmethod
+    def flush(cls):
+        if gpc.is_first_rank(ParallelMode.PIPELINE):
+            return
+        # Collect all stored computations during backward as a W for each micro batch.
+        cls.weight_grad_queue.put(cls.cache)
+        cls.cache = []
 
     @classmethod
     def pop(cls):
-        # Run computation for a single W.
+        if gpc.is_first_rank(ParallelMode.PIPELINE):
+            return
         assert cls.weight_grad_queue.qsize() > 0
-        weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args = cls.weight_grad_queue.get()
-        grad_weight, grad_bias = grad_compute_func(input_tensor, grad_output, has_d_bias)
-        if is_using_isp():
-            isp_grad_hook = args[0]
-            grad_weight, _ = isp_grad_hook(grad_weight, async_op=False, is_bias=False)
-            if grad_bias is not None:
-                grad_bias, _ = isp_grad_hook(grad_bias, async_op=False, is_bias=True)
+        stored_w_grad_computation = cls.weight_grad_queue.get()
+        # Run computation for a single W.
+        for weight, bias, input_tensor, grad_output, has_d_bias, grad_compute_func, *args in stored_w_grad_computation:
+            grad_weight, grad_bias = grad_compute_func(input_tensor, grad_output, has_d_bias)
+            if is_using_isp():
+                isp_grad_hook = args[0]
+                grad_weight, _ = isp_grad_hook(grad_weight, async_op=False, is_bias=False)
+                if grad_bias is not None:
+                    grad_bias, _ = isp_grad_hook(grad_bias, async_op=False, is_bias=True)
 
-        # Gradient Accumulation
-        weight.grad = weight.grad + grad_weight if weight.grad is not None else grad_weight
-        if has_d_bias:
-            bias.grad = bias.grad + grad_bias if bias.grad is not None else grad_bias
+            # Gradient Accumulation
+            weight.grad = weight.grad + grad_weight if weight.grad is not None else grad_weight
+            if has_d_bias:
+                bias.grad = bias.grad + grad_bias if bias.grad is not None else grad_bias
 
 
 class PipelineScheduler(BaseScheduler):
@@ -951,6 +964,7 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
                     scatter_gather_tensors=self.scatter_gather_tensors,
                 )
 
+                WeightGradStore.flush()
                 if i >= gpc.get_local_rank(ParallelMode.PIPELINE):
                     WeightGradStore.pop()
 
@@ -976,8 +990,8 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
 
             if not gpc.is_first_rank(ParallelMode.PIPELINE):
                 comm.send_backward(input_obj_grad, scatter_gather_tensors=self.scatter_gather_tensors)
 
-            if WeightGradStore.size() > 0:
-                WeightGradStore.pop()
+            WeightGradStore.flush()
+            WeightGradStore.pop()
 
         while WeightGradStore.size() > 0:
             WeightGradStore.pop()
diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py
index 8c08ab7d9..2426ab8a6 100644
--- a/internlm/model/modules/linear.py
+++ b/internlm/model/modules/linear.py
@@ -133,8 +133,10 @@ def backward(ctx, grad_output, *args):
             handle_x.wait()
             x = x.reshape(batch_dim, x.shape[-1])
 
-        if gpc.is_using_parallel_mode(ParallelMode.PIPELINE) and gpc.config.parallel["pipeline"].get(
-            "zero_bubble", False
+        if (
+            gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
+            and gpc.config.parallel["pipeline"].get("zero_bubble", False)
+            and not gpc.is_first_rank(ParallelMode.PIPELINE)
         ):
             from internlm.core.scheduler.pipeline_scheduler import WeightGradStore
 
@@ -234,8 +236,10 @@ def backward(ctx, grad_output, *args):
 
         total_weight = communicator.weight_hook(weight, module=module)
 
-        is_using_ZB = gpc.is_using_parallel_mode(ParallelMode.PIPELINE) and gpc.config.parallel["pipeline"].get(
-            "zero_bubble", False
+        is_using_ZB = (
+            gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
+            and gpc.config.parallel["pipeline"].get("zero_bubble", False)
+            and not gpc.is_first_rank(ParallelMode.PIPELINE)
         )
 
         # compute weight grad
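Reviewer note (not part of the patch): the scheduler change turns deferred weight-gradient
work into a two-stage pipeline. During the backward (B) pass each linear layer calls
WeightGradStore.put() to stash its dW computation in `cache`; flush() runs once per micro
batch and moves the whole cache into `weight_grad_queue` as a single W unit; pop() later
drains one unit and accumulates the resulting gradients, so the W work can be scheduled
separately from B. Below is a minimal, self-contained sketch of that put -> flush -> pop
life cycle using plain torch; the class and method names mirror the patch, but the rank
checks, ISP hooks and pipeline communication are omitted and the helper names are
illustrative only.

import queue

import torch


class WeightGradStoreSketch:
    cache = []                         # per-micro-batch staging area (filled by put)
    weight_grad_queue = queue.Queue()  # one entry per micro batch (filled by flush)

    @classmethod
    def put(cls, weight, input_tensor, grad_output, grad_compute_func):
        # Called from a linear layer's backward: defer dW instead of computing it now.
        cls.cache.append((weight, input_tensor, grad_output, grad_compute_func))

    @classmethod
    def flush(cls):
        # Called once per micro batch: package everything deferred so far as one W unit.
        cls.weight_grad_queue.put(cls.cache)
        cls.cache = []

    @classmethod
    def pop(cls):
        # Called when the schedule has room: run one W unit and accumulate into .grad.
        for weight, input_tensor, grad_output, grad_compute_func in cls.weight_grad_queue.get():
            grad_weight = grad_compute_func(input_tensor, grad_output)
            weight.grad = weight.grad + grad_weight if weight.grad is not None else grad_weight


def linear_weight_grad(x, dy):
    # For y = x @ w.t(), the weight gradient is dW = dy.t() @ x.
    return dy.t() @ x


w = torch.nn.Parameter(torch.zeros(4, 3))
for _ in range(2):                                           # two micro batches
    x, dy = torch.randn(8, 3), torch.randn(8, 4)
    WeightGradStoreSketch.put(w, x, dy, linear_weight_grad)  # during backward (B)
    WeightGradStoreSketch.flush()                            # micro batch finished
while WeightGradStoreSketch.weight_grad_queue.qsize() > 0:
    WeightGradStoreSketch.pop()                              # run the deferred W work
print(w.grad.shape)                                          # torch.Size([4, 3])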
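Reviewer note (not part of the patch): both hunks in linear.py gate the deferral on
gpc.config.parallel["pipeline"].get("zero_bubble", False), i.e. zero bubble is an opt-in
pipeline setting, and the added `not gpc.is_first_rank(...)` term keeps pp rank 0 on the
ordinary immediate backward path, matching the new assert in WeightGradStore.put(). A
hypothetical config fragment for enabling it is shown below; only the "zero_bubble" key
is taken from the patch, while the surrounding structure and the other keys are
assumptions about a typical InternEvo-style parallel config.

parallel = dict(
    pipeline=dict(
        size=4,            # assumed: number of pipeline stages
        zero_bubble=True,  # turns on the WeightGradStore path patched here
    ),
)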