From 7cc6653c69c19cd42da05e6a8712159a146407e7 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 18 Jul 2024 11:16:42 +0000 Subject: [PATCH 1/4] memory efficient async linear --- .../parallel/tensor_parallel/functional.py | 48 ++++++------------- 1 file changed, 15 insertions(+), 33 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 47c0b5a1..1a82254e 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -149,14 +149,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = torch.empty( - gathered_batch_size, - *intermediate_size, - hidden_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=tensor.requires_grad, - ) + gathered_tensor = MemoryBuffer().get("allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -261,7 +254,7 @@ def backward(ctx, grad_output): use_bias = ctx.use_bias tp_mode = ctx.tp_mode - handle: Optional[dist.Work] = None + handle1: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape @@ -273,14 +266,8 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = torch.empty( - unsharded_batch_size, - *rest_size, - device=tensor.device, - dtype=tensor.dtype, - requires_grad=False, - ) - handle = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) + unsharded_tensor = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation total_tensor = unsharded_tensor @@ -289,9 +276,6 @@ def backward(ctx, grad_output): grad_tensor = grad_output.matmul(weight) - if handle is not None: - handle.wait() - # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only # clones it if it's not contiguous: @@ -303,7 +287,7 @@ def backward(ctx, grad_output): grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim) total_tensor = total_tensor.view(math.prod(total_tensor_first_dims), total_tensor_last_dim) - handle: Optional[dist.Work] = None + handle2: Optional[dist.Work] = None if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: if group.size() == 1: sub_grad_tensor = grad_tensor @@ -312,23 +296,27 @@ def backward(ctx, grad_output): tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter - handle = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) + handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # reduce scatter is scheduled before the weight gradient computation elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: # Asynchronous all-reduce - handle = dist.all_reduce(grad_tensor, group=group, async_op=True) + handle2 = dist.all_reduce(grad_tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # all-reduce is scheduled before the weight gradient computation else: raise ValueError() + grad_bias = grad_output.sum(dim=0) if use_bias else None + + if handle1 is not None: + handle1.wait() + # TODO @thomasw21: This sounds like we don't have the optimal physical layout grad_weight = grad_output.t().matmul(total_tensor) - grad_bias = grad_output.sum(dim=0) if use_bias else None - if handle is not None: - handle.wait() + if handle2 is not None: + handle2.wait() if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return sub_grad_tensor, grad_weight, grad_bias, None, None @@ -472,13 +460,7 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = torch.empty( - unsharded_batch_size, - *rest_size, - device=grad_output.device, - dtype=grad_output.dtype, - requires_grad=False, - ) + total_grad_output = MemoryBuffer().get("allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only From cb0f2609e357747a0a3001dd9b12c649f9e6eef7 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 18 Jul 2024 11:17:09 +0000 Subject: [PATCH 2/4] precommit --- .../parallel/tensor_parallel/functional.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 1a82254e..3821d544 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,14 +19,13 @@ from torch.nn import functional as F import nanotron.distributed as dist -from nanotron.parallel.utils import MemoryBuffer from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, differentiable_reduce_scatter_sum, ) from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1 +from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1 class _ShardedCrossEntropy(torch.autograd.Function): @@ -149,7 +148,9 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = MemoryBuffer().get("allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype) + gathered_tensor = MemoryBuffer().get( + "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -266,7 +267,9 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - unsharded_tensor = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + unsharded_tensor = MemoryBuffer().get( + "allgather", (unsharded_batch_size, *rest_size), dtype=tensor.dtype + ) handle1 = dist.all_gather_into_tensor(unsharded_tensor, tensor, group=group, async_op=True) # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # gather is scheduled before the tensor gradient computation @@ -399,7 +402,6 @@ def backward(ctx, grad_output: torch.Tensor): return sub_grad_input, grad_weight, grad_bias, None, None - def column_linear( input: torch.Tensor, weight: torch.Tensor, @@ -460,7 +462,9 @@ def backward(ctx, grad_output): else: unsharded_batch_size = sharded_batch_size * group.size() - total_grad_output = MemoryBuffer().get("allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype) + total_grad_output = MemoryBuffer().get( + "allgather2", (unsharded_batch_size, *rest_size), dtype=tensor.dtype + ) # Doing gather + slicing during the NeMo forward pass can make this tensor # not be contiguous. 
PyTorch only checks if the tensor is contiguous, and only From 6d85d038d52ffc06ec4f2ae4705deae3b05d25d8 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 18 Jul 2024 11:46:39 +0000 Subject: [PATCH 3/4] Added no_recompute_allgather mode to async --- .../parallel/tensor_parallel/functional.py | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 3821d544..29480f9a 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -120,10 +120,12 @@ class _ColumnLinearAsyncCommunication(torch.autograd.Function): @staticmethod @assert_cuda_max_connections_set_to_1 - def forward(ctx, tensor, weight, bias, group, tp_mode): + def forward(ctx, tensor, weight, bias, group, tp_mode, tp_recompute_allgather): ctx.use_bias = bias is not None ctx.tp_mode = tp_mode ctx.group = group + ctx.tp_recompute_allgather = tp_recompute_allgather + ctx.tensor_shape = tensor.size() if tp_mode is TensorParallelLinearMode.ALL_REDUCE: gathered_tensor = tensor @@ -140,7 +142,7 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # `tensor` can sometimes not be contiguous # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L317 tensor = tensor.contiguous() - ctx.save_for_backward(tensor, weight) + # ctx.save_for_backward(tensor, weight) # TODO @thomasw21: gather along another dimension sharded_batch_size, *intermediate_size, hidden_size = tensor.shape @@ -148,9 +150,19 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): group = dist.distributed_c10d._get_default_group() gathered_batch_size = sharded_batch_size * group.size() - gathered_tensor = MemoryBuffer().get( - "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype - ) + if tp_recompute_allgather: + gathered_tensor = MemoryBuffer().get( + "allgather", (gathered_batch_size, *intermediate_size, hidden_size), dtype=tensor.dtype + ) + else: + gathered_tensor = torch.empty( + gathered_batch_size, + *intermediate_size, + hidden_size, + device=tensor.device, + dtype=tensor.dtype, + requires_grad=False, + ) handle = dist.all_gather_into_tensor(gathered_tensor, tensor, group=group, async_op=True) @@ -198,6 +210,10 @@ def forward(ctx, tensor, weight, bias, group, tp_mode): # Wait communication handle.wait() + if tp_recompute_allgather: + ctx.save_for_backward(tensor, weight) + else: + ctx.save_for_backward(gathered_tensor, weight) # Compute all the other shards that are obtained from AllGather # weights: w0 w1 w2 w3 @@ -256,7 +272,7 @@ def backward(ctx, grad_output): tp_mode = ctx.tp_mode handle1: Optional[dist.Work] = None - if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER and ctx.tp_recompute_allgather: # TODO @thomasw21: gather along another dimension sharded_batch_size, *rest_size = tensor.shape if group is None: @@ -296,7 +312,7 @@ def backward(ctx, grad_output): sub_grad_tensor = grad_tensor else: sub_grad_tensor = torch.empty( - tensor.shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False + ctx.tensor_shape, dtype=grad_tensor.dtype, device=grad_tensor.device, requires_grad=False ) # reduce_scatter handle2 = dist.reduce_scatter_tensor(sub_grad_tensor, grad_tensor, group=group, async_op=True) @@ -322,9 +338,9 @@ def backward(ctx, grad_output): handle2.wait() if tp_mode is 
TensorParallelLinearMode.REDUCE_SCATTER: - return sub_grad_tensor, grad_weight, grad_bias, None, None + return sub_grad_tensor, grad_weight, grad_bias, None, None, None elif tp_mode is TensorParallelLinearMode.ALL_REDUCE: - return grad_tensor, grad_weight, grad_bias, None, None + return grad_tensor, grad_weight, grad_bias, None, None, None else: raise ValueError(f"Got unexpected mode: {tp_mode}.") @@ -412,7 +428,7 @@ def column_linear( tp_recompute_allgather: bool = True, ): if async_communication: - return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) + return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: input = differentiable_identity(input, group=group) From 4c94b99a8cafd5f8dc2b7f208341a17a4d818234 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 30 Jul 2024 19:34:31 +0200 Subject: [PATCH 4/4] Added tp_recompute_allgather test --- tests/test_tensor_parallel.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 8e73973b..16008eaa 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -164,15 +164,32 @@ def _test_column_linear( @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)]) @pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode)) @pytest.mark.parametrize("async_communication", [False, True]) +@pytest.mark.parametrize("tp_recompute_allgather", [False, True]) @rerun_if_address_is_in_use() -def test_row_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool): +def test_row_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: pytest.skip("ALL_REDUCE mode does not support async communication") + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather: + pytest.skip("ALL_REDUCE mode is not affected by tp_recompute_allgather") - init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(tp_mode=tp_mode, async_communication=async_communication) + init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)( + tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather + ) -def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool): +def _test_row_linear( + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + tp_recompute_allgather: bool, +): if async_communication: os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" out_features = 3
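# ---------------------------------------------------------------------------
# Illustrative sketch of the pattern these patches rely on: all-gather the
# batch-sharded input asynchronously into a reused buffer and overlap the
# collective with independent compute, waiting on the handle only right before
# the gathered data is consumed. The names below (`_BUFFERS`, `get_buffer`,
# `column_linear_forward_sketch`) are stand-ins, not nanotron's MemoryBuffer
# API, and the snippet assumes torch.distributed is already initialized and
# CUDA_DEVICE_MAX_CONNECTIONS=1 is set, as the patched code requires.
import math
from typing import Dict, Tuple

import torch
import torch.distributed as dist

_BUFFERS: Dict[Tuple[str, torch.dtype], torch.Tensor] = {}


def get_buffer(name: str, shape: Tuple[int, ...], dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    """Return a cached flat buffer viewed as `shape`, growing it if needed."""
    numel = math.prod(shape)
    key = (name, dtype)
    buf = _BUFFERS.get(key)
    if buf is None or buf.numel() < numel or buf.device != device:
        buf = torch.empty(numel, dtype=dtype, device=device)
        _BUFFERS[key] = buf
    return buf[:numel].view(shape)


def column_linear_forward_sketch(x: torch.Tensor, weight: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """All-gather the sharded activation `x`, overlapping the collective with local work."""
    gathered = get_buffer("allgather", (x.shape[0] * group.size(), *x.shape[1:]), x.dtype, x.device)
    handle = dist.all_gather_into_tensor(gathered, x.contiguous(), group=group, async_op=True)
    # Work that does not depend on the gathered tensor (e.g. the matmul on the
    # local shard in the real kernel) can be issued here while the gather runs.
    handle.wait()  # the gathered activations are only needed from this point on
    return gathered.matmul(weight.t())


# Because the buffer returned by get_buffer is shared and overwritten by the
# next call, its contents cannot be stashed for the backward pass. That is the
# trade-off behind tp_recompute_allgather in these patches: when it is True,
# only the sharded input is saved and the all-gather is redone in backward
# (reusing the buffer); when it is False, forward allocates a private
# torch.empty and saves the gathered tensor, spending memory to skip the
# second all-gather.
# ---------------------------------------------------------------------------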