diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py
index c2eb66960f..87451bf621 100644
--- a/test/prototype/mx_formats/test_mx_linear.py
+++ b/test/prototype/mx_formats/test_mx_linear.py
@@ -11,8 +11,8 @@
 import torch
 import torch.nn as nn

-from torchao.prototype.mx_formats.config import MXLinearConfig
-from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
+from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
+from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES
 from torchao.prototype.mx_formats.mx_linear import (
     MXInferenceLinear,
     MXLinear,
@@ -50,7 +50,9 @@ def run_around_tests():
 @pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
 def test_linear_eager(elem_dtype, bias, input_shape):
     """
-    Smoke test for training linear module with mx weight
+    Smoke test for training linear module with mx weight, compares the following:
+    * baseline: float32
+    * experiment: emulated MX
     """
     # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
@@ -92,6 +94,49 @@ def test_linear_eager(elem_dtype, bias, input_shape):
     assert x_g_sqnr >= 8.0


+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+)
+@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, DTYPE_FP4])
+@pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)])
+def test_linear_eager_emulated_vs_real_gemm(elem_dtype, mkn):
+    M, K, N = 128, 128, 128
+    M, K, N = mkn
+
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_()
+    x_copy = copy.deepcopy(x)
+    g = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
+    m_emulated = nn.Sequential(
+        nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16),
+    )
+    m_real = copy.deepcopy(m_emulated)
+
+    config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
+    config_real = MXLinearConfig(
+        block_size=32,
+        elem_dtype=elem_dtype,
+        gemm_kernel_choice=MXGemmKernelChoice.CUTLASS,
+    )
+
+    swap_linear_with_mx_linear(m_emulated, config=config_emulated)
+    swap_linear_with_mx_linear(m_real, config=config_real)
+
+    y_emulated = m_emulated(x)
+    y_emulated.backward(g)
+
+    y_real = m_real(x_copy)
+    y_real.backward(g)
+
+    with torch.no_grad():
+        y_sqnr = compute_error(y_real, y_emulated)
+        w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad)
+        g_sqnr = compute_error(x_copy.grad, x.grad)
+        assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!"
+        assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!"
+        assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!"
+
+
 # TODO(future): enable compile support
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_activation_checkpointing():
diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index 2a15961586..f5014b7e31 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -7,6 +7,7 @@
 import pytest
 import torch

+from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP4,
     DTYPE_FP6_E2M3,
@@ -146,6 +147,7 @@ def test_exponent_nan_out(elem_dtype):
         block_size,
         torch.float,
         use_fp4_custom_triton_dequant_kernel,
+        MXGemmKernelChoice.EMULATED,
     )
     tensor_hp = tensor_mx.to_dtype(torch.float)
     assert torch.all(torch.isnan(tensor_hp[0:1]))
diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md
index 09e7563ebb..1f1db18b7d 100644
--- a/torchao/prototype/mx_formats/README.md
+++ b/torchao/prototype/mx_formats/README.md
@@ -41,10 +41,21 @@ This is a module to do MX training, the MX matmul is currently emulated.

 ```python
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
-from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice
+from torchao.utils import is_sm_at_least_100
+
+# early prototype: on MX-enabled hardware, you can use the real MX gemm backed by
+# torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support.
+gemm_kernel_choice = MXGemmKernelChoice.EMULATED
+if is_sm_at_least_100():
+    gemm_kernel_choice = MXGemmKernelChoice.CUTLASS

 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
+config = MXLinearConfig(
+    elem_dtype=torch.float8_e4m3fn,
+    block_size=32,
+    gemm_kernel_choice=gemm_kernel_choice,
+)
 swap_linear_with_mx_linear(m, config=config)

 # training loop (not shown)
diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py
index 7cdf2d4e58..d511d2614d 100644
--- a/torchao/prototype/mx_formats/config.py
+++ b/torchao/prototype/mx_formats/config.py
@@ -5,11 +5,26 @@
 # LICENSE file in the root directory of this source tree.

 from dataclasses import dataclass
+from enum import Enum
 from typing import Any, Optional

 import torch

-from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
+from torchao.prototype.mx_formats.constants import (
+    DTYPE_FP4,
+    SUPPORTED_ELEM_DTYPES,
+)
+
+
+class MXGemmKernelChoice(Enum):
+    # always available - MX operands are dequantized and a high precision
+    # gemm is run
+    EMULATED = "emulated"
+
+    # available only when CUDA capability is greater than or equal to 10.0
+    CUTLASS = "cutlass"
+
+    # TODO(future PR): add cuBLAS here once we land pytorch/pytorch support


 @dataclass
@@ -27,10 +42,15 @@ class MXLinearConfig:
     elem_dtype_weight_override: Optional[Any] = None
     elem_dtype_grad_output_override: Optional[Any] = None

+    # defines the gemm kernel choice; if the chosen kernel is not supported
+    # on the given hardware, an exception will be thrown
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
+
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False

     def __post_init__(self):
+        # validate elem_dtype and its overrides
         assert (
             self.elem_dtype in SUPPORTED_ELEM_DTYPES
         ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
@@ -42,3 +62,19 @@ def __post_init__(self):
             assert (
                 self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
             ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+
+        # validate that block size and elem_dtype match the kernel choice
+        if self.gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
+            assert (
+                self.block_size == 32
+            ), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {self.block_size}"
+            valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
+            assert (
+                self.elem_dtype in valid_dtypes
+            ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
+            assert (
+                self.elem_dtype_weight_override is None
+            ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
+            assert (
+                self.elem_dtype_grad_output_override is None
+            ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py
index a38a8c5499..e15f2ad727 100644
--- a/torchao/prototype/mx_formats/mx_linear.py
+++ b/torchao/prototype/mx_formats/mx_linear.py
@@ -13,7 +13,7 @@
 import torch
 import torch.nn.functional as F

-from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
 from torchao.prototype.mx_formats.mx_tensor import MXTensor


@@ -36,19 +36,25 @@ def forward(
         w_elem_dtype: Any,
         grad_elem_dtype: Any,
         block_size: int,
+        gemm_kernel_choice: MXGemmKernelChoice,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
         ctx.w_elem_dtype = w_elem_dtype
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
+        ctx.gemm_kernel_choice = gemm_kernel_choice

         # input @ weight_t = output
         input_orig_shape = input_hp.shape
         input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])

-        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size)
-        weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size)
+        input_mx_r_dim0 = MXTensor.to_mx(
+            input_hp_r, in_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice
+        )
+        weight_mx_dim0 = MXTensor.to_mx(
+            weight_hp, w_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice
+        )
         output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
         output = output.reshape(*input_orig_shape[:-1], output.shape[-1])

@@ -62,6 +68,7 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         w_elem_dtype = ctx.w_elem_dtype
         grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
+        gemm_kernel_choice = ctx.gemm_kernel_choice

         grad_output_orig_shape = grad_output_hp.shape
         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -71,9 +78,17 @@ def backward(ctx, grad_output_hp: torch.Tensor):

         # grad_output @ weight = grad_input
         grad_output_mx_dim0 = MXTensor.to_mx(
-            grad_output_hp_r, grad_elem_dtype, block_size
+            grad_output_hp_r,
+            grad_elem_dtype,
+            block_size,
+            gemm_kernel_choice=gemm_kernel_choice,
+        )
+        weight_mx_dim1 = MXTensor.to_mx(
+            weight_hp_t_c,
+            w_elem_dtype,
+            block_size,
+            gemm_kernel_choice=gemm_kernel_choice,
         )
-        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size)
         grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
@@ -81,15 +96,21 @@ def backward(ctx, grad_output_hp: torch.Tensor):

         # input_t @ grad_output = grad_weight
         grad_output_mx_dim1 = MXTensor.to_mx(
-            grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size
+            grad_output_hp_r.t().contiguous(),
+            grad_elem_dtype,
+            block_size,
+            gemm_kernel_choice=gemm_kernel_choice,
         )
         input_t_mx_dim0_tmp = MXTensor.to_mx(
-            input_hp_r.t().contiguous(), in_elem_dtype, block_size
+            input_hp_r.t().contiguous(),
+            in_elem_dtype,
+            block_size,
+            gemm_kernel_choice=gemm_kernel_choice,
         )
         input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)

-        return grad_input, grad_weight, None, None, None, None
+        return grad_input, grad_weight, None, None, None, None, None


 class MXLinear(torch.nn.Linear):
@@ -132,6 +153,7 @@ def forward(self, x):
             config.elem_dtype_weight_override or config.elem_dtype,
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
+            config.gemm_kernel_choice,
         )
         if self.bias is not None:
             y = y + self.bias
@@ -163,7 +185,10 @@ def from_float(
         # TODO(future PR): set to new_mod.weight directly, will need to work
         # through some errors
         new_mod.weight_mx = MXTensor.to_mx(
-            mod.weight, config.elem_dtype, block_size=config.block_size
+            mod.weight,
+            config.elem_dtype,
+            block_size=config.block_size,
+            gemm_kernel_choice=config.gemm_kernel_choice,
         )
         new_mod.bias = mod.bias
         new_mod.config = config
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index 5fb3e8c6c0..16e61e0653 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -22,11 +22,15 @@
 import torch
 from torch.utils._pytree import tree_map

+# from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+import torchao.ops
+from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import DTYPE_FP4
 from torchao.prototype.mx_formats.mx_tensor import (  # noqa: E501
     MXTensor,
     tensor_size_hp_to_fp4x2,
 )
+from torchao.prototype.mx_formats.utils import to_blocked

 aten = torch.ops.aten

@@ -55,6 +59,7 @@ def mx_desugar_op(aten_op, args, kwargs=None):
         old._block_size,
         old._orig_dtype,
         old._use_fp4_custom_triton_dequant_kernel,
+        old._gemm_kernel_choice,
     )
     return new

@@ -64,12 +69,34 @@ def mx_mm(aten_op, args, kwargs=None):
     a = args[0]
     b = args[1]
     assert isinstance(a, MXTensor) and isinstance(b, MXTensor)
-    a_hp = a.to_dtype(a._orig_dtype)
-    b_hp = b.to_dtype(b._orig_dtype)
-    # assert memory layout we expect to be required in hardware
-    assert a_hp.is_contiguous()
-    assert b_hp.t().is_contiguous()
-    res = aten_op(a_hp, b_hp)
+    assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported"
+    if a._gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
+        # real MX gemm backed by torchao's CUTLASS kernels
+        M, K, N = a.shape[0], a.shape[1], b.shape[1]
+        assert b._data.t().is_contiguous()
+        a_scale = a._scale_e8m0.view(M, K // 32)
+        b_scale = b._scale_e8m0.view(N, K // 32)
+        a_scale_block = to_blocked(a_scale)
+        b_scale_block = to_blocked(b_scale)
+        if a._elem_dtype == torch.float8_e4m3fn:
+            assert b._elem_dtype == torch.float8_e4m3fn
+            res = torchao.ops.mx_fp8_bf16(
+                a._data, b._data, a_scale_block, b_scale_block
+            )
+        else:
+            assert a._elem_dtype == DTYPE_FP4
+            assert b._elem_dtype == DTYPE_FP4
+            res = torchao.ops.mx_fp4_bf16(
+                a._data, b._data, a_scale_block, b_scale_block
+            )
+    else:
+        # emulated MX gemm
+        a_hp = a.to_dtype(a._orig_dtype)
+        b_hp = b.to_dtype(b._orig_dtype)
+        # assert memory layout we expect to be required in hardware
+        assert a_hp.is_contiguous()
+        assert b_hp.t().is_contiguous()
+        res = aten_op(a_hp, b_hp)
     return res


@@ -84,6 +111,7 @@ def mx_t(aten_op, args, kwargs=None):
         old._block_size,
         old._orig_dtype,
         old._use_fp4_custom_triton_dequant_kernel,
+        old._gemm_kernel_choice,
     )
     return new

@@ -123,6 +151,7 @@ def mx_view_op(aten_op, args, kwargs=None):
         args[0]._block_size,
         args[0]._orig_dtype,
         args[0]._use_fp4_custom_triton_dequant_kernel,
+        args[0]._gemm_kernel_choice,
     )


@@ -147,5 +176,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         args[0]._block_size,
         kwargs["dtype"],
         args[0]._use_fp4_custom_triton_dequant_kernel,
+        args[0]._gemm_kernel_choice,
     )
     return res
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 838ab2338c..6c0a718c78 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -21,6 +21,7 @@

 import torch

+from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP4,
@@ -331,6 +332,7 @@ def forward(
         block_size,
         scaling_mode,
         use_fp4_custom_triton_dequant_kernel,
+        gemm_kernel_choice,
     ):
         scale_e8m0_biased, data_lp = to_mx(
             data_hp, elem_dtype, block_size, scaling_mode
@@ -342,11 +344,12 @@ def forward(
             block_size,
             data_hp.dtype,
             use_fp4_custom_triton_dequant_kernel,
+            gemm_kernel_choice,
         )

     @staticmethod
     def backward(ctx, g):
-        return g, None, None, None, None
+        return g, None, None, None, None, None


 @torch._dynamo.allow_in_graph
@@ -380,6 +383,7 @@ def __new__(
         block_size,
         orig_dtype,
         use_fp4_custom_triton_dequant_kernel,
+        gemm_kernel_choice,
     ):
         new_size = data_bits.size()
         if elem_dtype == DTYPE_FP4:
@@ -440,6 +444,7 @@ def __new__(
         self._use_fp4_custom_triton_dequant_kernel = (
             use_fp4_custom_triton_dequant_kernel
         )
+        self._gemm_kernel_choice = gemm_kernel_choice
         return self

     def __repr__(self):
@@ -467,6 +472,7 @@ def to_mx(
         block_size: int = BLOCK_SIZE_DEFAULT,
         scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
         use_fp4_custom_triton_dequant_kernel: bool = False,
+        gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
     ):
         return ToMXConstrFunc.apply(
             data_hp,
@@ -474,6 +480,7 @@ def to_mx(
             block_size,
             scaling_mode,
             use_fp4_custom_triton_dequant_kernel,
+            gemm_kernel_choice,
         )

     def __tensor_flatten__(self):
@@ -482,6 +489,7 @@ def __tensor_flatten__(self):
             "_block_size": self._block_size,
             "_orig_dtype": self._orig_dtype,
             "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel,
+            "_gemm_kernel_choice": self._gemm_kernel_choice,
         }
         return ["_scale_e8m0", "_data"], ctx

@@ -499,6 +507,7 @@ def __tensor_unflatten__(
             metadata["_block_size"],
             metadata["_orig_dtype"],
             metadata["_use_fp4_custom_triton_dequant_kernel"],
+            metadata["_gemm_kernel_choice"],
         )

     # Do not force the MXTensor type on the returned tensor
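
For reference, a minimal end-to-end sketch of how the `gemm_kernel_choice` knob introduced in this diff is meant to be used, assembled from the README and test changes above. It is not part of the patch; the module shape, sizes, and dtypes below are illustrative assumptions.

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.utils import is_sm_at_least_100

# fall back to the emulated gemm unless the hardware can run the CUTLASS MX kernels
gemm_kernel_choice = MXGemmKernelChoice.EMULATED
if is_sm_at_least_100():
    gemm_kernel_choice = MXGemmKernelChoice.CUTLASS

# illustrative module; any stack of nn.Linear layers works with the swap
m = nn.Sequential(nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16))
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,  # the CUTLASS MX kernels require block_size=32
    gemm_kernel_choice=gemm_kernel_choice,
)
swap_linear_with_mx_linear(m, config=config)

# forward + backward run through MXLinear with the chosen gemm kernel
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()
```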