MX: hook up mxfp8 and mxfp4 CUTLASS kernels to MXLinear
Summary:

1. add a gemm kernel choice setting to `MXLinearConfig` to select between the
   emulated gemm and the CUTLASS gemm (see the usage sketch below)
2. respect that setting in the `torch.mm` op override
3. add numerical tests comparing the emulated and real gemms end-to-end on
   activations, weights, and gradients

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 55f72a0e0d18898c2f02fb2d88c537b382ed5a67
ghstack-comment-id: 2657958104
Pull Request resolved: #1713
vkuzo committed Feb 14, 2025
1 parent 40d01cd commit 860131d
Showing 7 changed files with 180 additions and 22 deletions.
51 changes: 48 additions & 3 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -11,8 +11,8 @@
import torch
import torch.nn as nn

-from torchao.prototype.mx_formats.config import MXLinearConfig
-from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
+from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
+from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES
from torchao.prototype.mx_formats.mx_linear import (
MXInferenceLinear,
MXLinear,
@@ -50,7 +50,9 @@ def run_around_tests():
@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
def test_linear_eager(elem_dtype, bias, input_shape):
"""
-Smoke test for training linear module with mx weight
+Smoke test for training linear module with mx weight, compares the following:
+* baseline: float32
+* experiment: emulated MX
"""
# elem_dtype is a tuple of (input, weight, gradient) dtypes.
grad_shape = list(input_shape)
@@ -92,6 +94,49 @@ def test_linear_eager(elem_dtype, bias, input_shape):
assert x_g_sqnr >= 8.0


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
)
@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, DTYPE_FP4])
@pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)])
def test_linear_eager_emulated_vs_real_gemm(elem_dtype, mkn):
M, K, N = mkn

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_()
x_copy = copy.deepcopy(x)
g = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
m_emulated = nn.Sequential(
nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16),
)
m_real = copy.deepcopy(m_emulated)

config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
config_real = MXLinearConfig(
block_size=32,
elem_dtype=elem_dtype,
gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
)

swap_linear_with_mx_linear(m_emulated, config=config_emulated)
swap_linear_with_mx_linear(m_real, config=config_real)

y_emulated = m_emulated(x)
y_emulated.backward(g)

y_real = m_real(x_copy)
y_real.backward(g)

with torch.no_grad():
y_sqnr = compute_error(y_real, y_emulated)
w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad)
g_sqnr = compute_error(x_copy.grad, x.grad)
assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!"
assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!"
assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!"


# TODO(future): enable compile support
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_activation_checkpointing():
2 changes: 2 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -7,6 +7,7 @@
import pytest
import torch

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
@@ -146,6 +147,7 @@ def test_exponent_nan_out(elem_dtype):
block_size,
torch.float,
use_fp4_custom_triton_dequant_kernel,
MXGemmKernelChoice.EMULATED,
)
tensor_hp = tensor_mx.to_dtype(torch.float)
assert torch.all(torch.isnan(tensor_hp[0:1]))
15 changes: 13 additions & 2 deletions torchao/prototype/mx_formats/README.md
@@ -41,10 +41,21 @@ This is a module to do MX training, the MX matmul is currently emulated.

```python
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
-from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice
+from torchao.utils import is_sm_at_least_100

# early prototype: on MX-enabled hardware, you can use the real MX gemm backed by
# torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support.
gemm_kernel_choice = MXGemmKernelChoice.EMULATED
if is_sm_at_least_100():
gemm_kernel_choice = MXGemmKernelChoice.CUTLASS

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
+config = MXLinearConfig(
+    elem_dtype=torch.float8_e4m3fn,
+    block_size=32,
+    gemm_kernel_choice=gemm_kernel_choice,
+)
swap_linear_with_mx_linear(m, config=config)

# training loop (not shown)
38 changes: 37 additions & 1 deletion torchao/prototype/mx_formats/config.py
@@ -5,11 +5,26 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

import torch

-from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
+from torchao.prototype.mx_formats.constants import (
+    DTYPE_FP4,
+    SUPPORTED_ELEM_DTYPES,
+)


class MXGemmKernelChoice(Enum):
# always available - MX operands are dequantized and a high precision
# gemm is run
EMULATED = "emulated"

# available only when CUDA capability is greater than or equal to 10.0
CUTLASS = "cutlass"

# TODO(future PR): add cuBLAS here once we land pytorch/pytorch support


@dataclass
@@ -27,10 +42,15 @@ class MXLinearConfig:
elem_dtype_weight_override: Optional[Any] = None
elem_dtype_grad_output_override: Optional[Any] = None

# defines the gemm kernel choice; if the chosen kernel is not supported
# on the given hardware, an exception will be thrown
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED

# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False

def __post_init__(self):
# validate elem_dtype and its overrides
assert (
self.elem_dtype in SUPPORTED_ELEM_DTYPES
), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
@@ -42,3 +62,19 @@ def __post_init__(self):
assert (
self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype_grad_output_override}"

# validate that block_size and elem_dtype match the kernel choice
if self.gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
assert (
self.block_size == 32
), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {self.block_size}"
valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
assert (
self.elem_dtype in valid_dtypes
), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
assert (
self.elem_dtype_weight_override is None
), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
assert (
self.elem_dtype_grad_output_override is None
), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
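
As a quick illustration of the validation above, a sketch of what the new
`__post_init__` checks accept and reject (assuming the asserts land exactly as
shown in this diff):

```python
import pytest
import torch

from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.constants import DTYPE_FP4

# accepted: CUTLASS requires block_size == 32 and an fp8/fp4 element dtype
MXLinearConfig(
    elem_dtype=DTYPE_FP4,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
)

# rejected: any other block size with the CUTLASS kernel fails in __post_init__
with pytest.raises(AssertionError):
    MXLinearConfig(
        elem_dtype=torch.float8_e4m3fn,
        block_size=64,
        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
    )
```
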
43 changes: 34 additions & 9 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -13,7 +13,7 @@
import torch
import torch.nn.functional as F

-from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.mx_tensor import MXTensor


@@ -36,19 +36,25 @@ def forward(
w_elem_dtype: Any,
grad_elem_dtype: Any,
block_size: int,
gemm_kernel_choice: MXGemmKernelChoice,
):
ctx.save_for_backward(input_hp, weight_hp)
ctx.in_elem_dtype = in_elem_dtype
ctx.w_elem_dtype = w_elem_dtype
ctx.grad_elem_dtype = grad_elem_dtype
ctx.block_size = block_size
ctx.gemm_kernel_choice = gemm_kernel_choice

# input @ weight_t = output
input_orig_shape = input_hp.shape
input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])

-input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size)
-weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size)
+input_mx_r_dim0 = MXTensor.to_mx(
+    input_hp_r, in_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice
+)
+weight_mx_dim0 = MXTensor.to_mx(
+    weight_hp, w_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice
+)
output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
output = output.reshape(*input_orig_shape[:-1], output.shape[-1])

@@ -62,6 +68,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
w_elem_dtype = ctx.w_elem_dtype
grad_elem_dtype = ctx.grad_elem_dtype
block_size = ctx.block_size
gemm_kernel_choice = ctx.gemm_kernel_choice

grad_output_orig_shape = grad_output_hp.shape
grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -71,25 +78,39 @@

# grad_output @ weight = grad_input
grad_output_mx_dim0 = MXTensor.to_mx(
-    grad_output_hp_r, grad_elem_dtype, block_size
+    grad_output_hp_r,
+    grad_elem_dtype,
+    block_size,
+    gemm_kernel_choice=gemm_kernel_choice,
)
-weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size)
+weight_mx_dim1 = MXTensor.to_mx(
+    weight_hp_t_c,
+    w_elem_dtype,
+    block_size,
+    gemm_kernel_choice=gemm_kernel_choice,
+)
grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
grad_input = grad_input.reshape(
*grad_output_orig_shape[:-1], grad_input.shape[-1]
)

# input_t @ grad_output = grad_weight
grad_output_mx_dim1 = MXTensor.to_mx(
-    grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size
+    grad_output_hp_r.t().contiguous(),
+    grad_elem_dtype,
+    block_size,
+    gemm_kernel_choice=gemm_kernel_choice,
)
input_t_mx_dim0_tmp = MXTensor.to_mx(
-    input_hp_r.t().contiguous(), in_elem_dtype, block_size
+    input_hp_r.t().contiguous(),
+    in_elem_dtype,
+    block_size,
+    gemm_kernel_choice=gemm_kernel_choice,
+)
input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)

-return grad_input, grad_weight, None, None, None, None
+return grad_input, grad_weight, None, None, None, None, None


class MXLinear(torch.nn.Linear):
@@ -132,6 +153,7 @@ def forward(self, x):
config.elem_dtype_weight_override or config.elem_dtype,
config.elem_dtype_grad_output_override or config.elem_dtype,
config.block_size,
config.gemm_kernel_choice,
)
if self.bias is not None:
y = y + self.bias
@@ -163,7 +185,10 @@ def from_float(
# TODO(future PR): set to new_mod.weight directly, will need to work
# through some errors
new_mod.weight_mx = MXTensor.to_mx(
-    mod.weight, config.elem_dtype, block_size=config.block_size
+    mod.weight,
+    config.elem_dtype,
+    block_size=config.block_size,
+    gemm_kernel_choice=config.gemm_kernel_choice,
)
new_mod.bias = mod.bias
new_mod.config = config
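
Side note on the extra `None` added to the `backward` return above: a
`torch.autograd.Function` must return one value per `forward` argument, and
threading `gemm_kernel_choice` through `forward` adds a seventh,
non-differentiable slot. A toy sketch of the pattern (hypothetical names, not
torchao code):

```python
import torch


class _ScaleByFlag(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float, flag: str):
        # non-tensor args like `flag` are stashed on ctx, not saved as tensors
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # three forward args -> three return values; only `x` gets a gradient
        return grad_out * ctx.scale, None, None


x = torch.randn(4, requires_grad=True)
y = _ScaleByFlag.apply(x, 2.0, "emulated")
y.sum().backward()
assert torch.allclose(x.grad, torch.full((4,), 2.0))
```
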
42 changes: 36 additions & 6 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -22,11 +22,15 @@
import torch
from torch.utils._pytree import tree_map

# from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
import torchao.ops
from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import ( # noqa: E501
MXTensor,
tensor_size_hp_to_fp4x2,
)
from torchao.prototype.mx_formats.utils import to_blocked

aten = torch.ops.aten

@@ -55,6 +59,7 @@ def mx_desugar_op(aten_op, args, kwargs=None):
old._block_size,
old._orig_dtype,
old._use_fp4_custom_triton_dequant_kernel,
old._gemm_kernel_choice,
)
return new

@@ -64,12 +69,34 @@
a = args[0]
b = args[1]
assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
-a_hp = a.to_dtype(a._orig_dtype)
-b_hp = b.to_dtype(b._orig_dtype)
-# assert memory layout we expect to be required in hardware
-assert a_hp.is_contiguous()
-assert b_hp.t().is_contiguous()
-res = aten_op(a_hp, b_hp)
+assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported"
+if a._gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
+    # real MX gemm backed by torchao's CUTLASS kernels
+    M, K, N = a.shape[0], a.shape[1], b.shape[1]
+    assert b._data.t().is_contiguous()
+    a_scale = a._scale_e8m0.view(M, K // 32)
+    b_scale = b._scale_e8m0.view(N, K // 32)
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
+    if a._elem_dtype == torch.float8_e4m3fn:
+        assert b._elem_dtype == torch.float8_e4m3fn
+        res = torchao.ops.mx_fp8_bf16(
+            a._data, b._data, a_scale_block, b_scale_block
+        )
+    else:
+        assert a._elem_dtype == DTYPE_FP4
+        assert b._elem_dtype == DTYPE_FP4
+        res = torchao.ops.mx_fp4_bf16(
+            a._data, b._data, a_scale_block, b_scale_block
+        )
+else:
+    # emulated MX gemm
+    a_hp = a.to_dtype(a._orig_dtype)
+    b_hp = b.to_dtype(b._orig_dtype)
+    # assert memory layout we expect to be required in hardware
+    assert a_hp.is_contiguous()
+    assert b_hp.t().is_contiguous()
+    res = aten_op(a_hp, b_hp)
return res


@@ -84,6 +111,7 @@ def mx_t(aten_op, args, kwargs=None):
old._block_size,
old._orig_dtype,
old._use_fp4_custom_triton_dequant_kernel,
old._gemm_kernel_choice,
)
return new

@@ -123,6 +151,7 @@ def mx_view_op(aten_op, args, kwargs=None):
args[0]._block_size,
args[0]._orig_dtype,
args[0]._use_fp4_custom_triton_dequant_kernel,
args[0]._gemm_kernel_choice,
)


@@ -147,5 +176,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
args[0]._block_size,
kwargs["dtype"],
args[0]._use_fp4_custom_triton_dequant_kernel,
args[0]._gemm_kernel_choice,
)
return res
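
To tie the pieces together, a sketch of exercising the new `mx_mm` path
directly (assumes an MX-capable GPU, i.e. `is_sm_at_least_100()`; shapes and
dtypes follow the new test above):

```python
import torch

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.mx_tensor import MXTensor

M, K, N = 128, 256, 512
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# tagging both operands with the CUTLASS choice routes torch.mm through
# torchao.ops.mx_fp8_bf16 in the override above
a_mx = MXTensor.to_mx(
    a, torch.float8_e4m3fn, 32, gemm_kernel_choice=MXGemmKernelChoice.CUTLASS
)
b_mx = MXTensor.to_mx(
    b, torch.float8_e4m3fn, 32, gemm_kernel_choice=MXGemmKernelChoice.CUTLASS
)

# mx_mm expects the second operand transposed with contiguous underlying data,
# matching how MXLinear calls it
out = torch.mm(a_mx, b_mx.t())  # (M, N), bfloat16
```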