Merge pull request #56 from huggingface/xrsrke/fp8-training-clean-up
[FP8 Training] A single forward and backward pass for a linear in FP8
xrsrke authored Feb 14, 2024
2 parents c1963cf + 75bb1b8 commit 5bc00bb
Showing 12 changed files with 584 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/nanotron/fp8/__init__.py
@@ -0,0 +1,12 @@
import warnings

from nanotron.fp8.dtypes import DTypes # noqa
from nanotron.fp8.linear import FP8Linear # noqa
from nanotron.fp8.parameter import FP8Parameter # noqa
from nanotron.fp8.tensor import FP8Tensor # noqa

try:
import transformer_engine as te # noqa
import transformer_engine_extensions as tex # noqa
except ImportError:
warnings.warn("Please install Transformer engine for FP8 training!")
18 changes: 18 additions & 0 deletions src/nanotron/fp8/constants.py
@@ -0,0 +1,18 @@
import torch

from nanotron.fp8.dtypes import DTypes

FP8_GPU_NAMES = ["h100", "rtx 4090"]

INITIAL_AMAX = 1.0
INITIAL_SCALING_FACTOR = 1.0

# FP8_DTYPES = [torch.fp8e4m3, torch.fp8e5m2]
# FP8E4M3_DTYPE = torch.fp8e4m3
# FP8E5M2_DTYPE = torch.fp8e5m2

FP8_DTYPES = [torch.int8, torch.uint8]
FP8E4M3_DTYPE = torch.int8
FP8E5M2_DTYPE = torch.uint8

DTYPE_TO_FP8_MAX = {DTypes.FP8E4M3: 448.0, DTypes.FP8E5M2: 57344.0, DTypes.KFLOAT16: 65504.0}
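
The FP8 maxima above drive the scaling-factor update used when quantizing tensors: in the standard delayed-scaling recipe, the scale maps the observed absolute maximum (amax) just inside the representable FP8 range. The sketch below is an illustration of that idea only; the actual update_scaling_factor lives in nanotron.fp8.tensor (not part of this diff) and may handle margins or non-finite amax values differently.

import torch

from nanotron.fp8.constants import DTYPE_TO_FP8_MAX
from nanotron.fp8.dtypes import DTypes


def naive_update_scaling_factor(amax: torch.Tensor, dtype: DTypes, margin: int = 0) -> torch.Tensor:
    # Hypothetical sketch of the delayed-scaling recipe: choose a scale so that
    # amax lands just inside the representable range of the target FP8 format.
    fp8_max = DTYPE_TO_FP8_MAX[dtype]
    return fp8_max / amax / (2**margin)


# e.g. a tensor whose largest magnitude is 3.2, quantized to E4M3 (max 448.0)
scale = naive_update_scaling_factor(torch.tensor(3.2), DTypes.FP8E4M3)  # -> 140.0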
7 changes: 7 additions & 0 deletions src/nanotron/fp8/dtypes.py
@@ -0,0 +1,7 @@
from enum import Enum, auto


class DTypes(Enum):
FP8E4M3 = auto()
FP8E5M2 = auto()
KFLOAT16 = auto()
70 changes: 70 additions & 0 deletions src/nanotron/fp8/kernel.py
@@ -0,0 +1,70 @@
import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex

from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.meta import FP8Meta


@torch.no_grad()
def fp8_matmul_kernel(
mat_a: FP8Tensor,
transpose_a: bool,
mat_b: FP8Tensor,
transpose_b: bool,
use_split_accumulator: bool,
) -> torch.Tensor:
assert (
mat_a.device != "cpu" and mat_b.device != "cpu"
), "The tensors must be on a CUDA device in order to use the FP8 kernel!!"

device = mat_a.device

_empty_tensor = torch.Tensor()
output = torch.empty(mat_a.shape[0], mat_b.shape[1], device=device, dtype=torch.float32)
workspace = torch.empty(33_554_432, dtype=torch.int8, device=device)
accumulate = False

out_dtype = getattr(tex.DType, "kFloat32")
# NOTE: TE currently doesn't support adding a bias in FP8
# along with the matmul; it only accepts an empty bias tensor
bias = torch.tensor([], dtype=torch.float32)
TE_CONFIG_TRANSPOSE_BIAS = False

mat_a_fp8_meta: FP8Meta = mat_a.fp8_meta
mat_b_fp8_meta: FP8Meta = mat_b.fp8_meta

# NOTE: these are the only transpose configs that TE accepts,
# so we have to transpose the A and B matrices to match them
TE_CONFIG_TRANSPOSE_A = True
TE_CONFIG_TRANSPOSE_B = False
SCALE = AMAX = _empty_tensor

mat_a = tex.fp8_transpose(mat_a, mat_a_fp8_meta.te_dtype) if transpose_a is False else mat_a
mat_b = tex.fp8_transpose(mat_b, mat_b_fp8_meta.te_dtype) if transpose_b is True else mat_b

tex.te_gemm(
mat_a,
mat_a_fp8_meta.inverse_scale,
mat_a_fp8_meta.te_dtype,
TE_CONFIG_TRANSPOSE_A,
mat_b,
mat_b_fp8_meta.inverse_scale,
mat_b_fp8_meta.te_dtype,
TE_CONFIG_TRANSPOSE_B,
output,
SCALE,
out_dtype,
AMAX,
bias,
out_dtype,
_empty_tensor,
TE_CONFIG_TRANSPOSE_BIAS,
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
0,
)

return output
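
For orientation, a hypothetical invocation of this kernel is sketched below. It mirrors the call made by _FP8Matmul.forward in linear.py, assumes an FP8-capable CUDA GPU with Transformer Engine installed, and uses square 16x16 matrices for simplicity (the sizes are illustrative, not from this commit).

import torch

from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.kernel import fp8_matmul_kernel
from nanotron.fp8.tensor import FP8Tensor

# Quantize two FP32 matrices to FP8 (E4M3) on the GPU.
weight = FP8Tensor(torch.randn(16, 16, device="cuda"), dtype=DTypes.FP8E4M3)
inp = FP8Tensor(torch.randn(16, 16, device="cuda"), dtype=DTypes.FP8E4M3)

# Same argument pattern as the forward pass: FP8 operands, FP32 output.
out = fp8_matmul_kernel(
    mat_a=weight, transpose_a=True, mat_b=inp, transpose_b=False, use_split_accumulator=False
)
assert out.dtype == torch.float32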
112 changes: 112 additions & 0 deletions src/nanotron/fp8/linear.py
@@ -0,0 +1,112 @@
from typing import Optional, Tuple, TypedDict, Union

import torch
import torch.nn.functional as F
import transformer_engine as te # noqa
from torch import nn

from nanotron.fp8.constants import INITIAL_AMAX, INITIAL_SCALING_FACTOR
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.kernel import fp8_matmul_kernel
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.parameter import FP8Parameter
from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor


class FP8LinearMeta(TypedDict):
"""FP8 metadata for FP8Linear."""

input_grad: FP8Meta
weight_grad: FP8Meta
output_grad: FP8Meta


class FP8Linear(nn.Linear):
def __init__(self, in_features: int, out_features: int, bias: bool = True, device: Optional[torch.device] = None):
super().__init__(in_features, out_features, bias, device)
# TODO(xrsrke): add device and the two FP8 dtypes as arguments
if self.weight.device != torch.device("cpu"):
self.weight = FP8Parameter(self.weight, dtype=DTypes.FP8E4M3)

# NOTE: quantization metadata for input gradients, weight gradients, and output gradients
# TODO(xrsrke): don't hard-code these initial values
fp8e4m3_scale = update_scaling_factor(
amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR),
dtype=DTypes.FP8E4M3,
)
fp8e5m2_scale = update_scaling_factor(
amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32),
scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32),
dtype=DTypes.FP8E5M2,
)
self.fp8_meta: FP8LinearMeta = {
# kfloat8_e4m3
"input_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
"weight_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale),
# kfloat8_e5m2
"output_grad": FP8Meta(amax=1, dtype=DTypes.FP8E5M2, scale=fp8e5m2_scale),
}

def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:
# NOTE: only do fp8 kernel if both input and weight are on CUDA device
if input.device == torch.device("cpu") or self.weight.device == torch.device("cpu"):
return F.linear(input, self.weight, self.bias)

# NOTE: a phony tensor that makes PyTorch trigger the backward pass,
# because requires_grad is set to False on the weight and bias,
# so that we can compute the gradients ourselves with the FP8 kernels
phony = torch.empty(0, device=input.device, requires_grad=True)
output, _ = _FP8Matmul.apply(input, self.weight, self.fp8_meta, phony)

# TODO(xrsrke): add support for adding bias in fp8
# TODO(xrsrke): support returning an FP8 tensor as output,
# since we will quantize it back to FP8 anyway in the next linear
output = output if self.bias is None else output + self.bias
return output


class _FP8Matmul(torch.autograd.Function):
@staticmethod
@torch.no_grad()
def forward(
ctx, input: FP8Tensor, weight: FP8Tensor, fp8_meta: FP8LinearMeta, phony: torch.Tensor
) -> torch.Tensor:
if type(input) == torch.Tensor:
input = FP8Tensor(input, dtype=DTypes.FP8E4M3)

ctx.save_for_backward(input, weight)
ctx.fp8_meta = fp8_meta

# NOTE: pass FP8Tensor instead of FP8Parameter
output = fp8_matmul_kernel(
mat_a=weight.data, transpose_a=True, mat_b=input, transpose_b=False, use_split_accumulator=False
)

return output, phony

@staticmethod
@torch.no_grad()
def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[torch.Tensor, None, None, None]:
"""
∂L/∂X = ∂L/∂Y @ Wᵀ
∂L/∂W = Xᵀ @ ∂L/∂Y
Source: https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html
"""
# TODO(xrsrke): investigate how grad_output.contiguous() affects the outputs
input, weight = ctx.saved_tensors

if type(grad_output) == torch.Tensor:
grad_output = torch.ones_like(grad_output)
grad_output = grad_output.contiguous()
grad_output = FP8Tensor(grad_output, dtype=DTypes.FP8E5M2)

grad_input = fp8_matmul_kernel(
mat_a=grad_output, transpose_a=True, mat_b=weight, transpose_b=True, use_split_accumulator=True
)
grad_weight = fp8_matmul_kernel(
mat_a=input, transpose_a=False, mat_b=grad_output, transpose_b=False, use_split_accumulator=True
)
weight.grad = grad_weight

return grad_input, None, None, None
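
Taken together, a single FP8 forward and backward pass through this layer (the subject of this PR) might look like the sketch below. This is a hedged example, not a test from this commit: it assumes an FP8-capable GPU such as an H100 with Transformer Engine installed, and the 16x16 sizes and bias=False are arbitrary choices.

import torch

from nanotron.fp8.linear import FP8Linear

# The weight is quantized to FP8E4M3 inside FP8Linear.__init__ when it is on CUDA.
linear = FP8Linear(16, 16, bias=False, device="cuda")
x = torch.randn(16, 16, device="cuda", requires_grad=True)

y = linear(x)        # forward: FP8 GEMM via fp8_matmul_kernel, FP32 output
y.sum().backward()   # backward: _FP8Matmul computes grad_input and weight.grad with FP8 GEMMs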
40 changes: 40 additions & 0 deletions src/nanotron/fp8/meta.py
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Union

import torch
import transformer_engine as te # noqa
import transformer_engine_extensions as tex

from nanotron.fp8.constants import DTYPE_TO_FP8_MAX
from nanotron.fp8.tensor import convert_torch_dtype_to_te_dtype


@dataclass
class FP8Meta:
"""Metadata for FP8Tensor."""

amax: Union[int, float]
scale: torch.Tensor

# TODO(xrsrke): change to Literal[torch.int8, torch.uint8]
dtype: torch.dtype

@property
def te_dtype(self) -> tex.DType:
return convert_torch_dtype_to_te_dtype(self.dtype)

def __post_init__(self):
# NOTE: transformer engine only accepts torch tensors
self.amax = torch.tensor(self.amax, device="cuda") if not isinstance(self.amax, torch.Tensor) else self.amax

@property
def fp8_max(self) -> float:
"""Return the maximum normal value for the current dtype."""
return DTYPE_TO_FP8_MAX[self.dtype]

@property
def inverse_scale(self) -> torch.Tensor:
return 1 / self.scale

def __repr__(self) -> str:
return f"FP8Meta(amax={self.amax}, scale={self.scale}, inverse_scale={self.inverse_scale}, dtype={self.dtype})"
43 changes: 43 additions & 0 deletions src/nanotron/fp8/parameter.py
@@ -0,0 +1,43 @@
import torch
from torch import nn

from nanotron.fp8.constants import FP8_DTYPES
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.tensor import FP8Tensor


class FP8Parameter(nn.Parameter):
"""
A custom FP8 parameter class that allows gradients
to flow into FP8 tensors (which are integer tensors).
"""

def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True) -> nn.Parameter:
assert isinstance(data, torch.Tensor), "data must be a tensor"
assert data.dtype not in FP8_DTYPES, "Currently only supports turning a non-FP8 tensor into an FP8 parameter"
assert data.device != torch.device("cpu"), "FP8Parameter only supports CUDA tensors"
# TODO(xrsrke): if the tensor is on cpu, then bypass quantization

with torch.no_grad():
# TODO(xrsrke): support taking an FP8Tensor as data;
# currently we can only quantize the tensor to FP8 after the parameter is created,
# because creating the parameter directly from an FP8 (integer) tensor raises
# "Only Tensors of floating point and complex dtype can require gradients"
self = torch.Tensor._make_subclass(cls, data, requires_grad)
self._data = FP8Tensor(data, dtype=dtype)
return self

@property
def data(self) -> FP8Tensor:
return self._data

@data.setter
def data(self, data: FP8Tensor):
self._data = data

@property
def fp8_meta(self) -> FP8Meta:
return self.data.fp8_meta

def __repr__(self) -> str:
return f"FP8Parameter({self.data}, fp8_meta={self.fp8_meta}, requires_grad={self.requires_grad}"