From 782f399c3780322fa7d5d1d10b097ddfeca56b84 Mon Sep 17 00:00:00 2001 From: Luis Montero Date: Mon, 10 Jul 2023 15:37:00 +0200 Subject: [PATCH] feat: expose statuses to compile torch --- .gitleaksignore | 2 +- .../ml/quantization/quantized_module.py | 41 ++++++- src/concrete/ml/torch/compile.py | 24 +++- tests/quantization/test_quantized_module.py | 109 +++++++++++++++++- 4 files changed, 165 insertions(+), 11 deletions(-) diff --git a/.gitleaksignore b/.gitleaksignore index 5083008ecc..7a6d9b6678 100644 --- a/.gitleaksignore +++ b/.gitleaksignore @@ -2,4 +2,4 @@ 3e2986af4179c38ad5103f571f93372358e4e74f:tests/deployment/test_deployment.py:generic-api-key:59 46d53ae370263367fc49a56638f361495a0ad5d0:tests/deployment/test_deployment.py:generic-api-key:59 2d3b4ca188efb338c03d8d2c921ef39ffc5537e3:tests/deployment/test_deployment.py:generic-api-key:59 -198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59 \ No newline at end of file +198d3fef188aaf3e3a582b9f7943f7ac6e9b5186:tests/deployment/test_deployment.py:generic-api-key:59 diff --git a/src/concrete/ml/quantization/quantized_module.py b/src/concrete/ml/quantization/quantized_module.py index d0456a6bdc..6aa5abf600 100644 --- a/src/concrete/ml/quantization/quantized_module.py +++ b/src/concrete/ml/quantization/quantized_module.py @@ -2,7 +2,7 @@ import copy import re from functools import partial -from typing import Any, Dict, Generator, Iterable, List, Optional, TextIO, Tuple, Union +from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, TextIO, Tuple, Union import numpy import onnx @@ -577,6 +577,7 @@ def compile( p_error: Optional[float] = None, global_p_error: Optional[float] = None, verbose: bool = False, + inputs_encryption_status: Optional[Sequence[str]] = None, ) -> Circuit: """Compile the module's forward function. @@ -598,9 +599,15 @@ def compile( error to a default value. verbose (bool): Indicate if compilation information should be printed during compilation. Default to False. + inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', + 'encrypted') for each input. Returns: Circuit: The compiled Circuit. + + Raises: + ValueError: if inputs_encryption_status does not match with the + parameters of the quantized module """ inputs = to_tuple(inputs) @@ -621,9 +628,39 @@ def compile( self._clear_forward, self.ordered_module_input_names ) + if inputs_encryption_status is None: + inputs_encryption_status = tuple( + "encrypted" for _ in orig_args_to_proxy_func_args.values() + ) + else: + if len(inputs_encryption_status) < len(orig_args_to_proxy_func_args.values()): + raise ValueError( + f"Missing arguments from '{inputs_encryption_status}', expected " + f"{len(orig_args_to_proxy_func_args.values())} arguments." + ) + if len(inputs_encryption_status) > len(orig_args_to_proxy_func_args.values()): + raise ValueError( + f"Too many arguments in '{inputs_encryption_status}', expected " + f"{len(orig_args_to_proxy_func_args.values())} arguments." + ) + if not all(value in {"clear", "encrypted"} for value in inputs_encryption_status): + raise ValueError( + f"Unexpected status from '{inputs_encryption_status}'," + " expected 'clear' or 'encrypted'." + ) + if not any(value == "encrypted" for value in inputs_encryption_status): + raise ValueError( + f"At least one input should be encrypted but got {inputs_encryption_status}" + ) + + assert inputs_encryption_status is not None # For mypy + inputs_encryption_status_dict = dict( + zip(orig_args_to_proxy_func_args.values(), inputs_encryption_status) + ) + compiler = Compiler( forward_proxy, - {arg_name: "encrypted" for arg_name in orig_args_to_proxy_func_args.values()}, + parameter_encryption_statuses=inputs_encryption_status_dict, ) # Quantize the inputs diff --git a/src/concrete/ml/torch/compile.py b/src/concrete/ml/torch/compile.py index 6130da146b..fd7aeef926 100644 --- a/src/concrete/ml/torch/compile.py +++ b/src/concrete/ml/torch/compile.py @@ -3,7 +3,7 @@ import tempfile import warnings from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Dict, Optional, Sequence, Tuple, Union import numpy import onnx @@ -118,6 +118,7 @@ def _compile_torch_or_onnx_model( p_error: Optional[float] = None, global_p_error: Optional[float] = None, verbose: bool = False, + inputs_encryption_status: Optional[Sequence[str]] = None, ) -> QuantizedModule: """Compile a torch module or ONNX into an FHE equivalent. @@ -142,6 +143,8 @@ def _compile_torch_or_onnx_model( global_p_error (Optional[float]): probability of error of the full circuit. In FHE simulation `global_p_error` is set to 0 verbose (bool): whether to show compilation information + inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted') + for each input. Returns: QuantizedModule: The resulting compiled QuantizedModule. @@ -186,6 +189,7 @@ def _compile_torch_or_onnx_model( p_error=p_error, global_p_error=global_p_error, verbose=verbose, + inputs_encryption_status=inputs_encryption_status, ) return quantized_module @@ -199,11 +203,12 @@ def compile_torch_model( configuration: Optional[Configuration] = None, artifacts: Optional[DebugArtifacts] = None, show_mlir: bool = False, - n_bits=MAX_BITWIDTH_BACKWARD_COMPATIBLE, + n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE, rounding_threshold_bits: Optional[int] = None, p_error: Optional[float] = None, global_p_error: Optional[float] = None, verbose: bool = False, + inputs_encryption_status: Optional[Sequence[str]] = None, ) -> QuantizedModule: """Compile a torch module into an FHE equivalent. @@ -229,6 +234,8 @@ def compile_torch_model( global_p_error (Optional[float]): probability of error of the full circuit. In FHE simulation `global_p_error` is set to 0 verbose (bool): whether to show compilation information + inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted') + for each input. Returns: QuantizedModule: The resulting compiled QuantizedModule. @@ -261,6 +268,7 @@ def compile_torch_model( p_error=p_error, global_p_error=global_p_error, verbose=verbose, + inputs_encryption_status=inputs_encryption_status, ) @@ -272,11 +280,12 @@ def compile_onnx_model( configuration: Optional[Configuration] = None, artifacts: Optional[DebugArtifacts] = None, show_mlir: bool = False, - n_bits: Union[int, Dict] = MAX_BITWIDTH_BACKWARD_COMPATIBLE, + n_bits: Union[int, Dict[str, int]] = MAX_BITWIDTH_BACKWARD_COMPATIBLE, rounding_threshold_bits: Optional[int] = None, p_error: Optional[float] = None, global_p_error: Optional[float] = None, verbose: bool = False, + inputs_encryption_status: Optional[Sequence[str]] = None, ) -> QuantizedModule: """Compile a torch module into an FHE equivalent. @@ -302,6 +311,8 @@ def compile_onnx_model( global_p_error (Optional[float]): probability of error of the full circuit. In FHE simulation `global_p_error` is set to 0 verbose (bool): whether to show compilation information + inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted') + for each input. Returns: QuantizedModule: The resulting compiled QuantizedModule. @@ -326,6 +337,7 @@ def compile_onnx_model( p_error=p_error, global_p_error=global_p_error, verbose=verbose, + inputs_encryption_status=inputs_encryption_status, ) @@ -333,7 +345,7 @@ def compile_onnx_model( def compile_brevitas_qat_model( torch_model: torch.nn.Module, torch_inputset: Dataset, - n_bits: Optional[Union[int, dict]] = None, + n_bits: Optional[Union[int, Dict[str, int]]] = None, configuration: Optional[Configuration] = None, artifacts: Optional[DebugArtifacts] = None, show_mlir: bool = False, @@ -342,6 +354,7 @@ def compile_brevitas_qat_model( global_p_error: Optional[float] = None, output_onnx_file: Union[None, Path, str] = None, verbose: bool = False, + inputs_encryption_status: Optional[Sequence[str]] = None, ) -> QuantizedModule: """Compile a Brevitas Quantization Aware Training model. @@ -374,6 +387,8 @@ def compile_brevitas_qat_model( output_onnx_file (str): temporary file to store ONNX model. If None a temporary file is generated verbose (bool): whether to show compilation information + inputs_encryption_status (Optional[Sequence[str]]): encryption status ('clear', 'encrypted') + for each input. Returns: QuantizedModule: The resulting compiled QuantizedModule. @@ -476,6 +491,7 @@ def compile_brevitas_qat_model( p_error=p_error, global_p_error=global_p_error, verbose=verbose, + inputs_encryption_status=inputs_encryption_status, ) # Remove the tempfile if we used one diff --git a/tests/quantization/test_quantized_module.py b/tests/quantization/test_quantized_module.py index 49072a2f50..0c1a840aca 100644 --- a/tests/quantization/test_quantized_module.py +++ b/tests/quantization/test_quantized_module.py @@ -305,10 +305,14 @@ def test_bitwidth_report(model_class, input_shape, activation_function, default_ ) # Ensure the value are the expected ones - # Disable mypy here as it does not seem to understand that the `range` value is - # a tuple in both the report dictionaries - assert op_report["range"][0] == expected_report["range"][0] # type: ignore[index] - assert op_report["range"][1] == expected_report["range"][1] # type: ignore[index] + assert isinstance(expected_report, dict) + assert isinstance(op_report, dict) + op_report_range = op_report["range"] + expected_report_range = expected_report["range"] + assert isinstance(op_report_range, (tuple, list)) + assert isinstance(expected_report_range, (tuple, list)) + assert op_report_range[0] == expected_report_range[0] + assert op_report_range[1] == expected_report_range[1] assert op_report["bitwidth"] == expected_report["bitwidth"] @@ -351,6 +355,103 @@ def test_quantized_module_rounding_fhe(model_class, input_shape, default_configu # TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3800 +# TODO: extend this test with multi-input encryption status +@pytest.mark.parametrize("model_class, input_shape", [pytest.param(FC, (100, 32 * 32 * 3))]) +def test_inputs_encryption_status(model_class, input_shape, default_configuration): + """Check that giving inputs_encryption_status work properly.""" + + torch_fc_model = model_class(activation_function=nn.ReLU) + torch_fc_model.eval() + + # Create random input + numpy_input = numpy.random.uniform(size=input_shape) + torch_input = torch.from_numpy(numpy_input).float() + assert isinstance(torch_input, torch.Tensor) + assert isinstance(numpy_input, numpy.ndarray) + + # Compile with rounding activated + with pytest.raises(ValueError, match="Missing arguments from '.*', expected 1 arguments."): + # Empty status + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + inputs_encryption_status=[], + ) + + # Wrong encryption status + with pytest.raises( + ValueError, match="Unexpected status from '.*', expected 'clear' or 'encrypted'." + ): + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + inputs_encryption_status=("random",), + ) + + # Additional argument (error from Concrete Python) + with pytest.raises(ValueError, match="Too many arguments in '.*', expected 1 arguments."): + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + inputs_encryption_status=( + "encrypted", + "encrypted", + ), + ) + + # No encrypted value + with pytest.raises(ValueError, match="At least one input should be encrypted but got .*"): + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + inputs_encryption_status=("clear",), + ) + + # Correct + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + inputs_encryption_status=("encrypted",), + ) + + # Default (redundant with other test) + compile_torch_model( + torch_fc_model, + torch_input, + False, + default_configuration, + n_bits=2, + p_error=0.01, + rounding_threshold_bits=6, + inputs_encryption_status=None, + ) + + def quantized_module_predictions_are_equal( quantized_module_1: QuantizedModule, quantized_module_2: QuantizedModule,