From 5891aed8439e0b66029a34c37a4d4a70639067cc Mon Sep 17 00:00:00 2001 From: Andrei Stoian Date: Thu, 7 Sep 2023 10:16:07 +0200 Subject: [PATCH] chore: fix correctness --- src/concrete/ml/common/utils.py | 8 ++++++++ src/concrete/ml/onnx/ops_impl.py | 6 +++++- .../ml/quantization/quantized_module_passes.py | 8 +++----- src/concrete/ml/quantization/quantizers.py | 6 +++++- src/concrete/ml/sklearn/qnn_module.py | 10 +++++----- tests/torch/test_brevitas_qat.py | 15 +++++++-------- 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/concrete/ml/common/utils.py b/src/concrete/ml/common/utils.py index 3af1800b5..4414056fb 100644 --- a/src/concrete/ml/common/utils.py +++ b/src/concrete/ml/common/utils.py @@ -42,6 +42,14 @@ # Indicate if the old simulation method should be used when simulating FHE executions USE_OLD_VL = True +# Debug option for testing round PBS optimization +# Setting this option to true will make quantizers "round half up" +# For example: 0.5 -> 1, 1.5 -> 2 instead of "round half to even" +# When the option is set to false, Concrete ML uses numpy.rint +# which has the same behavior as torch.round -> Brevitas nets +# should be exact compared to their Concrete ML QuantizedModule +QUANT_ROUND_LIKE_ROUND_PBS = False + class FheMode(str, enum.Enum): """Enum representing the execution mode. diff --git a/src/concrete/ml/onnx/ops_impl.py b/src/concrete/ml/onnx/ops_impl.py index f1d64f712..d06489a49 100644 --- a/src/concrete/ml/onnx/ops_impl.py +++ b/src/concrete/ml/onnx/ops_impl.py @@ -14,6 +14,7 @@ from scipy import special from typing_extensions import SupportsIndex +from ..common import utils from ..common.debugging import assert_false, assert_true from .onnx_impl_utils import ( compute_onnx_pool_padding, @@ -1653,7 +1654,10 @@ def numpy_brevitas_quant( y = numpy.clip(y, min_int_val, max_int_val) # Quantize to produce integers representing the float quantized values - y = numpy.rint(y) + if utils.QUANT_ROUND_LIKE_ROUND_PBS: + y = numpy.floor(y + 0.5) + else: + y = numpy.rint(y) # Compute quantized floating point values y = (y - zero_point) * scale diff --git a/src/concrete/ml/quantization/quantized_module_passes.py b/src/concrete/ml/quantization/quantized_module_passes.py index 35a51eb78..bda911c27 100644 --- a/src/concrete/ml/quantization/quantized_module_passes.py +++ b/src/concrete/ml/quantization/quantized_module_passes.py @@ -295,11 +295,9 @@ def integer_log2(value: float) -> Tuple[int, bool]: # number of lsbs to round is the negative of the sum of log2 # of the scale factors lsbs_to_round = -(log2_input + log2_weights - log2_output) - # log2_output - log2_input - log2_weights - # TODO: check this part with Andrei - # How is it possible to have like that? - path_start_node.rounding_threshold_bits = lsbs_to_round - path_start_node.lsbs_to_remove = lsbs_to_round + if lsbs_to_round > 0: + path_start_node.rounding_threshold_bits = lsbs_to_round + path_start_node.lsbs_to_remove = lsbs_to_round else: invalid_paths.append(path_start_node) diff --git a/src/concrete/ml/quantization/quantizers.py b/src/concrete/ml/quantization/quantizers.py index 73e9d145d..c3317cea4 100644 --- a/src/concrete/ml/quantization/quantizers.py +++ b/src/concrete/ml/quantization/quantizers.py @@ -7,6 +7,7 @@ import numpy +from ..common import utils from ..common.debugging import assert_true from ..common.serialization.dumpers import dump, dumps @@ -745,7 +746,10 @@ def quant(self, values: numpy.ndarray) -> numpy.ndarray: assert self.offset is not None assert self.scale is not None - qvalues = numpy.rint(values / self.scale + self.zero_point) + if utils.QUANT_ROUND_LIKE_ROUND_PBS: + qvalues = numpy.floor(values / self.scale + self.zero_point + 0.5) + else: + qvalues = numpy.rint(values / self.scale + self.zero_point) # Clipping can be performed for PTQ and for precomputed (for now only Brevitas) QAT # (where quantizer parameters are available in ONNX layers). diff --git a/src/concrete/ml/sklearn/qnn_module.py b/src/concrete/ml/sklearn/qnn_module.py index 8fa46de76..92d415e6d 100644 --- a/src/concrete/ml/sklearn/qnn_module.py +++ b/src/concrete/ml/sklearn/qnn_module.py @@ -9,7 +9,6 @@ from torch import nn from ..common.debugging import assert_true -from ..common.utils import MAX_BITWIDTH_BACKWARD_COMPATIBLE from ..quantization.qat_quantizers import Int8ActPerTensorPoT, Int8WeightPerTensorPoT @@ -29,14 +28,15 @@ def __init__( n_layers: int, n_outputs: int, n_hidden_neurons_multiplier: int = 4, - n_w_bits: int = 3, - n_a_bits: int = 3, - n_accum_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE, + n_w_bits: int = 4, + n_a_bits: int = 4, + # No pruning by default as roundPBS keeps the PBS precision low + n_accum_bits: int = 32, n_prune_neurons_percentage: float = 0.0, activation_function: Type = nn.ReLU, quant_narrow: bool = False, quant_signed: bool = True, - power_of_two_scaling: bool = False, + power_of_two_scaling: bool = True, # Default to true: use roundPBS to speed up the NNs ): """Sparse Quantized Neural Network constructor. diff --git a/tests/torch/test_brevitas_qat.py b/tests/torch/test_brevitas_qat.py index b9b8c9751..216148106 100644 --- a/tests/torch/test_brevitas_qat.py +++ b/tests/torch/test_brevitas_qat.py @@ -15,6 +15,7 @@ from torch import nn from torch.utils.data import DataLoader, TensorDataset +from concrete.ml.common import utils from concrete.ml.common.utils import ( is_classifier_or_partial_classifier, is_regressor_or_partial_regressor, @@ -514,6 +515,8 @@ def test_brevitas_power_of_two( net, x_all, _ = train_brevitas_network_tinymnist(is_cnn, n_bits, True, False, power_of_two) + utils.QUANT_ROUND_LIKE_ROUND_PBS = True + # If rounding threshold is set -> nothing happens # If Quantizer is not setup -> nothing happens quantized_module = compile_brevitas_qat_model( @@ -590,11 +593,7 @@ def test_brevitas_power_of_two( ) # # Compare the result with the optimized network and without - # # they should be equal (allow 3 non-matching value out of 100) - # TODO: actually verify correctness here, this is just a placeholder - # https://github.com/zama-ai/concrete-ml-internal/issues/3946 - assert y_pred_sim_round.shape == y_pred_clear_round.shape - assert y_pred_clear_round.shape == y_pred_clear_no_round.shape - - # assert numpy.sum(y_pred_sim_round != y_pred_clear_round) <= 3 - # assert numpy.sum(y_pred_clear_round != y_pred_clear_no_round) <= 3 + # # they should be equal + + assert numpy.sum(y_pred_sim_round != y_pred_clear_round) == 0 + assert numpy.sum(y_pred_clear_round != y_pred_clear_no_round) == 0