MX: hook up mxfp8 and mxfp4 CUTLASS kernels to MXLinear #1713

Merged (15 commits) on Feb 14, 2025
51 changes: 48 additions & 3 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -11,8 +11,8 @@
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES
from torchao.prototype.mx_formats.mx_linear import (
MXInferenceLinear,
MXLinear,
@@ -50,7 +50,9 @@ def run_around_tests():
@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
def test_linear_eager(elem_dtype, bias, input_shape):
"""
Smoke test for training linear module with mx weight
    Smoke test for training a linear module with an mx weight; compares the following:
* baseline: float32
* experiment: emulated MX
"""
# elem_dtype is a tuple of (input, weight, gradient) dtypes.
grad_shape = list(input_shape)
@@ -92,6 +94,49 @@ def test_linear_eager(elem_dtype, bias, input_shape):
assert x_g_sqnr >= 8.0


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
)
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, DTYPE_FP4])
@pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)])
def test_linear_eager_emulated_vs_real_gemm(elem_dtype, mkn):
M, K, N = mkn

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_()
x_copy = copy.deepcopy(x)
g = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
m_emulated = nn.Sequential(
nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16),
)
m_real = copy.deepcopy(m_emulated)

config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
config_real = MXLinearConfig(
block_size=32,
elem_dtype=elem_dtype,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
)

swap_linear_with_mx_linear(m_emulated, config=config_emulated)
swap_linear_with_mx_linear(m_real, config=config_real)

y_emulated = m_emulated(x)
y_emulated.backward(g)

y_real = m_real(x_copy)
y_real.backward(g)

with torch.no_grad():
y_sqnr = compute_error(y_real, y_emulated)
w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad)
g_sqnr = compute_error(x_copy.grad, x.grad)
assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!"
assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!"
assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!"


# TODO(future): enable compile support
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_activation_checkpointing():
2 changes: 2 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -7,6 +7,7 @@
import pytest
import torch

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
@@ -146,6 +147,7 @@ def test_exponent_nan_out(elem_dtype):
block_size,
torch.float,
use_fp4_custom_triton_dequant_kernel,
MXGemmKernelChoice.EMULATED,
)
tensor_hp = tensor_mx.to_dtype(torch.float)
assert torch.all(torch.isnan(tensor_hp[0:1]))
15 changes: 13 additions & 2 deletions torchao/prototype/mx_formats/README.md
@@ -41,10 +41,21 @@ This is a module to do MX training, the MX matmul is currently emulated.

```python
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice
from torchao.utils import is_sm_at_least_100

# early prototype: on MX-enabled hardware, you can use the real MX gemm backed by
# torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support.
gemm_kernel_choice = MXGemmKernelChoice.EMULATED
if is_sm_at_least_100():
gemm_kernel_choice = MXGemmKernelChoice.CUTLASS

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
config = MXLinearConfig(
elem_dtype=torch.float8_e4m3fn,
block_size=32,
gemm_kernel_choice=gemm_kernel_choice,
)
swap_linear_with_mx_linear(m, config=config)

# training loop (not shown)
38 changes: 37 additions & 1 deletion torchao/prototype/mx_formats/config.py
@@ -5,11 +5,26 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

import torch

from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
SUPPORTED_ELEM_DTYPES,
)


class MXGemmKernelChoice(Enum):
# always available - MX operands are dequantized and a high precision
# gemm is run
EMULATED = "emulated"

# available only when CUDA capability is greater than or equal to 10.0
CUTLASS = "cutlass"

# TODO(future PR): add cuBLAS here once we land pytorch/pytorch support


@dataclass
@@ -27,10 +42,15 @@ class MXLinearConfig:
elem_dtype_weight_override: Optional[Any] = None
elem_dtype_grad_output_override: Optional[Any] = None

    # defines the gemm kernel choice; if the chosen kernel is not supported
    # on the given hardware, an exception will be thrown
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED

# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False

def __post_init__(self):
# validate elem_dtype and its overrides
assert (
self.elem_dtype in SUPPORTED_ELEM_DTYPES
), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
@@ -42,3 +62,19 @@ def __post_init__(self):
assert (
self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"

        # validate that block size and elem_dtype match the kernel choice
if self.gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
assert (
self.block_size == 32
), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {self.block_size}"
valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
assert (
self.elem_dtype in valid_dtypes
), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
assert (
self.elem_dtype_weight_override is None
), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
assert (
self.elem_dtype_grad_output_override is None
), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
43 changes: 34 additions & 9 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -13,7 +13,7 @@
import torch
import torch.nn.functional as F

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.mx_tensor import MXTensor


@@ -36,19 +36,25 @@ def forward(
w_elem_dtype: Any,
grad_elem_dtype: Any,
block_size: int,
gemm_kernel_choice: MXGemmKernelChoice,
):
ctx.save_for_backward(input_hp, weight_hp)
ctx.in_elem_dtype = in_elem_dtype
ctx.w_elem_dtype = w_elem_dtype
ctx.grad_elem_dtype = grad_elem_dtype
ctx.block_size = block_size
ctx.gemm_kernel_choice = gemm_kernel_choice

# input @ weight_t = output
input_orig_shape = input_hp.shape
input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])

input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size)
weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size)
input_mx_r_dim0 = MXTensor.to_mx(
input_hp_r, in_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice
)
weight_mx_dim0 = MXTensor.to_mx(
weight_hp, w_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice
)
output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
output = output.reshape(*input_orig_shape[:-1], output.shape[-1])

@@ -62,6 +68,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
w_elem_dtype = ctx.w_elem_dtype
grad_elem_dtype = ctx.grad_elem_dtype
block_size = ctx.block_size
gemm_kernel_choice = ctx.gemm_kernel_choice

grad_output_orig_shape = grad_output_hp.shape
grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -71,25 +78,39 @@ def backward(ctx, grad_output_hp: torch.Tensor):

# grad_output @ weight = grad_input
grad_output_mx_dim0 = MXTensor.to_mx(
grad_output_hp_r, grad_elem_dtype, block_size
grad_output_hp_r,
grad_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)
weight_mx_dim1 = MXTensor.to_mx(
weight_hp_t_c,
w_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)
weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size)
grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
grad_input = grad_input.reshape(
*grad_output_orig_shape[:-1], grad_input.shape[-1]
)

# input_t @ grad_output = grad_weight
grad_output_mx_dim1 = MXTensor.to_mx(
grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size
grad_output_hp_r.t().contiguous(),
grad_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)
input_t_mx_dim0_tmp = MXTensor.to_mx(
input_hp_r.t().contiguous(), in_elem_dtype, block_size
input_hp_r.t().contiguous(),
in_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)
input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)

return grad_input, grad_weight, None, None, None, None
return grad_input, grad_weight, None, None, None, None, None


class MXLinear(torch.nn.Linear):
@@ -132,6 +153,7 @@ def forward(self, x):
config.elem_dtype_weight_override or config.elem_dtype,
config.elem_dtype_grad_output_override or config.elem_dtype,
config.block_size,
config.gemm_kernel_choice,
)
if self.bias is not None:
y = y + self.bias
@@ -163,7 +185,10 @@ def from_float(
# TODO(future PR): set to new_mod.weight directly, will need to work
# through some errors
new_mod.weight_mx = MXTensor.to_mx(
mod.weight, config.elem_dtype, block_size=config.block_size
mod.weight,
config.elem_dtype,
block_size=config.block_size,
gemm_kernel_choice=config.gemm_kernel_choice,
)
new_mod.bias = mod.bias
new_mod.config = config
42 changes: 36 additions & 6 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -22,11 +22,15 @@
import torch
from torch.utils._pytree import tree_map

# from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
import torchao.ops
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import ( # noqa: E501
MXTensor,
tensor_size_hp_to_fp4x2,
)
from torchao.prototype.mx_formats.utils import to_blocked

aten = torch.ops.aten

@@ -55,6 +59,7 @@ def mx_desugar_op(aten_op, args, kwargs=None):
old._block_size,
old._orig_dtype,
old._use_fp4_custom_triton_dequant_kernel,
old._gemm_kernel_choice,
)
return new

@@ -64,12 +69,34 @@ def mx_mm(aten_op, args, kwargs=None):
a = args[0]
b = args[1]
assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
a_hp = a.to_dtype(a._orig_dtype)
b_hp = b.to_dtype(b._orig_dtype)
# assert memory layout we expect to be required in hardware
assert a_hp.is_contiguous()
assert b_hp.t().is_contiguous()
res = aten_op(a_hp, b_hp)
assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported"
if a._gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
# real MX gemm backed by torchao's CUTLASS kernels
M, K, N = a.shape[0], a.shape[1], b.shape[1]
assert b._data.t().is_contiguous()
a_scale = a._scale_e8m0.view(M, K // 32)
b_scale = b._scale_e8m0.view(N, K // 32)
a_scale_block = to_blocked(a_scale)
b_scale_block = to_blocked(b_scale)
if a._elem_dtype == torch.float8_e4m3fn:
assert b._elem_dtype == torch.float8_e4m3fn
res = torchao.ops.mx_fp8_bf16(
a._data, b._data, a_scale_block, b_scale_block
)
else:
assert a._elem_dtype == DTYPE_FP4
assert b._elem_dtype == DTYPE_FP4
res = torchao.ops.mx_fp4_bf16(
a._data, b._data, a_scale_block, b_scale_block
)
else:
# emulated MX gemm
a_hp = a.to_dtype(a._orig_dtype)
b_hp = b.to_dtype(b._orig_dtype)
# assert memory layout we expect to be required in hardware
assert a_hp.is_contiguous()
assert b_hp.t().is_contiguous()
res = aten_op(a_hp, b_hp)
return res

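For context, a rough sketch of how this handler is exercised through the `MXTensor` subclass, mirroring the calls made by `MXLinear`. Shapes are illustrative; the CUTLASS branch assumes an MX-capable GPU (sm100+) and a torchao build that ships `mx_fp8_bf16`, otherwise pick `MXGemmKernelChoice.EMULATED`:

```python
import torch

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_tensor import MXTensor

M, K, N = 128, 256, 512
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

kernel_choice = MXGemmKernelChoice.CUTLASS  # or EMULATED on non-MX hardware
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32, gemm_kernel_choice=kernel_choice)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32, gemm_kernel_choice=kernel_choice)

# torch.mm on two MXTensors routes through mx_mm above; with CUTLASS selected,
# the e8m0 scales are converted to the blocked layout and torchao.ops.mx_fp8_bf16 runs
y = torch.mm(a_mx, b_mx.t())
```
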

@@ -84,6 +111,7 @@ def mx_t(aten_op, args, kwargs=None):
old._block_size,
old._orig_dtype,
old._use_fp4_custom_triton_dequant_kernel,
old._gemm_kernel_choice,
)
return new

@@ -123,6 +151,7 @@ def mx_view_op(aten_op, args, kwargs=None):
args[0]._block_size,
args[0]._orig_dtype,
args[0]._use_fp4_custom_triton_dequant_kernel,
args[0]._gemm_kernel_choice,
)


@@ -147,5 +176,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
args[0]._block_size,
kwargs["dtype"],
args[0]._use_fp4_custom_triton_dequant_kernel,
args[0]._gemm_kernel_choice,
)
return res