From 9ef890e7c1ea3ebf7ce32637413a8e633bff0405 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jordan=20Fr=C3=A9ry?=
Date: Tue, 2 Apr 2024 11:30:18 +0200
Subject: [PATCH] feat: add new approx rounding

---
 conftest.py                                   |   3 +-
 src/concrete/ml/common/utils.py               |   2 -
 .../ml/quantization/base_quantized_op.py      |  45 +++-
 src/concrete/ml/quantization/post_training.py |  18 +-
 src/concrete/ml/torch/compile.py              |  69 +++--
 tests/sklearn/test_sklearn_models.py          |  55 ++--
 tests/torch/test_brevitas_qat.py              |   2 +-
 tests/torch/test_compile_torch.py             | 237 +++++++++---------
 .../cifar/cifar_brevitas_training/README.md   |   2 +-
 .../evaluate_one_example_fhe.py               |   8 +-
 10 files changed, 261 insertions(+), 180 deletions(-)

diff --git a/conftest.py b/conftest.py
index c5b92d295..5663c7247 100644
--- a/conftest.py
+++ b/conftest.py
@@ -147,7 +147,8 @@ def default_configuration():
         enable_unsafe_features=True,
         use_insecure_key_cache=True,
         insecure_key_cache_location="ConcreteNumpyKeyCache",
-        fhe_simulation=False,  # Simulation compilation is done lazilly on circuit.simulate
+        # Simulation compilation is done lazily on circuit.simulate
+        fhe_simulation=False,
         fhe_execution=True,
     )
diff --git a/src/concrete/ml/common/utils.py b/src/concrete/ml/common/utils.py
index b4cfe8de3..325de1617 100644
--- a/src/concrete/ml/common/utils.py
+++ b/src/concrete/ml/common/utils.py
@@ -36,8 +36,6 @@

 # Indicate if the old virtual library method should be used instead of the compiler simulation
 # when simulating FHE executions
-# Set 'USE_OLD_VL' to False by default once the new simulation is fixed
-# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4091
 USE_OLD_VL = False

 # Debug option for testing round PBS optimization
diff --git a/src/concrete/ml/quantization/base_quantized_op.py b/src/concrete/ml/quantization/base_quantized_op.py
index 385295e07..eb93ac942 100644
--- a/src/concrete/ml/quantization/base_quantized_op.py
+++ b/src/concrete/ml/quantization/base_quantized_op.py
@@ -20,6 +20,8 @@
     UniformQuantizationParameters,
 )

+# pylint: disable=too-many-lines
+
 ONNXOpInputOutputType = Union[
     numpy.ndarray,
     QuantizedArray,
@@ -873,13 +875,22 @@ class QuantizedMixingOp(QuantizedOp, is_utility=True):
     """

     lsbs_to_remove: Optional[Union[int, dict]] = None
-    rounding_threshold_bits: Optional[int] = None
+    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None

-    def __init__(self, *args, rounding_threshold_bits: Optional[int] = None, **kwargs) -> None:
+    def __init__(
+        self,
+        *args,
+        rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
+        **kwargs,
+    ) -> None:
         """Initialize quantized ops parameters plus specific parameters.

         Args:
-            rounding_threshold_bits (Optional[int]): Number of bits to round to.
+            rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): if not None,
+                all accumulators in the model are rounded down to the given bits of precision.
+                Can be an int or a dictionary with keys 'method' and 'n_bits', where 'method' is
+                either fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE, and 'n_bits' is either
+                'auto' or an int.
             *args: positional argument to pass to the parent class.
             **kwargs: named argument to pass to the parent class.
         """
@@ -959,15 +970,27 @@ def cnp_round(

         Returns:
             numpy.ndarray: The rounded array.
""" - # Ensure lsbs_to_remove is initialized as a dictionary if not hasattr(self, "lsbs_to_remove") or not isinstance(self.lsbs_to_remove, dict): self.lsbs_to_remove = {} - if self.rounding_threshold_bits is not None and calibrate_rounding: + n_bits = None + exactness = fhe.Exactness.EXACT + + if isinstance(self.rounding_threshold_bits, dict): + n_bits = self.rounding_threshold_bits.get("n_bits", None) + exactness = self.rounding_threshold_bits.get("method", exactness) + # PoT is replacing inplace the rounding_threshold_bits to an int + elif isinstance(self.rounding_threshold_bits, int): + n_bits = self.rounding_threshold_bits + + if n_bits is not None and calibrate_rounding: # Compute lsbs_to_remove only when calibration is True current_n_bits_accumulator = compute_bits_precision(x) - computed_lsbs_to_remove = current_n_bits_accumulator - self.rounding_threshold_bits + + # mypy + assert isinstance(n_bits, int) + computed_lsbs_to_remove = current_n_bits_accumulator - n_bits assert_true( not isinstance(x, fhe.tracing.Tracer), @@ -987,6 +1010,12 @@ def cnp_round( assert isinstance(lsbs_value, int) if lsbs_value > 0: - x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_value) - + # Rounding to low bit-width with approximate can cause issues with overflow protection + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4345 + if exactness == fhe.Exactness.APPROXIMATE: + x = fhe.round_bit_pattern( + x, lsbs_to_remove=lsbs_value, exactness=exactness, overflow_protection=False + ) + else: + x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_value, exactness=exactness) return x diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py index a6930ec56..27d6a3317 100644 --- a/src/concrete/ml/quantization/post_training.py +++ b/src/concrete/ml/quantization/post_training.py @@ -212,21 +212,23 @@ class ONNXConverter: of the network's inputs. "op_inputs" and "op_weights" both control the quantization for inputs and weights of all layers. numpy_model (NumpyModule): Model in numpy. - rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down - to the given bits of precision + rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision + rounding for model accumulators. Accepts None, an int, or a dict. + The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE) + and 'n_bits' ('auto' or int) """ quant_ops_dict: Dict[str, Tuple[Tuple[str, ...], QuantizedOp]] n_bits: Dict[str, int] quant_params: Dict[str, numpy.ndarray] numpy_model: NumpyModule - rounding_threshold_bits: Optional[int] + rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] def __init__( self, n_bits: Union[int, Dict], numpy_model: NumpyModule, - rounding_threshold_bits: Optional[int] = None, + rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None, ): self.quant_ops_dict = {} @@ -834,8 +836,12 @@ class PostTrainingAffineQuantization(ONNXConverter): - op_weights: learned parameters or constants in the network - model_outputs: final model output quantization bits numpy_model (NumpyModule): Model in numpy. - rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down - to the given bits of precision + rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): if not None, every + accumulators in the model are rounded down to the given + bits of precision. 
+            Can be an int or a dictionary with keys
+            'method' and 'n_bits', where 'method' is either
+            fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE, and
+            'n_bits' is either 'auto' or an int.
         is_signed: Whether the weights of the layers can be signed.
             Currently, only the weights can be signed.
diff --git a/src/concrete/ml/torch/compile.py b/src/concrete/ml/torch/compile.py
index 9ee4c7e9c..4fa820b27 100644
--- a/src/concrete/ml/torch/compile.py
+++ b/src/concrete/ml/torch/compile.py
@@ -11,7 +11,7 @@
 from brevitas.export.onnx.qonnx.manager import QONNXManager as BrevitasONNXManager
 from brevitas.nn.quant_layer import QuantInputOutputLayer as QNNMixingLayer
 from brevitas.nn.quant_layer import QuantNonLinearActLayer as QNNUnivariateLayer
-from concrete.fhe import ParameterSelectionStrategy
+from concrete.fhe import Exactness, ParameterSelectionStrategy
 from concrete.fhe.compilation.artifacts import DebugArtifacts
 from concrete.fhe.compilation.configuration import Configuration
@@ -72,7 +72,7 @@ def build_quantized_module(
     torch_inputset: Dataset,
     import_qat: bool = False,
     n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
-    rounding_threshold_bits: Optional[int] = None,
+    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
     reduce_sum_copy=False,
 ) -> QuantizedModule:
     """Build a quantized module from a Torch or ONNX model.
@@ -88,8 +88,10 @@
         import_qat (bool): Flag to signal that the network being imported contains quantizers
             in its computation graph and that Concrete ML should not re-quantize it
         n_bits: the number of bits for the quantization
-        rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
-            to the given bits of precision
+        rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
+            rounding for model accumulators. Accepts None, an int, or a dict.
+            The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
+            and 'n_bits' ('auto' or int)
         reduce_sum_copy (bool): if the inputs of QuantizedReduceSum should be copied to avoid
             bit-width propagation
@@ -135,7 +137,7 @@ def _compile_torch_or_onnx_model(
     artifacts: Optional[DebugArtifacts] = None,
     show_mlir: bool = False,
     n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
-    rounding_threshold_bits: Optional[int] = None,
+    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
     p_error: Optional[float] = None,
     global_p_error: Optional[float] = None,
     verbose: bool = False,
@@ -164,8 +166,10 @@
            - "model_inputs" and "model_outputs" (optional, default to 5 bits).
            When using a single integer for n_bits, its value is assigned to "op_inputs" and
            "op_weights" bits. Default is 8 bits.
-        rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
-            to the given bits of precision
+        rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
+            rounding for model accumulators. Accepts None, an int, or a dict.
+            The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
+            and 'n_bits' ('auto' or int)
         p_error (Optional[float]): probability of error of a single PBS
         global_p_error (Optional[float]): probability of error of the full circuit. In FHE
            simulation `global_p_error` is set to 0
         verbose (bool): whether to show compilation information

    Returns:
        QuantizedModule: The resulting compiled QuantizedModule.
+
+    Raises:
+        NotImplementedError: If 'auto' rounding is specified but not implemented.
+        ValueError: If an invalid type or value is provided for rounding_threshold_bits.
     """
+    n_bits_rounding: Union[None, str, int] = None
+    method: Exactness = Exactness.EXACT
+
+    # Only process if rounding_threshold_bits is not None
+    if rounding_threshold_bits is not None:
+        if isinstance(rounding_threshold_bits, int):
+            n_bits_rounding = rounding_threshold_bits
+        elif isinstance(rounding_threshold_bits, dict):
+            n_bits_rounding = rounding_threshold_bits.get("n_bits")
+            if n_bits_rounding == "auto":
+                raise NotImplementedError("Automatic rounding is not implemented yet.")
+            method = rounding_threshold_bits.get("method", method)
+            # The method may be given as a string or directly as an fhe.Exactness value
+            if isinstance(method, str):
+                method_str = method.upper()
+                if method_str not in ["EXACT", "APPROXIMATE"]:
+                    raise ValueError(
+                        f"{method_str} is not a valid method. Must be one of EXACT, APPROXIMATE."
+                    )
+                method = Exactness[method_str]
+        else:
+            raise ValueError("Invalid type for rounding_threshold_bits. Must be int or dict.")
+
+        assert n_bits_rounding is not None, "n_bits_rounding cannot be None"
+        rounding_threshold_bits = {"n_bits": n_bits_rounding, "method": method}

     inputset_as_numpy_tuple = tuple(
         convert_torch_tensor_or_numpy_array_to_numpy_array(val) for val in to_tuple(torch_inputset)
@@ -234,7 +265,7 @@ def compile_torch_model(
     artifacts: Optional[DebugArtifacts] = None,
     show_mlir: bool = False,
     n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
-    rounding_threshold_bits: Optional[int] = None,
+    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
     p_error: Optional[float] = None,
     global_p_error: Optional[float] = None,
     verbose: bool = False,
@@ -264,8 +295,10 @@
            - "model_inputs" and "model_outputs" (optional, default to 5 bits).
            When using a single integer for n_bits, its value is assigned to "op_inputs" and
            "op_weights" bits. Default is 8 bits.
-        rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
-            to the given bits of precision
+        rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
+            rounding for model accumulators. Accepts None, an int, or a dict.
+            The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE)
+            and 'n_bits' ('auto' or int)
         p_error (Optional[float]): probability of error of a single PBS
         global_p_error (Optional[float]): probability of error of the full circuit. In FHE
            simulation `global_p_error` is set to 0
@@ -316,7 +349,7 @@ def compile_onnx_model(
     artifacts: Optional[DebugArtifacts] = None,
     show_mlir: bool = False,
     n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
-    rounding_threshold_bits: Optional[int] = None,
+    rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None,
     p_error: Optional[float] = None,
     global_p_error: Optional[float] = None,
     verbose: bool = False,
@@ -346,8 +379,10 @@
            - "model_inputs" and "model_outputs" (optional, default to 5 bits).
            When using a single integer for n_bits, its value is assigned to "op_inputs" and
            "op_weights" bits. Default is 8 bits.
-        rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down
-            to the given bits of precision
+        rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision
+            rounding for model accumulators. Accepts None, an int, or a dict.
+ The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE) + and 'n_bits' ('auto' or int) p_error (Optional[float]): probability of error of a single PBS global_p_error (Optional[float]): probability of error of the full circuit. In FHE simulation `global_p_error` is set to 0 @@ -393,7 +428,7 @@ def compile_brevitas_qat_model( configuration: Optional[Configuration] = None, artifacts: Optional[DebugArtifacts] = None, show_mlir: bool = False, - rounding_threshold_bits: Optional[int] = None, + rounding_threshold_bits: Union[None, int, Dict[str, Union[str, int]]] = None, p_error: Optional[float] = None, global_p_error: Optional[float] = None, output_onnx_file: Union[None, Path, str] = None, @@ -424,8 +459,10 @@ def compile_brevitas_qat_model( during compilation show_mlir (bool): if set, the MLIR produced by the converter and which is going to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo - rounding_threshold_bits (int): if not None, every accumulators in the model are rounded down - to the given bits of precision + rounding_threshold_bits (Union[None, int, Dict[str, Union[str, int]]]): Defines precision + rounding for model accumulators. Accepts None, an int, or a dict. + The dict can specify 'method' (fhe.Exactness.EXACT or fhe.Exactness.APPROXIMATE) + and 'n_bits' ('auto' or int) p_error (Optional[float]): probability of error of a single PBS global_p_error (Optional[float]): probability of error of the full circuit. In FHE simulation `global_p_error` is set to 0 diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py index a82081e80..8298ca6a1 100644 --- a/tests/sklearn/test_sklearn_models.py +++ b/tests/sklearn/test_sklearn_models.py @@ -46,7 +46,6 @@ from concrete.ml.common.serialization.dumpers import dump, dumps from concrete.ml.common.serialization.loaders import load, loads from concrete.ml.common.utils import ( - USE_OLD_VL, array_allclose_and_same_shape, get_model_class, get_model_name, @@ -1689,14 +1688,15 @@ def test_fitted_compiled_error_raises( @pytest.mark.flaky @pytest.mark.parametrize("model_class, parameters", MODELS_AND_DATASETS) @pytest.mark.parametrize( - "error_param", - [{"p_error": 0.9999999999990905}], # 1 - 2**-40 - ids=["p_error"], + "error_param, expected_diff", + [({"p_error": 1 - 2**-40}, True), ({"p_error": 2**-40}, False)], + ids=["p_error_high", "p_error_low"], ) def test_p_error_simulation( model_class, parameters, error_param, + expected_diff, load_data, is_weekly_option, ): @@ -1711,17 +1711,12 @@ def test_p_error_simulation( # Get data-set, initialize and fit the model model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option) - # Check if model is linear - is_linear_model = is_model_class_in_a_list(model_class, _get_sklearn_linear_models()) - - # Do not run the test for linear models since there is no PBS (i.e. p_error has no impact) - if is_linear_model: - pytest.skip("Linear models do not have PBS") - - # Compile with a large p_error to be sure the result is random. + # Compile with the specified p_error. 
    model.compile(x, **error_param)

-    def check_for_divergent_predictions(x, model, fhe, max_iterations=N_ALLOWED_FHE_RUN):
+    def check_for_divergent_predictions(
+        x, model, fhe, max_iterations=N_ALLOWED_FHE_RUN, tolerance=1e-5
+    ):
         """Detect divergence between simulated/FHE execution and clear run."""

         # KNeighborsClassifier does not provide a predict_proba method for now
@@ -1736,31 +1731,35 @@ def check_for_divergent_predictions(x, model, fhe, max_iterations=N_ALLOWED_FHE_
         y_expected = predict_function(x, fhe="disable")
         for i in range(max_iterations):
             y_pred = predict_function(x[i : i + 1], fhe=fhe).ravel()
-            if not numpy.array_equal(y_pred, y_expected[i : i + 1].ravel()):
+            if not numpy.allclose(y_pred, y_expected[i : i + 1].ravel(), atol=tolerance):
                 return True
         return False

     simulation_diff_found = check_for_divergent_predictions(x, model, fhe="simulate")
     fhe_diff_found = check_for_divergent_predictions(x, model, fhe="execute")

-    # Check for differences in predictions
-    # Remark that, with the old VL, linear models (or, more generally, circuits without PBS) were
-    # badly simulated. It has been fixed in the new simulation.
-    if is_linear_model and USE_OLD_VL:
-
-        # In FHE, high p_error affect the crypto parameters which
-        # makes the predictions slightly different
-        assert fhe_diff_found, "FHE predictions should be different for linear models"
+    # Check if model is linear
+    is_linear_model = is_model_class_in_a_list(model_class, _get_sklearn_linear_models())

-        # linear models p_error is not simulated
-        assert not simulation_diff_found, "SIMULATE predictions not the same for linear models"
+    # Skip the following if model is linear
+    # Simulation and FHE differ with very high p_error on leveled circuits
+    # FIXME https://github.com/zama-ai/concrete-ml-internal/issues/4343
+    if is_linear_model:
+        pytest.skip("Skipping test for linear models")

+    # Check for differences in predictions based on expected_diff
+    if expected_diff:
+        assert_msg = (
+            "With high p_error, predictions should differ in both FHE and simulation."
+            f" Found differences: FHE={fhe_diff_found}, Simulation={simulation_diff_found}"
+        )
+        assert fhe_diff_found and simulation_diff_found, assert_msg
     else:
-        assert fhe_diff_found and simulation_diff_found, (
-            f"Predictions not different in at least one run.\n"
-            f"FHE predictions differ: {fhe_diff_found}\n"
-            f"SIMULATE predictions differ: {simulation_diff_found}"
+        assert_msg = (
+            "With low p_error, predictions should not differ in FHE or simulation."
+ f" Found differences: FHE={fhe_diff_found}, Simulation={simulation_diff_found}" ) + assert not (fhe_diff_found or simulation_diff_found), assert_msg # This test is only relevant for classifier models diff --git a/tests/torch/test_brevitas_qat.py b/tests/torch/test_brevitas_qat.py index 769171d98..fa778431a 100644 --- a/tests/torch/test_brevitas_qat.py +++ b/tests/torch/test_brevitas_qat.py @@ -542,7 +542,7 @@ def test_brevitas_power_of_two( elif manual_rounding: # If manual rounding was set, LSBs_to_remove must be equal # to the accumulator size minus the requested rounding_threshold_bits - assert node_op.rounding_threshold_bits == manual_rounding + assert node_op.rounding_threshold_bits.get("n_bits", None) == manual_rounding assert node_op.produces_graph_output or node_op.lsbs_to_remove is not None # The power-of-two optimization will only work diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py index 1d3e65eba..06c4d245d 100644 --- a/tests/torch/test_compile_torch.py +++ b/tests/torch/test_compile_torch.py @@ -300,13 +300,12 @@ def accuracy_test_rounding( check_is_good_execution_for_cml_vs_circuit, is_brevitas_qat=False, ): - """Check rounding behavior. + """Check rounding behavior with both EXACT and APPROXIMATE methods. The original quantized_numpy_module, compiled over the torch_model without rounding is - compared against quantized_numpy_module_round_low_precision, the torch_model compiled with - a rounding threshold of 2 bits, and quantized_numpy_module_round_high_precision, - the torch_model compiled with maximum bit-width computed with 8 bits (we can't go higher - as rounding does not work with CRT enconding). + compared against quantized_numpy_module_round_low_precision and + quantized_numpy_module_round_high_precision, the torch_model compiled with a rounding threshold + of 2 bits and 8 bits respectively, using both EXACT and APPROXIMATE methods. The final assertion tests whether the mean absolute error between quantized_numpy_module_round_high_precision and quantized_numpy_module is lower than @@ -318,125 +317,77 @@ def accuracy_test_rounding( # feature with enough precision. 
assert quantized_numpy_module.fhe_circuit.graph.maximum_integer_bit_width() >= 4 - # Compile with a rounding threshold equal to the maximum bit-width - # computed in the original quantized_numpy_module - if is_brevitas_qat: - # the q_result_high_precision should round to 8 bits as we can't round higher - quantized_numpy_module_round_high_precision = compile_brevitas_qat_model( - torch_model, - inputset, - n_bits=n_bits, - configuration=configuration, - rounding_threshold_bits=8, - verbose=verbose, - ) - - # and another quantized module with a rounding threshold equal to 2 bits - quantized_numpy_module_round_low_precision = compile_brevitas_qat_model( - torch_model, - inputset, - n_bits=n_bits, - configuration=configuration, - rounding_threshold_bits=2, - verbose=verbose, - ) + # Define rounding thresholds for high and low precision with both EXACT and APPROXIMATE methods + rounding_thresholds = { + "high_exact": {"method": "EXACT", "n_bits": 8}, + "low_exact": {"method": "EXACT", "n_bits": 2}, + "high_approximate": {"method": "APPROXIMATE", "n_bits": 8}, + "low_approximate": {"method": "APPROXIMATE", "n_bits": 2}, + } - else: - # the q_result_high_precision should round to 8 bits as we can't round higher - quantized_numpy_module_round_high_precision = compile_torch_model( - torch_model, - inputset, - import_qat=import_qat, - configuration=configuration, - n_bits=n_bits, - rounding_threshold_bits=8, - verbose=verbose, - ) + compiled_modules = {} - # and another quantized module with a rounding threshold equal to 2 bits - quantized_numpy_module_round_low_precision = compile_torch_model( - torch_model, - inputset, - import_qat=import_qat, - configuration=configuration, - n_bits=n_bits, - verbose=verbose, - rounding_threshold_bits=2, - ) + # Compile models with different rounding thresholds and methods + for key, rounding_threshold in rounding_thresholds.items(): + if is_brevitas_qat: + compiled_modules[key] = compile_brevitas_qat_model( + torch_model, + inputset, + n_bits=n_bits, + configuration=configuration, + rounding_threshold_bits=rounding_threshold, + verbose=verbose, + ) + else: + compiled_modules[key] = compile_torch_model( + torch_model, + inputset, + import_qat=import_qat, + configuration=configuration, + n_bits=n_bits, + rounding_threshold_bits=rounding_threshold, + verbose=verbose, + ) n_percent_inputset_examples_test = 0.1 # Using the input-set allows to remove any chance of overflow. 
x_test = create_test_inputset(inputset, n_percent_inputset_examples_test) - # Make sure the two modules have the same quantization result + # Make sure the modules have the same quantization result qtest = to_tuple(quantized_numpy_module.quantize_input(*x_test)) - qtest_high = to_tuple(quantized_numpy_module_round_high_precision.quantize_input(*x_test)) - qtest_low = to_tuple(quantized_numpy_module_round_low_precision.quantize_input(*x_test)) - - assert all( - numpy.array_equal(qtest_i, qtest_high_i) - for (qtest_i, qtest_high_i) in zip(qtest, qtest_high) - ) - assert all( - numpy.array_equal(qtest_i, qtest_low_i) for (qtest_i, qtest_low_i) in zip(qtest, qtest_low) - ) - - results = [] - results_high_precision = [] - results_low_precision = [] + for _, module in compiled_modules.items(): + qtest_rounded = to_tuple(module.quantize_input(*x_test)) + assert all( + numpy.array_equal(qtest_i, qtest_rounded_i) + for (qtest_i, qtest_rounded_i) in zip(qtest, qtest_rounded) + ) + results: dict = {key: [] for key in compiled_modules} for i in range(x_test[0].shape[0]): - - # Extract example i for each tensor in the test tuple with quantized values while - # keeping the dimension of the original tensors (e.g., if it is a tuple of two (100, 10) - # tensors, then each quantized value becomes a tuple of two tensors of shape (1, 10). q_x = tuple(q[[i]] for q in to_tuple(qtest)) - - # encrypt, run, and decrypt with different precision modes - q_result = quantized_numpy_module.quantized_forward(*q_x, fhe="simulate") - q_result_high_precision = quantized_numpy_module_round_high_precision.quantized_forward( - *q_x, - fhe="simulate", - ) - q_result_low_precision = quantized_numpy_module_round_low_precision.quantized_forward( - *q_x, - fhe="simulate", - ) - - # de-quantize the results to obtain the actual output values - result = quantized_numpy_module.dequantize_output(q_result) - result_high_precision = quantized_numpy_module_round_high_precision.dequantize_output( - q_result_high_precision - ) - result_low_precision = quantized_numpy_module_round_low_precision.dequantize_output( - q_result_low_precision - ) - - # append the results to respective lists - results.append(result) - results_high_precision.append(result_high_precision) - results_low_precision.append(result_low_precision) + for key, module in compiled_modules.items(): + q_result = module.quantized_forward(*q_x, fhe="simulate") + result = module.dequantize_output(q_result) + results[key].append(result) # Check modules predictions FHE simulation vs Concrete ML. - check_is_good_execution_for_cml_vs_circuit(x_test, quantized_numpy_module, simulate=simulate) - check_is_good_execution_for_cml_vs_circuit( - x_test, quantized_numpy_module_round_high_precision, simulate=simulate - ) - # low bit-width rounding is not behaving as expected with new simulation - # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4331 - # check_is_good_execution_for_cml_vs_circuit( - # x_test, quantized_numpy_module_round_low_precision, simulate=simulate - # ) - - # Check that high precision gives a better match than low precision - # MSE is preferred over MAE here to spot a lack of diversity in the 2 bits rounded model - # e.g., results_low_precision = mean(results) should impact more MSE than MAE. 
-    # mse_high_precision = numpy.mean(numpy.square(numpy.subtract(results, results_high_precision)))
-    # mse_low_precision = numpy.mean(numpy.square(numpy.subtract(results, results_low_precision)))

-    # This assert is too unstable and creates more and more flaky tests, we will investigate a
-    # better way to assess the rounding feature's performance
-    # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3662
-    # assert mse_high_precision <= mse_low_precision, "Rounding is not working as expected."
+    for key, module in compiled_modules.items():
+
+        # low bit-width rounding is not behaving as expected with new simulation
+        # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4331
+        if "low" not in key:
+            check_is_good_execution_for_cml_vs_circuit(x_test, module, simulate=simulate)
+
+    # FIXME: The following MSE comparison is commented out due to instability issues.
+    # We will investigate a better way to assess the rounding feature's performance.
+    # https://github.com/zama-ai/concrete-ml-internal/issues/3662
+    # mse_results = {
+    #     key: numpy.mean(numpy.square(numpy.subtract(results['original'], result_list)))
+    #     for key, result_list in results.items()
+    # }
+    # assert mse_results['high_exact'] <= mse_results['low_exact'], (
+    #     "Rounding is not working as expected.")
+    # assert mse_results['high_approximate'] <= mse_results['low_approximate'], (
+    #     "Rounding is not working as expected.")


 # This test is a known flaky
@@ -1442,3 +1393,65 @@ def forward(self, x):
         AssertionError, match="Input 'x' is missing in the ONNX graph after export."
     ):
         compile_torch_model(model, torch_input, rounding_threshold_bits=3)
+
+
+@pytest.mark.parametrize(
+    "rounding_threshold_bits, expected_exception, match_message",
+    [
+        ({"n_bits": "auto"}, NotImplementedError, "Automatic rounding is not implemented yet."),
+        (
+            "invalid_type",
+            ValueError,
+            "Invalid type for rounding_threshold_bits. Must be int or dict.",
+        ),
+        (
+            {"method": "INVALID_METHOD"},
+            ValueError,
+            "INVALID_METHOD is not a valid method. Must be one of EXACT, APPROXIMATE.",
+        ),
+    ],
+)
+def test_compile_torch_model_rounding_threshold_bits_errors(
+    rounding_threshold_bits, expected_exception, match_message, default_configuration
+):
+    """Test that compile_torch_model raises errors for invalid rounding_threshold_bits."""
+    model = FCSmall(input_output=5, activation_function=nn.ReLU)
+    torch_inputset = torch.randn(10, 5)
+
+    with pytest.raises(expected_exception, match=match_message):
+        compile_torch_model(
+            torch_model=model,
+            torch_inputset=torch_inputset,
+            rounding_threshold_bits=rounding_threshold_bits,
+            configuration=default_configuration,
+        )
+
+
+@pytest.mark.parametrize(
+    "rounding_method, expected_reinterpret",
+    [
+        ("APPROXIMATE", True),
+        ("EXACT", False),
+    ],
+)
+def test_rounding_mode(rounding_method, expected_reinterpret, default_configuration):
+    """Test that the underlying FHE circuit uses the right rounding method."""
+    model = FCSmall(input_output=5, activation_function=nn.ReLU)
+    torch_inputset = torch.randn(10, 5)
+
+    compiled_module = compile_torch_model(
+        torch_model=model,
+        torch_inputset=torch_inputset,
+        rounding_threshold_bits={"method": rounding_method, "n_bits": 4},
+        configuration=default_configuration,
+    )
+
+    # Search the compiled circuit's MLIR for the expected rounding pattern
+    mlir = compiled_module.fhe_circuit.mlir
+    if expected_reinterpret:
+        assert (
+            "reinterpret_precision" in mlir and "round" not in mlir
+        ), "Expected 'reinterpret_precision' in the MLIR and no 'round' operation."
+    else:
+        assert "reinterpret_precision" not in mlir, "Unexpected 'reinterpret_precision' found."
diff --git a/use_case_examples/cifar/cifar_brevitas_training/README.md b/use_case_examples/cifar/cifar_brevitas_training/README.md
index b826c07b7..ba426aebd 100644
--- a/use_case_examples/cifar/cifar_brevitas_training/README.md
+++ b/use_case_examples/cifar/cifar_brevitas_training/README.md
@@ -98,7 +98,7 @@ Experiments were conducted on an m6i.metal machine offering 128 CPU cores and 51
 | VGG FHE (simulation\*) | 6 bits | 86.0 |
 | VGG FHE | 6 bits | 86.0\*\* |

-We ran the FHE inference over 10 examples and achieved 100% similar predictions between the simulation and FHE. The overall accuracy for the entire data-set is expected to match the simulation. The original model (no rounding) with a maximum of 13 bits of precision runs in around 9 hours on the specified hardware. Using the rounding approach, the final model ran in **4 minutes**. This significant performance improvement demonstrates the benefits of the rounding operator in the FHE setting.
+We ran the FHE inference over 10 examples and achieved 100% similar predictions between the simulation and FHE. The overall accuracy for the entire data-set is expected to match the simulation. The original model (no rounding) with a maximum of 13 bits of precision runs in around 9 hours on the specified hardware. Using the rounding approach, the final model ran in **40 seconds**. This significant performance improvement demonstrates the benefits of the rounding operator in the FHE setting.

 \* Simulation is used to evaluate the accuracy in the clear for faster debugging.
 \*\* We ran the FHE inference over 10 examples and got 100% similar predictions between the simulation and FHE. The overall accuracy for the entire data-set is expected to match the simulation.
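For reference, a minimal usage sketch of the dict-based `rounding_threshold_bits` argument introduced by this patch (the toy model, shapes, and bit-widths below are hypothetical, not taken from the patch):

```python
# Minimal sketch: compiling a small torch model with the new rounding API.
import torch
from concrete.fhe import Exactness
from concrete.ml.torch.compile import compile_torch_model

# Hypothetical toy model and calibration input set
model = torch.nn.Sequential(torch.nn.Linear(5, 3), torch.nn.ReLU(), torch.nn.Linear(3, 2))
inputset = torch.randn(100, 5)

# Backward-compatible form: an int rounds all accumulators with the EXACT method
q_module_exact = compile_torch_model(model, inputset, rounding_threshold_bits=6)

# New form: a dict selects the method; APPROXIMATE trades bit-exactness for speed
q_module_approx = compile_torch_model(
    model,
    inputset,
    rounding_threshold_bits={"method": Exactness.APPROXIMATE, "n_bits": 6},
)
```

Passing `"APPROXIMATE"` as a plain string also works, since `_compile_torch_or_onnx_model` normalizes string methods to `fhe.Exactness` values.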
diff --git a/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py b/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py index cd2b34ff7..afec3d6a8 100644 --- a/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py +++ b/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py @@ -6,6 +6,7 @@ from pathlib import Path import torch +from concrete.fhe import Exactness from concrete.fhe.compilation.configuration import Configuration from models import cnv_2w2a from torch.utils.data import DataLoader @@ -90,7 +91,7 @@ def wrapper(*args, **kwargs): torch_model, x, configuration=configuration, - rounding_threshold_bits=6, + rounding_threshold_bits={"method": Exactness.APPROXIMATE, "n_bits": 6}, p_error=P_ERROR, ) assert isinstance(quantized_numpy_module, QuantizedModule) @@ -137,11 +138,8 @@ def wrapper(*args, **kwargs): print(f"Quantization of a single input (image) took {quantization_execution_time} seconds") print(f"Size of CLEAR input is {q_x_numpy.nbytes} bytes\n") - # Use new VL with .simulate() once CP's multi-parameter/precision bug is fixed - # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3856 - p_error = quantized_numpy_module.fhe_circuit.p_error expected_quantized_prediction, clear_inference_time = measure_execution_time( - partial(quantized_numpy_module.fhe_circuit.graph, p_error=p_error) + partial(quantized_numpy_module.fhe_circuit.simulate) )(q_x_numpy) # Encrypt the input
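At a lower level, the `cnp_round` change above delegates to Concrete's `fhe.round_bit_pattern` primitive. Here is a standalone sketch of that behavior, assuming a concrete-python version that supports the `exactness` keyword (the circuit and values are toy examples, not part of the patch):

```python
# Minimal sketch of the primitive used by cnp_round; toy circuit, not from the patch.
from concrete import fhe

lsbs_to_remove = 3  # hypothetical value; the patch computes this during calibration

@fhe.compiler({"x": "encrypted"})
def f(x):
    # Round x to a multiple of 2**3 before the table lookup; as in the patch,
    # the approximate method also disables overflow protection (see issue 4345)
    x = fhe.round_bit_pattern(
        x,
        lsbs_to_remove=lsbs_to_remove,
        exactness=fhe.Exactness.APPROXIMATE,
        overflow_protection=False,
    )
    return x // 8

circuit = f.compile(range(64))
# 17 rounds to 16 with EXACT (17 // 8 == 2); APPROXIMATE may round either way,
# so the simulated output here can be 2 or 3
print(circuit.simulate(17))
```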