diff --git a/tests/quantization/test_compilation.py b/tests/quantization/test_compilation.py
index e482a57b3..e32e81418 100644
--- a/tests/quantization/test_compilation.py
+++ b/tests/quantization/test_compilation.py
@@ -46,6 +46,7 @@
 @pytest.mark.parametrize("n_bits", [2])
 @pytest.mark.parametrize("simulate", [True, False])
 @pytest.mark.parametrize("verbose", [True, False])
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 def test_quantized_module_compilation(
     input_output_feature,
     model,
@@ -55,6 +56,7 @@ def test_quantized_module_compilation(
     simulate,
     check_graph_input_has_no_tlu,
     verbose,
+    compress_evaluation_keys,
     check_is_good_execution_for_cml_vs_circuit,
 ):
     """Test a neural network compilation for FHE inference."""
@@ -82,6 +84,7 @@ def test_quantized_module_compilation(
         numpy_input,
         default_configuration,
         verbose=verbose,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
     check_is_good_execution_for_cml_vs_circuit(numpy_input, quantized_model, simulate=simulate)
@@ -104,6 +107,7 @@
 )
 @pytest.mark.parametrize("simulate", [True, False])
 @pytest.mark.parametrize("verbose", [True, False])
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 def test_quantized_cnn_compilation(
     input_output_feature,
     model,
@@ -112,6 +116,7 @@ def test_quantized_cnn_compilation(
     check_graph_input_has_no_tlu,
     simulate,
     verbose,
+    compress_evaluation_keys,
     check_is_good_execution_for_cml_vs_circuit,
 ):
     """Test a convolutional neural network compilation for FHE inference."""
@@ -142,6 +147,7 @@ def test_quantized_cnn_compilation(
         numpy_input,
         default_configuration,
         verbose=verbose,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
     check_is_good_execution_for_cml_vs_circuit(numpy_input, quantized_model, simulate=simulate)
     check_graph_input_has_no_tlu(quantized_model.fhe_circuit.graph)
@@ -317,10 +323,7 @@ def test_compile_multi_input_nn_with_input_tlus(
     quantized_model = post_training_quant.quantize_module(numpy_input)
 
     # Compile
-    quantized_model.compile(
-        numpy_input,
-        default_configuration,
-    )
+    quantized_model.compile(numpy_input, default_configuration)
     check_is_good_execution_for_cml_vs_circuit(numpy_input, quantized_model, simulate=True)
 
     if can_remove_tlu:
diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py
index 6666087e1..be7cb7929 100644
--- a/tests/sklearn/test_sklearn_models.py
+++ b/tests/sklearn/test_sklearn_models.py
@@ -1492,6 +1492,7 @@ def test_input_support(
 
 
 @pytest.mark.parametrize("model_class, parameters", MODELS_AND_DATASETS)
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 def test_inference_methods(
     model_class,
     parameters,
@@ -1499,6 +1500,7 @@ def test_inference_methods(
     load_data,
     is_weekly_option,
     check_float_array_equal,
     default_configuration,
+    compress_evaluation_keys,
     verbose=True,
 ):
     """Test inference methods."""
@@ -1506,7 +1508,7 @@ def test_inference_methods(
 
     model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option)
 
-    model.compile(x, default_configuration)
+    model.compile(x, default_configuration, compress_evaluation_keys=compress_evaluation_keys)
 
     if verbose:
         print("Run check_inference_methods")
@@ -1562,6 +1564,7 @@ def test_pipeline(
         < min(N_BITS_LINEAR_MODEL_CRYPTO_PARAMETERS, N_BITS_THRESHOLD_TO_FORCE_EXECUTION_NOT_IN_FHE)
     ],
 )
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 # pylint: disable=too-many-branches
 def test_predict_correctness(
     model_class,
@@ -1572,6 +1575,7 @@ def test_predict_correctness(
     default_configuration,
     check_is_good_execution_for_cml_vs_circuit,
     is_weekly_option,
+    compress_evaluation_keys,
     verbose=True,
 ):
     """Test prediction correctness between clear quantized and FHE simulation or execution."""
@@ -1591,7 +1595,7 @@ def test_predict_correctness(
     if verbose:
         print("Compile the model")
 
-    model.compile(x, default_configuration)
+    model.compile(x, default_configuration, compress_evaluation_keys=compress_evaluation_keys)
 
     if verbose:
         print(f"Check prediction correctness for {fhe_samples} samples.")
@@ -1623,6 +1627,7 @@ def test_predict_correctness(
         < min(N_BITS_LINEAR_MODEL_CRYPTO_PARAMETERS, N_BITS_THRESHOLD_TO_FORCE_EXECUTION_NOT_IN_FHE)
     ],
 )
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 # pylint: disable=too-many-branches
 def test_separated_inference(
     model_class,
@@ -1633,6 +1638,7 @@ def test_separated_inference(
     default_configuration,
     is_weekly_option,
     check_float_array_equal,
+    compress_evaluation_keys,
     verbose=True,
 ):
     """Test prediction correctness between clear quantized and FHE simulation or execution."""
@@ -1653,7 +1659,9 @@ def test_separated_inference(
     if verbose:
         print("Compile the model")
 
-    fhe_circuit = model.compile(x, default_configuration)
+    fhe_circuit = model.compile(
+        x, default_configuration, compress_evaluation_keys=compress_evaluation_keys
+    )
 
     if verbose:
         print("Run check_separated_inference")
diff --git a/tests/torch/test_compile_torch.py b/tests/torch/test_compile_torch.py
index 41e7846ec..c0555b2a2 100644
--- a/tests/torch/test_compile_torch.py
+++ b/tests/torch/test_compile_torch.py
@@ -70,7 +70,9 @@ def create_test_inputset(inputset, n_percent_inputset_examples_test):
     return x_test
 
 
-def get_and_compile_quantized_module(model, inputset, import_qat, n_bits, configuration, verbose):
+def get_and_compile_quantized_module(
+    model, inputset, import_qat, n_bits, configuration, verbose, compress_evaluation_keys
+):
     """Get and compile the quantized module built from the given model."""
     quantized_numpy_module = build_quantized_module(
         model,
@@ -87,6 +89,7 @@ def get_and_compile_quantized_module(model, inputset, import_qat, n_bits, config
         p_error=p_error,
         global_p_error=global_p_error,
         verbose=verbose,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
     return quantized_numpy_module
@@ -108,6 +111,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
     get_and_compile=False,
     input_shape=None,
     is_brevitas_qat=False,
+    compress_evaluation_keys=True,
 ) -> QuantizedModule:
     """Test the different model architecture from torch numpy."""
 
@@ -164,6 +168,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
                 n_bits=n_bits,
                 configuration=default_configuration,
                 verbose=verbose,
+                compress_evaluation_keys=compress_evaluation_keys,
             )
 
         else:
@@ -174,6 +179,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
                 configuration=default_configuration,
                 n_bits=n_bits,
                 verbose=verbose,
+                compress_evaluation_keys=compress_evaluation_keys,
             )
     else:
         if is_brevitas_qat:
@@ -185,6 +191,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
                 n_bits=n_bits,
                 configuration=default_configuration,
                 verbose=verbose,
+                compress_evaluation_keys=compress_evaluation_keys,
             )
 
         elif get_and_compile:
@@ -195,6 +202,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
                 n_bits=n_bits,
                 configuration=default_configuration,
                 verbose=verbose,
+                compress_evaluation_keys=compress_evaluation_keys,
             )
 
         else:
@@ -205,6 +213,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
                 configuration=default_configuration,
                 n_bits=n_bits,
                 verbose=verbose,
+                compress_evaluation_keys=compress_evaluation_keys,
             )
 
     n_examples_test = 1
@@ -228,6 +237,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
                 n_bits=n_bits,
                 configuration=default_configuration,
                 verbose=verbose,
+                compress_evaluation_keys=compress_evaluation_keys,
             )
 
         else:
@@ -253,6 +263,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
                 n_bits=n_bits,
                 configuration=default_configuration,
                 verbose=verbose,
+                compress_evaluation_keys=compress_evaluation_keys,
             )
 
         else:
@@ -263,6 +274,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
                 configuration=default_configuration,
                 n_bits=n_bits,
                 verbose=verbose,
+                compress_evaluation_keys=compress_evaluation_keys,
             )
 
     accuracy_test_rounding(
@@ -276,6 +288,7 @@ def compile_and_test_torch_or_onnx(  # pylint: disable=too-many-locals, too-many
         verbose=verbose,
         check_is_good_execution_for_cml_vs_circuit=check_is_good_execution_for_cml_vs_circuit,
         is_brevitas_qat=is_brevitas_qat,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
     if dump_onnx:
@@ -298,6 +311,7 @@ def accuracy_test_rounding(
     simulate,
     verbose,
    check_is_good_execution_for_cml_vs_circuit,
+    compress_evaluation_keys,
     is_brevitas_qat=False,
 ):
     """Check rounding behavior.
@@ -329,6 +343,7 @@ def accuracy_test_rounding(
             configuration=configuration,
             rounding_threshold_bits=8,
             verbose=verbose,
+            compress_evaluation_keys=compress_evaluation_keys,
         )
 
         # and another quantized module with a rounding threshold equal to 2 bits
@@ -339,6 +354,7 @@ def accuracy_test_rounding(
             configuration=configuration,
             rounding_threshold_bits=2,
             verbose=verbose,
+            compress_evaluation_keys=compress_evaluation_keys,
         )
 
     else:
@@ -351,6 +367,7 @@ def accuracy_test_rounding(
             n_bits=n_bits,
             rounding_threshold_bits=8,
             verbose=verbose,
+            compress_evaluation_keys=compress_evaluation_keys,
         )
 
         # and another quantized module with a rounding threshold equal to 2 bits
@@ -362,6 +379,7 @@ def accuracy_test_rounding(
             n_bits=n_bits,
             verbose=verbose,
             rounding_threshold_bits=2,
+            compress_evaluation_keys=compress_evaluation_keys,
         )
 
     n_percent_inputset_examples_test = 0.1
@@ -463,6 +481,9 @@
 @pytest.mark.parametrize("simulate", [True, False], ids=["FHE_simulation", "FHE"])
 @pytest.mark.parametrize("is_onnx", [True, False], ids=["is_onnx", ""])
 @pytest.mark.parametrize("get_and_compile", [True, False], ids=["get_and_compile", "compile"])
+@pytest.mark.parametrize(
+    "compress_evaluation_keys", [True, False], ids=["compress_evaluation_keys", ""]
+)
 def test_compile_torch_or_onnx_networks(
     input_output_feature,
     model,
@@ -473,6 +494,7 @@ def test_compile_torch_or_onnx_networks(
     get_and_compile,
     check_is_good_execution_for_cml_vs_circuit,
     is_weekly_option,
+    compress_evaluation_keys,
 ):
     """Test the different model architecture from torch numpy."""
 
@@ -495,6 +517,7 @@ def test_compile_torch_or_onnx_networks(
         check_is_good_execution_for_cml_vs_circuit=check_is_good_execution_for_cml_vs_circuit,
         verbose=False,
         get_and_compile=get_and_compile,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
 
@@ -517,6 +540,7 @@
 )
 @pytest.mark.parametrize("simulate", [True, False])
 @pytest.mark.parametrize("is_onnx", [True, False])
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 def test_compile_torch_or_onnx_conv_networks(  # pylint: disable=unused-argument
     model,
     is_1d,
@@ -527,6 +551,7 @@ def test_compile_torch_or_onnx_conv_networks(  # pylint: disable=unused-argument
     check_graph_input_has_no_tlu,
     check_graph_output_has_no_tlu,
     check_is_good_execution_for_cml_vs_circuit,
+    compress_evaluation_keys,
 ):
     """Test the different model architecture from torch numpy."""
 
@@ -547,6 +572,7 @@ def test_compile_torch_or_onnx_conv_networks(  # pylint: disable=unused-argument
         check_is_good_execution_for_cml_vs_circuit=check_is_good_execution_for_cml_vs_circuit,
         verbose=False,
         input_shape=input_shape,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
     check_graph_input_has_no_tlu(q_module.fhe_circuit.graph)
@@ -598,6 +624,7 @@ def test_compile_torch_or_onnx_conv_networks(  # pylint: disable=unused-argument
 )
 @pytest.mark.parametrize("simulate", [True, False])
 @pytest.mark.parametrize("is_onnx", [True, False])
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 def test_compile_torch_or_onnx_activations(
     input_output_feature,
     model,
@@ -606,6 +633,7 @@ def test_compile_torch_or_onnx_activations(
     simulate,
     is_onnx,
     check_is_good_execution_for_cml_vs_circuit,
+    compress_evaluation_keys,
 ):
     """Test the different model architecture from torch numpy."""
 
@@ -622,6 +650,7 @@ def test_compile_torch_or_onnx_activations(
         is_onnx,
         check_is_good_execution_for_cml_vs_circuit,
         verbose=False,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
 
@@ -640,6 +669,7 @@ def test_compile_torch_or_onnx_activations(
     [pytest.param(n_bits) for n_bits in [1, 2]],
 )
 @pytest.mark.parametrize("simulate", [True, False])
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 def test_compile_torch_qat(
     input_output_feature,
     model,
@@ -647,6 +677,7 @@ def test_compile_torch_qat(
     default_configuration,
     simulate,
     check_is_good_execution_for_cml_vs_circuit,
+    compress_evaluation_keys,
 ):
     """Test the different model architecture from torch numpy."""
 
@@ -666,6 +697,7 @@ def test_compile_torch_qat(
         is_onnx,
         check_is_good_execution_for_cml_vs_circuit,
         verbose=False,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
 
@@ -678,6 +710,7 @@ def test_compile_torch_qat(
     [pytest.param(n_bits) for n_bits in [2]],
 )
 @pytest.mark.parametrize("simulate", [True, False])
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 def test_compile_brevitas_qat(
     model_class,
     input_output_feature,
@@ -686,6 +719,7 @@ def test_compile_brevitas_qat(
     simulate,
     default_configuration,
     check_is_good_execution_for_cml_vs_circuit,
+    compress_evaluation_keys,
 ):
     """Test compile_brevitas_qat_model."""
 
@@ -710,6 +744,7 @@ def test_compile_brevitas_qat(
         check_is_good_execution_for_cml_vs_circuit=check_is_good_execution_for_cml_vs_circuit,
         verbose=False,
         is_brevitas_qat=is_brevitas_qat,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
 
@@ -758,12 +793,14 @@ def test_compile_brevitas_qat(
         pytest.param(nn.ReLU, id="relu"),
     ],
 )
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 def test_dump_torch_network(
     model_class,
     expected_onnx_str,
     activation_function,
     default_configuration,
     check_is_good_execution_for_cml_vs_circuit,
+    compress_evaluation_keys,
 ):
     """This is a test which is equivalent to tests in test_dump_onnx.py, but for torch modules."""
     input_output_feature = 7
@@ -783,10 +820,12 @@ def test_dump_torch_network(
         dump_onnx=True,
         expected_onnx_str=expected_onnx_str,
         verbose=False,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
 
 @pytest.mark.parametrize("verbose", [True, False], ids=["with_verbose", "without_verbose"])
+@pytest.mark.parametrize("compress_evaluation_keys", [True, False])
 # pylint: disable-next=too-many-locals
 def test_pretrained_mnist_qat(
     default_configuration,
@@ -795,6 +834,7 @@
     check_graph_output_has_no_tlu,
     check_is_good_execution_for_cml_vs_circuit,
     is_weekly_option,
+    compress_evaluation_keys,
 ):
     """Load a QAT MNIST model and confirm we get the same results in FHE simulation as with ONNX."""
     if not is_weekly_option:
@@ -845,6 +885,7 @@
         configuration=default_configuration,
         n_bits=n_bits,
         verbose=verbose,
+        compress_evaluation_keys=compress_evaluation_keys,
     )
 
     quantized_numpy_module.check_model_is_compiled()
@@ -884,6 +925,7 @@
         configuration=default_configuration,
         n_bits=n_bits,
         verbose=verbose,
+        compress_evaluation_keys=compress_evaluation_keys,
    )
 
     # As this is a custom QAT network, the input goes through multiple univariate