diff --git a/.gitleaksignore b/.gitleaksignore index 5083008ecc..7a6d9b6678 100644 --- a/.gitleaksignore +++ b/.gitleaksignore @@ -2,4 +2,4 @@ 3e2986af4179c38ad5103f571f93372358e4e74f:tests/deployment/test_deployment.py:generic-api-key:59 46d53ae370263367fc49a56638f361495a0ad5d0:tests/deployment/test_deployment.py:generic-api-key:59 2d3b4ca188efb338c03d8d2c921ef39ffc5537e3:tests/deployment/test_deployment.py:generic-api-key:59 -198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59 \ No newline at end of file +198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59 diff --git a/script/make_utils/setup_os_deps.sh b/script/make_utils/setup_os_deps.sh index e360387691..808f2cdead 100755 --- a/script/make_utils/setup_os_deps.sh +++ b/script/make_utils/setup_os_deps.sh @@ -29,8 +29,8 @@ done linux_install_gitleaks () { - GITLEAKS_VERSION=8.5.2 - GITLEAKS_LINUX_X64_SHA256=d83e4721c58638d5a2128ca70341c87fe78b6275483e7dc769a9ca6fe4d25dfd + GITLEAKS_VERSION=8.17.0 + GITLEAKS_LINUX_X64_SHA256=e0e1d641cc55bcf3c0ecc1abcfc6b432e86611a53121d87ce40eacd9467f98c3 GITLEAKS_ARCHIVE_LINK="https://github.com/zricethezav/gitleaks/releases/download/v${GITLEAKS_VERSION}/gitleaks_${GITLEAKS_VERSION}_linux_x64.tar.gz" diff --git a/src/concrete/ml/quantization/quantized_module.py b/src/concrete/ml/quantization/quantized_module.py index d0456a6bdc..961b324708 100644 --- a/src/concrete/ml/quantization/quantized_module.py +++ b/src/concrete/ml/quantization/quantized_module.py @@ -577,6 +577,7 @@ def compile( p_error: Optional[float] = None, global_p_error: Optional[float] = None, verbose: bool = False, + statuses: Optional[Dict[str, str]] = None, ) -> Circuit: """Compile the module's forward function. @@ -598,9 +599,13 @@ def compile( error to a default value. verbose (bool): Indicate if compilation information should be printed during compilation. Default to False. + statuses (Optional[Dict[str, str]]): same dict as Concrete Python encryption statuses Returns: Circuit: The compiled Circuit. + + Raises: + ValueError: if statuses does not match with the parameters of the quantized module """ inputs = to_tuple(inputs) @@ -621,9 +626,23 @@ def compile( self._clear_forward, self.ordered_module_input_names ) + if statuses is None: + statuses = {arg_name: "encrypted" for arg_name in orig_args_to_proxy_func_args.values()} + else: + if not all(arg_name in statuses for arg_name in orig_args_to_proxy_func_args.values()): + raise ValueError( + f"Missing arguments from '{statuses}', expected: " + f"'{list(orig_args_to_proxy_func_args.values())}'." + ) + if not all(value in {"clear", "encrypted"} for value in statuses.values()): + raise ValueError( + "Unexpected status from '{statuses}', expected clear or encrypted." + ) + + assert statuses is not None # For mypy compiler = Compiler( forward_proxy, - {arg_name: "encrypted" for arg_name in orig_args_to_proxy_func_args.values()}, + parameter_encryption_statuses=statuses, ) # Quantize the inputs diff --git a/src/concrete/ml/torch/compile.py b/src/concrete/ml/torch/compile.py index 0457d8d4ba..c8f1ef9b5f 100644 --- a/src/concrete/ml/torch/compile.py +++ b/src/concrete/ml/torch/compile.py @@ -3,7 +3,7 @@ import tempfile import warnings from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import numpy import onnx @@ -118,6 +118,7 @@ def _compile_torch_or_onnx_model( p_error: Optional[float] = None, global_p_error: Optional[float] = None, verbose: bool = False, + statuses: Optional[Dict[str, str]] = None, ) -> QuantizedModule: """Compile a torch module or ONNX into an FHE equivalent. @@ -142,6 +143,7 @@ def _compile_torch_or_onnx_model( global_p_error (Optional[float]): probability of error of the full circuit. In FHE simulation `global_p_error` is set to 0 verbose (bool): whether to show compilation information + statuses (Optional[Dict[str, str]]): same dict as Concrete Python encryption statuses Returns: QuantizedModule: The resulting compiled QuantizedModule. @@ -186,6 +188,7 @@ def _compile_torch_or_onnx_model( p_error=p_error, global_p_error=global_p_error, verbose=verbose, + statuses=statuses, ) return quantized_module @@ -204,6 +207,7 @@ def compile_torch_model( p_error: Optional[float] = None, global_p_error: Optional[float] = None, verbose: bool = False, + statuses: Optional[Dict[str, str]] = None, ) -> QuantizedModule: """Compile a torch module into an FHE equivalent. @@ -229,6 +233,7 @@ def compile_torch_model( global_p_error (Optional[float]): probability of error of the full circuit. In FHE simulation `global_p_error` is set to 0 verbose (bool): whether to show compilation information + statuses (Optional[Dict[str, str]]): same dict as Concrete Python encryption statuses Returns: QuantizedModule: The resulting compiled QuantizedModule. @@ -261,6 +266,7 @@ def compile_torch_model( p_error=p_error, global_p_error=global_p_error, verbose=verbose, + statuses=statuses, ) diff --git a/tests/quantization/test_quantized_module.py b/tests/quantization/test_quantized_module.py index 49072a2f50..916ea53049 100644 --- a/tests/quantization/test_quantized_module.py +++ b/tests/quantization/test_quantized_module.py @@ -351,6 +351,110 @@ def test_quantized_module_rounding_fhe(model_class, input_shape, default_configu # TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3800 +@pytest.mark.parametrize("model_class, input_shape", [pytest.param(FC, (100, 32 * 32 * 3))]) +def test_statuses(model_class, input_shape, default_configuration): + """Check that giving statuses work properly.""" + + torch_fc_model = model_class(activation_function=nn.ReLU) + torch_fc_model.eval() + + # Create random input + numpy_input = numpy.random.uniform(size=input_shape) + torch_input = torch.from_numpy(numpy_input).float() + assert isinstance(torch_input, torch.Tensor) + assert isinstance(numpy_input, numpy.ndarray) + + # Compile with rounding activated + with pytest.raises(ValueError): + # Empty status + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + statuses={}, + ) + + # Wrong name status + with pytest.raises(ValueError): + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + statuses={"x": "encrypted"}, + ) + + # Wrong encryption status + with pytest.raises(ValueError): + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + statuses={"_onnx__Gemm_0": "random"}, + ) + + # Additional argument (error from Concrete Python) + with pytest.raises(ValueError): + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + statuses={"_onnx__Gemm_0": "encrypted", "test": "encrypted"}, + ) + + # No encrypted value + with pytest.raises(RuntimeError): + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + statuses={"_onnx__Gemm_0": "clear"}, + ) + + # Correct + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + statuses={"_onnx__Gemm_0": "encrypted"}, + ) + + # Default (redundant with other test) + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + statuses=None, + ) + + def quantized_module_predictions_are_equal( quantized_module_1: QuantizedModule, quantized_module_2: QuantizedModule,