From ca2922aff7c55084948ce85baab8300373d0f3dd Mon Sep 17 00:00:00 2001 From: Kevin Hsieh Date: Thu, 30 Nov 2023 10:18:18 -0800 Subject: [PATCH 1/4] Update onnx quantsim load_encodings_for_sim Signed-off-by: Kevin Hsieh --- .../src/python/aimet_onnx/qc_quantize_op.py | 41 ++- .../onnx/src/python/aimet_onnx/quantsim.py | 248 +++++++++++++++--- .../onnx/test/python/test_quantsim.py | 64 ++++- 3 files changed, 308 insertions(+), 45 deletions(-) diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py index 51fe7afff90..1e49052c525 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/qc_quantize_op.py @@ -2,7 +2,7 @@ # ============================================================================= # @@-COPYRIGHT-START-@@ # -# Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved. +# Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: @@ -36,7 +36,7 @@ # ============================================================================= """ Custom QcQuantizeOp to quantize weights and activations using ONNXRuntime """ -from typing import Union, List +from typing import Union, List, Optional import aimet_common.libpymo as libpymo from aimet_common.libpymo import TensorQuantizerOpMode from aimet_common.defs import QuantScheme, MAP_QUANT_SCHEME_TO_PYMO, MAP_ROUND_MODE_TO_PYMO, QuantizationDataType @@ -237,9 +237,42 @@ def encodings(self) -> libpymo.TfEncoding: """ return self.quant_info.encoding - def load_encodings(self, encoding): + def update_quantizer_and_load_encodings(self, encoding: List[libpymo.TfEncoding], is_symmetric: Optional[bool], + is_strict_symmetric: Optional[bool], is_unsigned_symmetric: Optional[bool], + data_type: QuantizationDataType): """ - Loads pre-existing encodings to quantizer which can be used during quantize-dequantize + Update quantizer settings and load pre-existing encodings to quantizer which can be used during + quantize-dequantize. + + :param encoding: The libpymo.TfEncoding object to be used by the C++ op + :param is_symmetric: True if encoding is symmetric, False otherwise + :param is_strict_symmetric: True if encoding is strict symmetric, False otherwise + :param is_unsigned_symmetric: True if encoding is unsigned symmetric, False otherwise + :param data_type: Data type of encoding + """ + self.enabled = True + self.bitwidth = encoding[0].bw + self.data_type = data_type + if self.data_type == QuantizationDataType.int: + assert self.use_symmetric_encodings is not None + assert self.use_strict_symmetric is not None + assert self.use_unsigned_symmetric is not None + + self.use_symmetric_encodings = is_symmetric + if self.use_symmetric_encodings: + self.use_strict_symmetric = is_strict_symmetric + # is_unsigned_symmetric is a special case since the flag could be enabled but the encoding can be signed + # if the observed tensor had negative values. + # To err on the side of caution, only set self.use_unsigned_symmetric if we know for sure that the encodings + # were unsigned. + if self.use_symmetric_encodings and is_unsigned_symmetric: + self.use_unsigned_symmetric = is_unsigned_symmetric + + self.load_encodings(encoding) + + def load_encodings(self, encoding: List[libpymo.TfEncoding]): + """ + Load pre-existing encodings to quantizer which can be used during quantize-dequantize :param encoding: The libpymo.TfEncoding object to be used by the C++ op """ diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py index 9685679b5bb..0ff33af6ea6 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py @@ -2,7 +2,7 @@ # ============================================================================= # @@-COPYRIGHT-START-@@ # -# Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved. +# Copyright (c) 2022-2023, Qualcomm Innovation Center, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: @@ -36,8 +36,9 @@ # ============================================================================= """ Implementation for simulating models running on Quantized hardware """ +from dataclasses import dataclass import os -from typing import Dict, List, Union, Tuple +from typing import Dict, List, Union, Tuple, Optional import json import numpy as np import onnx @@ -53,7 +54,7 @@ from aimet_common import libquant_info from aimet_common.defs import QuantScheme, QuantizationDataType from aimet_common.quantsim import encoding_version, extract_global_quantizer_args -from aimet_common.utils import save_json_yaml +from aimet_common.utils import save_json_yaml, AimetLogger from aimet_onnx import utils from aimet_onnx.meta.operations import Op from aimet_onnx.meta.utils import get_op_given_param_name, get_param_shape_using_connected_graph @@ -62,6 +63,8 @@ from aimet_onnx.quantsim_config.quantsim_config import QuantSimConfigurator from aimet_onnx.utils import make_dummy_input, add_hook_to_get_activation, remove_activation_hooks +logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant) + # pylint: disable=no-name-in-module, ungrouped-imports if version.parse(onnx.__version__) >= version.parse("1.14.0"): from onnx import ModelProto @@ -77,6 +80,32 @@ data_types_to_quantize = [np.float32] +@dataclass +class LoadEncodingMismatchInfo: + """ + Dataclass tracking information about mismatched quantizer vs. encoding settings. + """ + quantizer_name = '' + enabled_mismatch = None + dtype_mismatch = None + bitwidth_mismatch = None + is_symmetric_mismatch = None + is_strict_symmetric_mismatch = None + is_unsigned_symmetric_mismatch = None + + def has_mismatch(self) -> bool: + """ + Returns True if there is a mismatched setting. + + :return: True if there is a mismatched setting, False otherwise + """ + return (self.enabled_mismatch is not None or + self.dtype_mismatch is not None or + self.bitwidth_mismatch is not None or + self.is_symmetric_mismatch is not None or + self.is_strict_symmetric_mismatch is not None or + self.is_unsigned_symmetric_mismatch is not None) + class QuantizationSimModel: """ Creates a QuantizationSimModel model by adding quantization simulations ops to a given model """ @@ -514,15 +543,6 @@ def set_and_freeze_param_encodings(self, encoding_path: str): :param encoding_path: path from where to load parameter encodings file """ - def _create_libpymo_encodings(encoding): - libpymo_encodings = [] - for enc_val in encoding: - enc = libpymo.TfEncoding() - enc.bw, enc.delta, enc.max, enc.min, enc.offset = enc_val['bitwidth'], enc_val['scale'], enc_val['max'], \ - enc_val['min'], enc_val['offset'] - libpymo_encodings.append(enc) - return libpymo_encodings - # Load encodings file with open(encoding_path) as json_file: encodings = json.load(json_file) @@ -530,16 +550,15 @@ def _create_libpymo_encodings(encoding): for quantizer_name in encodings: if quantizer_name in self.qc_quantize_op_dict: libpymo_encodings = _create_libpymo_encodings(encodings[quantizer_name]) - self.qc_quantize_op_dict[quantizer_name].load_encodings(libpymo_encodings) - self.qc_quantize_op_dict[quantizer_name].bitwidth = encodings[quantizer_name][0]['bitwidth'] - dtype = QuantizationDataType.float - if encodings[quantizer_name][0]['dtype'] == 'int': - dtype = QuantizationDataType.int - self.qc_quantize_op_dict[quantizer_name].data_type = dtype - is_symmetric = False - if encodings[quantizer_name][0]['is_symmetric'] == 'True': - is_symmetric = True - self.qc_quantize_op_dict[quantizer_name].use_symmetric_encodings = is_symmetric + is_symmetric, is_strict_symmetric, is_unsigned_symmetric = \ + get_symmetric_properties(encodings[quantizer_name]) + data_type = QuantizationDataType.int if encodings[quantizer_name][0]['dtype'] == 'int' else \ + QuantizationDataType.float + self.qc_quantize_op_dict[quantizer_name].update_quantizer_and_load_encodings(libpymo_encodings, + is_symmetric, + is_strict_symmetric, + is_unsigned_symmetric, + data_type) self.qc_quantize_op_dict[quantizer_name].freeze_encodings() def get_all_quantizers(self) -> Tuple[List, List]: @@ -558,7 +577,8 @@ def get_all_quantizers(self) -> Tuple[List, List]: return param_quantizers, activation_quantizers -def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_path: str): +def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_path: str, strict=True) -> \ + List[LoadEncodingMismatchInfo]: """ Loads the saved encodings to quant sim model. The encoding filename to load should end in .encodings, generated as part of quantsim export. @@ -566,26 +586,176 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_p :param quant_sim_model: Quantized model to load encodings for. Note: The model configuration should be the same as when encodings were exported. :param onnx_encoding_path: Path of the encodings file to load. + :param strict: If set to True and encoding settings between encodings to load do not line up with Quantsim + initialized settings, an assertion will be thrown. If set to False, quantizer settings will update to align with + encodings to load. + :return: List of LoadEncodingMismatchInfo objects containing quantizer names and mismatched settings """ - def _create_libpymo_encodings(encoding): - libpymo_encodings = [] - for enc_val in encoding: - enc = libpymo.TfEncoding() - enc.bw, enc.delta, enc.max, enc.min, enc.offset = enc_val['bitwidth'], enc_val['scale'], enc_val['max'], \ - enc_val['min'], enc_val['offset'] - libpymo_encodings.append(enc) - return libpymo_encodings + mismatched_encodings = [] # Load encodings file with open(onnx_encoding_path) as json_file: encodings = json.load(json_file) - for quantizer in encodings['activation_encodings']: - if quantizer in quant_sim_model.qc_quantize_op_dict: - libpymo_encodings = _create_libpymo_encodings(encodings['activation_encodings'][quantizer]) - quant_sim_model.qc_quantize_op_dict[quantizer].load_encodings(libpymo_encodings) + for quantizer_name, quantizer in quant_sim_model.qc_quantize_op_dict.items(): + if quantizer_name not in encodings['activation_encodings'] and \ + quantizer_name not in encodings['param_encodings']: + validate_encoding_settings(quantizer_name, quantizer, None, mismatched_encodings) + quantizer.enabled = False + continue + + if quantizer_name in encodings['activation_encodings']: + encodings_to_load = encodings['activation_encodings'][quantizer_name] + else: + encodings_to_load = encodings['param_encodings'][quantizer_name] + + is_symmetric, is_strict_symmetric, is_unsigned_symmetric = \ + get_symmetric_properties(encodings_to_load) + data_type = QuantizationDataType.int if encodings_to_load[0]['dtype'] == 'int' else \ + QuantizationDataType.float + libpymo_encodings = _create_libpymo_encodings(encodings_to_load) + validate_encoding_settings(quantizer_name, quantizer, encodings_to_load, mismatched_encodings) + quant_sim_model.qc_quantize_op_dict[quantizer_name].update_quantizer_and_load_encodings( + libpymo_encodings, is_symmetric, is_strict_symmetric, is_unsigned_symmetric, data_type) + + log_and_catch_mismatched_encodings(mismatched_encodings, strict) + return mismatched_encodings - for quantizer in encodings['param_encodings']: - if quantizer in quant_sim_model.qc_quantize_op_dict: - libpymo_encodings = _create_libpymo_encodings(encodings['param_encodings'][quantizer]) - quant_sim_model.qc_quantize_op_dict[quantizer].load_encodings(libpymo_encodings) +def log_and_catch_mismatched_encodings(mismatched_encodings: List[LoadEncodingMismatchInfo], strict: bool): + """ + If mismatched_encodings is not empty, log details for each entry. If strict is True, raise an AssertionError. + + :param mismatched_encodings: List of mismatched quantizer names and encoding settings + :param strict: If True, raise an AssertionError if there are mismatched settings + """ + if mismatched_encodings: + logging_strings = ['The following quantizers had settings not matching with provided encodings to load:'] + for mismatched_encoding_info in mismatched_encodings: + logging_strings.append(mismatched_encoding_info.quantizer_name + ':') + if mismatched_encoding_info.enabled_mismatch: + logging_strings.append(f'\tenabled: {mismatched_encoding_info.enabled_mismatch[0]}, ' + f'loaded encoding enabled: ' + f'{mismatched_encoding_info.enabled_mismatch[1]}') + + if mismatched_encoding_info.dtype_mismatch: + logging_strings.append(f'\tdtype: {mismatched_encoding_info.dtype_mismatch[0]}, ' + f'loaded encoding dtype: ' + f'{mismatched_encoding_info.dtype_mismatch[1]}') + + if mismatched_encoding_info.bitwidth_mismatch: + logging_strings.append(f'\tbitwidth: ' + f'{mismatched_encoding_info.bitwidth_mismatch[0]}, loaded encoding bitwidth:' + f'{mismatched_encoding_info.bitwidth_mismatch[1]}') + + if mismatched_encoding_info.is_symmetric_mismatch: + logging_strings.append(f'\tsymmetric: ' + f'{mismatched_encoding_info.is_symmetric_mismatch[0]}, ' + f'loaded encoding symmetric: ' + f'{mismatched_encoding_info.is_symmetric_mismatch[1]}') + + if mismatched_encoding_info.is_strict_symmetric_mismatch: + logging_strings.append(f'\tstrict symmetric: ' + f'{mismatched_encoding_info.is_strict_symmetric_mismatch[0]}, ' + f'loaded encoding strict symmetric: ' + f'{mismatched_encoding_info.is_strict_symmetric_mismatch[1]}') + + if mismatched_encoding_info.is_unsigned_symmetric_mismatch: + logging_strings.append(f'\tunsigned symmetric: ' + f'{mismatched_encoding_info.is_unsigned_symmetric_mismatch[0]}, ' + f'loaded encoding unsigned symmetric: ' + f'{mismatched_encoding_info.is_unsigned_symmetric_mismatch[1]}') + log_message = '\n'.join(logging_strings) + if strict: + logger.error(log_message) + raise AssertionError(log_message) + logger.info(log_message) + + +def _create_libpymo_encodings(encoding: Dict[str, Union[str, int, float]]) -> List[libpymo.TfEncoding]: + """ + Given encoding dict, return a TfEncoding object with corresponding info. + + :param encoding: Encoding dict to create TfEncoding object with + :return: TfEncoding object containing encoding dict info + """ + libpymo_encodings = [] + for enc_val in encoding: + enc = libpymo.TfEncoding() + enc.bw = enc_val['bitwidth'] + enc.delta, enc.max, enc.min, enc.offset = 0.0, 0.0, 0.0, 0 + if enc_val['dtype'] == 'int': + enc.delta, enc.max, enc.min, enc.offset = (enc_val['scale'], enc_val['max'], enc_val['min'], + enc_val['offset']) + libpymo_encodings.append(enc) + return libpymo_encodings + + +def get_symmetric_properties(encodings: List[Dict]) -> Tuple[Optional[bool], Optional[bool], Optional[bool]]: + """ + Return symmetric properties of the given encodings. If encodings are float, return None for each. + + :param encodings: Encodings to get symmetric properties for + :return: Tuple of is_symmetric, is_strict_symmetric, and is_unsigned symmetric properties + """ + if encodings[0]['dtype'] == 'float': + return None, None, None + + is_symmetric = encodings[0]['is_symmetric'] == 'True' + + is_strict_symmetric = False + if is_symmetric and encodings[0]['offset'] == -2**(encodings[0]['bitwidth'] - 1) + 1: + is_strict_symmetric = True + + # Note: Even if the original quantizer had is_unsigned_symmetric set to True, if any observed values were negative, + # the resulting encodings will look signed. This logic can only perform a best effort check to return True only if + # any encoding showed unsigned symmetric properties. + is_unsigned_symmetric = False + if is_symmetric: + for encoding in encodings: + if encoding['offset'] == 0: + is_unsigned_symmetric = True + break + return is_symmetric, is_strict_symmetric, is_unsigned_symmetric + +def validate_encoding_settings(quantizer_name: str, quantizer: QcQuantizeOp, encodings_to_load: Optional[List[Dict]], + mismatched_encodings_info: List[LoadEncodingMismatchInfo]): + """ + Check that quantizer settings align with the settings in encodings_to_load. If settings do not align, track the + mismatching settings in a LoadEncodingMismatchInfo object and add it to mismatched_encodings_info list. + + :param quantizer_name: Name of quantizer to check + :param quantizer: Quantizer to check + :param encodings_to_load: Encodings to check + :param mismatched_encodings_info: List holding information of quantizer names with mismatched settings + """ + encoding_mismatch_info = LoadEncodingMismatchInfo() + + # Match enabled state + if quantizer.enabled and encodings_to_load is None: + encoding_mismatch_info.enabled_mismatch = (quantizer.enabled, False) + if not quantizer.enabled and encodings_to_load is not None: + encoding_mismatch_info.enabled_mismatch = (quantizer.enabled, True) + + if encodings_to_load is not None: + is_symmetric, is_strict_symmetric, is_unsigned_symmetric = get_symmetric_properties(encodings_to_load) + + if quantizer.bitwidth != encodings_to_load[0]['bitwidth']: + encoding_mismatch_info.bitwidth_mismatch = (quantizer.bitwidth, encodings_to_load[0]['bitwidth']) + if quantizer.data_type.name != encodings_to_load[0]['dtype']: + encoding_mismatch_info.dtype_mismatch = (quantizer.data_type.name, encodings_to_load[0]['dtype']) + if quantizer.use_symmetric_encodings != is_symmetric: + encoding_mismatch_info.is_symmetric_mismatch = (quantizer.use_symmetric_encodings, is_symmetric) + if quantizer.use_strict_symmetric != is_strict_symmetric: + encoding_mismatch_info.is_strict_symmetric_mismatch = (quantizer.use_strict_symmetric, is_strict_symmetric) + + # Unsigned symmetric is a special case because even if the setting is true, the encodings may appear to be + # signed symmetric if any observed tensor values were < 0. + # In this case, only mark a mismatch if quantizer was set to signed symmetric but an unsigned symmetric + # encoding was seen. + if quantizer.use_unsigned_symmetric != is_unsigned_symmetric and not quantizer.use_unsigned_symmetric: + encoding_mismatch_info.is_unsigned_symmetric_mismatch = (quantizer.use_unsigned_symmetric, + is_unsigned_symmetric) + + if encoding_mismatch_info.has_mismatch(): + encoding_mismatch_info.quantizer_name = quantizer_name + mismatched_encodings_info.append(encoding_mismatch_info) diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py index de360e5af79..8eefb30f94e 100644 --- a/TrainingExtensions/onnx/test/python/test_quantsim.py +++ b/TrainingExtensions/onnx/test/python/test_quantsim.py @@ -36,8 +36,7 @@ # ============================================================================= import json import os - -import onnx +import onnx.numpy_helper import torch import numpy as np from onnx import load_model @@ -446,6 +445,67 @@ def callback(session, args): assert np.allclose(out2, out3) + @pytest.mark.parametrize('strict', [False, True]) + def test_load_encodings_strict_and_non_strict(self, strict): + model = single_residual_model().model + + # Update weights for testing is_unsigned_symmetric override later + weight_initializers = [i.name for i in model.graph.initializer if len(i.dims) > 1] + weight_initializer_3 = [i for i in model.graph.initializer if i.name == weight_initializers[3]][0] + weight_initializer_3_data = onnx.numpy_helper.to_array(weight_initializer_3) + weight_initializer_3.raw_data = np.asarray(np.abs(weight_initializer_3_data), dtype=np.float32).tobytes() + + sim = QuantizationSimModel(model, config_file=get_path_for_per_channel_config()) + + conv_ops = [node for node in sim.model.model.graph.node if node.op_type == 'Conv'] + relu_ops = [node for node in sim.model.model.graph.node if node.op_type == 'Relu'] + avgpool_ops = [node for node in sim.model.model.graph.node if node.op_type == 'AveragePool'] + + act_1 = conv_ops[0].output[0] + act_2 = relu_ops[0].output[0] + act_3 = avgpool_ops[0].output[0] + act_4 = conv_ops[2].output[0] + sim.get_qc_quantize_op()[act_1].enabled = True + sim.get_qc_quantize_op()[act_2].enabled = False + sim.get_qc_quantize_op()[act_3].data_type = QuantizationDataType.float + sim.get_qc_quantize_op()[weight_initializers[0]].bitwidth = 16 + sim.get_qc_quantize_op()[act_4].bitwidth = 4 + sim.get_qc_quantize_op()[weight_initializers[1]].use_symmetric_encodings = False + sim.get_qc_quantize_op()[weight_initializers[2]].use_strict_symmetric = True + sim.get_qc_quantize_op()[weight_initializers[3]].use_unsigned_symmetric = True + + def callback(session, args): + in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)} + session.run(None, in_tensor) + + dummy_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)} + + sim.compute_encodings(callback, None) + sim.export('./tmp', 'onnx_sim') + + out2 = sim.session.run(None, dummy_tensor) + + del sim + + sim = QuantizationSimModel(model, config_file=get_path_for_per_channel_config()) + if strict: + with pytest.raises(AssertionError): + load_encodings_to_sim(sim, './tmp/onnx_sim.encodings', strict=strict) + else: + mismatched_encodings = load_encodings_to_sim(sim, './tmp/onnx_sim.encodings', strict=strict) + out3 = sim.session.run(None, dummy_tensor) + sim.export('./tmp', 'loaded_onnx_sim') + + assert sim.get_qc_quantize_op()[act_1].enabled + assert not sim.get_qc_quantize_op()[act_2].enabled + assert sim.get_qc_quantize_op()[act_3].data_type == QuantizationDataType.float + assert sim.get_qc_quantize_op()[weight_initializers[0]].bitwidth == 16 + assert sim.get_qc_quantize_op()[act_4].bitwidth == 4 + assert not sim.get_qc_quantize_op()[weight_initializers[1]].use_symmetric_encodings + assert sim.get_qc_quantize_op()[weight_initializers[2]].use_strict_symmetric + assert sim.get_qc_quantize_op()[weight_initializers[3]].use_unsigned_symmetric + assert len(mismatched_encodings) == 8 + assert np.allclose(out2, out3) def test_model_with_constants(self): if version.parse(torch.__version__) >= version.parse("1.13"): From 57a59daa5fa2ddcae0095bfb7398ff35f7c6477c Mon Sep 17 00:00:00 2001 From: Kevin Hsieh Date: Mon, 4 Dec 2023 16:57:23 -0800 Subject: [PATCH 2/4] Address review comments Signed-off-by: Kevin Hsieh --- .../onnx/src/python/aimet_onnx/quantsim.py | 59 ++++++++++++------- 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py index 0ff33af6ea6..bb8aafbae14 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py @@ -81,17 +81,17 @@ data_types_to_quantize = [np.float32] @dataclass -class LoadEncodingMismatchInfo: +class EncodingMismatchInfo: """ Dataclass tracking information about mismatched quantizer vs. encoding settings. """ - quantizer_name = '' - enabled_mismatch = None - dtype_mismatch = None - bitwidth_mismatch = None - is_symmetric_mismatch = None - is_strict_symmetric_mismatch = None - is_unsigned_symmetric_mismatch = None + quantizer_name: str + enabled_mismatch: Optional[Tuple] = None + dtype_mismatch: Optional[Tuple] = None + bitwidth_mismatch: Optional[Tuple] = None + is_symmetric_mismatch: Optional[Tuple] = None + is_strict_symmetric_mismatch: Optional[Tuple] = None + is_unsigned_symmetric_mismatch: Optional[Tuple] = None def has_mismatch(self) -> bool: """ @@ -578,7 +578,7 @@ def get_all_quantizers(self) -> Tuple[List, List]: def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_path: str, strict=True) -> \ - List[LoadEncodingMismatchInfo]: + List[EncodingMismatchInfo]: """ Loads the saved encodings to quant sim model. The encoding filename to load should end in .encodings, generated as part of quantsim export. @@ -589,7 +589,7 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_p :param strict: If set to True and encoding settings between encodings to load do not line up with Quantsim initialized settings, an assertion will be thrown. If set to False, quantizer settings will update to align with encodings to load. - :return: List of LoadEncodingMismatchInfo objects containing quantizer names and mismatched settings + :return: List of EncodingMismatchInfo objects containing quantizer names and mismatched settings """ mismatched_encodings = [] @@ -597,10 +597,30 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_p with open(onnx_encoding_path) as json_file: encodings = json.load(json_file) + # First pass through quantizers to check for mismatched encodings + for quantizer_name, quantizer in quant_sim_model.qc_quantize_op_dict.items(): + if quantizer_name not in encodings['activation_encodings'] and \ + quantizer_name not in encodings['param_encodings']: + mismatched_info = get_encoding_mismatch_info(quantizer_name, quantizer, None) + if mismatched_info.has_mismatch(): + mismatched_encodings.append(mismatched_info) + continue + + if quantizer_name in encodings['activation_encodings']: + encodings_to_load = encodings['activation_encodings'][quantizer_name] + else: + encodings_to_load = encodings['param_encodings'][quantizer_name] + + mismatched_info = get_encoding_mismatch_info(quantizer_name, quantizer, encodings_to_load) + if mismatched_info.has_mismatch(): + mismatched_encodings.append(mismatched_info) + + log_and_catch_mismatched_encodings(mismatched_encodings, strict) + + # Second pass through quantizers to set quantizer settings for quantizer_name, quantizer in quant_sim_model.qc_quantize_op_dict.items(): if quantizer_name not in encodings['activation_encodings'] and \ quantizer_name not in encodings['param_encodings']: - validate_encoding_settings(quantizer_name, quantizer, None, mismatched_encodings) quantizer.enabled = False continue @@ -614,14 +634,12 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_p data_type = QuantizationDataType.int if encodings_to_load[0]['dtype'] == 'int' else \ QuantizationDataType.float libpymo_encodings = _create_libpymo_encodings(encodings_to_load) - validate_encoding_settings(quantizer_name, quantizer, encodings_to_load, mismatched_encodings) quant_sim_model.qc_quantize_op_dict[quantizer_name].update_quantizer_and_load_encodings( libpymo_encodings, is_symmetric, is_strict_symmetric, is_unsigned_symmetric, data_type) - log_and_catch_mismatched_encodings(mismatched_encodings, strict) return mismatched_encodings -def log_and_catch_mismatched_encodings(mismatched_encodings: List[LoadEncodingMismatchInfo], strict: bool): +def log_and_catch_mismatched_encodings(mismatched_encodings: List[EncodingMismatchInfo], strict: bool): """ If mismatched_encodings is not empty, log details for each entry. If strict is True, raise an AssertionError. @@ -717,18 +735,17 @@ def get_symmetric_properties(encodings: List[Dict]) -> Tuple[Optional[bool], Opt break return is_symmetric, is_strict_symmetric, is_unsigned_symmetric -def validate_encoding_settings(quantizer_name: str, quantizer: QcQuantizeOp, encodings_to_load: Optional[List[Dict]], - mismatched_encodings_info: List[LoadEncodingMismatchInfo]): +def get_encoding_mismatch_info(quantizer_name: str, quantizer: QcQuantizeOp, + encodings_to_load: Optional[List[Dict]]) -> EncodingMismatchInfo: """ Check that quantizer settings align with the settings in encodings_to_load. If settings do not align, track the - mismatching settings in a LoadEncodingMismatchInfo object and add it to mismatched_encodings_info list. + mismatching settings in a EncodingMismatchInfo object and add it to mismatched_encodings_info list. :param quantizer_name: Name of quantizer to check :param quantizer: Quantizer to check :param encodings_to_load: Encodings to check - :param mismatched_encodings_info: List holding information of quantizer names with mismatched settings """ - encoding_mismatch_info = LoadEncodingMismatchInfo() + encoding_mismatch_info = EncodingMismatchInfo(quantizer_name) # Match enabled state if quantizer.enabled and encodings_to_load is None: @@ -756,6 +773,4 @@ def validate_encoding_settings(quantizer_name: str, quantizer: QcQuantizeOp, enc encoding_mismatch_info.is_unsigned_symmetric_mismatch = (quantizer.use_unsigned_symmetric, is_unsigned_symmetric) - if encoding_mismatch_info.has_mismatch(): - encoding_mismatch_info.quantizer_name = quantizer_name - mismatched_encodings_info.append(encoding_mismatch_info) + return encoding_mismatch_info From 9b079a89d1c8c2bbb0177e574641a47112a62d4c Mon Sep 17 00:00:00 2001 From: Kevin Hsieh Date: Mon, 18 Dec 2023 18:36:35 -0800 Subject: [PATCH 3/4] Add logic to catch invalid encoding names Signed-off-by: Kevin Hsieh --- .../onnx/src/python/aimet_onnx/quantsim.py | 11 +++++++++++ .../onnx/test/python/test_quantsim.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py index bb8aafbae14..a5d337e8dc1 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py @@ -597,6 +597,17 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_p with open(onnx_encoding_path) as json_file: encodings = json.load(json_file) + # Check that all encoding names in the encodings to load are found in the model + encoding_names_not_found = [] + for quantizer_name in list(encodings['activation_encodings'].keys()) + list(encodings['param_encodings'].keys()): + if quantizer_name not in quant_sim_model.qc_quantize_op_dict: + encoding_names_not_found.append(quantizer_name) + if encoding_names_not_found: + logger.error('The following encoding names were present in the encodings to load but not found in the model: ' + '%s', str(encoding_names_not_found)) + raise AssertionError('The following encoding names were present in the encodings to load but not found in the ' + 'model: ' + str(encoding_names_not_found)) + # First pass through quantizers to check for mismatched encodings for quantizer_name, quantizer in quant_sim_model.qc_quantize_op_dict.items(): if quantizer_name not in encodings['activation_encodings'] and \ diff --git a/TrainingExtensions/onnx/test/python/test_quantsim.py b/TrainingExtensions/onnx/test/python/test_quantsim.py index 8eefb30f94e..e5ec6bc8a9d 100644 --- a/TrainingExtensions/onnx/test/python/test_quantsim.py +++ b/TrainingExtensions/onnx/test/python/test_quantsim.py @@ -445,6 +445,20 @@ def callback(session, args): assert np.allclose(out2, out3) + def test_load_encodings_assertion(self): + model = single_residual_model().model + sim = QuantizationSimModel(model, config_file=get_path_for_per_channel_config()) + def callback(session, args): + in_tensor = {'input': np.random.rand(1, 3, 32, 32).astype(np.float32)} + session.run(None, in_tensor) + + sim.compute_encodings(callback, None) + sim.export('./tmp', 'onnx_sim') + model = multi_output_model().model + sim = QuantizationSimModel(model) + with pytest.raises(AssertionError): + load_encodings_to_sim(sim, './tmp/onnx_sim.encodings', strict=False) + @pytest.mark.parametrize('strict', [False, True]) def test_load_encodings_strict_and_non_strict(self, strict): model = single_residual_model().model From 89ad162f8d6d285fe2e62f00c9c4f42909a43eb5 Mon Sep 17 00:00:00 2001 From: Kevin Hsieh Date: Tue, 19 Dec 2023 09:15:37 -0800 Subject: [PATCH 4/4] Fix pylint Signed-off-by: Kevin Hsieh --- .../onnx/src/python/aimet_onnx/quantsim.py | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py index a5d337e8dc1..8b3d3fbfe32 100644 --- a/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py +++ b/TrainingExtensions/onnx/src/python/aimet_onnx/quantsim.py @@ -597,16 +597,7 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_p with open(onnx_encoding_path) as json_file: encodings = json.load(json_file) - # Check that all encoding names in the encodings to load are found in the model - encoding_names_not_found = [] - for quantizer_name in list(encodings['activation_encodings'].keys()) + list(encodings['param_encodings'].keys()): - if quantizer_name not in quant_sim_model.qc_quantize_op_dict: - encoding_names_not_found.append(quantizer_name) - if encoding_names_not_found: - logger.error('The following encoding names were present in the encodings to load but not found in the model: ' - '%s', str(encoding_names_not_found)) - raise AssertionError('The following encoding names were present in the encodings to load but not found in the ' - 'model: ' + str(encoding_names_not_found)) + validate_encodings_to_load(encodings, quant_sim_model) # First pass through quantizers to check for mismatched encodings for quantizer_name, quantizer in quant_sim_model.qc_quantize_op_dict.items(): @@ -650,6 +641,29 @@ def load_encodings_to_sim(quant_sim_model: QuantizationSimModel, onnx_encoding_p return mismatched_encodings + +def validate_encodings_to_load(encodings_to_load: Dict, quant_sim_model: QuantizationSimModel): + """ + Validate that all names of encodings to load are found in the model. + + :param encodings_to_load: Encodings to load + :param quant_sim_model: Quantsim model to check for encoding names. + """ + # Check that all encoding names in the encodings to load are found in the model. This check only works for verifying + # that names in encodings_to_load are valid. The reverse check will not work, since quantizers which are disabled + # will not show up in encodings_to_load. + encoding_names_not_found = [] + for quantizer_name in (list(encodings_to_load['activation_encodings'].keys()) + + list(encodings_to_load['param_encodings'].keys())): + if quantizer_name not in quant_sim_model.qc_quantize_op_dict: + encoding_names_not_found.append(quantizer_name) + if encoding_names_not_found: + logger.error('The following encoding names were present in the encodings to load but not found in the model: ' + '%s', str(encoding_names_not_found)) + raise AssertionError('The following encoding names were present in the encodings to load but not found in the ' + 'model: ' + str(encoding_names_not_found)) + + def log_and_catch_mismatched_encodings(mismatched_encodings: List[EncodingMismatchInfo], strict: bool): """ If mismatched_encodings is not empty, log details for each entry. If strict is True, raise an AssertionError.