diff --git a/model_compression_toolkit/core/common/mixed_precision/distance_weighting.py b/model_compression_toolkit/core/common/mixed_precision/distance_weighting.py index 9483d40f0..ba52c014c 100644 --- a/model_compression_toolkit/core/common/mixed_precision/distance_weighting.py +++ b/model_compression_toolkit/core/common/mixed_precision/distance_weighting.py @@ -71,3 +71,6 @@ class MpDistanceWeighting(Enum): def __call__(self, distance_matrix: np.ndarray) -> np.ndarray: return self.value(distance_matrix) + + def __deepcopy__(self, memo): + return self diff --git a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py index 2451638af..ff3a27c87 100644 --- a/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py +++ b/model_compression_toolkit/core/common/mixed_precision/resource_utilization_tools/resource_utilization_data.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import copy + import numpy as np from typing import Callable, Any, Dict, Tuple from model_compression_toolkit.constants import FLOAT_BITWIDTH, BITS_TO_BYTES -from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig +from model_compression_toolkit.core import FrameworkInfo, ResourceUtilization, CoreConfig, QuantizationErrorMethod from model_compression_toolkit.core.common import Graph from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX @@ -57,7 +59,7 @@ def compute_resource_utilization_data(in_model: Any, """ - + core_config = _create_core_config_for_ru(core_config) # We assume that the resource_utilization_data API is used to compute the model resource utilization for # mixed precision scenario, so we run graph preparation under the assumption of enabled mixed precision. if transformed_graph is None: @@ -222,6 +224,8 @@ def requires_mixed_precision(in_model: Any, Returns: A boolean indicating if mixed precision is needed. """ is_mixed_precision = False + core_config = _create_core_config_for_ru(core_config) + transformed_graph = graph_preparation_runner(in_model, representative_data_gen, core_config.quantization_config, @@ -247,3 +251,21 @@ def requires_mixed_precision(in_model: Any, is_mixed_precision |= target_resource_utilization.total_memory < total_weights_memory_bytes + max_activation_tensor_size_bytes is_mixed_precision |= target_resource_utilization.bops < bops_count return is_mixed_precision + + +def _create_core_config_for_ru(core_config: CoreConfig) -> CoreConfig: + """ + Create a core config to use for resource utilization computation. + + Args: + core_config: input core config + + Returns: + Core config for resource utilization. + """ + core_config = copy.deepcopy(core_config) + # For resource utilization graph_preparation_runner runs with gptq=False (the default value). HMSE is not supported + # without GPTQ and will raise an error later so we replace it with MSE. + if core_config.quantization_config.weights_error_method == QuantizationErrorMethod.HMSE: + core_config.quantization_config.weights_error_method = QuantizationErrorMethod.MSE + return core_config diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py index b85fc8571..e005f1a2e 100644 --- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py @@ -58,13 +58,10 @@ def set_quantization_configuration_to_graph(graph: Graph, if quant_config.weights_error_method == QuantizationErrorMethod.HMSE: if not running_gptq: - Logger.warning(f"The HMSE error method for parameters selection is only supported when running GPTQ " - f"optimization due to long execution time that is not suitable for basic PTQ. " - f"Using the default MSE error method instead.") - quant_config.weights_error_method = QuantizationErrorMethod.MSE - else: - Logger.warning("Using the HMSE error method for weights quantization parameters search. " - "Note: This method may significantly increase runtime during the parameter search process.") + raise ValueError(f"The HMSE error method for parameters selection is only supported when running GPTQ " + f"optimization due to long execution time that is not suitable for basic PTQ.") + Logger.warning("Using the HMSE error method for weights quantization parameters search. " + "Note: This method may significantly increase runtime during the parameter search process.") nodes_to_manipulate_bit_widths = {} if bit_width_config is None else bit_width_config.get_nodes_to_manipulate_bit_widths(graph) diff --git a/tests/keras_tests/function_tests/test_hmse_error_method.py b/tests/keras_tests/function_tests/test_hmse_error_method.py index c217b4c95..24b7eff49 100644 --- a/tests/keras_tests/function_tests/test_hmse_error_method.py +++ b/tests/keras_tests/function_tests/test_hmse_error_method.py @@ -165,44 +165,11 @@ def test_uniform_threshold_selection_hmse_per_tensor(self): self._verify_params_calculation_execution(RANGE_MAX) def test_threshold_selection_hmse_no_gptq(self): - self._setup_with_args(quant_method=mct.target_platform.QuantizationMethod.SYMMETRIC, per_channel=True, - running_gptq=False) - - def _verify_node_default_mse_error(node_type): - node = [n for n in self.graph.nodes if n.type == node_type] - self.assertTrue(len(node) == 1, f"Expecting exactly 1 {node_type} node in test model.") - node = node[0] - - kernel_attr_error_method = ( - node.candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_error_method) - self.assertTrue(kernel_attr_error_method == mct.core.QuantizationErrorMethod.MSE, - f"Expecting {node_type} node quantization parameter error method to be the default " - f"MSE when not running with GPTQ, but is set to {kernel_attr_error_method}.") - - # verifying that the nodes quantization params error method is changed to the default MSE - _verify_node_default_mse_error(layers.Conv2D) - _verify_node_default_mse_error(layers.Dense) - - calculate_quantization_params(self.graph, fw_impl=self.keras_impl, repr_data_gen_fn=representative_dataset, - hessian_info_service=self.his, num_hessian_samples=1) - - def _verify_node_no_hessian_computed(node_type): - node = [n for n in self.graph.nodes if n.type == node_type] - self.assertTrue(len(node) == 1, f"Expecting exactly 1 {node_type} node in test model.") - node = node[0] - - expected_hessian_request = HessianScoresRequest(mode=HessianMode.WEIGHTS, - granularity=HessianScoresGranularity.PER_ELEMENT, - data_loader=None, - n_samples=1, - target_nodes=[node]) - - with self.assertRaises(ValueError, msg='Not enough hessians are cached to fulfill the request') as e: - self.his.fetch_hessian(expected_hessian_request) - - # verifying that no Hessian scores were computed - _verify_node_no_hessian_computed(layers.Conv2D) - _verify_node_no_hessian_computed(layers.Dense) + with self.assertRaises(ValueError) as e: + self._setup_with_args(quant_method=mct.target_platform.QuantizationMethod.SYMMETRIC, per_channel=True, + running_gptq=False) + self.assertTrue('The HMSE error method for parameters selection is only supported when running GPTQ ' + 'optimization due to long execution time that is not suitable for basic PTQ.' in e.exception.args[0]) def test_threshold_selection_hmse_no_kernel_attr(self): def _generate_bn_quantization_tpc(quant_method, per_channel): @@ -258,8 +225,9 @@ def _generate_bn_quantization_tpc(quant_method, per_channel): n_samples=1, target_nodes=[node]) - with self.assertRaises(ValueError, msg='Not enough hessians are cached to fulfill the request') as e: + with self.assertRaises(ValueError) as e: self.his.fetch_hessian(expected_hessian_request) + self.assertTrue('Not enough hessians are cached to fulfill the request' in e.exception.args[0]) if __name__ == '__main__':