Merge pull request huggingface#203 from AleHD/fix_tp_mem_cache
Fix tp mem cache
3outeille authored Aug 2, 2024
2 parents 2793c92 + 31c3c5a commit 4eb520f
Showing 9 changed files with 166 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/nanotron/config/config.py
@@ -111,7 +111,7 @@ def __post_init__(self):
class DataArgs:
"""Arguments related to the data and data files processing"""

dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs]
dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
seed: Optional[int]
num_loading_workers: Optional[int] = 1

2 changes: 2 additions & 0 deletions src/nanotron/config/parallelism_config.py
@@ -34,6 +34,8 @@ class ParallelismArgs:
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False

tp_recompute_allgather: bool = True

expert_parallel_size: int = 1

def __post_init__(self):
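The new `tp_recompute_allgather` flag defaults to `True`, i.e. the all-gathered activation is rebuilt during backward instead of being kept alive. A minimal sketch of turning it off from code, assuming `ParallelismArgs` keeps its usual `dp`/`pp`/`tp` fields and is importable from `nanotron.config` (neither is shown in this diff):

```
from nanotron.config import ParallelismArgs  # assumed import path; the dataclass lives in parallelism_config.py

parallelism = ParallelismArgs(
    dp=1,
    pp=1,
    tp=4,
    # Trade memory for speed: cache the gathered activation instead of re-gathering it in backward.
    tp_recompute_allgather=False,
)
```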
6 changes: 4 additions & 2 deletions src/nanotron/models/llama.py
@@ -155,6 +155,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
self.down_proj = TensorParallelRowLinear(
config.intermediate_size,
@@ -164,8 +165,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
# TODO @nouamane: why can't we torch.jit.script GLUActivation?
self.split_silu_mul = GLUActivation(config.hidden_act)
self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))

def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states)
@@ -316,6 +316,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
self.rotary_embedding = RotaryEmbedding(
@@ -765,6 +766,7 @@ def __init__(
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
"tp_recompute_allgather": parallel_config.tp_recompute_allgather,
},
module_input_keys={"x"},
module_output_keys={"logits"},
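The removed `torch.jit.script` TODO in the MLP is resolved by wrapping the activation in `torch.compile` instead. A rough stand-in for what gets compiled, assuming PyTorch 2.x; `SiLUGate` below is illustrative and not nanotron's `GLUActivation`:

```
import torch
import torch.nn as nn
import torch.nn.functional as F


class SiLUGate(nn.Module):
    """Toy gate-up activation: split the fused projection and apply silu(gate) * up."""

    def forward(self, merged_states: torch.Tensor) -> torch.Tensor:
        gate, up = torch.chunk(merged_states, 2, dim=-1)
        return F.silu(gate) * up


split_silu_mul = torch.compile(SiLUGate())   # torch.compile fuses the chunk/activation/multiply
out = split_silu_mul(torch.randn(8, 3, 16))  # [seq_length, batch_size, 2 * intermediate_size]
```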
src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -85,7 +85,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableReduceScatterSum.apply(grad_output, group), None
out = DifferentiableReduceScatterSum.apply(grad_output, group)
return out, None


class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -113,7 +114,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
requires_grad=False,
)
dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
return sharded_tensor
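Setting `requires_grad=False` on the reduce-scatter output buffer is safe because it is allocated inside an `autograd.Function.forward`, where the graph is wired up through the Function node rather than through the tensor's flag. A single-process toy showing the same pattern (plain scaling stands in for the collective; not nanotron code):

```
import torch


class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Freshly allocated output buffer; requires_grad=False is fine here because
        # autograd connects the graph through the Function node, not through this flag.
        out = torch.empty_like(x, requires_grad=False)
        out.copy_(2 * x)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        return 2 * grad_output


x = torch.randn(4, requires_grad=True)
ScaleByTwo.apply(x).sum().backward()
assert torch.allclose(x.grad, torch.full_like(x, 2.0))
```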
103 changes: 93 additions & 10 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -20,13 +20,12 @@

import nanotron.distributed as dist
from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import (
differentiable_all_gather,
differentiable_all_reduce_sum,
differentiable_identity,
differentiable_reduce_scatter_sum,
)
from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode
from nanotron.parallel.utils import assert_cuda_max_connections_set_to_1
from nanotron.parallel.utils import MemoryBuffer, assert_cuda_max_connections_set_to_1


class _ShardedCrossEntropy(torch.autograd.Function):
@@ -89,10 +88,10 @@ def forward(

@staticmethod
def backward(ctx, grad_output):
# Retreive tensors from the forward path.
# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors

# All the inputs have softmax as thier gradient.
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
sharded_hidden_size = softmax.size()[-1]
@@ -338,25 +337,109 @@ def backward(ctx, grad_output):
raise ValueError(f"Got unexpected mode: {tp_mode}.")


class _ColumnLinearNoAsyncCommunicationReduceScatterMode(torch.autograd.Function):
"""
Column linear with memory_buffer for the allgather, context parallel
enabled (i.e. tp_mode = TensorParallelLinearMode.REDUCE_SCATTER) and
async communication disabled.
"""

@staticmethod
def forward(
ctx,
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_recompute_allgather: bool,
):

# Do allgather.
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
if group.size() == 1:
total_input = input.contiguous()
elif tp_recompute_allgather:
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)
else:
total_input = torch.empty(unsharded_batch_size, *rest_size, dtype=input.dtype, device=input.device)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

# Prepare context.
ctx.group = group
ctx.tp_recompute_allgather = tp_recompute_allgather
ctx.input_size = input.shape
if tp_recompute_allgather:
ctx.save_for_backward(input, weight, bias)
else:
ctx.save_for_backward(total_input, weight, bias)

# Get linear output.
out = F.linear(total_input, weight, bias)
return out

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# Either allgather the inputs again or get them from context.
group = ctx.group
tp_recompute_allgather = ctx.tp_recompute_allgather
input_size = ctx.input_size
if group.size() == 1 or not tp_recompute_allgather:
total_input, weight, bias = ctx.saved_tensors
else:
input, weight, bias = ctx.saved_tensors
sharded_batch_size, *rest_size = input.shape
unsharded_batch_size = sharded_batch_size * group.size()
total_input = MemoryBuffer().get("allgather", (unsharded_batch_size, *rest_size), dtype=input.dtype)
dist.all_gather_into_tensor(total_input, input.contiguous(), group=group)

# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.contiguous()
grad_output_first_dims, grad_output_last_dim = grad_output.shape[:-1], grad_output.shape[-1]
total_input_first_dims, total_input_last_dim = total_input.shape[:-1], total_input.shape[-1]
grad_output = grad_output.view(math.prod(grad_output_first_dims), grad_output_last_dim)
total_input = total_input.view(math.prod(total_input_first_dims), total_input_last_dim)

# Compute gradients.
grad_weight = grad_output.T @ total_input
grad_input = grad_output @ weight
if group.size() == 1:
sub_grad_input = grad_input
else:
# Seems that `reduce_scatter` needs contiguous tensors: https://github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
# We set grad_input to be contiguous in case it isn't already.
grad_input = grad_input.contiguous()
sub_grad_input = torch.empty(
input_size, dtype=total_input.dtype, device=total_input.device, requires_grad=False
)
dist.reduce_scatter_tensor(sub_grad_input, grad_input, group=group, op=dist.ReduceOp.SUM)
grad_bias = torch.sum(grad_output, dim=0) if bias is not None else None

return sub_grad_input, grad_weight, grad_bias, None, None


def column_linear(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
group: dist.ProcessGroup,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool = True,
):
if async_communication:
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)

if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
input = differentiable_identity(input, group=group)
elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
input = differentiable_all_gather(input, group=group)
else:
raise ValueError(f"Got unexpected mode: {tp_mode}.")

return F.linear(input, weight, bias)
return F.linear(input, weight, bias)
if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply(
input, weight, bias, group, tp_recompute_allgather
)
raise ValueError(f"Got unexpected mode: {tp_mode}.")


class _RowLinearAsyncCommunication(torch.autograd.Function):
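What the new `_ColumnLinearNoAsyncCommunicationReduceScatterMode` buys: with `tp_recompute_allgather=True` only the local input shard is saved for backward and the gathered activation lives in the reusable `MemoryBuffer`, at the cost of one extra all-gather during backward. A single-process toy of that save-vs-recompute pattern, where `repeat`/`sum` stand in for all-gather/reduce-scatter (a sketch, not nanotron's implementation):

```
import torch

WORLD = 4  # pretend tensor-parallel group size


class RecomputeGatherLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, recompute: bool):
        gathered = x.repeat(WORLD, 1)                 # stand-in for dist.all_gather_into_tensor
        ctx.recompute = recompute
        # Memory-lean path saves only the shard; the other path saves the full gathered tensor.
        ctx.save_for_backward(x if recompute else gathered, weight)
        return gathered @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        saved, weight = ctx.saved_tensors
        gathered = saved.repeat(WORLD, 1) if ctx.recompute else saved  # re-gather only when recomputing
        grad_weight = grad_output.t() @ gathered
        grad_gathered = grad_output @ weight
        # Stand-in for reduce-scatter: fold the per-"rank" slices back onto the local shard.
        grad_input = grad_gathered.view(WORLD, -1, weight.shape[1]).sum(dim=0)
        return grad_input, grad_weight, None


x = torch.randn(2, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
RecomputeGatherLinear.apply(x, w, True).sum().backward()  # recompute path: only the shard was kept alive
```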
3 changes: 3 additions & 0 deletions src/nanotron/parallel/tensor_parallel/nn.py
@@ -51,6 +51,7 @@ def __init__(
dtype=None,
async_communication: bool = False,
contiguous_chunks: Optional[Tuple[int, ...]] = None,
tp_recompute_allgather: bool = True,
):
self.pg = pg
self.world_size = pg.size()
@@ -59,6 +60,7 @@

self.in_features = in_features
self.out_features = out_features // self.world_size
self.tp_recompute_allgather = tp_recompute_allgather

super().__init__(
in_features=self.in_features,
@@ -91,6 +93,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
group=self.pg,
tp_mode=self.mode,
async_communication=self.async_communication,
tp_recompute_allgather=self.tp_recompute_allgather,
)

def extra_repr(self) -> str:
20 changes: 20 additions & 0 deletions src/nanotron/parallel/utils.py
@@ -1,11 +1,31 @@
import functools
import operator
import os

import torch
from torch import nn

from nanotron import distributed as dist
from nanotron.parallel import ParallelContext
from nanotron.parallel.tied_parameters import get_tied_id_to_param
from nanotron.utils import Singleton


class MemoryBuffer(metaclass=Singleton):
"""
Global memory buffer to store intermediate activations that do not need to be cached for the backward pass.
"""

def __init__(self):
self.buffer = {}

def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
required_numel = functools.reduce(operator.mul, shape, 1)
if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
self.buffer[name, dtype] = torch.empty(
required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
)
return self.buffer[name, dtype][:required_numel].view(shape)


def assert_cuda_max_connections_set_to_1(func):
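A short usage sketch of the new `MemoryBuffer` (assumes a CUDA device, since `get` allocates on `torch.cuda.current_device()`): repeated requests under the same name reuse one allocation and return differently shaped views of it.

```
import torch
from nanotron.parallel.utils import MemoryBuffer

buf = MemoryBuffer()
a = buf.get("allgather", (4, 1024), dtype=torch.bfloat16)  # first call allocates 4 * 1024 elements
b = buf.get("allgather", (2, 1024), dtype=torch.bfloat16)  # smaller request: a view into the same storage
assert b.data_ptr() == a.data_ptr()   # no new allocation
assert MemoryBuffer() is buf          # Singleton metaclass: one buffer per process
```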
25 changes: 22 additions & 3 deletions src/nanotron/utils.py
@@ -1,11 +1,10 @@
import functools
import inspect
import math
import os
import random
import socket
from contextlib import ExitStack, contextmanager
from typing import Callable, ContextManager, List, Optional
from typing import ContextManager, List, Optional

import torch
from packaging import version
@@ -15,6 +14,25 @@
from nanotron import distributed as dist


class Singleton(type):
"""
Singleton metaclass.
Create objects using this class as the metaclass to enable singleton behaviour.
For instance:
```
class Logger(metaclass=Singleton):
...
```
"""

_instances = {}

def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]


class ContextManagers:
"""
Wrapper for `contextlib.ExitStack` which enters a collection of context managers. Adaptation of `ContextManagers`
@@ -52,7 +70,7 @@ def main_rank_first(group: dist.ProcessGroup):
@contextmanager
def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None):
"""Context manager that executes the code in the context with all the local rank zero of the group going first.
Usefull to run only once per node first (e.g. to create local files, etc)
Useful to run only once per node first (e.g. to create local files, etc)
"""
is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0
if is_main:
@@ -123,6 +141,7 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage:
else:
return tensor.storage().untyped()


def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype):
# TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage.
device = untyped_storage.device
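A minimal sketch of the `Singleton` metaclass in use (`Counter` is made up for illustration); this is what makes `MemoryBuffer` a single process-wide cache.

```
from nanotron.utils import Singleton


class Counter(metaclass=Singleton):
    def __init__(self):
        self.value = 0


first = Counter()
second = Counter()     # __call__ returns the cached instance; __init__ does not run again
first.value += 1
assert first is second and second.value == 1
```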
22 changes: 18 additions & 4 deletions tests/test_tensor_parallel.py
@@ -18,17 +18,30 @@
@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@pytest.mark.parametrize("async_communication", [False, True])
@pytest.mark.parametrize("tp_recompute_allgather", [False, True])
@rerun_if_address_is_in_use()
def test_column_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool):
def test_column_linear(
tp: int,
dp: int,
pp: int,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
pytest.skip("ALL_REDUCE mode does not support async communication")
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and tp_recompute_allgather:
pytest.skip("ALL_REDUCE mode is unaffected by tp_recompute_allgather")
init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)(
tp_mode=tp_mode, async_communication=async_communication
tp_mode=tp_mode, async_communication=async_communication, tp_recompute_allgather=tp_recompute_allgather
)


def _test_column_linear(
parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool
parallel_context: ParallelContext,
tp_mode: TensorParallelLinearMode,
async_communication: bool,
tp_recompute_allgather: bool,
):
if async_communication:
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
@@ -44,6 +57,7 @@ def _test_column_linear(
mode=tp_mode,
device="cuda",
async_communication=async_communication,
tp_recompute_allgather=tp_recompute_allgather,
)

# Un-sharded
@@ -86,7 +100,7 @@ def _test_column_linear(
random_input = sharded_random_input
else:
ValueError(f"Unsupported mode: {tp_mode}")
# It's important that `random_input` and `sharded_random_input` are two seperate tensors with seperate storage
# It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage
sharded_random_input = sharded_random_input.clone()
random_input.requires_grad = True
sharded_random_input.requires_grad = True
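One way to exercise the new `tp_recompute_allgather` parametrization locally is to drive pytest from Python (assumes a machine with the GPUs the test matrix expects):

```
import pytest

# -k narrows collection to the column-linear test and its enlarged parameter grid.
pytest.main(["tests/test_tensor_parallel.py", "-k", "test_column_linear"])
```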
