add the forward pass of fp8 tp
xrsrke committed May 23, 2024
1 parent cd67969 commit 2b0b9ef
Showing 18 changed files with 355 additions and 124 deletions.
28 changes: 27 additions & 1 deletion src/nanotron/fp8/distributed.py
@@ -1,7 +1,33 @@
from typing import List, Union

import torch
import torch.distributed as dist
from torch.distributed import * # noqa

from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8
from nanotron.parallel.parameters import NanotronParameter


def all_reduce(tensor: FP8Tensor, op: dist.ReduceOp, group: dist.ProcessGroup, async_op: bool = False) -> FP8Tensor:
    pass


def all_gather(
    tensor_list: List[torch.Tensor],
    tensor: Union[FP8Tensor, NanotronParameter],
    group: dist.ProcessGroup,
    async_op: bool = False,
) -> torch.Tensor:
    tensor = tensor.data if isinstance(tensor, NanotronParameter) else tensor
    # assert isinstance(tensor, FP8Tensor) if isinstance(tensor, FP8Tensor) else isinstance(tensor, torch.Tensor)

    if isinstance(tensor, FP8Tensor):
        tensor = (
            convert_tensor_from_fp8(tensor, tensor.fp8_meta, torch.float32)
            if tensor_list[0].dtype != tensor.dtype
            else tensor
        )

    dist.all_gather(tensor_list, tensor, group, async_op)

    return tensor
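
Note: the new all_gather above dequantizes an FP8 tensor to the dtype of the gather buffers before calling dist.all_gather, so the collective runs on matching dtypes. The single-rank sketch below reproduces that flow with a stand-in uint8-plus-scale codec; dequantize, gather_quantized, and the chosen port are illustrative assumptions, not nanotron's API.

import torch
import torch.distributed as dist

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # stand-in for convert_tensor_from_fp8: integer codes times a scale -> float32
    return q.to(torch.float32) * scale

def gather_quantized(tensor_list, q, scale, group=None):
    # the output buffers are float32, so dequantize first when dtypes differ
    t = dequantize(q, scale) if tensor_list[0].dtype != q.dtype else q
    dist.all_gather(tensor_list, t, group=group)
    return t

if __name__ == "__main__":
    # single-rank gloo group so the example runs on one machine
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
    q = torch.randint(0, 255, (4,), dtype=torch.uint8)
    scale = torch.tensor(0.01)
    out = [torch.empty(4, dtype=torch.float32)]
    gather_quantized(out, q, scale)
    print(out[0])
    dist.destroy_process_group()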
8 changes: 4 additions & 4 deletions src/nanotron/fp8/functional.py
@@ -51,7 +51,7 @@ def addmm(
    input,
    mat1,
    mat2,
    output: torch.Tensor,
    out: torch.Tensor,
    accum_qtype: DTypes,
    metadatas: FP8LinearMeta,
    beta: Union[float, int] = 1,
@@ -63,9 +63,9 @@ def addmm(
    assert beta == 1.0, "Currently only support beta=1."
    assert alpha == 1.0, "Currently only support alpha=1."

    output = mm(input=mat1, mat2=mat2, out=output, accum_qtype=accum_qtype, metadatas=metadatas)
    output = output if input is None else output + input
    return output
    out = mm(input=mat1, mat2=mat2, out=out, accum_qtype=accum_qtype, metadatas=metadatas)
    out = out if input is None else out + input
    return out


def linear(
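Note: with beta = alpha = 1 the addmm path above is just "matmul into out, then add input if present", which for plain float tensors matches torch.addmm. A quick sanity sketch (addmm_like is a hypothetical helper; torch.mm stands in for the FP8 mm):

import torch

def addmm_like(input, mat1, mat2, out):
    # mirror of the diff's control flow with a plain matmul
    torch.mm(mat1, mat2, out=out)
    return out if input is None else out + input

mat1, mat2 = torch.randn(2, 3), torch.randn(3, 4)
bias = torch.randn(2, 4)
out = torch.empty(2, 4)
assert torch.allclose(addmm_like(bias, mat1, mat2, out), torch.addmm(bias, mat1, mat2))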
7 changes: 5 additions & 2 deletions src/nanotron/fp8/linear.py
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union, cast

import pydevd
import torch
import transformer_engine as te # noqa
from torch import nn
@@ -52,9 +53,11 @@ def __init__(
        super().__init__(in_features, out_features, bias, device, QTYPE_TO_DTYPE[accum_qtype])
        # TODO(xrsrke): don't fixed dtype, take it from the FP8 recipe
        # DTypes.FP8E4M3
        self.weight = FP8Parameter(
        quant_w = FP8Parameter(
            self.weight, dtype=FP8LM_RECIPE.linear.weight.dtype, interval=FP8LM_RECIPE.linear.weight.interval
        )
        self.weight = quant_w
        assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
        self.metadatas = FP8LinearMeta()
        self.accum_qtype = accum_qtype

@@ -139,7 +142,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
        ∂L/∂W = Xᵀ @ ∂L/∂Y
        Reference: https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html
        """
        # pydevd.settrace(suspend=False, trace_only_current_thread=True)
        pydevd.settrace(suspend=False, trace_only_current_thread=True)
        fp8_input, fp8_weight = ctx.saved_tensors
        accum_qtype = ctx.accum_qtype

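Note: the backward hunk above relies on the standard linear-layer gradients quoted in its docstring (for Y = X @ W, dL/dX = dL/dY @ Wᵀ and dL/dW = Xᵀ @ dL/dY). A quick autograd check of those formulas on plain float32 tensors, included only as a reference, no FP8 involved:

import torch

X = torch.randn(4, 3, requires_grad=True)
W = torch.randn(3, 5, requires_grad=True)
Y = X @ W
grad_Y = torch.randn_like(Y)
Y.backward(grad_Y)

assert torch.allclose(X.grad, grad_Y @ W.t())  # dL/dX = dL/dY @ Wᵀ
assert torch.allclose(W.grad, X.t() @ grad_Y)  # dL/dW = Xᵀ @ dL/dY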
3 changes: 3 additions & 0 deletions src/nanotron/fp8/meta.py
@@ -47,9 +47,12 @@ def __post_init__(self):
        assert (
            self.scale.dtype == torch.float32
        ), f"Expected scale to be of dtype torch.float32, got {self.scale.dtype}"

        # TODO(xrsrke): move these to a constant
        assert self.amax.dtype in [
            torch.float32,
            torch.float16,
            torch.bfloat16,
        ], f"Expected amax to be of dtype torch.float32 or torch.float16, got {self.amax.dtype}"

        # if self.is_delayed_scaling is False and self.interval > 1:
2 changes: 1 addition & 1 deletion src/nanotron/fp8/tensor.py
@@ -431,7 +431,7 @@ def update_scaling_factor(
"""
# TODO(xrsrke): sometimes we store some params in fp16
# make this configurable
assert amax.dtype in [torch.float32, torch.float16]
assert amax.dtype in [torch.float32, torch.float16, torch.bfloat16], f"amax.dtype: {amax.dtype}"
# TODO(xrsrke): can we use lower precision for scaling_factor?
assert scaling_factor.dtype == torch.float32

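Note: the relaxed assert above reflects that with delayed scaling the amax statistic is often tracked in the training dtype (fp16 or bf16) while the scaling factor itself stays in float32. The sketch below shows one common recipe, scale = fp8_max / amax / 2**margin; the formula and the E4M3 constant are assumptions about the usual approach, not necessarily nanotron's exact implementation.

import torch

FP8_E4M3_MAX = 448.0  # largest normal value representable in E4M3

def update_scale(amax: torch.Tensor, prev_scale: torch.Tensor, margin: int = 0) -> torch.Tensor:
    assert amax.dtype in (torch.float32, torch.float16, torch.bfloat16)
    assert prev_scale.dtype == torch.float32
    # compute in float32 regardless of amax's storage dtype
    new_scale = (FP8_E4M3_MAX / amax.float()) / (2.0 ** margin)
    # fall back to the previous scale if amax was zero/inf/nan
    return torch.where(torch.isfinite(new_scale), new_scale, prev_scale)

print(update_scale(torch.tensor(3.5, dtype=torch.bfloat16), torch.tensor(1.0)))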
10 changes: 6 additions & 4 deletions src/nanotron/models/llama.py
@@ -14,8 +14,8 @@
# limitations under the License.
""" PyTorch LLaMa model.
"""
import math
from typing import Dict, Optional, Union

import torch
from torch import nn

@@ -607,15 +607,17 @@ def __init__(
        layer_idx: int,
    ):
        super().__init__()
        self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = CausalSelfAttention(
            config=config,
            parallel_config=parallel_config,
            tp_pg=tp_pg,
            layer_idx=layer_idx,
        )

        self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

    def forward(
@@ -945,7 +947,7 @@ def init_model_randomly(self, config: Config):
                continue

            module = model.get_submodule(module_name)
            parametrizator.parametrize(param_name, module)
            parametrizator.parametrize(full_param_name, module)

            assert full_param_name not in initialized_parameters
            initialized_parameters.add(full_param_name)
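Note: the layer-norm swap above is not a drop-in numerical equivalent. RMSNorm skips mean-centering and the bias term, so replacing TritonRMSNorm with nn.LayerNorm changes the model's math. A minimal reference RMSNorm (the usual formulation, not flash-attn's Triton kernel) for comparison:

import torch
from torch import nn

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # normalize by the root mean square over the last dimension; no centering, no bias
    rms = x.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return x * rms * weight

x = torch.randn(2, 8)
ln = nn.LayerNorm(8, eps=1e-5)
print(torch.allclose(rms_norm(x, torch.ones(8)), ln(x)))  # False in general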
4 changes: 2 additions & 2 deletions src/nanotron/nn/layer_norm.py
@@ -6,7 +6,7 @@ class TritonLayerNorm(nn.LayerNorm):
    def forward(
        self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
    ):
        from flash_attn.ops.triton.layer_norm import layer_norm_fn
        # from flash_attn.ops.triton.layer_norm import layer_norm_fn

        return layer_norm_fn(
            input,
@@ -37,7 +37,7 @@ def reset_parameters(self):
    def forward(
        self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False
    ):
        from flash_attn.ops.triton.layer_norm import layer_norm_fn
        # from flash_attn.ops.triton.layer_norm import layer_norm_fn

        return layer_norm_fn(
            input,
31 changes: 27 additions & 4 deletions src/nanotron/parallel/parameters.py
@@ -229,7 +229,10 @@ class NanotronParameter(nn.Parameter):
    NANOTRON_PARAMETER_METADATA_SHARDED_KEY = "sharded"

    def __new__(cls, tensor: torch.Tensor, requires_grad: bool = True):
        assert tensor.data.is_floating_point() or tensor.data.requires_grad is False
        try:
            assert tensor.data.is_floating_point() or tensor.data.requires_grad is False
        except AttributeError:
            assert 1 == 1

        # data = tensor.data.detach() if tensor.data.is_floating_point() else tensor.data
        # requires_grad = requires_grad if data.is_floating_point() else False
@@ -243,7 +246,16 @@ def __new__(cls, tensor: torch.Tensor, requires_grad: bool = True):
        # param = nn.Parameter.__new__(cls, data=data, requires_grad=requires_grad)
        # param.data =
        # NOTE: this somehow makes the param has the methods of NanotronParameter
        param = nn.Parameter._make_wrapper_subclass(cls, size=data.size())
        param = nn.Parameter._make_wrapper_subclass(
            cls,
            size=data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            requires_grad=data.requires_grad,
        )

        if isinstance(tensor, NanotronParameter):
            # Check that we don't inherit a weird class
@@ -335,7 +347,11 @@ def unwrap(e):
            # return cls(e, fp8_meta=metadatas[0]) if isinstance(e, torch.Tensor) else e

        def wrap(e):
            return cls(e) if not isinstance(e, NanotronParameter) else e
            # return cls(e) if not isinstance(e, NanotronParameter) else e
            if not isinstance(e, NanotronParameter) and isinstance(e, (torch.Tensor, FP8Tensor)):
                return cls(e)
            else:
                return e

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
@@ -344,7 +360,14 @@ def wrap(e):
            # NOTE: this is for parameter.data or parameter.detach()
            return args[0].data
        else:
            return tree_map(wrap, func(*args, **kwargs))
            outputs = func(*args, **kwargs)
            # if len(outputs) == 1 and not isinstance(outputs, torch.Tensor):
            #     # NOTE: in some distributed operation, it doesn't return anything
            #     # but do in-place operation
            #     return outputs
            # else:
            #     return tree_map(wrap, outputs)
            return tree_map(wrap, outputs)


def sanity_check(root_module: nn.Module):
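Note: the _make_wrapper_subclass change above passes the wrapped tensor's strides, storage offset, dtype, layout, device, and requires_grad instead of only its size, so the wrapper advertises the same metadata as the data it wraps. Below is a self-contained sketch of the same wrapper-subclass plus __torch_dispatch__ pattern, with tree_map-based unwrap/wrap as in the diff; WrapperTensor is a toy class, not NanotronParameter.

import torch
from torch.utils._pytree import tree_map

class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        # advertise the wrapped tensor's full metadata, as the diff now does
        return torch.Tensor._make_wrapper_subclass(
            cls,
            size=data.size(),
            strides=data.stride(),
            storage_offset=data.storage_offset(),
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __repr__(self):
        return f"WrapperTensor({self._data!r})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(e):
            return e._data if isinstance(e, WrapperTensor) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) and not isinstance(e, WrapperTensor) else e

        # run the op on the plain tensors, then re-wrap any tensor outputs
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)

x = WrapperTensor(torch.randn(2, 2))
y = x + 1  # routed through __torch_dispatch__ and re-wrapped
print(type(y).__name__, y)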
(diffs for the remaining 10 changed files were not loaded)