Added MXFP6 packing and fused unpack-dequantise kernels, pytests
alex-titterton committed Feb 10, 2025
1 parent cc6244c commit f7b03ca
Showing 7 changed files with 747 additions and 52 deletions.
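For orientation before the diff: the new pack_uint6 / triton_f6_*_to_bf16 paths store four 6-bit FP6 codes in three bytes instead of one byte each. Below is a minimal pure-PyTorch sketch of that 4-values-into-3-bytes idea; the bit ordering is an illustrative assumption and need not match torchao's pack_uint6 or its fused triton unpack-dequantise kernels.

import torch

def pack_fp6_reference(x_u8: torch.Tensor) -> torch.Tensor:
    # Pack four 6-bit values (stored one per uint8) into three bytes.
    # The bit layout here is an assumption for illustration; torchao's
    # pack_uint6 and its triton unpack kernels may order bits differently.
    assert x_u8.shape[-1] % 4 == 0, "last dimension must be a multiple of 4"
    v = x_u8.reshape(-1, 4).to(torch.int32) & 0x3F
    word = (v[:, 0] << 18) | (v[:, 1] << 12) | (v[:, 2] << 6) | v[:, 3]  # 24 bits
    packed = torch.stack(
        [(word >> 16) & 0xFF, (word >> 8) & 0xFF, word & 0xFF], dim=-1
    ).to(torch.uint8)
    return packed.reshape(*x_u8.shape[:-1], 3 * x_u8.shape[-1] // 4)

def unpack_fp6_reference(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_fp6_reference: three bytes back into four 6-bit values.
    p = packed.reshape(-1, 3).to(torch.int32)
    word = (p[:, 0] << 16) | (p[:, 1] << 8) | p[:, 2]
    vals = torch.stack(
        [(word >> 18) & 0x3F, (word >> 12) & 0x3F, (word >> 6) & 0x3F, word & 0x3F],
        dim=-1,
    ).to(torch.uint8)
    return vals.reshape(*packed.shape[:-1], 4 * packed.shape[-1] // 3)

vals = torch.randint(0, 64, (2, 8), dtype=torch.uint8)
assert torch.equal(unpack_fp6_reference(pack_fp6_reference(vals)), vals)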
41 changes: 41 additions & 0 deletions test/prototype/mx_formats/test_custom_cast.py
@@ -26,7 +26,10 @@
f32_to_f6_e3m2_unpacked,
get_bits,
pack_uint4,
pack_uint6,
triton_f4_to_bf16,
triton_f6_e2m3_to_bf16,
triton_f6_e3m2_to_bf16,
unpack_uint4,
)
from torchao.prototype.mx_formats.fp_format_spec import (
@@ -411,3 +414,41 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):

f6_e3m2_unpacked = f32_to_f6_e3m2_unpacked(torch.tensor(-f32_val, device=device))
assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp6_e2m3_pack_unpack():
orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
"cuda"
)
orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
orig_vals_f6_packed_unpacked = triton_f6_e2m3_to_bf16(orig_vals_f6_packed).to(
torch.float32
)
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
)
def test_fp6_e3m2_pack_unpack():
orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
"cuda"
)
orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals)
orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked)
assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4)
orig_vals_f6_packed_unpacked = triton_f6_e3m2_to_bf16(orig_vals_f6_packed).to(
torch.float32
)
assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
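Both new tests assert packed.numel() == 3 * unpacked.numel() // 4. The arithmetic behind that ratio, spelled out:

# Each fp6 value occupies a full byte when unpacked but only 6 bits when packed,
# so four values collapse from 4 bytes to 3 bytes.
n_vals = 8                          # elements in each test tensor above
unpacked_bytes = n_vals             # one uint8 per unpacked fp6 value
packed_bytes = 3 * n_vals // 4      # 8 * 6 bits = 48 bits = 6 bytes
assert (unpacked_bytes, packed_bytes) == (8, 6)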
26 changes: 13 additions & 13 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -53,13 +53,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
"""
# elem_dtype is a tuple of (input, weight, gradient) dtypes.
grad_shape = list(input_shape)
grad_shape[-1] = 6
grad_shape[-1] = 8

m = nn.Sequential(
nn.Linear(8, 6, bias=bias, device="cuda"),
nn.Linear(8, 8, bias=bias, device="cuda"),
)
m_mx = copy.deepcopy(m)
block_size = 2
block_size = 4
swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)

x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -90,14 +90,14 @@ def test_linear_eager(elem_dtype, bias, input_shape):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_activation_checkpointing():
input_shape = (2, 4)
grad_shape = (2, 6)
grad_shape = (2, 8)
elem_dtype = torch.float8_e4m3fn

m = nn.Sequential(
nn.Linear(4, 6, bias=True, device="cuda"),
nn.Linear(6, 6, bias=True, device="cuda"),
nn.Linear(4, 8, bias=True, device="cuda"),
nn.Linear(8, 8, bias=True, device="cuda"),
)
block_size = 2
block_size = 4
swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)

x = torch.randn(*input_shape, device="cuda").requires_grad_()
@@ -127,13 +127,13 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
M, K, N = 4, 8, 6
M, K, N = 4, 8, 8
input_shape = (M, K)
grad_shape = (M, N)
m_mx = nn.Sequential(
nn.Linear(K, N, bias=bias, device="cuda"),
)
block_size = 2
block_size = 4
swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
@@ -178,10 +178,10 @@ def test_inference_linear(elem_dtype, bias, input_shape):
"""
Smoke test for inference linear module with mx weight
"""
m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
block_size = 2
block_size = 4
swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)

x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
@@ -206,10 +206,10 @@ def test_inference_compile_simple(elem_dtype):
if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
block_size = 2
block_size = 4
swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
m_mx = torch.compile(m_mx, fullgraph="true")

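The edits in this file widen the test layers from 6 to 8 features and raise block_size from 2 to 4, consistent with fp6 packing grouping four values into three whole bytes. A hypothetical shape check (not a torchao API) that the new shapes satisfy and the old ones do not:

def fp6_shapes_ok(features: int, block_size: int) -> bool:
    # Hypothetical helper for illustration: with fp6 packing, groups of four
    # 6-bit values map onto whole bytes, so both the scaling block size and
    # the feature dimension it tiles should be multiples of 4.
    return block_size % 4 == 0 and features % block_size == 0

assert fp6_shapes_ok(features=8, block_size=4)        # updated test shapes
assert not fp6_shapes_ok(features=6, block_size=2)    # previous test shapes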
80 changes: 56 additions & 24 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -14,7 +14,7 @@
DTYPE_FP6_E3M2,
SUPPORTED_ELEM_DTYPES,
)
from torchao.prototype.mx_formats.custom_cast import pack_uint4
from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
from torchao.prototype.mx_formats.mx_tensor import (
E8M0_EXPONENT_NAN_VAL,
MXTensor,
@@ -70,15 +70,15 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_hello_world(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_all_zeros(elem_dtype):
data = torch.zeros(4, 4, device="cuda", dtype=torch.bfloat16)
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -88,7 +88,7 @@ def test_some_zeros(elem_dtype):
data = torch.randn(4, 4, device="cuda", dtype=torch.bfloat16)
data[0, :] = 0.0
data[:, 2] = 0.0
block_size = 2
block_size = 4
_test_mx(data, elem_dtype, block_size)


@@ -100,9 +100,9 @@ def test_exponent_nan_in(elem_dtype):
value is set to NaN
"""
tensor_hp = torch.tensor(
[float("nan"), 1, 2, 3, 4, 5], device="cuda", dtype=torch.bfloat16
[float("nan"), 1, 2, 3, 4, 5, 6, 7], device="cuda", dtype=torch.bfloat16
)
block_size = 2
block_size = 4
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL)
assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL)
@@ -115,24 +115,36 @@ def test_exponent_nan_out(elem_dtype):
If block exponent value is NaN, the MX tensor block value is NaN
"""
scale_e8m0_bits = torch.tensor(
[E8M0_EXPONENT_NAN_VAL, 23, 42], dtype=torch.uint8, device="cuda"
[E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda"
)

block_size = 4

if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=elem_dtype, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=elem_dtype, device="cuda"
) # noqa: E501
elif elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2):
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
) # noqa: E501
if config.pack_fp6:
data_bits = data_bits.reshape(-1, block_size)
data_bits = pack_uint6(data_bits)
elif elem_dtype == DTYPE_FP4:
data_bits = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.uint8, device="cuda") # noqa: E501
data_bits = torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda"
) # noqa: E501
data_bits = pack_uint4(data_bits)
else:
raise AssertionError("unsupported")
block_size = 2

tensor_mx = MXTensor(
scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float
)
tensor_hp = tensor_mx.to_dtype(torch.float)
assert torch.all(torch.isnan(tensor_hp[0:1]))
assert not torch.any(torch.isnan(tensor_hp[2:]))
assert torch.all(torch.isnan(tensor_hp.flatten()[0:4]))
assert not torch.any(torch.isnan(tensor_hp.flatten()[4:]))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -141,24 +141,26 @@ def test_ranks(elem_dtype):
"""
The reshaping logic works for various ranks
"""
B = 2
shapes = ((B * 4,), (B * 4, 2), (B * 4, 2, 2), (B * 4, 2, 2, 2))
B = 4
shapes = ((B * 4,), (B * 4, 4), (B * 4, 4, 4), (B * 4, 4, 4, 4))
for s in shapes:
tensor_hp = torch.randn(*s, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_block_sizes(elem_dtype):
@pytest.mark.parametrize("B", [1, 4, 32])
def test_block_sizes(elem_dtype, B):
"""
Smoke test for various block sizes
"""
for B in (1, 2, 32):
if B == 1 and elem_dtype == DTYPE_FP4:
pytest.skip("unsupported configuration")
tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)
if B == 1 and elem_dtype == DTYPE_FP4:
pytest.skip("unsupported configuration")
elif B % 4 != 0 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]:
pytest.skip("unsupported configuration")
tensor_hp = torch.randn(B, device="cuda", dtype=torch.bfloat16)
_test_mx(tensor_hp, elem_dtype, B)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -202,14 +216,32 @@ def test_cast_autograd(elem_dtype):
torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_view(elem_dtype):
x = torch.randn(1, 2, 4)
block_size = 2
x = torch.randn(1, 2, 4, device="cuda")
block_size = 4
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
x_mx_2 = x_mx.view(2, 4) # noqa: F841


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2])
@pytest.mark.parametrize("do_fp6_packing", [False, True])
def test_fp6_packing(elem_dtype, do_fp6_packing):
config.pack_fp6 = do_fp6_packing
x = torch.randn(1, 2, 4, device="cuda")
block_size = 4
x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
if config.pack_fp6:
expected_packed_shape = torch.Size([*x.shape[:-1], 3 * x.shape[-1] // 4])
else:
expected_packed_shape = x.shape
config.pack_fp6 = True

assert x_mx._data.shape == expected_packed_shape


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
@@ -231,7 +263,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
x = torch.randn(*shape, dtype=hp_dtype, device="cuda")
else:
x = torch.zeros(*shape, dtype=hp_dtype, device="cuda")
block_size = 2
block_size = 4
to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)

x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
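The new test_fp6_packing expects the MXTensor _data field to keep its shape when packing is off and to shrink its last dimension by a factor of 3/4 when packing is on. A small sketch of that expected-shape rule, mirroring the test logic:

import torch

def expected_mx_data_shape(shape: torch.Size, pack_fp6: bool) -> torch.Size:
    # Mirrors the shape logic in test_fp6_packing: packing shrinks the last
    # dimension from one byte per value to three bytes per four values.
    if not pack_fp6:
        return torch.Size(shape)
    return torch.Size([*shape[:-1], 3 * shape[-1] // 4])

assert expected_mx_data_shape(torch.Size([1, 2, 4]), pack_fp6=True) == torch.Size([1, 2, 3])
assert expected_mx_data_shape(torch.Size([1, 2, 4]), pack_fp6=False) == torch.Size([1, 2, 4])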
1 change: 1 addition & 0 deletions torchao/prototype/mx_formats/config.py
@@ -1,2 +1,3 @@
# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel = False
pack_fp6 = True
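A usage sketch for the new flag, pieced together from how the tests above exercise it; the import paths and the expected shapes are assumptions based on the test code, not documented API:

import torch
from torchao.prototype.mx_formats import config
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3  # import path assumed from the tests
from torchao.prototype.mx_formats.mx_tensor import MXTensor

config.pack_fp6 = False                            # store one fp6 code per byte
x = torch.randn(4, 8, device="cuda", dtype=torch.bfloat16)
x_unpacked = MXTensor.to_mx(x, DTYPE_FP6_E2M3, 4)  # _data expected to keep shape (4, 8)

config.pack_fp6 = True                             # new default: pack 4 codes into 3 bytes
x_packed = MXTensor.to_mx(x, DTYPE_FP6_E2M3, 4)    # _data expected to shrink to (4, 6)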
