From cf6f22434fae3042af9b14378e2b8bddafda1f31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jordan=20Fr=C3=A9ry?= Date: Wed, 17 Apr 2024 17:50:15 +0200 Subject: [PATCH] chore: use cp vl when circuit is CRT based (#619) --- .../ml/quantization/quantized_module.py | 8 +++- src/concrete/ml/sklearn/base.py | 8 +++- tests/quantization/test_quantized_module.py | 46 +++++++++++++++++++ 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/src/concrete/ml/quantization/quantized_module.py b/src/concrete/ml/quantization/quantized_module.py index 3b61ac816..6155c4ea8 100644 --- a/src/concrete/ml/quantization/quantized_module.py +++ b/src/concrete/ml/quantization/quantized_module.py @@ -522,9 +522,13 @@ def _fhe_forward( # If the inference should be executed using simulation if simulate: + is_crt_encoding = self.fhe_circuit.statistics["packing_key_switch_count"] != 0 - # If the old simulation method should be used - if USE_OLD_VL: + # If the virtual library method should be used + # For now, use the virtual library when simulating + # circuits that use CRT encoding because the official simulation is too slow + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4391 + if USE_OLD_VL or is_crt_encoding: predict_method = partial( self.fhe_circuit.graph, p_error=self.fhe_circuit.p_error ) # pragma: no cover diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index cf8a8e3e1..8a76458c4 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -645,9 +645,13 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy. 
# If the inference should be executed using simulation if fhe == "simulate": + is_crt_encoding = self.fhe_circuit.statistics["packing_key_switch_count"] != 0 - # If the old simulation method should be used - if USE_OLD_VL: + # If the virtual library method should be used + # For now, use the virtual library when simulating + # circuits that use CRT encoding because the official simulation is too slow + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4391 + if USE_OLD_VL or is_crt_encoding: predict_method = partial( self.fhe_circuit.graph, p_error=self.fhe_circuit.p_error ) # pragma: no cover diff --git a/tests/quantization/test_quantized_module.py b/tests/quantization/test_quantized_module.py index 0599dcad0..59136575c 100644 --- a/tests/quantization/test_quantized_module.py +++ b/tests/quantization/test_quantized_module.py @@ -512,3 +512,49 @@ def test_quantized_module_initialization_error(): ordered_module_output_names=None, # This makes the combination invalid quant_layers_dict={"layer1": (["input1"], "QuantizedOp")}, ) + + +@pytest.mark.parametrize("model_class, input_shape", [pytest.param(FC, (100, 32 * 32 * 3))]) +def test_crt_circuit_creation_with_rounding(model_class, input_shape, default_configuration): + """Test the creation of CRT and non-CRT circuits based on rounding settings.""" + + torch_fc_model = model_class(activation_function=nn.ReLU) + torch_fc_model.eval() + + # Create random input + numpy_input = numpy.random.uniform(size=input_shape) + torch_input = torch.from_numpy(numpy_input).float() + + # Compile with rounding_threshold_bits = 4 + quantized_model_with_rounding = compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=6, + p_error=0.01, + rounding_threshold_bits=4, + ) + + # Check that the packing key switch count is zero + # This should not be using CRT encoding + assert ( + quantized_model_with_rounding.fhe_circuit.statistics.get("packing_key_switch_count", 0) == 0 + ), 
"Packing key switch count should be 0 when rounding_threshold_bits is set." + + # Compile with rounding_threshold_bits = None + # Which should create a CRT based encoding circuit + quantized_model_without_rounding = compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=4, + rounding_threshold_bits=None, + ) + + # Check that the packing key switch count is not zero + assert ( + quantized_model_without_rounding.fhe_circuit.statistics.get("packing_key_switch_count", 0) + > 0 + ), "Packing key switch count should be > 0 when rounding_threshold_bits is not set."