diff --git a/conftest.py b/conftest.py index 23ad1ff5e5..edf9a2ee15 100644 --- a/conftest.py +++ b/conftest.py @@ -2,6 +2,7 @@ import hashlib import json +import os import random import re from pathlib import Path @@ -153,7 +154,7 @@ def default_configuration(): insecure_key_cache_location="ConcretePythonKeyCache", fhe_simulation=False, fhe_execution=True, - compress_input_ciphertexts=True, + compress_input_ciphertexts=os.environ.get("USE_INPUT_COMPRESSION", "1") == "1", ) @@ -171,7 +172,7 @@ def simulation_configuration(): insecure_key_cache_location="ConcretePythonKeyCache", fhe_simulation=True, fhe_execution=False, - compress_input_ciphertexts=True, + compress_input_ciphertexts=os.environ.get("USE_INPUT_COMPRESSION", "1") == "1", ) diff --git a/src/concrete/ml/common/utils.py b/src/concrete/ml/common/utils.py index f769531b7d..03bf4c0dd1 100644 --- a/src/concrete/ml/common/utils.py +++ b/src/concrete/ml/common/utils.py @@ -1,6 +1,7 @@ """Utils that can be re-used by other pieces of code in the module.""" import enum +import os import string from functools import partial from types import FunctionType @@ -47,6 +48,11 @@ # should be exact compared to their Concrete ML QuantizedModule QUANT_ROUND_LIKE_ROUND_PBS = False +# Enable input ciphertext compression +# Note: This setting is fixed and cannot be altered by users +# However, for internal testing purposes, we retain the capability to disable this feature +os.environ["USE_INPUT_COMPRESSION"] = os.environ.get("USE_INPUT_COMPRESSION", "1") + class FheMode(str, enum.Enum): """Enum representing the execution mode. diff --git a/src/concrete/ml/quantization/quantized_module.py b/src/concrete/ml/quantization/quantized_module.py index bf09a8f5a1..97172911c1 100644 --- a/src/concrete/ml/quantization/quantized_module.py +++ b/src/concrete/ml/quantization/quantized_module.py @@ -1,6 +1,7 @@ """QuantizedModule API.""" import copy +import os import re from functools import partial from typing import Any, Dict, Generator, Iterable, List, Optional, Sequence, TextIO, Tuple, Union @@ -737,6 +738,9 @@ def compile( # Find the right way to set parameters for compiler, depending on the way we want to default p_error, global_p_error = manage_parameters_for_pbs_errors(p_error, global_p_error) + # Enable input ciphertext compression + enable_input_compression = os.environ.get("USE_INPUT_COMPRESSION", "1") == "1" + self.fhe_circuit = compiler.compile( inputset, configuration=configuration, @@ -748,7 +752,7 @@ def compile( single_precision=False, fhe_simulation=False, fhe_execution=True, - compress_input_ciphertexts=True, + compress_input_ciphertexts=enable_input_compression, ) self._is_compiled = True diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index ef3075590a..8902a23a94 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -560,6 +560,9 @@ def compile( f"{type(module_to_compile)}." ) + # Enable input ciphertext compression + enable_input_compression = os.environ.get("USE_INPUT_COMPRESSION", "1") == "1" + self.fhe_circuit_ = module_to_compile.compile( inputset, configuration=configuration, @@ -571,7 +574,7 @@ def compile( single_precision=False, fhe_simulation=False, fhe_execution=True, - compress_input_ciphertexts=True, + compress_input_ciphertexts=enable_input_compression, ) self._is_compiled = True diff --git a/tests/deployment/test_client_server.py b/tests/deployment/test_client_server.py index b0507bfc56..0372f1910b 100644 --- a/tests/deployment/test_client_server.py +++ b/tests/deployment/test_client_server.py @@ -1,6 +1,7 @@ """Tests the deployment APIs.""" import json +import os import tempfile import warnings import zipfile @@ -107,8 +108,15 @@ def test_client_server_sklearn( x_test, model, key_dir, check_array_equal, check_float_array_equal ) + compilation_kwargs = { + "X": x_train, + "configuration": default_configuration, + } + # Compile the model - fhe_circuit = model.compile(x_train, configuration=default_configuration) + fhe_circuit = model.compile(**compilation_kwargs) + + check_input_compression(model, fhe_circuit, is_torch=False, **compilation_kwargs) # Check that client and server files are properly generated check_client_server_files(model) @@ -150,12 +158,17 @@ def test_client_server_custom_model( torch_model = FCSmall(2, nn.ReLU) + compilation_kwargs = { + "torch_inputset": x_train, + "configuration": default_configuration, + "n_bits": 2, + } + # Get the quantized module from the model and compile it - quantized_numpy_module = compile_torch_model( - torch_model, - x_train, - configuration=default_configuration, - n_bits=2, + quantized_numpy_module = compile_torch_model(torch_model, **compilation_kwargs) + + check_input_compression( + torch_model, quantized_numpy_module.fhe_circuit, is_torch=True, **compilation_kwargs ) # Check that client and server files are properly generated @@ -288,3 +301,34 @@ def check_client_server_execution( # Clean up disk_network.cleanup() + + +def check_input_compression(model, fhe_circuit_compressed, is_torch, **compilation_kwargs): + """Check that input compression properly reduces input sizes.""" + + # Check that input ciphertext compression is enabled + assert os.environ.get("USE_INPUT_COMPRESSION") == "1", "'USE_INPUT_COMPRESSION' is not enabled" + + compressed_size = fhe_circuit_compressed.size_of_inputs + + with pytest.MonkeyPatch.context() as mp_context: + + # Disable input ciphertext compression + mp_context.setenv("USE_INPUT_COMPRESSION", "0") + + # Check that input ciphertext compression is disabled + assert ( + os.environ.get("USE_INPUT_COMPRESSION") == "0" + ), "'USE_INPUT_COMPRESSION' is not disabled" + + if is_torch: + fhe_circuit_uncompressed = compile_torch_model(model, **compilation_kwargs).fhe_circuit + else: + fhe_circuit_uncompressed = model.compile(**compilation_kwargs) + + uncompressed_size = fhe_circuit_uncompressed.size_of_inputs + + assert compressed_size < uncompressed_size, ( + "Compressed input ciphertext's is not smaller than the uncompressed input ciphertext. Got " + f"{compressed_size} bytes (compressed) and {uncompressed_size} bytes (uncompressed)." + ) diff --git a/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py b/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py index dbfc811f87..afec3d6a87 100644 --- a/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py +++ b/use_case_examples/cifar/cifar_brevitas_training/evaluate_one_example_fhe.py @@ -82,7 +82,6 @@ def wrapper(*args, **kwargs): enable_unsafe_features=True, use_insecure_key_cache=True, insecure_key_cache_location=KEYGEN_CACHE_DIR, - compress_input_ciphertexts=True, ) print("Compiling the model.")