Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
irenaby committed Oct 30, 2024
1 parent f366430 commit 59a4c71
Showing 1 changed file with 7 additions and 39 deletions.
46 changes: 7 additions & 39 deletions tests/keras_tests/function_tests/test_hmse_error_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,44 +165,11 @@ def test_uniform_threshold_selection_hmse_per_tensor(self):
self._verify_params_calculation_execution(RANGE_MAX)

def test_threshold_selection_hmse_no_gptq(self):
self._setup_with_args(quant_method=mct.target_platform.QuantizationMethod.SYMMETRIC, per_channel=True,
running_gptq=False)

def _verify_node_default_mse_error(node_type):
node = [n for n in self.graph.nodes if n.type == node_type]
self.assertTrue(len(node) == 1, f"Expecting exactly 1 {node_type} node in test model.")
node = node[0]

kernel_attr_error_method = (
node.candidates_quantization_cfg[0].weights_quantization_cfg.get_attr_config(KERNEL).weights_error_method)
self.assertTrue(kernel_attr_error_method == mct.core.QuantizationErrorMethod.MSE,
f"Expecting {node_type} node quantization parameter error method to be the default "
f"MSE when not running with GPTQ, but is set to {kernel_attr_error_method}.")

# verifying that the nodes quantization params error method is changed to the default MSE
_verify_node_default_mse_error(layers.Conv2D)
_verify_node_default_mse_error(layers.Dense)

calculate_quantization_params(self.graph, fw_impl=self.keras_impl, repr_data_gen_fn=representative_dataset,
hessian_info_service=self.his, num_hessian_samples=1)

def _verify_node_no_hessian_computed(node_type):
node = [n for n in self.graph.nodes if n.type == node_type]
self.assertTrue(len(node) == 1, f"Expecting exactly 1 {node_type} node in test model.")
node = node[0]

expected_hessian_request = HessianScoresRequest(mode=HessianMode.WEIGHTS,
granularity=HessianScoresGranularity.PER_ELEMENT,
data_loader=None,
n_samples=1,
target_nodes=[node])

with self.assertRaises(ValueError, msg='Not enough hessians are cached to fulfill the request') as e:
self.his.fetch_hessian(expected_hessian_request)

# verifying that no Hessian scores were computed
_verify_node_no_hessian_computed(layers.Conv2D)
_verify_node_no_hessian_computed(layers.Dense)
with self.assertRaises(ValueError) as e:
self._setup_with_args(quant_method=mct.target_platform.QuantizationMethod.SYMMETRIC, per_channel=True,
running_gptq=False)
self.assertTrue('The HMSE error method for parameters selection is only supported when running GPTQ '
'optimization due to long execution time that is not suitable for basic PTQ.' in e.exception.args[0])

def test_threshold_selection_hmse_no_kernel_attr(self):
def _generate_bn_quantization_tpc(quant_method, per_channel):
Expand Down Expand Up @@ -258,8 +225,9 @@ def _generate_bn_quantization_tpc(quant_method, per_channel):
n_samples=1,
target_nodes=[node])

with self.assertRaises(ValueError, msg='Not enough hessians are cached to fulfill the request') as e:
with self.assertRaises(ValueError) as e:
self.his.fetch_hessian(expected_hessian_request)
self.assertTrue('Not enough hessians are cached to fulfill the request' in e.exception.args[0])


if __name__ == '__main__':
Expand Down

0 comments on commit 59a4c71

Please sign in to comment.