MX: move block_size and elem_dtype into MXLinearConfig
Summary:

Moves `block_size` and the `elem_dtype*` settings out of individual function arguments and into `MXLinearConfig`, and updates `MXLinear.from_float`, `MXInferenceLinear.from_float`, the swap functions, the tests, and the README to take the config object.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ab131f5cca3b85a6ef9cb1eaca3b32ea0d717b2f
ghstack-comment-id: 2649054194
Pull Request resolved: #1689
vkuzo committed Feb 14, 2025
1 parent 5371cff commit e85d370
Showing 4 changed files with 74 additions and 75 deletions.
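For readers skimming the diff: the net effect is that callers construct a single `MXLinearConfig` instead of passing `elem_dtype` and `block_size` separately. A minimal before/after sketch (illustrative module; assumes a CUDA device and torchao with this commit applied):

```python
import torch

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()

# before this commit:
#   swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
# after this commit, the settings travel in one config object:
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
swap_linear_with_mx_linear(m, config=config)
```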
47 changes: 19 additions & 28 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 
+from torchao.prototype.mx_formats.config import MXLinearConfig
 from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
 from torchao.prototype.mx_formats.mx_linear import (
     MXInferenceLinear,
@@ -59,8 +60,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
         nn.Linear(8, 6, bias=bias, device="cuda"),
     )
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)
+    config = MXLinearConfig(
+        block_size=2,
+        elem_dtype=elem_dtype[0],
+        elem_dtype_weight_override=elem_dtype[1],
+        elem_dtype_grad_output_override=elem_dtype[2],
+    )
+    swap_linear_with_mx_linear(m_mx, config=config)
 
     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
     x = copy.deepcopy(x_ref)
Expand Down Expand Up @@ -97,8 +103,8 @@ def test_activation_checkpointing():
nn.Linear(4, 6, bias=True, device="cuda"),
nn.Linear(6, 6, bias=True, device="cuda"),
)
block_size = 2
swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m, config=config)

x = torch.randn(*input_shape, device="cuda").requires_grad_()
g = torch.randn(*grad_shape, device="cuda")
@@ -133,8 +139,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    block_size = 2
-    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
 
@@ -181,8 +187,8 @@ def test_inference_linear(elem_dtype, bias, input_shape):
     m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_inference_linear(m_mx, config=config)
 
     x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
     y_ref = m(x)
@@ -209,8 +215,8 @@ def test_inference_compile_simple(elem_dtype):
     m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_inference_linear(m_mx, config=config)
     m_mx = torch.compile(m_mx, fullgraph=True)
 
     x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
@@ -223,20 +229,6 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 13.5
 
 
-def test_mx_linear_input_weight_gradient_dtypes():
-    m = nn.Sequential(nn.Linear(32, 32))
-    swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32)
-    assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0]
-    assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1]
-    assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2]
-
-    m = nn.Sequential(nn.Linear(32, 32))
-    swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
-    assert m[0].in_elem_dtype == torch.float8_e4m3fn
-    assert m[0].w_elem_dtype == torch.float8_e4m3fn
-    assert m[0].grad_elem_dtype == torch.float8_e4m3fn
-
-
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
@@ -245,12 +237,11 @@ def test_filter_fn():
     m2 = copy.deepcopy(m1)
     filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731
 
-    swap_linear_with_mx_linear(
-        m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn
-    )
+    config = MXLinearConfig(block_size=32)
+    swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear
 
-    swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn)  # noqa: E501
+    swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
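The `filter_fn` hook is unchanged by this commit: it still receives `(mod, fqn)` and gates which modules get swapped, as `test_filter_fn` above exercises. The same pattern in user code looks like this (a sketch; assumes torchao at this commit, and the swap itself does not require CUDA):

```python
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import MXLinear, swap_linear_with_mx_linear

m = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))
swap_linear_with_mx_linear(
    m,
    config=MXLinearConfig(block_size=32),  # elem_dtype defaults to torch.float8_e4m3fn
    filter_fn=lambda mod, fqn: fqn != "1",  # leave the second Linear untouched
)
assert isinstance(m[0], MXLinear)
assert not isinstance(m[1], MXLinear)
```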
11 changes: 6 additions & 5 deletions torchao/prototype/mx_formats/README.md
@@ -41,10 +41,11 @@ This is a module to do MX training, the MX matmul is currently emulated.
 
 ```python
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
+from torchao.prototype.mx_formats.config import MXLinearConfig
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-elem_dtype = torch.float8_e4m3fn
-swap_linear_with_mx_linear(m, elem_dtype, block_size=32)
+config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
+swap_linear_with_mx_linear(m, config=config)
 
 # training loop (not shown)
 ```
@@ -55,11 +56,11 @@ This is a module to do MX inference, weights are in MX and matmul is in high precision.
 
 ```python
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
+from torchao.prototype.mx_formats.config import MXLinearConfig
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-elem_dtype = torch.float8_e4m3fn
-block_size = 32
-swap_linear_with_mx_inference_linear(m, elem_dtype, block_size)
+config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
+swap_linear_with_mx_inference_linear(m, config=config)
 
 # do inference (not shown)
 ```
31 changes: 31 additions & 0 deletions torchao/prototype/mx_formats/config.py
@@ -5,9 +5,40 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
+from typing import Any, Optional
+
+import torch
+
+from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
 
 
 @dataclass
 class MXLinearConfig:
+    # block size for scaling, default is 32 to match
+    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
+    # section 5.2
+    block_size: int = 32
+
+    # element dtype, used for activations, weights and gradients
+    elem_dtype: Any = torch.float8_e4m3fn
+
+    # overrides for element dtype for weights and gradients
+    # TODO(future PR): refactor to make this cleaner
+    elem_dtype_weight_override: Optional[Any] = None
+    elem_dtype_grad_output_override: Optional[Any] = None
+
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
+
+    def __post_init__(self):
+        assert (
+            self.elem_dtype in SUPPORTED_ELEM_DTYPES
+        ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+        if self.elem_dtype_weight_override is not None:
+            assert (
+                self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES
+            ), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype_weight_override}"
+        if self.elem_dtype_grad_output_override is not None:
+            assert (
+                self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
+            ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype_grad_output_override}"
60 changes: 18 additions & 42 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -107,22 +107,11 @@ class MXLinear(torch.nn.Linear):
     def from_float(
         cls,
         mod,
-        elem_dtype,
-        elem_dtype_weight_override=None,
-        elem_dtype_grad_output_override=None,
         *,
-        # TODO(next PR): move elem_dtype* and block size into config
-        config: MXLinearConfig = None,
-        block_size=32,
+        config: Optional[MXLinearConfig] = MXLinearConfig(),
     ):
+        # TODO(before land): remove this
         assert isinstance(config, MXLinearConfig)
         mod.__class__ = MXLinear
-        mod.in_elem_dtype = elem_dtype
-        mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
-        mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
-        mod.block_size = block_size
-        # TODO(next PR): fix this
-        if config is None:
-            config = MXLinearConfig()
         mod.config = config
         return mod
 
@@ -135,13 +124,14 @@ def forward(self, x):
         else:
             w = self.weight
 
+        config = self.config
         y = mx_mm.apply(
             x,
             w,
-            self.in_elem_dtype,
-            self.w_elem_dtype,
-            self.grad_elem_dtype,
-            self.block_size,
+            config.elem_dtype,
+            config.elem_dtype_weight_override or config.elem_dtype,
+            config.elem_dtype_grad_output_override or config.elem_dtype,
+            config.block_size,
         )
         if self.bias is not None:
             y = y + self.bias
@@ -158,9 +148,11 @@ class MXInferenceLinear(torch.nn.Linear):
 
     @classmethod
     @torch.no_grad()
-    def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
-        # TODO(next PR): move elem_dtype and block_size into config
-
+    def from_float(
+        cls,
+        mod,
+        config: Optional[MXLinearConfig] = MXLinearConfig(),
+    ):
         with torch.device("meta"):
             super_kwargs = {
                 "in_features": mod.in_features,
@@ -171,10 +163,9 @@ def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
         # TODO(future PR): set to new_mod.weight directly, will need to work
         # through some errors
         new_mod.weight_mx = MXTensor.to_mx(
-            mod.weight, elem_dtype, block_size=block_size
+            mod.weight, config.elem_dtype, block_size=config.block_size
         )
         new_mod.bias = mod.bias
-        new_mod.elem_dtype = elem_dtype
         new_mod.config = config
         return new_mod
 
@@ -213,13 +204,8 @@ def _is_linear(mod, fqn):
 
 def swap_linear_with_mx_linear(
     model,
-    elem_dtype,
-    elem_dtype_weight_override=None,
-    elem_dtype_grad_output_override=None,
     *,
-    # TODO(next PR): move elem_dtype* and block_size into config
     config: Optional[MXLinearConfig] = None,
-    block_size=32,
     filter_fn=None,
 ):
     if filter_fn is None:
@@ -232,24 +218,16 @@ def __fn(mod, fqn):
     combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXLinear.from_float(
-            mod,
-            elem_dtype,
-            elem_dtype_weight_override,
-            elem_dtype_grad_output_override,
-            config=config,
-            block_size=block_size,
-        ),
+        lambda mod: MXLinear.from_float(mod, config=config),
         combined_filter_fn,
     )
 
 
 def swap_linear_with_mx_inference_linear(
     model,
-    elem_dtype,
-    block_size,
-    filter_fn=None,
     *,
     config: Optional[MXLinearConfig] = None,
+    filter_fn=None,
 ):
     if filter_fn is None:
         combined_filter_fn = _is_linear
@@ -261,8 +239,6 @@ def __fn(mod, fqn):
     combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXInferenceLinear.from_float(
-            mod, elem_dtype, block_size, config=config
-        ),
+        lambda mod: MXInferenceLinear.from_float(mod, config=config),
         combined_filter_fn,
     )

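Putting the pieces together, an end-to-end sketch that mirrors the updated README and tests (assumes a CUDA device and torchao with this commit applied):

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import (
    swap_linear_with_mx_inference_linear,
    swap_linear_with_mx_linear,
)

config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)

# training: MXLinear emulates the MX matmul in forward and backward
m_train = nn.Sequential(nn.Linear(32, 32, device="cuda"))
swap_linear_with_mx_linear(m_train, config=config)
y = m_train(torch.randn(2, 32, device="cuda", requires_grad=True))
y.sum().backward()

# inference: MXInferenceLinear stores weights in MX, matmul stays high precision
m_inf = nn.Sequential(nn.Linear(32, 32, device="cuda", dtype=torch.bfloat16))
swap_linear_with_mx_inference_linear(m_inf, config=config)
with torch.no_grad():
    y = m_inf(torch.randn(2, 32, device="cuda", dtype=torch.bfloat16))
```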