
Commit e4b5ded

Update

[ghstack-poisoned]

vkuzo committed Feb 10, 2025
1 parent 0220b19
Showing 5 changed files with 88 additions and 21 deletions.
20 changes: 13 additions & 7 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -7,7 +7,6 @@
import pytest
import torch

from torchao.prototype.mx_formats import config
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
@@ -139,8 +138,14 @@ def test_exponent_nan_out(elem_dtype):
else:
raise AssertionError("unsupported")
block_size = 2
use_fp4_custom_triton_dequant_kernel = False
tensor_mx = MXTensor(
scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float
scale_e8m0_bits,
data_bits,
elem_dtype,
block_size,
torch.float,
use_fp4_custom_triton_dequant_kernel,
)
tensor_hp = tensor_mx.to_dtype(torch.float)
assert torch.all(torch.isnan(tensor_hp[0:1]))
@@ -188,15 +193,16 @@ def test_transpose(elem_dtype, fp4_triton):
M, K = 128, 256
block_size = 32
tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
config.use_fp4_custom_triton_dequant_kernel = fp4_triton
tensor_mx = MXTensor.to_mx(
tensor_hp,
elem_dtype,
block_size,
use_fp4_custom_triton_dequant_kernel=fp4_triton,
)
tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
config.use_fp4_custom_triton_dequant_kernel = False

tensor_mx_t = tensor_mx.t()
config.use_fp4_custom_triton_dequant_kernel = fp4_triton
tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype)
config.use_fp4_custom_triton_dequant_kernel = False

assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape
torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
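For orientation, here is a minimal sketch (illustrative, not taken from the diff) of the explicit-argument flow these tests now exercise; it assumes a CUDA machine with torchao's prototype mx_formats package and the fp4 triton kernel available:

import torch

from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import MXTensor

# Quantize a bf16 tensor to MX fp4; the dequant-kernel choice is now an
# explicit per-tensor argument instead of a module-level config flag.
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(
    x,
    DTYPE_FP4,
    block_size=32,
    use_fp4_custom_triton_dequant_kernel=True,
)
x_dq = x_mx.to_dtype(torch.bfloat16)  # dequantizes via the chosen kernel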
15 changes: 13 additions & 2 deletions torchao/prototype/mx_formats/config.py
@@ -1,2 +1,13 @@
# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel = False
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass


@dataclass
class MXLinearConfig:
# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False
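A rough usage sketch of the new dataclass (assuming only what the diff shows):

from torchao.prototype.mx_formats.config import MXLinearConfig

# Default keeps the prior behavior: no custom triton fp4 dequant kernel.
cfg_default = MXLinearConfig()

# Opting in is now a per-config choice instead of mutating a global flag.
cfg_triton = MXLinearConfig(use_fp4_custom_triton_dequant_kernel=True)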
22 changes: 19 additions & 3 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -8,11 +8,12 @@
Defines the prototype UX for converting a model to use mx weights
"""

from typing import Any
from typing import Any, Optional

import torch
import torch.nn.functional as F

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_tensor import MXTensor


@@ -110,13 +111,19 @@ def from_float(
elem_dtype_weight_override=None,
elem_dtype_grad_output_override=None,
*,
# TODO(next PR): move elem_dtype* and block size into config
config: MXLinearConfig = None,
block_size=32,
):
mod.__class__ = MXLinear
mod.in_elem_dtype = elem_dtype
mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
mod.block_size = block_size
# TODO(next PR): fix this
if config is None:
config = MXLinearConfig()
mod.config = config
return mod

def forward(self, x):
@@ -151,7 +158,9 @@ class MXInferenceLinear(torch.nn.Linear):

@classmethod
@torch.no_grad()
def from_float(cls, mod, elem_dtype, block_size):
def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
# TODO(next PR): move elem_dtype and block_size into config

with torch.device("meta"):
super_kwargs = {
"in_features": mod.in_features,
@@ -166,6 +175,7 @@ def from_float(cls, mod, elem_dtype, block_size):
)
new_mod.bias = mod.bias
new_mod.elem_dtype = elem_dtype
new_mod.config = config
return new_mod

@torch.no_grad()
@@ -207,6 +217,8 @@ def swap_linear_with_mx_linear(
elem_dtype_weight_override=None,
elem_dtype_grad_output_override=None,
*,
# TODO(next PR): move elem_dtype* and block_size into config
config: Optional[MXLinearConfig] = None,
block_size=32,
filter_fn=None,
):
@@ -225,6 +237,7 @@ def __fn(mod, fqn):
elem_dtype,
elem_dtype_weight_override,
elem_dtype_grad_output_override,
config=config,
block_size=block_size,
),
combined_filter_fn,
@@ -236,6 +249,7 @@ def swap_linear_with_mx_inference_linear(
elem_dtype,
block_size,
filter_fn=None,
config: Optional[MXLinearConfig] = None,
):
if filter_fn is None:
combined_filter_fn = _is_linear
@@ -247,6 +261,8 @@ def __fn(mod, fqn):
combined_filter_fn = __fn
replace_with_custom_fn_if_matches_filter(
model,
lambda mod: MXInferenceLinear.from_float(mod, elem_dtype, block_size),
lambda mod: MXInferenceLinear.from_float(
mod, elem_dtype, block_size, config=config
),
combined_filter_fn,
)
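As a sketch of how the config threads through the swap APIs (the model and element dtype here are illustrative assumptions):

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

model = nn.Sequential(nn.Linear(256, 256, bias=False))

# Pass the config explicitly; with config=None, MXLinear.from_float falls
# back to a default MXLinearConfig() (see the TODO in the diff above).
swap_linear_with_mx_linear(
    model,
    torch.float8_e4m3fn,
    config=MXLinearConfig(use_fp4_custom_triton_dequant_kernel=False),
    block_size=32,
)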
6 changes: 4 additions & 2 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -54,6 +54,7 @@ def mx_desugar_op(aten_op, args, kwargs=None):
old._elem_dtype,
old._block_size,
old._orig_dtype,
old._use_fp4_custom_triton_dequant_kernel,
)
return new

@@ -82,6 +83,7 @@ def mx_t(aten_op, args, kwargs=None):
old._elem_dtype,
old._block_size,
old._orig_dtype,
old._use_fp4_custom_triton_dequant_kernel,
)
return new

@@ -120,6 +122,7 @@ def mx_view_op(aten_op, args, kwargs=None):
args[0]._elem_dtype,
args[0]._block_size,
args[0]._orig_dtype,
args[0]._use_fp4_custom_triton_dequant_kernel,
)


@@ -130,7 +133,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
tensor.
"""
assert isinstance(args[0], MXTensor)
# print('before', args[0], args[0].dtype, args[0]._orig_dtype)
assert (
len(kwargs) == 1 and "dtype" in kwargs
), "Only support dtype kwarg for autocast"
@@ -144,6 +146,6 @@
args[0]._elem_dtype,
args[0]._block_size,
kwargs["dtype"],
args[0]._use_fp4_custom_triton_dequant_kernel,
)
# print('after', res, res.dtype, res._orig_dtype)
return res
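The point of these op changes is that any op which rebuilds an MXTensor (desugar, transpose, view, autocast copy) must carry the kernel flag forward. A small sketch of the property this preserves (illustrative, assuming a CUDA setup with fp4 support):

import torch

from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(
    x, DTYPE_FP4, block_size=32, use_fp4_custom_triton_dequant_kernel=True
)

# Transposing returns a new MXTensor; the flag travels with it, so a later
# to_dtype() uses the same dequant path as the original tensor.
assert x_mx.t()._use_fp4_custom_triton_dequant_kernel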
46 changes: 39 additions & 7 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -21,7 +21,6 @@

import torch

import torchao.prototype.mx_formats.config as config
from torchao.prototype.mx_formats.constants import (
BLOCK_SIZE_DEFAULT,
DTYPE_FP4,
@@ -239,7 +238,14 @@ def get_fp_scale(scale_e8m0):
return s_fp


def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
def to_dtype(
data_lp,
scale_e8m0,
elem_dtype,
block_size,
target_dtype,
use_fp4_custom_triton_dequant_kernel,
):
orig_shape = data_lp.shape
is_transposed = not data_lp.is_contiguous()
# if the underlying data is transposed, convert to row major before
@@ -258,7 +264,7 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
data_hp = f6_e3m2_unpacked_to_f32(data_lp)
data_hp = data_hp.to(target_dtype)
elif elem_dtype == DTYPE_FP4:
if config.use_fp4_custom_triton_dequant_kernel:
if use_fp4_custom_triton_dequant_kernel:
data_hp_rescaled = triton_f4_to_scaled_bf16(
data_lp,
scale_e8m0,
@@ -318,17 +324,29 @@ class ToMXConstrFunc(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode):
def forward(
ctx,
data_hp,
elem_dtype,
block_size,
scaling_mode,
use_fp4_custom_triton_dequant_kernel,
):
scale_e8m0_biased, data_lp = to_mx(
data_hp, elem_dtype, block_size, scaling_mode
)
return MXTensor(
scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype
scale_e8m0_biased,
data_lp,
elem_dtype,
block_size,
data_hp.dtype,
use_fp4_custom_triton_dequant_kernel,
)

@staticmethod
def backward(ctx, g):
return g, None, None, None
return g, None, None, None, None


@torch._dynamo.allow_in_graph
Expand All @@ -345,6 +363,7 @@ def forward(ctx, tensor_lp, target_dtype):
tensor_lp._elem_dtype,
tensor_lp._block_size,
target_dtype,
tensor_lp._use_fp4_custom_triton_dequant_kernel,
)

@staticmethod
@@ -360,6 +379,7 @@ def __new__(
elem_dtype,
block_size,
orig_dtype,
use_fp4_custom_triton_dequant_kernel,
):
new_size = data_bits.size()
if elem_dtype == DTYPE_FP4:
@@ -417,6 +437,9 @@ def __new__(
self._elem_dtype = elem_dtype
self._block_size = block_size
self._orig_dtype = orig_dtype
self._use_fp4_custom_triton_dequant_kernel = (
use_fp4_custom_triton_dequant_kernel
)
return self

def __repr__(self):
@@ -443,14 +466,22 @@ def to_mx(
elem_dtype: Union[torch.dtype, str],
block_size: int = BLOCK_SIZE_DEFAULT,
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
use_fp4_custom_triton_dequant_kernel: bool = False,
):
return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode)
return ToMXConstrFunc.apply(
data_hp,
elem_dtype,
block_size,
scaling_mode,
use_fp4_custom_triton_dequant_kernel,
)

def __tensor_flatten__(self):
ctx = {
"_elem_dtype": self._elem_dtype,
"_block_size": self._block_size,
"_orig_dtype": self._orig_dtype,
"_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel,
}
return ["_scale_e8m0", "_data"], ctx

@@ -467,6 +498,7 @@ def __tensor_unflatten__(
metadata["_elem_dtype"],
metadata["_block_size"],
metadata["_orig_dtype"],
metadata["_use_fp4_custom_triton_dequant_kernel"],
)

# Do not force the MXTensor type on the returned tensor
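Since the flag now lives in the __tensor_flatten__ metadata, it survives subclass flattening and unflattening, e.g. under torch.compile. A hedged sketch of what that enables (illustrative, assuming a CUDA build where the prototype compiles):

import torch

from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import MXTensor

def dequant(x_mx):
    return x_mx.to_dtype(torch.bfloat16)

x = torch.randn(64, 64, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(
    x, DTYPE_FP4, block_size=32, use_fp4_custom_triton_dequant_kernel=True
)

# The compiled region reconstructs MXTensor via __tensor_unflatten__ and sees
# the same dequant configuration as eager mode.
y = torch.compile(dequant)(x_mx)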
