From c83452072fd2c228ca66e9861533af55010db326 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 29 Jan 2025 20:32:05 -0800
Subject: [PATCH 1/8] Update

[ghstack-poisoned]
---
 torchao/prototype/mx_formats/mx_tensor.py | 60 ++++++++++++++++++++---
 1 file changed, 54 insertions(+), 6 deletions(-)

diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 8eeeaf8bfd..1581628d58 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -16,6 +16,7 @@
 * Zeros: N/A
 """
 
+from enum import Enum, auto
 from typing import Dict, Union
 
 import torch
@@ -53,11 +54,38 @@
     unpack_uint4,
 )
 
+# TODO(later): read from somewhere else?
+SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
+EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
+EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
+EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
+EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3
+EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2
+
+
+class ScaleCalculationMode(Enum):
+    """
+    Enum representing the different methods for calculating MX block scaling.
+    There are three methods available:
+    FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^floor(log2(max_abs(v))-max_exp).
+    It can result in overflow issues for large values and is bad for gradient quantization.
+    CEIL: This method avoids overflow issues, but small values may shift to 0 due to a large scaling factor.
+    It uses X = 2^ceil(log2(max_abs(v))-max_exp).
+    EVEN: This method is a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)).
+    It provides better accuracy for MX4 training compared to FLOOR and CEIL.
+    The default is currently FLOOR.
+    """
+
+    FLOOR = auto()
+    CEIL = auto()
+    EVEN = auto()
+
 
 def to_mx(
     data_hp: torch.Tensor,
     elem_dtype: Union[torch.dtype, str],
     block_size: int,
+    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
 ):
     """
     Takes a high precision tensor and converts to MX scale and raw data, in
@@ -88,25 +116,45 @@ def to_mx(
     # where the values are zero.
     eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)
 
-    # Find largest power of 2 less than or equal to max_abs.
- largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps)) - # Set X to be the largest power-of-two less than or equal to # max_abs(v), divided by the largest power of two representable - # in the element data type + # in the element data type, and get the mbits at the same time if elem_dtype == torch.float8_e4m3fn: target_max_pow2 = F8E4M3_MAX_POW2 + mbits = MBITS_F8_E4M3 elif elem_dtype == torch.float8_e5m2: target_max_pow2 = F8E5M2_MAX_POW2 + mbits = MBITS_F8_E5M2 elif elem_dtype == DTYPE_FP6_E2M3: target_max_pow2 = F6_E2M3_MAX_POW2 + mbits = MBITS_F6_E2M3 elif elem_dtype == DTYPE_FP6_E3M2: target_max_pow2 = F6_E3M2_MAX_POW2 + mbits = MBITS_F6_E3M2 elif elem_dtype == DTYPE_FP4: target_max_pow2 = F4_E2M1_MAX_POW2 + mbits = MBITS_F4_E2M1 else: - raise AssertionError("unsupported") - scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2 + raise AssertionError("unsupported element dtype") + + # rounding before calculating the largest power of 2 + # X = 2^(floor(log2(rounding(max_abs(v)))-max_exp)) + if scaling_mode == ScaleCalculationMode.EVEN: + nan_mask = torch.isnan(max_abs) + max_abs = max_abs.to(torch.float32).view(torch.int32) + val_to_add = 1 << (MBITS_F32 - mbits - 1) + mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32 + max_abs = (max_abs + val_to_add) & mask + max_abs = max_abs.view(torch.float32) + max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device) + + # Calculate the scale for different modes + if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN): + scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2 + elif scaling_mode == ScaleCalculationMode.CEIL: + scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2 + else: + raise AssertionError("unsupported scaling calculation mode") # Clamp to exponents that can be represented in e8m0 scale_e8m0_unbiased = torch.clamp( From 85da2973ba29d4c6978416dbe33896ef8ecf40fd Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 29 Jan 2025 20:53:42 -0800 Subject: [PATCH 2/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 16 ++++++++++++++-- torchao/prototype/mx_formats/mx_tensor.py | 11 +++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index ae87ee021e..2bad17a13d 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -18,6 +18,7 @@ from torchao.prototype.mx_formats.mx_tensor import ( E8M0_EXPONENT_NAN_VAL, MXTensor, + ScaleCalculationMode, to_dtype, ) from torchao.quantization.utils import compute_error @@ -43,8 +44,10 @@ def run_before_and_after_tests(): torch._dynamo.reset() -def _test_mx(data_hp, elem_dtype, block_size): - data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size) +def _test_mx( + data_hp, elem_dtype, block_size, scale_calculation_mode=ScaleCalculationMode.FLOOR +): + data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size, scale_calculation_mode) data_mx_dq = data_mx.to_dtype(data_hp.dtype) def assert_sqnr_gt_threshold(orig, new, threshold): @@ -70,6 +73,15 @@ def test_hello_world(elem_dtype): _test_mx(data, elem_dtype, block_size) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode]) +@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) +def test_realistic_numerics(elem_dtype, scale_calculation_mode): + 
data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16) + block_size = 32 + _test_mx(data, elem_dtype, block_size, scale_calculation_mode) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) def test_all_zeros(elem_dtype): diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 1581628d58..801f29ac3c 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -318,15 +318,17 @@ class ToMXConstrFunc(torch.autograd.Function): """ @staticmethod - def forward(ctx, data_hp, elem_dtype, block_size): - scale_e8m0_biased, data_lp = to_mx(data_hp, elem_dtype, block_size) + def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode): + scale_e8m0_biased, data_lp = to_mx( + data_hp, elem_dtype, block_size, scaling_mode + ) return MXTensor( scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype ) @staticmethod def backward(ctx, g): - return g, None, None + return g, None, None, None @torch._dynamo.allow_in_graph @@ -440,8 +442,9 @@ def to_mx( data_hp: torch.Tensor, elem_dtype: Union[torch.dtype, str], block_size: int = BLOCK_SIZE_DEFAULT, + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, ): - return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size) + return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode) def __tensor_flatten__(self): ctx = { From 0220b1923b448a91a193e89e0331c6664fd7e1a0 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 09:17:20 -0800 Subject: [PATCH 3/8] Update [ghstack-poisoned] --- third_party/cutlass | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/cutlass b/third_party/cutlass index bf9da7b76c..b78588d163 160000 --- a/third_party/cutlass +++ b/third_party/cutlass @@ -1 +1 @@ -Subproject commit bf9da7b76c766d7ee7d536afc77880a4ef1f1156 +Subproject commit b78588d1630aa6643bf021613717bafb705df4ef From e4b5dedac3006777d21b8466ebc38c7e25eaf3c7 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 09:57:17 -0800 Subject: [PATCH 4/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 20 +++++---- torchao/prototype/mx_formats/config.py | 15 ++++++- torchao/prototype/mx_formats/mx_linear.py | 22 ++++++++-- torchao/prototype/mx_formats/mx_ops.py | 6 ++- torchao/prototype/mx_formats/mx_tensor.py | 46 +++++++++++++++++---- 5 files changed, 88 insertions(+), 21 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index ad718beb9c..9e97e1c32b 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -7,7 +7,6 @@ import pytest import torch -from torchao.prototype.mx_formats import config from torchao.prototype.mx_formats.constants import ( DTYPE_FP4, DTYPE_FP6_E2M3, @@ -139,8 +138,14 @@ def test_exponent_nan_out(elem_dtype): else: raise AssertionError("unsupported") block_size = 2 + use_fp4_custom_triton_dequant_kernel = False tensor_mx = MXTensor( - scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float + scale_e8m0_bits, + data_bits, + elem_dtype, + block_size, + torch.float, + use_fp4_custom_triton_dequant_kernel, ) tensor_hp = tensor_mx.to_dtype(torch.float) assert torch.all(torch.isnan(tensor_hp[0:1])) @@ -188,15 +193,16 @@ def test_transpose(elem_dtype, fp4_triton): M, K = 128, 256 block_size = 32 tensor_hp = torch.randn(M, K, device="cuda", 
dtype=torch.bfloat16) - tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) - config.use_fp4_custom_triton_dequant_kernel = fp4_triton + tensor_mx = MXTensor.to_mx( + tensor_hp, + elem_dtype, + block_size, + use_fp4_custom_triton_dequant_kernel=fp4_triton, + ) tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t() - config.use_fp4_custom_triton_dequant_kernel = False tensor_mx_t = tensor_mx.t() - config.use_fp4_custom_triton_dequant_kernel = fp4_triton tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype) - config.use_fp4_custom_triton_dequant_kernel = False assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 3e7e03d8f6..7b68b5b6a5 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -1,2 +1,13 @@ -# If True, uses a custom triton kernel for fp4 dequantize -use_fp4_custom_triton_dequant_kernel = False +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + + +@dataclass +class MXLinearConfig: + # If True, uses a custom triton kernel for fp4 dequantize + use_fp4_custom_triton_dequant_kernel: bool = False diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index d7aa744334..72c2b6ab39 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -8,11 +8,12 @@ Defines the prototype UX for converting a model to use mx weights """ -from typing import Any +from typing import Any, Optional import torch import torch.nn.functional as F +from torchao.prototype.mx_formats.config import MXLinearConfig from torchao.prototype.mx_formats.mx_tensor import MXTensor @@ -110,6 +111,8 @@ def from_float( elem_dtype_weight_override=None, elem_dtype_grad_output_override=None, *, + # TODO(next PR): move elem_dtype* and block size into config + config: MXLinearConfig = None, block_size=32, ): mod.__class__ = MXLinear @@ -117,6 +120,10 @@ def from_float( mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype mod.block_size = block_size + # TODO(next PR): fix this + if config is None: + config = MXLinearConfig() + mod.config = config return mod def forward(self, x): @@ -151,7 +158,9 @@ class MXInferenceLinear(torch.nn.Linear): @classmethod @torch.no_grad() - def from_float(cls, mod, elem_dtype, block_size): + def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig): + # TODO(next PR): move elem_dtype and block_size into config + with torch.device("meta"): super_kwargs = { "in_features": mod.in_features, @@ -166,6 +175,7 @@ def from_float(cls, mod, elem_dtype, block_size): ) new_mod.bias = mod.bias new_mod.elem_dtype = elem_dtype + new_mod.config = config return new_mod @torch.no_grad() @@ -207,6 +217,8 @@ def swap_linear_with_mx_linear( elem_dtype_weight_override=None, elem_dtype_grad_output_override=None, *, + # TODO(next PR): move elem_dtype* and block_size into config + config: Optional[MXLinearConfig] = None, block_size=32, filter_fn=None, ): @@ -225,6 +237,7 @@ def __fn(mod, fqn): elem_dtype, elem_dtype_weight_override, elem_dtype_grad_output_override, + config=config, block_size=block_size, ), combined_filter_fn, @@ -236,6 
+249,7 @@ def swap_linear_with_mx_inference_linear( elem_dtype, block_size, filter_fn=None, + config: Optional[MXLinearConfig] = None, ): if filter_fn is None: combined_filter_fn = _is_linear @@ -247,6 +261,8 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXInferenceLinear.from_float(mod, elem_dtype, block_size), + lambda mod: MXInferenceLinear.from_float( + mod, elem_dtype, block_size, config=config + ), combined_filter_fn, ) diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 57fb0d54b4..5fb3e8c6c0 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -54,6 +54,7 @@ def mx_desugar_op(aten_op, args, kwargs=None): old._elem_dtype, old._block_size, old._orig_dtype, + old._use_fp4_custom_triton_dequant_kernel, ) return new @@ -82,6 +83,7 @@ def mx_t(aten_op, args, kwargs=None): old._elem_dtype, old._block_size, old._orig_dtype, + old._use_fp4_custom_triton_dequant_kernel, ) return new @@ -120,6 +122,7 @@ def mx_view_op(aten_op, args, kwargs=None): args[0]._elem_dtype, args[0]._block_size, args[0]._orig_dtype, + args[0]._use_fp4_custom_triton_dequant_kernel, ) @@ -130,7 +133,6 @@ def autocast_to_copy(aten_op, args, kwargs=None): tensor. """ assert isinstance(args[0], MXTensor) - # print('before', args[0], args[0].dtype, args[0]._orig_dtype) assert ( len(kwargs) == 1 and "dtype" in kwargs ), "Only support dtype kwarg for autocast" @@ -144,6 +146,6 @@ def autocast_to_copy(aten_op, args, kwargs=None): args[0]._elem_dtype, args[0]._block_size, kwargs["dtype"], + args[0]._use_fp4_custom_triton_dequant_kernel, ) - # print('after', res, res.dtype, res._orig_dtype) return res diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 801f29ac3c..838ab2338c 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -21,7 +21,6 @@ import torch -import torchao.prototype.mx_formats.config as config from torchao.prototype.mx_formats.constants import ( BLOCK_SIZE_DEFAULT, DTYPE_FP4, @@ -239,7 +238,14 @@ def get_fp_scale(scale_e8m0): return s_fp -def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype): +def to_dtype( + data_lp, + scale_e8m0, + elem_dtype, + block_size, + target_dtype, + use_fp4_custom_triton_dequant_kernel, +): orig_shape = data_lp.shape is_transposed = not data_lp.is_contiguous() # if the underlying data is transposed, convert to row major before @@ -258,7 +264,7 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype): data_hp = f6_e3m2_unpacked_to_f32(data_lp) data_hp = data_hp.to(target_dtype) elif elem_dtype == DTYPE_FP4: - if config.use_fp4_custom_triton_dequant_kernel: + if use_fp4_custom_triton_dequant_kernel: data_hp_rescaled = triton_f4_to_scaled_bf16( data_lp, scale_e8m0, @@ -318,17 +324,29 @@ class ToMXConstrFunc(torch.autograd.Function): """ @staticmethod - def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode): + def forward( + ctx, + data_hp, + elem_dtype, + block_size, + scaling_mode, + use_fp4_custom_triton_dequant_kernel, + ): scale_e8m0_biased, data_lp = to_mx( data_hp, elem_dtype, block_size, scaling_mode ) return MXTensor( - scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype + scale_e8m0_biased, + data_lp, + elem_dtype, + block_size, + data_hp.dtype, + use_fp4_custom_triton_dequant_kernel, ) @staticmethod def backward(ctx, g): - return g, None, None, None + return g, 
None, None, None, None @torch._dynamo.allow_in_graph @@ -345,6 +363,7 @@ def forward(ctx, tensor_lp, target_dtype): tensor_lp._elem_dtype, tensor_lp._block_size, target_dtype, + tensor_lp._use_fp4_custom_triton_dequant_kernel, ) @staticmethod @@ -360,6 +379,7 @@ def __new__( elem_dtype, block_size, orig_dtype, + use_fp4_custom_triton_dequant_kernel, ): new_size = data_bits.size() if elem_dtype == DTYPE_FP4: @@ -417,6 +437,9 @@ def __new__( self._elem_dtype = elem_dtype self._block_size = block_size self._orig_dtype = orig_dtype + self._use_fp4_custom_triton_dequant_kernel = ( + use_fp4_custom_triton_dequant_kernel + ) return self def __repr__(self): @@ -443,14 +466,22 @@ def to_mx( elem_dtype: Union[torch.dtype, str], block_size: int = BLOCK_SIZE_DEFAULT, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, + use_fp4_custom_triton_dequant_kernel: bool = False, ): - return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode) + return ToMXConstrFunc.apply( + data_hp, + elem_dtype, + block_size, + scaling_mode, + use_fp4_custom_triton_dequant_kernel, + ) def __tensor_flatten__(self): ctx = { "_elem_dtype": self._elem_dtype, "_block_size": self._block_size, "_orig_dtype": self._orig_dtype, + "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel, } return ["_scale_e8m0", "_data"], ctx @@ -467,6 +498,7 @@ def __tensor_unflatten__( metadata["_elem_dtype"], metadata["_block_size"], metadata["_orig_dtype"], + metadata["_use_fp4_custom_triton_dequant_kernel"], ) # Do not force the MXTensor type on the returned tensor From 7c1166e399c5a4e70fa643bedba0c573f8d0da26 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 11:36:29 -0800 Subject: [PATCH 5/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 9e97e1c32b..2a15961586 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -264,12 +264,14 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): to_dtype_c = torch.compile(to_dtype, fullgraph=True) + use_fp4_custom_triton_dequant_kernel = False x_mx_dq = to_dtype( x_mx._data, x_mx._scale_e8m0, x_mx._elem_dtype, x_mx._block_size, hp_dtype, # noqa: E501 + use_fp4_custom_triton_dequant_kernel, ) x_mx_c_dq = to_dtype_c( x_mx_c._data, @@ -277,5 +279,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx_c._elem_dtype, x_mx_c._block_size, hp_dtype, + use_fp4_custom_triton_dequant_kernel, ) torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0) From 8819b28642c36876cbb191d0e33ec94f7f27099b Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 10 Feb 2025 11:36:29 -0800 Subject: [PATCH 6/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 47 +++++++--------- torchao/prototype/mx_formats/README.md | 11 ++-- torchao/prototype/mx_formats/config.py | 31 +++++++++++ torchao/prototype/mx_formats/mx_linear.py | 60 +++++++-------------- 4 files changed, 74 insertions(+), 75 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 17a76a750d..c2eb66960f 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -11,6 +11,7 @@ import torch import torch.nn as nn +from torchao.prototype.mx_formats.config import MXLinearConfig from 
torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES from torchao.prototype.mx_formats.mx_linear import ( MXInferenceLinear, @@ -59,8 +60,13 @@ def test_linear_eager(elem_dtype, bias, input_shape): nn.Linear(8, 6, bias=bias, device="cuda"), ) m_mx = copy.deepcopy(m) - block_size = 2 - swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size) + config = MXLinearConfig( + block_size=2, + elem_dtype=elem_dtype[0], + elem_dtype_weight_override=elem_dtype[1], + elem_dtype_grad_output_override=elem_dtype[2], + ) + swap_linear_with_mx_linear(m_mx, config=config) x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() x = copy.deepcopy(x_ref) @@ -97,8 +103,8 @@ def test_activation_checkpointing(): nn.Linear(4, 6, bias=True, device="cuda"), nn.Linear(6, 6, bias=True, device="cuda"), ) - block_size = 2 - swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_linear(m, config=config) x = torch.randn(*input_shape, device="cuda").requires_grad_() g = torch.randn(*grad_shape, device="cuda") @@ -133,8 +139,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast): m_mx = nn.Sequential( nn.Linear(K, N, bias=bias, device="cuda"), ) - block_size = 2 - swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_linear(m_mx, config=config) m_mx_c = copy.deepcopy(m_mx) m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") @@ -181,8 +187,8 @@ def test_inference_linear(elem_dtype, bias, input_shape): m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16)) m = m.cuda() m_mx = copy.deepcopy(m) - block_size = 2 - swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_inference_linear(m_mx, config=config) x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16) y_ref = m(x) @@ -209,8 +215,8 @@ def test_inference_compile_simple(elem_dtype): m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16)) m = m.cuda() m_mx = copy.deepcopy(m) - block_size = 2 - swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size) + config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype) + swap_linear_with_mx_inference_linear(m_mx, config=config) m_mx = torch.compile(m_mx, fullgraph="true") x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16) @@ -223,20 +229,6 @@ def test_inference_compile_simple(elem_dtype): assert sqnr >= 13.5 -def test_mx_linear_input_weight_gradient_dtypes(): - m = nn.Sequential(nn.Linear(32, 32)) - swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32) - assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0] - assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1] - assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2] - - m = nn.Sequential(nn.Linear(32, 32)) - swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32) - assert m[0].in_elem_dtype == torch.float8_e4m3fn - assert m[0].w_elem_dtype == torch.float8_e4m3fn - assert m[0].grad_elem_dtype == torch.float8_e4m3fn - - def test_filter_fn(): m1 = nn.Sequential( nn.Linear(32, 32), @@ -245,12 +237,11 @@ def test_filter_fn(): m2 = copy.deepcopy(m1) filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731 - swap_linear_with_mx_linear( - m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn - ) + config = MXLinearConfig(block_size=32) + 
swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn) assert type(m1[0]) == MXLinear assert type(m1[1]) == torch.nn.Linear - swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn) # noqa: E501 + swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn) # noqa: E501 assert type(m2[0]) == MXInferenceLinear assert type(m2[1]) == torch.nn.Linear diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 32f45e3755..09e7563ebb 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -41,10 +41,11 @@ This is a module to do MX training, the MX matmul is currently emulated. ```python from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear +from torchao.prototype.mx_formats.config import MXLinearConfig m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() -elem_dtype = torch.float8_e4m3fn -swap_linear_with_mx_linear(m, elem_dtype, block_size=32) +config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32) +swap_linear_with_mx_linear(m, config=config) # training loop (not shown) ``` @@ -55,11 +56,11 @@ This is a module to do MX inference, weights are in MX and matmul is in high pre ```python from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear +from torchao.prototype.mx_formats.config import MXLinearConfig m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() -elem_dtype = torch.float8_e4m3fn -block_size = 32 -swap_linear_with_mx_inference_linear(m, elem_dtype, block_size) +config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32) +swap_linear_with_mx_inference_linear(m, config=config) # do inference (not shown) ``` diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 7b68b5b6a5..7cdf2d4e58 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -5,9 +5,40 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass +from typing import Any, Optional + +import torch + +from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES @dataclass class MXLinearConfig: + # block size for scaling, default is 32 to match + # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, + # section 5.2 + block_size: int = 32 + + # element dtype, used for activations, weights and gradients + elem_dtype: Any = torch.float8_e4m3fn + + # overrides for element dtype for weights and gradients + # TODO(future PR): refactor to make this cleaner + elem_dtype_weight_override: Optional[Any] = None + elem_dtype_grad_output_override: Optional[Any] = None + # If True, uses a custom triton kernel for fp4 dequantize use_fp4_custom_triton_dequant_kernel: bool = False + + def __post_init__(self): + assert ( + self.elem_dtype in SUPPORTED_ELEM_DTYPES + ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" + if self.elem_dtype_weight_override is not None: + assert ( + self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES + ), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" + if self.elem_dtype_grad_output_override is not None: + assert ( + self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES + ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 72c2b6ab39..a38a8c5499 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -107,22 +107,11 @@ class MXLinear(torch.nn.Linear): def from_float( cls, mod, - elem_dtype, - elem_dtype_weight_override=None, - elem_dtype_grad_output_override=None, - *, - # TODO(next PR): move elem_dtype* and block size into config - config: MXLinearConfig = None, - block_size=32, + config: Optional[MXLinearConfig] = MXLinearConfig(), ): + # TODO(before land): remove this + assert isinstance(config, MXLinearConfig) mod.__class__ = MXLinear - mod.in_elem_dtype = elem_dtype - mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype - mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype - mod.block_size = block_size - # TODO(next PR): fix this - if config is None: - config = MXLinearConfig() mod.config = config return mod @@ -135,13 +124,14 @@ def forward(self, x): else: w = self.weight + config = self.config y = mx_mm.apply( x, w, - self.in_elem_dtype, - self.w_elem_dtype, - self.grad_elem_dtype, - self.block_size, + config.elem_dtype, + config.elem_dtype_weight_override or config.elem_dtype, + config.elem_dtype_grad_output_override or config.elem_dtype, + config.block_size, ) if self.bias is not None: y = y + self.bias @@ -158,9 +148,11 @@ class MXInferenceLinear(torch.nn.Linear): @classmethod @torch.no_grad() - def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig): - # TODO(next PR): move elem_dtype and block_size into config - + def from_float( + cls, + mod, + config: Optional[MXLinearConfig] = MXLinearConfig(), + ): with torch.device("meta"): super_kwargs = { "in_features": mod.in_features, @@ -171,10 +163,9 @@ def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig): # TODO(future PR): set to new_mod.weight directly, will need to work # through some errors new_mod.weight_mx = MXTensor.to_mx( - mod.weight, elem_dtype, block_size=block_size + mod.weight, config.elem_dtype, block_size=config.block_size ) 
new_mod.bias = mod.bias - new_mod.elem_dtype = elem_dtype new_mod.config = config return new_mod @@ -213,13 +204,8 @@ def _is_linear(mod, fqn): def swap_linear_with_mx_linear( model, - elem_dtype, - elem_dtype_weight_override=None, - elem_dtype_grad_output_override=None, *, - # TODO(next PR): move elem_dtype* and block_size into config config: Optional[MXLinearConfig] = None, - block_size=32, filter_fn=None, ): if filter_fn is None: @@ -232,24 +218,16 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXLinear.from_float( - mod, - elem_dtype, - elem_dtype_weight_override, - elem_dtype_grad_output_override, - config=config, - block_size=block_size, - ), + lambda mod: MXLinear.from_float(mod, config=config), combined_filter_fn, ) def swap_linear_with_mx_inference_linear( model, - elem_dtype, - block_size, - filter_fn=None, + *, config: Optional[MXLinearConfig] = None, + filter_fn=None, ): if filter_fn is None: combined_filter_fn = _is_linear @@ -261,8 +239,6 @@ def __fn(mod, fqn): combined_filter_fn = __fn replace_with_custom_fn_if_matches_filter( model, - lambda mod: MXInferenceLinear.from_float( - mod, elem_dtype, block_size, config=config - ), + lambda mod: MXInferenceLinear.from_float(mod, config=config), combined_filter_fn, ) From 24399306aefe3609c525bab29e133492f2bdafc2 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 13 Feb 2025 16:02:13 -0800 Subject: [PATCH 7/8] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 51 +++++++++++++++++++-- test/prototype/mx_formats/test_mx_tensor.py | 2 + torchao/prototype/mx_formats/README.md | 15 +++++- torchao/prototype/mx_formats/config.py | 38 ++++++++++++++- torchao/prototype/mx_formats/mx_linear.py | 43 +++++++++++++---- torchao/prototype/mx_formats/mx_ops.py | 42 ++++++++++++++--- torchao/prototype/mx_formats/mx_tensor.py | 11 ++++- 7 files changed, 180 insertions(+), 22 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index c2eb66960f..87451bf621 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -11,8 +11,8 @@ import torch import torch.nn as nn -from torchao.prototype.mx_formats.config import MXLinearConfig -from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES +from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig +from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES from torchao.prototype.mx_formats.mx_linear import ( MXInferenceLinear, MXLinear, @@ -50,7 +50,9 @@ def run_around_tests(): @pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)]) def test_linear_eager(elem_dtype, bias, input_shape): """ - Smoke test for training linear module with mx weight + Smoke test for training linear module with mx weight, compares the following: + * baseline: float32 + * experiment: emulated MX """ # elem_dtype is a tuple of (input, weight, gradient) dtypes. 
grad_shape = list(input_shape) @@ -92,6 +94,49 @@ def test_linear_eager(elem_dtype, bias, input_shape): assert x_g_sqnr >= 8.0 +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8" +) +@pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, DTYPE_FP4]) +@pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)]) +def test_linear_eager_emulated_vs_real_gemm(elem_dtype, mkn): + M, K, N = 128, 128, 128 + M, K, N = mkn + + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_() + x_copy = copy.deepcopy(x) + g = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) + m_emulated = nn.Sequential( + nn.Linear(K, N, bias=False, device="cuda", dtype=torch.bfloat16), + ) + m_real = copy.deepcopy(m_emulated) + + config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype) + config_real = MXLinearConfig( + block_size=32, + elem_dtype=elem_dtype, + gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + ) + + swap_linear_with_mx_linear(m_emulated, config=config_emulated) + swap_linear_with_mx_linear(m_real, config=config_real) + + y_emulated = m_emulated(x) + y_emulated.backward(g) + + y_real = m_real(x_copy) + y_real.backward(g) + + with torch.no_grad(): + y_sqnr = compute_error(y_real, y_emulated) + w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad) + g_sqnr = compute_error(x_copy.grad, x.grad) + assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!" + assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!" + assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!" + + # TODO(future): enable compile support @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") def test_activation_checkpointing(): diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 2a15961586..f5014b7e31 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -7,6 +7,7 @@ import pytest import torch +from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import ( DTYPE_FP4, DTYPE_FP6_E2M3, @@ -146,6 +147,7 @@ def test_exponent_nan_out(elem_dtype): block_size, torch.float, use_fp4_custom_triton_dequant_kernel, + MXGemmKernelChoice.EMULATED, ) tensor_hp = tensor_mx.to_dtype(torch.float) assert torch.all(torch.isnan(tensor_hp[0:1])) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 09e7563ebb..1f1db18b7d 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -41,10 +41,21 @@ This is a module to do MX training, the MX matmul is currently emulated. ```python from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear -from torchao.prototype.mx_formats.config import MXLinearConfig +from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice +from torchao.utils import is_sm_at_least_100 + +# early prototype: on MX-enabled hardware, you can use the real MX gemm backed by +# torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support. 
+gemm_kernel_choice = MXGemmKernelChoice.EMULATED +if is_sm_at_least_100(): + gemm_kernel_choice = MXGemmKernelChoice.CUTLASS m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() -config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32) +config = MXLinearConfig( + elem_dtype=torch.float8_e4m3fn, + block_size=32, + gemm_kernel_choice=gemm_kernel_choice, +) swap_linear_with_mx_linear(m, config=config) # training loop (not shown) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 7cdf2d4e58..5f57ec940f 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -5,11 +5,26 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass +from enum import StrEnum from typing import Any, Optional import torch -from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES +from torchao.prototype.mx_formats.constants import ( + DTYPE_FP4, + SUPPORTED_ELEM_DTYPES, +) + + +class MXGemmKernelChoice(StrEnum): + # always available - MX operands are dequantized and a high precision + # gemm is run + EMULATED = "emulated" + + # available only when CUDA capability is greater than or equal to 10.0 + CUTLASS = "cutlass" + + # TODO(future PR): add cuBLAS here once we land pytorch/pytorch support @dataclass @@ -27,10 +42,15 @@ class MXLinearConfig: elem_dtype_weight_override: Optional[Any] = None elem_dtype_grad_output_override: Optional[Any] = None + # defines the gemm kernel choice, if the chosen kernel is not supported + # on the given hardware an exception will be thrown + gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED + # If True, uses a custom triton kernel for fp4 dequantize use_fp4_custom_triton_dequant_kernel: bool = False def __post_init__(self): + # validate elem_dtype and its overrides assert ( self.elem_dtype in SUPPORTED_ELEM_DTYPES ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" @@ -42,3 +62,19 @@ def __post_init__(self): assert ( self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}" + + # validate that block size and elem_dtype matches kernel choice + if self.gemm_kernel_choice == MXGemmKernelChoice.CUTLASS: + assert ( + self.block_size == 32 + ), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {self.block_size}" + valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4] + assert ( + self.elem_dtype in valid_dtypes + ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}" + assert ( + self.elem_dtype_weight_override is None + ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels" + assert ( + self.elem_dtype_grad_output_override is None + ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels" diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index a38a8c5499..e15f2ad727 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -13,7 +13,7 @@ import torch import torch.nn.functional as F -from torchao.prototype.mx_formats.config import MXLinearConfig +from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig from torchao.prototype.mx_formats.mx_tensor import MXTensor @@ -36,19 +36,25 @@ def forward( w_elem_dtype: Any, grad_elem_dtype: Any, block_size: int, + 
gemm_kernel_choice: MXGemmKernelChoice, ): ctx.save_for_backward(input_hp, weight_hp) ctx.in_elem_dtype = in_elem_dtype ctx.w_elem_dtype = w_elem_dtype ctx.grad_elem_dtype = grad_elem_dtype ctx.block_size = block_size + ctx.gemm_kernel_choice = gemm_kernel_choice # input @ weight_t = output input_orig_shape = input_hp.shape input_hp_r = input_hp.reshape(-1, input_orig_shape[-1]) - input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size) - weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size) + input_mx_r_dim0 = MXTensor.to_mx( + input_hp_r, in_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice + ) + weight_mx_dim0 = MXTensor.to_mx( + weight_hp, w_elem_dtype, block_size, gemm_kernel_choice=gemm_kernel_choice + ) output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t()) output = output.reshape(*input_orig_shape[:-1], output.shape[-1]) @@ -62,6 +68,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): w_elem_dtype = ctx.w_elem_dtype grad_elem_dtype = ctx.grad_elem_dtype block_size = ctx.block_size + gemm_kernel_choice = ctx.gemm_kernel_choice grad_output_orig_shape = grad_output_hp.shape grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1]) @@ -71,9 +78,17 @@ def backward(ctx, grad_output_hp: torch.Tensor): # grad_output @ weight = grad_input grad_output_mx_dim0 = MXTensor.to_mx( - grad_output_hp_r, grad_elem_dtype, block_size + grad_output_hp_r, + grad_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, + ) + weight_mx_dim1 = MXTensor.to_mx( + weight_hp_t_c, + w_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, ) - weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size) grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) grad_input = grad_input.reshape( *grad_output_orig_shape[:-1], grad_input.shape[-1] @@ -81,15 +96,21 @@ def backward(ctx, grad_output_hp: torch.Tensor): # input_t @ grad_output = grad_weight grad_output_mx_dim1 = MXTensor.to_mx( - grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size + grad_output_hp_r.t().contiguous(), + grad_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, ) input_t_mx_dim0_tmp = MXTensor.to_mx( - input_hp_r.t().contiguous(), in_elem_dtype, block_size + input_hp_r.t().contiguous(), + in_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, ) input_t_mx_dim0 = input_t_mx_dim0_tmp.t() grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0) - return grad_input, grad_weight, None, None, None, None + return grad_input, grad_weight, None, None, None, None, None class MXLinear(torch.nn.Linear): @@ -132,6 +153,7 @@ def forward(self, x): config.elem_dtype_weight_override or config.elem_dtype, config.elem_dtype_grad_output_override or config.elem_dtype, config.block_size, + config.gemm_kernel_choice, ) if self.bias is not None: y = y + self.bias @@ -163,7 +185,10 @@ def from_float( # TODO(future PR): set to new_mod.weight directly, will need to work # through some errors new_mod.weight_mx = MXTensor.to_mx( - mod.weight, config.elem_dtype, block_size=config.block_size + mod.weight, + config.elem_dtype, + block_size=config.block_size, + gemm_kernel_choice=config.gemm_kernel_choice, ) new_mod.bias = mod.bias new_mod.config = config diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 5fb3e8c6c0..16e61e0653 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -22,11 +22,15 @@ import torch from torch.utils._pytree 
import tree_map +# from torchao.ops import mx_fp4_bf16, mx_fp8_bf16 +import torchao.ops +from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import DTYPE_FP4 from torchao.prototype.mx_formats.mx_tensor import ( # noqa: E501 MXTensor, tensor_size_hp_to_fp4x2, ) +from torchao.prototype.mx_formats.utils import to_blocked aten = torch.ops.aten @@ -55,6 +59,7 @@ def mx_desugar_op(aten_op, args, kwargs=None): old._block_size, old._orig_dtype, old._use_fp4_custom_triton_dequant_kernel, + old._gemm_kernel_choice, ) return new @@ -64,12 +69,34 @@ def mx_mm(aten_op, args, kwargs=None): a = args[0] b = args[1] assert isinstance(a, MXTensor) and isinstance(b, MXTensor) - a_hp = a.to_dtype(a._orig_dtype) - b_hp = b.to_dtype(b._orig_dtype) - # assert memory layout we expect to be required in hardware - assert a_hp.is_contiguous() - assert b_hp.t().is_contiguous() - res = aten_op(a_hp, b_hp) + assert a._gemm_kernel_choice == b._gemm_kernel_choice, "unsupported" + if a._gemm_kernel_choice == MXGemmKernelChoice.CUTLASS: + # real MX gemm backed by torchao's CUTLASS kernels + M, K, N = a.shape[0], a.shape[1], b.shape[1] + assert b._data.t().is_contiguous() + a_scale = a._scale_e8m0.view(M, K // 32) + b_scale = b._scale_e8m0.view(N, K // 32) + a_scale_block = to_blocked(a_scale) + b_scale_block = to_blocked(b_scale) + if a._elem_dtype == torch.float8_e4m3fn: + assert b._elem_dtype == torch.float8_e4m3fn + res = torchao.ops.mx_fp8_bf16( + a._data, b._data, a_scale_block, b_scale_block + ) + else: + assert a._elem_dtype == DTYPE_FP4 + assert b._elem_dtype == DTYPE_FP4 + res = torchao.ops.mx_fp4_bf16( + a._data, b._data, a_scale_block, b_scale_block + ) + else: + # emulated MX gemm + a_hp = a.to_dtype(a._orig_dtype) + b_hp = b.to_dtype(b._orig_dtype) + # assert memory layout we expect to be required in hardware + assert a_hp.is_contiguous() + assert b_hp.t().is_contiguous() + res = aten_op(a_hp, b_hp) return res @@ -84,6 +111,7 @@ def mx_t(aten_op, args, kwargs=None): old._block_size, old._orig_dtype, old._use_fp4_custom_triton_dequant_kernel, + old._gemm_kernel_choice, ) return new @@ -123,6 +151,7 @@ def mx_view_op(aten_op, args, kwargs=None): args[0]._block_size, args[0]._orig_dtype, args[0]._use_fp4_custom_triton_dequant_kernel, + args[0]._gemm_kernel_choice, ) @@ -147,5 +176,6 @@ def autocast_to_copy(aten_op, args, kwargs=None): args[0]._block_size, kwargs["dtype"], args[0]._use_fp4_custom_triton_dequant_kernel, + args[0]._gemm_kernel_choice, ) return res diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 838ab2338c..6c0a718c78 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -21,6 +21,7 @@ import torch +from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import ( BLOCK_SIZE_DEFAULT, DTYPE_FP4, @@ -331,6 +332,7 @@ def forward( block_size, scaling_mode, use_fp4_custom_triton_dequant_kernel, + gemm_kernel_choice, ): scale_e8m0_biased, data_lp = to_mx( data_hp, elem_dtype, block_size, scaling_mode @@ -342,11 +344,12 @@ def forward( block_size, data_hp.dtype, use_fp4_custom_triton_dequant_kernel, + gemm_kernel_choice, ) @staticmethod def backward(ctx, g): - return g, None, None, None, None + return g, None, None, None, None, None @torch._dynamo.allow_in_graph @@ -380,6 +383,7 @@ def __new__( block_size, orig_dtype, use_fp4_custom_triton_dequant_kernel, + gemm_kernel_choice, ): new_size 
= data_bits.size() if elem_dtype == DTYPE_FP4: @@ -440,6 +444,7 @@ def __new__( self._use_fp4_custom_triton_dequant_kernel = ( use_fp4_custom_triton_dequant_kernel ) + self._gemm_kernel_choice = gemm_kernel_choice return self def __repr__(self): @@ -467,6 +472,7 @@ def to_mx( block_size: int = BLOCK_SIZE_DEFAULT, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, use_fp4_custom_triton_dequant_kernel: bool = False, + gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED, ): return ToMXConstrFunc.apply( data_hp, @@ -474,6 +480,7 @@ def to_mx( block_size, scaling_mode, use_fp4_custom_triton_dequant_kernel, + gemm_kernel_choice, ) def __tensor_flatten__(self): @@ -482,6 +489,7 @@ def __tensor_flatten__(self): "_block_size": self._block_size, "_orig_dtype": self._orig_dtype, "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel, + "_gemm_kernel_choice": self._gemm_kernel_choice, } return ["_scale_e8m0", "_data"], ctx @@ -499,6 +507,7 @@ def __tensor_unflatten__( metadata["_block_size"], metadata["_orig_dtype"], metadata["_use_fp4_custom_triton_dequant_kernel"], + metadata["_gemm_kernel_choice"], ) # Do not force the MXTensor type on the returned tensor From 12add12f02329187dede12ec3c2b1fc993e51eb0 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 13 Feb 2025 16:10:33 -0800 Subject: [PATCH 8/8] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 5f57ec940f..d511d2614d 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from enum import StrEnum +from enum import Enum from typing import Any, Optional import torch @@ -16,7 +16,7 @@ ) -class MXGemmKernelChoice(StrEnum): +class MXGemmKernelChoice(Enum): # always available - MX operands are dequantized and a high precision # gemm is run EMULATED = "emulated"
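
Usage sketch (editorial addition, not part of the patches above): a minimal example of how the pieces introduced in this stack fit together, assuming the final signatures from patches 1-8 are applied and a CUDA device is available. All names come from the diffs above; the tensor sizes and the choice of the emulated gemm kernel are illustrative only.

```python
import torch

from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode
from torchao.quantization.utils import compute_error

# quantize/dequantize round trip, exercising the EVEN scaling mode from patch 1
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(
    x,
    torch.float8_e4m3fn,
    block_size=32,
    scaling_mode=ScaleCalculationMode.EVEN,
)
x_dq = x_mx.to_dtype(torch.bfloat16)
print("round-trip SQNR (dB):", compute_error(x, x_dq))

# config-driven module swap from patches 6-7; EMULATED dequantizes the MX
# operands and runs a high precision gemm, so it works on any CUDA GPU
m = torch.nn.Sequential(
    torch.nn.Linear(128, 128, device="cuda", dtype=torch.bfloat16)
)
config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
)
swap_linear_with_mx_linear(m, config=config)
y = m(x)
```

On hardware with CUDA capability >= 10.0, passing `gemm_kernel_choice=MXGemmKernelChoice.CUTLASS` instead selects the real MX gemm path added in patch 7.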