From d8feece2cc546737ee7506debef6d6736743b2d9 Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Wed, 29 Jan 2025 16:55:41 -0800 Subject: [PATCH] up --- .../experimental/tests/test_q_dq_layout.py | 76 +++++++++++++++---- 1 file changed, 61 insertions(+), 15 deletions(-) diff --git a/torchao/experimental/tests/test_q_dq_layout.py b/torchao/experimental/tests/test_q_dq_layout.py index a5b95fe91c..9f8f6a7ddf 100644 --- a/torchao/experimental/tests/test_q_dq_layout.py +++ b/torchao/experimental/tests/test_q_dq_layout.py @@ -8,11 +8,12 @@ import unittest import torch +from torch.testing import FileCheck from torchao.dtypes import PlainLayout from torchao.experimental.q_dq_layout import QDQLayout from torchao.experimental.quant_api import int8_dynamic_activation_intx_weight -from torchao.quantization.granularity import PerGroup, PerRow +from torchao.quantization.granularity import PerGroup from torchao.quantization.quant_api import quantize_ from torchao.utils import unwrap_tensor_subclass @@ -20,7 +21,7 @@ class TestQDQLayout(unittest.TestCase): def test_accuracy(self): """ - Checks the accuracy of PackedLinearInt8DynamicActivationIntxWeightLayout() by comparing + Checks the accuracy of TestQDQLayout() by comparing its results to the results of a reference model that uses PlainLayout() """ granularity = PerGroup(128) @@ -84,24 +85,18 @@ def test_accuracy(self): def test_export(self): """ - Checks that models quantized with PackedLinearInt8DynamicActivationIntxWeightLayout() work with - torch.export.export, torch.compile, and AOTI. + Checks that models quantized with TestQDQLayout() export as expected """ - granularity = PerRow() - m = 3 - k0 = 512 - k1 = 256 - k2 = 128 - k3 = 1024 + granularity = PerGroup(64) weight_dtype = torch.int4 - has_weight_zeros = True + has_weight_zeros = False layers = [ - torch.nn.Linear(k0, k1, bias=False), - torch.nn.Linear(k1, k2, bias=False), - torch.nn.Linear(k2, k3, bias=False), + torch.nn.Linear(512, 256, bias=False), ] model = torch.nn.Sequential(*layers) - activations = torch.randn(2, 1, m, k0, dtype=torch.float32) + activations = torch.randn(1, 512, dtype=torch.float32) + + to_export_with_old_api = copy.deepcopy(model) print("Quantizing model") quantize_( @@ -122,6 +117,57 @@ def test_export(self): exported_results = exported.module()(activations) self.assertTrue(torch.allclose(eager_results, exported_results)) + expected_lines = [ + "torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(input_1, torch.int8)", + "torch.ops.quantized_decomposed.quantize_per_token.default(input_1, getitem, getitem_1, -128, 127, torch.int8)", + "torch.ops.quantized_decomposed.dequantize_per_token.default(quantize_per_token, getitem, getitem_1, -128, 127, torch.int8, torch.float32)", + "torch.ops.aten.to.dtype(dequantize_per_token, torch.float32)", + "torch.ops.quantized_decomposed.dequantize_per_channel_group.default(p_fn_0_parametrizations_weight_original0, p_fn_0_parametrizations_weight_original1, None, -8, 7, torch.int8, 64, torch.float32)", + "torch.ops.aten.linear.default(to, dequantize_per_channel_group)", + ] + for line in expected_lines: + FileCheck().check_count(line, 1, exactly=True).run( + exported.graph_module.code + ) + + # Compare exported graph with old API + # TODO: delete after old API is deprecated + from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer + + quantizer = Int8DynActInt4WeightQuantizer( + groupsize=granularity.group_size, + padding_allowed=False, + precision=torch.float32, + scales_precision=torch.float32, + device=torch.device("cpu"), + # mapping_type=MappingType.ASYMMETRIC, + ) + quantizer.quantize(to_export_with_old_api) + exported_from_old_api = torch.export.export( + to_export_with_old_api, + (activations,), + ) + + expected_lines_old_api = [ + "torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric.default(to, torch.int8)", + "torch.ops.quantized_decomposed.quantize_per_token.default(to, getitem, getitem_1, -128, 127, torch.int8)", + "torch.ops.quantized_decomposed.dequantize_per_token.default(quantize_per_token, getitem, getitem_1, -128, 127, torch.int8, torch.float32)", + "torch.ops.aten.to.dtype(dequantize_per_token, torch.float32)", + "torch.ops.quantized_decomposed.dequantize_per_channel_group.default(b_getattr_l__fn_____0___weight, b_getattr_l__fn_____0___scales, b_getattr_l__fn_____0___zeros, -8, 7, torch.int8, 64, torch.float32)", + "torch.ops.aten.linear.default(to_1, dequantize_per_channel_group)", + ] + for line in expected_lines_old_api: + FileCheck().check_count(line, 1, exactly=True).run( + exported_from_old_api.graph_module.code + ) + + # TODO: there are slight differences in the results because exported_results uses + # asymmetric with zero_point_domain NONE (has_weight_zeros=False) + # and results_from_old_api uses symmetric (but with an asymmetric range) + # I think the new API might make more sense, but need more thought + # results_from_old_api = exported_from_old_api.module()(activations) + # self.assertTrue(torch.allclose(exported_results, results_from_old_api)) + if __name__ == "__main__": unittest.main()