Merge pull request #56 from huggingface/xrsrke/fp8-training-clean-up
[FP8 Training] A single forward and backward pass for a linear in FP8
Showing 12 changed files with 584 additions and 0 deletions.
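Taken together, the new modules let a single linear layer run its forward and backward pass through the FP8 GEMM kernel. Below is a minimal usage sketch (not part of the commit), assuming an FP8-capable CUDA GPU (e.g. H100) and Transformer Engine installed; the square 16x16 sizes and the sum loss are arbitrary illustration:

import torch
from nanotron.fp8.linear import FP8Linear

# assumes an FP8-capable GPU (e.g. H100) and transformer_engine installed
linear = FP8Linear(16, 16, bias=True, device=torch.device("cuda"))
x = torch.randn(16, 16, device="cuda", dtype=torch.float32)

output = linear(x)   # forward: the weight was quantized to FP8E4M3 at init, the matmul runs via tex.te_gemm
loss = output.sum()
loss.backward()      # backward: _FP8Matmul.backward computes the gradients with the FP8 kernels

print(linear.weight.grad)  # assigned manually in _FP8Matmul.backward (weight.grad = grad_weight)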
@@ -0,0 +1,12 @@
import warnings

from nanotron.fp8.dtypes import DTypes  # noqa
from nanotron.fp8.linear import FP8Linear  # noqa
from nanotron.fp8.parameter import FP8Parameter  # noqa
from nanotron.fp8.tensor import FP8Tensor  # noqa

try:
    import transformer_engine as te  # noqa
    import transformer_engine_extensions as tex  # noqa
except ImportError:
    warnings.warn("Please install Transformer engine for FP8 training!")
@@ -0,0 +1,18 @@
import torch

from nanotron.fp8.dtypes import DTypes

FP8_GPU_NAMES = ["h100", "rtx 4090"]

INITIAL_AMAX = 1.0
INITIAL_SCALING_FACTOR = 1.0

# FP8_DTYPES = [torch.fp8e4m3, torch.fp8e5m2]
# FP8E4M3_DTYPE = torch.fp8e4m3
# FP8E5M2_DTYPE = torch.fp8e5m2

FP8_DTYPES = [torch.int8, torch.uint8]
FP8E4M3_DTYPE = torch.int8
FP8E5M2_DTYPE = torch.uint8

DTYPE_TO_FP8_MAX = {DTypes.FP8E4M3: 448.0, DTypes.FP8E5M2: 57344.0, DTypes.KFLOAT16: 65504.0}
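On how these constants relate: scaling picks a factor so that the observed amax maps onto the representable range of the target format (448 for E4M3, 57344 for E5M2). The sketch below is illustrative only and is not the code in this commit; the actual logic lives in nanotron.fp8.tensor.update_scaling_factor, and naive_scaling_factor is a hypothetical helper:

import torch

from nanotron.fp8.constants import DTYPE_TO_FP8_MAX, INITIAL_AMAX
from nanotron.fp8.dtypes import DTypes

def naive_scaling_factor(amax: torch.Tensor, dtype: DTypes) -> torch.Tensor:
    # Illustrative only: scale so that `amax` lands at the format's max normal value.
    # The real update is implemented in nanotron.fp8.tensor.update_scaling_factor.
    return DTYPE_TO_FP8_MAX[dtype] / amax

amax = torch.tensor(INITIAL_AMAX)
print(naive_scaling_factor(amax, DTypes.FP8E4M3))  # 448.0 when amax == 1.0
print(naive_scaling_factor(amax, DTypes.FP8E5M2))  # 57344.0 when amax == 1.0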
@@ -0,0 +1,7 @@
from enum import Enum, auto


class DTypes(Enum):
    FP8E4M3 = auto()
    FP8E5M2 = auto()
    KFLOAT16 = auto()
@@ -0,0 +1,70 @@
import torch
import transformer_engine as te  # noqa
import transformer_engine_extensions as tex

from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.meta import FP8Meta


@torch.no_grad()
def fp8_matmul_kernel(
    mat_a: FP8Tensor,
    transpose_a: bool,
    mat_b: FP8Tensor,
    transpose_b: bool,
    use_split_accumulator: bool,
) -> torch.Tensor:
    assert (
        mat_a.device != "cpu" and mat_b.device != "cpu"
    ), "The tensors must be on a CUDA device in order to use the FP8 kernel!!"

    device = mat_a.device

    _empty_tensor = torch.Tensor()
    output = torch.empty(mat_a.shape[0], mat_b.shape[1], device=device, dtype=torch.float32)
    workspace = torch.empty(33_554_432, dtype=torch.int8, device=device)
    accumulate = False

    out_dtype = getattr(tex.DType, "kFloat32")
    # NOTE: currently TE doesn't support adding a bias in FP8
    # along with the matmul; it only takes an empty bias
    bias = torch.tensor([], dtype=torch.float32)
    TE_CONFIG_TRANSPOSE_BIAS = False

    mat_a_fp8_meta: FP8Meta = mat_a.fp8_meta
    mat_b_fp8_meta: FP8Meta = mat_b.fp8_meta

    # NOTE: these are the only configs that TE accepts,
    # so we have to transpose the A and B matrices to match them
    TE_CONFIG_TRANSPOSE_A = True
    TE_CONFIG_TRANSPOSE_B = False
    SCALE = AMAX = _empty_tensor

    mat_a = tex.fp8_transpose(mat_a, mat_a_fp8_meta.te_dtype) if transpose_a is False else mat_a
    mat_b = tex.fp8_transpose(mat_b, mat_b_fp8_meta.te_dtype) if transpose_b is True else mat_b

    tex.te_gemm(
        mat_a,
        mat_a_fp8_meta.inverse_scale,
        mat_a_fp8_meta.te_dtype,
        TE_CONFIG_TRANSPOSE_A,
        mat_b,
        mat_b_fp8_meta.inverse_scale,
        mat_b_fp8_meta.te_dtype,
        TE_CONFIG_TRANSPOSE_B,
        output,
        SCALE,
        out_dtype,
        AMAX,
        bias,
        out_dtype,
        _empty_tensor,
        TE_CONFIG_TRANSPOSE_BIAS,
        workspace,
        workspace.shape[0],
        accumulate,
        use_split_accumulator,
        0,
    )

    return output
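For reference, the kernel is called with two FP8Tensors that carry their own FP8Meta, and the transpose flags follow TE's GEMM conventions. A minimal sketch (not part of the commit) that mirrors the call pattern used in _FP8Matmul.forward further down, assuming an FP8-capable CUDA GPU and Transformer Engine installed; the square 16x16 shapes are illustrative:

import torch

from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.kernel import fp8_matmul_kernel
from nanotron.fp8.tensor import FP8Tensor

# assumes an FP8-capable GPU and transformer_engine installed
weight = FP8Tensor(torch.randn(16, 16, device="cuda"), dtype=DTypes.FP8E4M3)
input = FP8Tensor(torch.randn(16, 16, device="cuda"), dtype=DTypes.FP8E4M3)

# same flags as the forward pass: weight goes in as mat_a with transpose_a=True,
# the activation as mat_b; the result comes back in the FP32 buffer allocated
# inside the kernel, sized (mat_a.shape[0], mat_b.shape[1])
out = fp8_matmul_kernel(
    mat_a=weight, transpose_a=True, mat_b=input, transpose_b=False, use_split_accumulator=False
)
print(out.shape, out.dtype)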
@@ -0,0 +1,112 @@
from typing import Optional, Tuple, TypedDict, Union

import torch
import torch.nn.functional as F
import transformer_engine as te  # noqa
from torch import nn

from nanotron.fp8.constants import INITIAL_AMAX, INITIAL_SCALING_FACTOR
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.kernel import fp8_matmul_kernel
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.parameter import FP8Parameter
from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor


class FP8LinearMeta(TypedDict):
    """FP8 metadata for FP8Linear."""

    input_grad: FP8Meta
    weight_grad: FP8Meta
    output_grad: FP8Meta


class FP8Linear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, device: Optional[torch.device] = None):
        super().__init__(in_features, out_features, bias, device)
        # TODO(xrsrke): add device, and 2 fp8 dtypes
        if self.weight.device != torch.device("cpu"):
            self.weight = FP8Parameter(self.weight, dtype=DTypes.FP8E4M3)

            # NOTE: quantization metadata for input gradients, weight gradients, and output gradients
            # TODO(xrsrke): don't hard-code these
            fp8e4m3_scale = update_scaling_factor(
                amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
                scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR),
                dtype=DTypes.FP8E4M3,
            )
            fp8e5m2_scale = update_scaling_factor(
                amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
                scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32),
                dtype=DTypes.FP8E5M2,
            )
            self.fp8_meta: FP8LinearMeta = {
                # kfloat8_e4m3
                "input_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
                "weight_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
                # kfloat8_e5m2
                "output_grad": FP8Meta(amax=1, dtype=DTypes.FP8E5M2, scale=fp8e5m2_scale),
            }

    def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:
        # NOTE: only run the FP8 kernel if both the input and the weight are on a CUDA device
        if input.device == torch.device("cpu") or self.weight.device == torch.device("cpu"):
            return F.linear(input, self.weight, self.bias)

        # NOTE: just a phony tensor to make PyTorch trigger the backward pass,
        # because the weight's and bias's requires_grad are set to False,
        # so that we can compute the gradients ourselves with the FP8 kernels
        phony = torch.empty(0, device=input.device, requires_grad=True)
        output, _ = _FP8Matmul.apply(input, self.weight, self.fp8_meta, phony)

        # TODO(xrsrke): add support for adding bias in fp8
        # TODO(xrsrke): support returning an FP8 tensor as output,
        # since we will quantize it back to FP8 anyway in the next linear
        output = output if self.bias is None else output + self.bias
        return output


class _FP8Matmul(torch.autograd.Function):
    @staticmethod
    @torch.no_grad()
    def forward(
        ctx, input: FP8Tensor, weight: FP8Tensor, fp8_meta: FP8LinearMeta, phony: torch.Tensor
    ) -> torch.Tensor:
        if type(input) == torch.Tensor:
            input = FP8Tensor(input, dtype=DTypes.FP8E4M3)

        ctx.save_for_backward(input, weight)
        ctx.fp8_meta = fp8_meta

        # NOTE: pass the FP8Tensor instead of the FP8Parameter
        output = fp8_matmul_kernel(
            mat_a=weight.data, transpose_a=True, mat_b=input, transpose_b=False, use_split_accumulator=False
        )

        return output, phony

    @staticmethod
    @torch.no_grad()
    def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[torch.Tensor, None, None, None]:
        """
        ∂L/∂X = ∂L/∂Y @ Wᵀ
        ∂L/∂W = Xᵀ @ ∂L/∂Y
        Source: https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html
        """
        # TODO(xrsrke): investigate how grad_output.contiguous() affects the outputs
        input, weight = ctx.saved_tensors

        if type(grad_output) == torch.Tensor:
            grad_output = torch.ones_like(grad_output)
            grad_output = grad_output.contiguous()
            grad_output = FP8Tensor(grad_output, dtype=DTypes.FP8E5M2)

        grad_input = fp8_matmul_kernel(
            mat_a=grad_output, transpose_a=True, mat_b=weight, transpose_b=True, use_split_accumulator=True
        )
        grad_weight = fp8_matmul_kernel(
            mat_a=input, transpose_a=False, mat_b=grad_output, transpose_b=False, use_split_accumulator=True
        )
        weight.grad = grad_weight

        return grad_input, None, None, None
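The formulas in the backward docstring are the standard matmul gradients for Y = X @ W. A plain-PyTorch sanity check of that math (not part of the commit; full precision, no FP8, arbitrary shapes):

import torch

# Convention from the docstring above: Y = X @ W
X = torch.randn(16, 64, requires_grad=True)
W = torch.randn(64, 32, requires_grad=True)

Y = X @ W
dY = torch.randn_like(Y)
Y.backward(dY)

# ∂L/∂X = ∂L/∂Y @ Wᵀ  and  ∂L/∂W = Xᵀ @ ∂L/∂Y
assert torch.allclose(X.grad, dY @ W.T, atol=1e-4)
assert torch.allclose(W.grad, X.T @ dY, atol=1e-4)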
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Union

import torch
import transformer_engine as te  # noqa
import transformer_engine_extensions as tex

from nanotron.fp8.constants import DTYPE_TO_FP8_MAX
from nanotron.fp8.tensor import convert_torch_dtype_to_te_dtype


@dataclass
class FP8Meta:
    """Metadata for FP8Tensor."""

    amax: Union[int, float]
    scale: torch.Tensor

    # TODO(xrsrke): change to Literal[torch.int8, torch.uint8]
    dtype: torch.dtype

    @property
    def te_dtype(self) -> tex.DType:
        return convert_torch_dtype_to_te_dtype(self.dtype)

    def __post_init__(self):
        # NOTE: transformer engine only accepts torch tensors
        self.amax = torch.tensor(self.amax, device="cuda") if not isinstance(self.amax, torch.Tensor) else self.amax

    @property
    def fp8_max(self) -> float:
        """Return the maximum normal value for the current dtype."""
        return DTYPE_TO_FP8_MAX[self.dtype]

    @property
    def inverse_scale(self) -> torch.Tensor:
        return 1 / self.scale

    def __repr__(self) -> str:
        return f"FP8Meta(amax={self.amax}, scale={self.scale}, inverse_scale={self.inverse_scale}, dtype={self.dtype})"
@@ -0,0 +1,43 @@
import torch
from torch import nn

from nanotron.fp8.constants import FP8_DTYPES
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.tensor import FP8Tensor


class FP8Parameter(nn.Parameter):
    """
    A custom FP8 parameter class that allows gradients
    to flow into FP8 tensors (which are integer tensors).
    """

    def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True) -> nn.Parameter:
        assert isinstance(data, torch.Tensor), "data must be a tensor"
        assert data.dtype not in FP8_DTYPES, "Currently only supports turning a non-FP8 tensor into an FP8 parameter"
        assert data.device != torch.device("cpu"), "FP8Parameter only supports CUDA tensors"
        # TODO(xrsrke): if the tensor is on cpu, then bypass quantization

        with torch.no_grad():
            # TODO(xrsrke): support taking an FP8Tensor as data;
            # currently we can only quantize a tensor to FP8 after the parameter is created,
            # because creating the parameter from an FP8 tensor raises
            # "Only Tensors of floating point and complex dtype can require gradients"
            self = torch.Tensor._make_subclass(cls, data, requires_grad)
            self._data = FP8Tensor(data, dtype=dtype)
        return self

    @property
    def data(self) -> FP8Tensor:
        return self._data

    @data.setter
    def data(self, data: FP8Tensor):
        self._data = data

    @property
    def fp8_meta(self) -> FP8Meta:
        return self.data.fp8_meta

    def __repr__(self) -> str:
        return f"FP8Parameter({self.data}, fp8_meta={self.fp8_meta}, requires_grad={self.requires_grad})"