diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index ad718beb9c..9e97e1c32b 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -7,7 +7,6 @@
 import pytest
 import torch

-from torchao.prototype.mx_formats import config
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP4,
     DTYPE_FP6_E2M3,
@@ -139,8 +138,14 @@ def test_exponent_nan_out(elem_dtype):
     else:
         raise AssertionError("unsupported")
     block_size = 2
+    use_fp4_custom_triton_dequant_kernel = False
     tensor_mx = MXTensor(
-        scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float
+        scale_e8m0_bits,
+        data_bits,
+        elem_dtype,
+        block_size,
+        torch.float,
+        use_fp4_custom_triton_dequant_kernel,
     )
     tensor_hp = tensor_mx.to_dtype(torch.float)
     assert torch.all(torch.isnan(tensor_hp[0:1]))
@@ -188,15 +193,16 @@ def test_transpose(elem_dtype, fp4_triton):
     M, K = 128, 256
     block_size = 32
     tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
-    tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
-    config.use_fp4_custom_triton_dequant_kernel = fp4_triton
+    tensor_mx = MXTensor.to_mx(
+        tensor_hp,
+        elem_dtype,
+        block_size,
+        use_fp4_custom_triton_dequant_kernel=fp4_triton,
+    )
     tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
-    config.use_fp4_custom_triton_dequant_kernel = False

     tensor_mx_t = tensor_mx.t()
-    config.use_fp4_custom_triton_dequant_kernel = fp4_triton
     tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype)
-    config.use_fp4_custom_triton_dequant_kernel = False

     assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py
index 3e7e03d8f6..7b68b5b6a5 100644
--- a/torchao/prototype/mx_formats/config.py
+++ b/torchao/prototype/mx_formats/config.py
@@ -1,2 +1,13 @@
-# If True, uses a custom triton kernel for fp4 dequantize
-use_fp4_custom_triton_dequant_kernel = False
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+
+@dataclass
+class MXLinearConfig:
+    # If True, uses a custom triton kernel for fp4 dequantize
+    use_fp4_custom_triton_dequant_kernel: bool = False
diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py
index d7aa744334..72c2b6ab39 100644
--- a/torchao/prototype/mx_formats/mx_linear.py
+++ b/torchao/prototype/mx_formats/mx_linear.py
@@ -8,11 +8,12 @@
 Defines the prototype UX for converting a model to use mx weights
 """

-from typing import Any
+from typing import Any, Optional

 import torch
 import torch.nn.functional as F

+from torchao.prototype.mx_formats.config import MXLinearConfig
 from torchao.prototype.mx_formats.mx_tensor import MXTensor


@@ -110,6 +111,8 @@ def from_float(
         elem_dtype_weight_override=None,
         elem_dtype_grad_output_override=None,
         *,
+        # TODO(next PR): move elem_dtype* and block size into config
+        config: MXLinearConfig = None,
         block_size=32,
     ):
         mod.__class__ = MXLinear
@@ -117,6 +120,10 @@ def from_float(
         mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
         mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
         mod.block_size = block_size
+        # TODO(next PR): fix this
+        if config is None:
+            config = MXLinearConfig()
+        mod.config = config
         return mod

     def forward(self, x):
@@ -151,7 +158,9 @@ class MXInferenceLinear(torch.nn.Linear):

     @classmethod
     @torch.no_grad()
-    def from_float(cls, mod, elem_dtype, block_size):
+    def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
+        # TODO(next PR): move elem_dtype and block_size into config
+
         with torch.device("meta"):
             super_kwargs = {
                 "in_features": mod.in_features,
@@ -166,6 +175,7 @@ def from_float(cls, mod, elem_dtype, block_size):
         )
         new_mod.bias = mod.bias
         new_mod.elem_dtype = elem_dtype
+        new_mod.config = config
         return new_mod

     @torch.no_grad()
@@ -207,6 +217,8 @@ def swap_linear_with_mx_linear(
     elem_dtype_weight_override=None,
     elem_dtype_grad_output_override=None,
     *,
+    # TODO(next PR): move elem_dtype* and block_size into config
+    config: Optional[MXLinearConfig] = None,
     block_size=32,
     filter_fn=None,
 ):
@@ -225,6 +237,7 @@ def __fn(mod, fqn):
             elem_dtype,
             elem_dtype_weight_override,
             elem_dtype_grad_output_override,
+            config=config,
             block_size=block_size,
         ),
         combined_filter_fn,
@@ -236,6 +249,7 @@ def swap_linear_with_mx_inference_linear(
     elem_dtype,
     block_size,
     filter_fn=None,
+    config: Optional[MXLinearConfig] = None,
 ):
     if filter_fn is None:
         combined_filter_fn = _is_linear
@@ -247,6 +261,8 @@ def __fn(mod, fqn):
         combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXInferenceLinear.from_float(mod, elem_dtype, block_size),
+        lambda mod: MXInferenceLinear.from_float(
+            mod, elem_dtype, block_size, config=config
+        ),
         combined_filter_fn,
     )
diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py
index 57fb0d54b4..5fb3e8c6c0 100644
--- a/torchao/prototype/mx_formats/mx_ops.py
+++ b/torchao/prototype/mx_formats/mx_ops.py
@@ -54,6 +54,7 @@ def mx_desugar_op(aten_op, args, kwargs=None):
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,
+        old._use_fp4_custom_triton_dequant_kernel,
     )
     return new

@@ -82,6 +83,7 @@ def mx_t(aten_op, args, kwargs=None):
         old._elem_dtype,
         old._block_size,
         old._orig_dtype,
+        old._use_fp4_custom_triton_dequant_kernel,
     )
     return new

@@ -120,6 +122,7 @@ def mx_view_op(aten_op, args, kwargs=None):
         args[0]._elem_dtype,
         args[0]._block_size,
         args[0]._orig_dtype,
+        args[0]._use_fp4_custom_triton_dequant_kernel,
     )


@@ -130,7 +133,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
     tensor.
     """
     assert isinstance(args[0], MXTensor)
-    # print('before', args[0], args[0].dtype, args[0]._orig_dtype)
     assert (
         len(kwargs) == 1 and "dtype" in kwargs
     ), "Only support dtype kwarg for autocast"
@@ -144,6 +146,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
         args[0]._elem_dtype,
         args[0]._block_size,
         kwargs["dtype"],
+        args[0]._use_fp4_custom_triton_dequant_kernel,
     )
-    # print('after', res, res.dtype, res._orig_dtype)
     return res
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 801f29ac3c..838ab2338c 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -21,7 +21,6 @@

 import torch

-import torchao.prototype.mx_formats.config as config
 from torchao.prototype.mx_formats.constants import (
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP4,
@@ -239,7 +238,14 @@ def get_fp_scale(scale_e8m0):
     return s_fp


-def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
+def to_dtype(
+    data_lp,
+    scale_e8m0,
+    elem_dtype,
+    block_size,
+    target_dtype,
+    use_fp4_custom_triton_dequant_kernel,
+):
     orig_shape = data_lp.shape
     is_transposed = not data_lp.is_contiguous()
     # if the underlying data is transposed, convert to row major before
@@ -258,7 +264,7 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
         data_hp = f6_e3m2_unpacked_to_f32(data_lp)
         data_hp = data_hp.to(target_dtype)
     elif elem_dtype == DTYPE_FP4:
-        if config.use_fp4_custom_triton_dequant_kernel:
+        if use_fp4_custom_triton_dequant_kernel:
             data_hp_rescaled = triton_f4_to_scaled_bf16(
                 data_lp,
                 scale_e8m0,
@@ -318,17 +324,29 @@ class ToMXConstrFunc(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode):
+    def forward(
+        ctx,
+        data_hp,
+        elem_dtype,
+        block_size,
+        scaling_mode,
+        use_fp4_custom_triton_dequant_kernel,
+    ):
         scale_e8m0_biased, data_lp = to_mx(
             data_hp, elem_dtype, block_size, scaling_mode
         )
         return MXTensor(
-            scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype
+            scale_e8m0_biased,
+            data_lp,
+            elem_dtype,
+            block_size,
+            data_hp.dtype,
+            use_fp4_custom_triton_dequant_kernel,
         )

     @staticmethod
     def backward(ctx, g):
-        return g, None, None, None
+        return g, None, None, None, None


 @torch._dynamo.allow_in_graph
@@ -345,6 +363,7 @@ def forward(ctx, tensor_lp, target_dtype):
             tensor_lp._elem_dtype,
             tensor_lp._block_size,
             target_dtype,
+            tensor_lp._use_fp4_custom_triton_dequant_kernel,
         )

     @staticmethod
@@ -360,6 +379,7 @@ def __new__(
         elem_dtype,
         block_size,
         orig_dtype,
+        use_fp4_custom_triton_dequant_kernel,
     ):
         new_size = data_bits.size()
         if elem_dtype == DTYPE_FP4:
@@ -417,6 +437,9 @@ def __new__(
         self._elem_dtype = elem_dtype
         self._block_size = block_size
         self._orig_dtype = orig_dtype
+        self._use_fp4_custom_triton_dequant_kernel = (
+            use_fp4_custom_triton_dequant_kernel
+        )
         return self

     def __repr__(self):
@@ -443,14 +466,22 @@ def to_mx(
         elem_dtype: Union[torch.dtype, str],
         block_size: int = BLOCK_SIZE_DEFAULT,
         scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
+        use_fp4_custom_triton_dequant_kernel: bool = False,
     ):
-        return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode)
+        return ToMXConstrFunc.apply(
+            data_hp,
+            elem_dtype,
+            block_size,
+            scaling_mode,
+            use_fp4_custom_triton_dequant_kernel,
+        )

     def __tensor_flatten__(self):
         ctx = {
             "_elem_dtype": self._elem_dtype,
             "_block_size": self._block_size,
             "_orig_dtype": self._orig_dtype,
+            "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel,
         }
         return ["_scale_e8m0", "_data"], ctx

@@ -467,6 +498,7 @@ def __tensor_unflatten__(
             metadata["_elem_dtype"],
             metadata["_block_size"],
             metadata["_orig_dtype"],
+            metadata["_use_fp4_custom_triton_dequant_kernel"],
         )

     # Do not force the MXTensor type on the returned tensor