Commit d8feece ("up")
metascroy committed Jan 30, 2025
1 parent d9667b1

Showing 1 changed file with 61 additions and 15 deletions.
torchao/experimental/tests/test_q_dq_layout.py (76 changes: 61 additions & 15 deletions)
@@ -8,19 +8,20 @@
 import unittest
 
 import torch
+from torch.testing import FileCheck
 
 from torchao.dtypes import PlainLayout
 from torchao.experimental.q_dq_layout import QDQLayout
 from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
-from torchao.quantization.granularity import PerGroup, PerRow
+from torchao.quantization.granularity import PerGroup
 from torchao.quantization.quant_api import quantize_
 from torchao.utils import unwrap_tensor_subclass
 
 
 class TestQDQLayout(unittest.TestCase):
     def test_accuracy(self):
         """
-        Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout() by comparing
+        Checks the accuracy of QDQLayout() by comparing
         its results to the results of a reference model that uses PlainLayout()
         """
         granularity = PerGroup(128)
@@ -84,24 +85,18 @@ def test_accuracy(self):
 
     def test_export(self):
         """
-        Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with
-        torch.export.export, torch.compile, and AOTI.
+        Checks that models quantized with QDQLayout() export as expected
         """
-        granularity = PerRow()
-        m = 3
-        k0 = 512
-        k1 = 256
-        k2 = 128
-        k3 = 1024
+        granularity = PerGroup(64)
         weight_dtype = torch.int4
-        has_weight_zeros = True
+        has_weight_zeros = False
         layers = [
-            torch.nn.Linear(k0, k1, bias=False),
-            torch.nn.Linear(k1, k2, bias=False),
-            torch.nn.Linear(k2, k3, bias=False),
+            torch.nn.Linear(512, 256, bias=False),
         ]
         model = torch.nn.Sequential(*layers)
-        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)
+        activations = torch.randn(1, 512, dtype=torch.float32)
+
+        to_export_with_old_api = copy.deepcopy(model)
 
         print("Quantizing model")
         quantize_(
@@ -122,6 +117,57 @@
         exported_results = exported.module()(activations)
         self.assertTrue(torch.allclose(eager_results, exported_results))
 
+        expected_lines = [
+            "torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(input_1, torch.int8)",
+            "torch.ops.quantized_decomposed.quantize_per_token.default(input_1, getitem, getitem_1, -128, 127, torch.int8)",
+            "torch.ops.quantized_decomposed.dequantize_per_token.default(quantize_per_token, getitem, getitem_1, -128, 127, torch.int8, torch.float32)",
+            "torch.ops.aten.to.dtype(dequantize_per_token, torch.float32)",
+            "torch.ops.quantized_decomposed.dequantize_per_channel_group.default(p_fn_0_parametrizations_weight_original0, p_fn_0_parametrizations_weight_original1, None, -8, 7, torch.int8, 64, torch.float32)",
+            "torch.ops.aten.linear.default(to, dequantize_per_channel_group)",
+        ]
+        for line in expected_lines:
+            FileCheck().check_count(line, 1, exactly=True).run(
+                exported.graph_module.code
+            )
+
+        # Compare exported graph with old API
+        # TODO: delete after old API is deprecated
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+
+        quantizer = Int8DynActInt4WeightQuantizer(
+            groupsize=granularity.group_size,
+            padding_allowed=False,
+            precision=torch.float32,
+            scales_precision=torch.float32,
+            device=torch.device("cpu"),
+            # mapping_type=MappingType.ASYMMETRIC,
+        )
+        quantizer.quantize(to_export_with_old_api)
+        exported_from_old_api = torch.export.export(
+            to_export_with_old_api,
+            (activations,),
+        )
+
+        expected_lines_old_api = [
+            "torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(to, torch.int8)",
+            "torch.ops.quantized_decomposed.quantize_per_token.default(to, getitem, getitem_1, -128, 127, torch.int8)",
+            "torch.ops.quantized_decomposed.dequantize_per_token.default(quantize_per_token, getitem, getitem_1, -128, 127, torch.int8, torch.float32)",
+            "torch.ops.aten.to.dtype(dequantize_per_token, torch.float32)",
+            "torch.ops.quantized_decomposed.dequantize_per_channel_group.default(b_getattr_l__fn_____0___weight, b_getattr_l__fn_____0___scales, b_getattr_l__fn_____0___zeros, -8, 7, torch.int8, 64, torch.float32)",
+            "torch.ops.aten.linear.default(to_1, dequantize_per_channel_group)",
+        ]
+        for line in expected_lines_old_api:
+            FileCheck().check_count(line, 1, exactly=True).run(
+                exported_from_old_api.graph_module.code
+            )
+
+        # TODO: the results differ slightly because exported_results uses
+        # asymmetric quantization with zero_point_domain NONE (has_weight_zeros=False),
+        # while results_from_old_api uses symmetric quantization (but with an
+        # asymmetric range). The new API may make more sense; this needs more thought.
+        # results_from_old_api = exported_from_old_api.module()(activations)
+        # self.assertTrue(torch.allclose(exported_results, results_from_old_api))
+
 
 
 if __name__ == "__main__":
     unittest.main()
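Note on the collapsed context: the arguments to the quantize_ call above are hidden between hunks. Below is a hedged sketch of how that call is typically wired up from this file's imports and variables; the keyword names (weight_dtype, granularity, has_weight_zeros, layout) are assumptions inferred from the test, not copied from the collapsed lines.

import torch
from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False))

# Assumed keyword names; verify against the torchao revision in use.
quantize_(
    model,
    int8_dynamic_activation_intx_weight(
        weight_dtype=torch.int4,
        granularity=PerGroup(64),
        has_weight_zeros=False,
        layout=QDQLayout(),
    ),
)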

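The FileCheck assertions in this commit count exact op occurrences in the exported graph's generated Python source. A minimal self-contained sketch of the same pattern (the Tiny module is an illustration, not part of the commit):

import torch
from torch.testing import FileCheck


class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)


# torch.export.export returns an ExportedProgram; its graph_module.code is the
# Python source of the traced graph, which FileCheck scans as plain text.
ep = torch.export.export(Tiny(), (torch.randn(2, 4),))
FileCheck().check_count(
    "torch.ops.aten.relu.default", 1, exactly=True
).run(ep.graph_module.code)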
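The trailing TODO points at a real numerical gap between the two schemes. Here is a rough arithmetic illustration of why symmetric and asymmetric int4 quantization dequantize slightly differently (plain tensor math with assumed qparams, not the torchao kernels):

import torch

w = torch.tensor([0.1, -0.9, 0.4, 0.7])

# Symmetric (old-API flavor): zero point pinned at 0, scale set by the largest
# magnitude, mapped onto the asymmetric integer range [-8, 7].
scale_sym = w.abs().max() / 7.5
q_sym = torch.clamp(torch.round(w / scale_sym), -8, 7)

# Asymmetric without a zero point (new-API flavor, has_weight_zeros=False):
# scale spans the full min-to-max spread of the group instead.
scale_asym = (w.max() - w.min()) / 15
q_asym = torch.clamp(torch.round(w / scale_asym), -8, 7)

# The dequantized values differ slightly, which is why the final allclose
# comparison in test_export remains commented out.
print(q_sym * scale_sym)
print(q_asym * scale_asym)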