Remove preserve_zero and zero_point_domain from choose_qparams_affine #2149

Draft · wants to merge 18 commits into base: main
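For orientation, here is a minimal before/after sketch of the call-site migration the test diffs below imply. It assumes this PR's branch of torchao; the shapes, dtypes, quant range, and mapping types are illustrative, and the old-API call is shown only in a comment.

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    choose_qparams_affine_tinygemm,
)

x = torch.randn(10, 256)
block_size = (1, 128)  # groupwise: one scale/zero_point per 128 columns

# Before this PR, the zero-point behavior was selected via kwargs, e.g.:
#   choose_qparams_affine(x, MappingType.ASYMMETRIC, block_size, torch.int32,
#                         0, 15, 1e-6, preserve_zero=False,
#                         zero_point_domain=ZeroPointDomain.FLOAT)

# After this PR, the float zero-point domain (tinygemm) has a dedicated op:
scale, zero_point = choose_qparams_affine_tinygemm(
    x,
    MappingType.ASYMMETRIC,
    block_size,
    torch.int32,
    0,  # quant_min for 4-bit
    15,  # quant_max for 4-bit
    1e-6,  # eps
    scale_dtype=torch.bfloat16,
    zero_point_dtype=torch.bfloat16,
)

# ...and the generic op no longer accepts preserve_zero / zero_point_domain:
scale, zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    block_size,
    torch.int8,
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)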
test/dtypes/test_uintx.py (2 additions, 8 deletions)
@@ -10,7 +10,6 @@
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
dequantize_affine,
quantize_affine,
@@ -112,7 +111,7 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
mapping_type = MappingType.SYMMETRIC
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int32
zero_point_domain = ZeroPointDomain.INT
# zero_point_domain is ZeroPointDomain.INT
block_size = (1, group_size)

scale, zero_point = choose_qparams_affine(
@@ -123,8 +122,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
eps=eps,
scale_dtype=torch.float32,
zero_point_dtype=zero_point_dtype,
preserve_zero=True,
zero_point_domain=zero_point_domain,
)

aqt = quantize_affine(
@@ -133,15 +130,12 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
scale,
zero_point,
dtype,
zero_point_domain=zero_point_domain,
)
# Note: output will be a uint8 tensor for sub-byte dtypes for now

q = to_uintx(aqt, dtype, -1)
assert q is not None, "quantization failed"
dequant = dequantize_affine(
q, block_size, scale, zero_point, dtype, zero_point_domain=zero_point_domain
)
dequant = dequantize_affine(q, block_size, scale, zero_point, dtype)
assert dequant is not None, "dequantization failed"
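Taken together, the updated test boils down to the round trip below; a sketch against this PR's branch, using a plain int8 target and small shapes as stand-ins for the sub-byte uintx dtypes the real test loops over.

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
    quantize_affine,
)

x = torch.randn(4, 64)
block_size = (1, 32)  # one scale/zero_point per group of 32 columns
target_dtype = torch.int8

# No preserve_zero or zero_point_domain arguments anymore.
scale, zero_point = choose_qparams_affine(
    x,
    MappingType.SYMMETRIC,
    block_size,
    target_dtype,
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float32,
    zero_point_dtype=torch.int32,
)

q = quantize_affine(x, block_size, scale, zero_point, target_dtype)
dequant = dequantize_affine(q, block_size, scale, zero_point, target_dtype)
assert (x - dequant).abs().max() < 1.0  # coarse sanity check on the round trip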


test/quantization/pt2e/test_quantize_pt2e.py (0 additions, 1 deletion)
@@ -7,7 +7,6 @@
# Owner(s): ["oncall: quantization"]
# ruff: noqa: F841


import unittest

import torch
test/quantization/test_quant_primitives.py (90 additions, 162 deletions)
@@ -9,20 +9,16 @@
import unittest

import torch
from parameterized import parameterized

from torchao.float8.float8_utils import EPS as float8_eps
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
choose_qparams_affine_float8,
choose_qparams_affine_tinygemm,
dequantize_affine,
dequantize_affine_float8,
fake_quantize_affine,
fake_quantize_affine_cachemask,
quantize_affine,
quantize_affine_float8,
)

# TODO: remove test for utils?
@@ -650,35 +646,6 @@ def test_raises(self):
with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
_ = quantize_affine(input, block_size, scale, zero_point, dtype)

def test_not_preserve_zero_not_supported(self):
"""Making sure preserve_zero == False is not supported for symmetric quant"""
input = torch.randn(10, 256)
n_bit = 4
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 128)
quant_min = 0
quant_max = 2**n_bit - 1
eps = 1e-6
scale_dtype = torch.bfloat16
zero_point_dtype = torch.bfloat16
with self.assertRaisesRegex(
ValueError,
"preserve_zero == False is not supported for symmetric quantization",
):
choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
)

def test_get_groupwise_affine_qparams(self):
input = torch.randn(10, 256)
n_bit = 4
@@ -702,22 +669,33 @@ def test_get_groupwise_affine_qparams(self):
dtype=torch.bfloat16,
zero_point_domain=zero_point_domain,
)
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=zero_point_domain == ZeroPointDomain.INT,
zero_point_domain=zero_point_domain,
)
if zero_point_domain == ZeroPointDomain.FLOAT:
scale, zero_point = choose_qparams_affine_tinygemm(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
)
else:
scale, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
)

self.assertTrue(torch.equal(scale, scale_ref))
self.assertTrue(torch.equal(zero_point, zero_point_ref))

def test_groupwise_affine_quantize_tensor_from_qparams(self):
input = torch.randn(10, 256)
@@ -847,119 +825,69 @@ def test_fake_quantize_affine_cachemask(self):
torch.testing.assert_close(dequantized, fake_quantized)
torch.testing.assert_close(expected_mask, mask)

def test_none_zero_point_domain(self):
"""A None value for a ZeroPointDomain should not work, but ZeroPointDomain.NONE should"""
input = torch.randn(10, 256)
mapping_type = MappingType.SYMMETRIC
dtype = torch.int8
block_size = (1, 128)
quant_min = None
quant_max = None
eps = 1e-6
scale_dtype = torch.float32
zero_point_dtype = torch.int64
try:
_, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=True,
zero_point_domain=None,
)
except ValueError:
# This exception was expected
# Now test for ZeroPointDomain.NONE
_, zero_point = choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=True,
zero_point_domain=ZeroPointDomain.NONE,
)
self.assertTrue(zero_point is None)
else:
# An exception should have been thrown for zero_point_domain None
self.assertTrue(
False,
msg="A runtime exception should have been thrown for zero_point_domain None",
)

@parameterized.expand(
[
(
torch.float32,
torch.float8_e4m3fn,
),
(
torch.float32,
torch.float8_e5m2,
),
(
torch.bfloat16,
torch.float8_e4m3fn,
),
(
torch.bfloat16,
torch.float8_e5m2,
),
]
)
def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
input = torch.randn(10, 10)

# float8 quantization primitives
scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)

# reference implementation using generic primitives
expected_scale, _ = choose_qparams_affine(
input,
MappingType.SYMMETRIC,
input.shape,
float8_dtype,
eps=float8_eps, # use same EPS as float8 training
scale_dtype=torch.float32,
quant_min=torch.finfo(float8_dtype).min,
quant_max=torch.finfo(float8_dtype).max,
)
expected_quantized = quantize_affine(
input,
input.shape,
scale,
output_dtype=float8_dtype,
quant_min=torch.finfo(float8_dtype).min,
quant_max=torch.finfo(float8_dtype).max,
zero_point=None,
zero_point_domain=ZeroPointDomain.NONE,
)
expected_dequantized = dequantize_affine(
expected_quantized,
input.shape,
scale,
input_dtype=float8_dtype,
output_dtype=hp_dtype,
quant_min=torch.finfo(float8_dtype).min,
quant_max=torch.finfo(float8_dtype).max,
zero_point=None,
zero_point_domain=ZeroPointDomain.NONE,
)

self.assertTrue(torch.equal(expected_scale, scale))
torch.testing.assert_close(expected_quantized, quantized)
torch.testing.assert_close(expected_dequantized, dequantized)
# @parameterized.expand(
# [
# (
# torch.float32,
# torch.float8_e4m3fn,
# ),
# (
# torch.float32,
# torch.float8_e5m2,
# ),
# (
# torch.bfloat16,
# torch.float8_e4m3fn,
# ),
# (
# torch.bfloat16,
# torch.float8_e5m2,
# ),
# ]
# )
# def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
# input = torch.randn(10, 10)

# # float8 quantization primitives
# scale = choose_qparams_affine_float8(input, float8_dtype=float8_dtype)
# quantized = quantize_affine_float8(input, scale, float8_dtype=float8_dtype)
# dequantized = dequantize_affine_float8(quantized, scale, output_dtype=hp_dtype)

# # reference implementation using generic primitives
# expected_scale, _ = choose_qparams_affine(
# input,
# MappingType.SYMMETRIC,
# input.shape,
# float8_dtype,
# eps=float8_eps, # use same EPS as float8 training
# scale_dtype=torch.float32,
# quant_min=torch.finfo(float8_dtype).min,
# quant_max=torch.finfo(float8_dtype).max,
# )
# expected_quantized = quantize_affine(
# input,
# input.shape,
# scale,
# output_dtype=float8_dtype,
# quant_min=torch.finfo(float8_dtype).min,
# quant_max=torch.finfo(float8_dtype).max,
# zero_point=None,
# )
# expected_dequantized = dequantize_affine(
# expected_quantized,
# input.shape,
# scale,
# input_dtype=float8_dtype,
# output_dtype=hp_dtype,
# quant_min=torch.finfo(float8_dtype).min,
# quant_max=torch.finfo(float8_dtype).max,
# zero_point=None,
# zero_point_domain=ZeroPointDomain.NONE,
# )

# self.assertTrue(torch.equal(expected_scale, scale))
# torch.testing.assert_close(expected_quantized, quantized)
# torch.testing.assert_close(expected_dequantized, dequantized)


if __name__ == "__main__":
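The float8 test above is commented out rather than ported. As the commented block shows, the float8 flow has its own primitives and carries no zero point at all; a minimal sketch of that path, assuming the same three functions the commented code calls:

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_float8,
    dequantize_affine_float8,
    quantize_affine_float8,
)

x = torch.randn(10, 10)
f8 = torch.float8_e4m3fn

scale = choose_qparams_affine_float8(x, float8_dtype=f8)  # per-tensor scale
q = quantize_affine_float8(x, scale, float8_dtype=f8)  # float8 payload
dequant = dequantize_affine_float8(q, scale, output_dtype=torch.float32)
print((x - dequant).abs().max())  # expect a small quantization error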
test/sparsity/test_marlin.py (0 additions, 6 deletions)

@@ -14,7 +14,6 @@
from torchao.quantization.quant_api import int4_weight_only, quantize_
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
choose_qparams_affine,
quantize_affine,
)
@@ -92,8 +91,6 @@ def test_pack_unpack_equivalence(self):
eps = 1e-6
zero_point_dtype = torch.bfloat16
mapping_type = MappingType.SYMMETRIC
preserve_zero = True
zero_point_domain = ZeroPointDomain.INT
scale_dtype = None

w = torch.rand(shape, dtype=torch.float16, device="cuda")
@@ -112,8 +109,6 @@
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain,
)
w_q_24 = quantize_affine(
w_24,
@@ -123,7 +118,6 @@
target_dtype,
quant_min,
quant_max,
zero_point_domain,
)
scales = scales.reshape(-1, w_q_24.shape[1])

Expand Down