From d31429a23482dd90cf280c11923defbfdce52b48 Mon Sep 17 00:00:00 2001
From: Roman Bredehoft
Date: Tue, 11 Jun 2024 17:48:58 +0200
Subject: [PATCH] chore: clean and speed up fhe training tests

---
 .../ml/deployment/fhe_client_server.py   |  4 ++--
 .../ml/quantization/quantized_module.py  |  2 --
 src/concrete/ml/sklearn/base.py          |  2 --
 src/concrete/ml/sklearn/linear_model.py  |  8 ++++++-
 tests/deployment/test_client_server.py   |  2 +-
 tests/sklearn/test_fhe_training.py       | 22 +++++++++++++++++--
 6 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/src/concrete/ml/deployment/fhe_client_server.py b/src/concrete/ml/deployment/fhe_client_server.py
index 70de77ad61..7018524e29 100644
--- a/src/concrete/ml/deployment/fhe_client_server.py
+++ b/src/concrete/ml/deployment/fhe_client_server.py
@@ -470,6 +470,6 @@ def deserialize_decrypt_dequantize(

         # In training mode, note that this step does not make much sense for now. Still, nothing
         # breaks since QuantizedModule don't do anything in post-processing
-        result = self.model.post_processing(*result)
+        result_post_processed = self.model.post_processing(*result)

-        return result
+        return result_post_processed
diff --git a/src/concrete/ml/quantization/quantized_module.py b/src/concrete/ml/quantization/quantized_module.py
index de9a071fa4..6c3db8f1bc 100644
--- a/src/concrete/ml/quantization/quantized_module.py
+++ b/src/concrete/ml/quantization/quantized_module.py
@@ -886,8 +886,6 @@ def compile(
             global_p_error=global_p_error,
             verbose=verbose,
             single_precision=False,
-            fhe_simulation=False,
-            fhe_execution=True,
             compress_input_ciphertexts=enable_input_compression,
         )

diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py
index 2acd87eb96..c7adddbe29 100644
--- a/src/concrete/ml/sklearn/base.py
+++ b/src/concrete/ml/sklearn/base.py
@@ -579,8 +579,6 @@ def compile(
             global_p_error=global_p_error,
             verbose=verbose,
             single_precision=False,
-            fhe_simulation=False,
-            fhe_execution=True,
             compress_input_ciphertexts=enable_input_compression,
         )

diff --git a/src/concrete/ml/sklearn/linear_model.py b/src/concrete/ml/sklearn/linear_model.py
index 341a57fd74..1a5e529467 100644
--- a/src/concrete/ml/sklearn/linear_model.py
+++ b/src/concrete/ml/sklearn/linear_model.py
@@ -190,6 +190,7 @@ def __init__(
         self.learning_rate_value = 1.0
         self.batch_size = 8
         self.training_p_error = 0.01
+        self.training_fhe_configuration = None

         self.fit_encrypted = fit_encrypted
         self.parameters_range = parameters_range
@@ -344,10 +345,15 @@ def _get_training_quantized_module(
             fit_bias=self.fit_intercept,
         )

+        if self.training_fhe_configuration is None:
+            configuration = Configuration()
+        else:
+            configuration = self.training_fhe_configuration
+
        # Enable the underlying FHE circuit to be composed with itself
        # This feature is used in order to be able to iterate in the clear n times without having
        # to encrypt/decrypt the weight/bias values between each loop
-        configuration = Configuration(composable=True)
+        configuration.composable = True

         composition_mapping = {0: 2, 1: 3}

diff --git a/tests/deployment/test_client_server.py b/tests/deployment/test_client_server.py
index c14e64aea7..8bfed42e67 100644
--- a/tests/deployment/test_client_server.py
+++ b/tests/deployment/test_client_server.py
@@ -63,11 +63,11 @@ def dev_send_clientspecs_and_modelspecs_to_client(self):
 @pytest.mark.parametrize("model_class, parameters", MODELS_AND_DATASETS)
 @pytest.mark.parametrize("n_bits", [2])
 def test_client_server_sklearn_inference(
-    default_configuration,
     model_class,
     parameters,
     n_bits,
     load_data,
+    default_configuration,
     check_is_good_execution_for_cml_vs_circuit,
     check_array_equal,
     check_float_array_equal,
diff --git a/tests/sklearn/test_fhe_training.py b/tests/sklearn/test_fhe_training.py
index a5aa76aeab..a9935503c1 100644
--- a/tests/sklearn/test_fhe_training.py
+++ b/tests/sklearn/test_fhe_training.py
@@ -312,6 +312,7 @@ def check_encrypted_fit(
     parameters_range,
     max_iter,
     fit_intercept,
+    configuration,
     check_accuracy=None,
     fhe=None,
     partial_fit=False,
@@ -356,6 +357,8 @@ def check_encrypted_fit(
     # We need to lower the p-error to make sure that the test passes
     model.training_p_error = 1e-15

+    model.training_fhe_configuration = configuration
+
     if partial_fit:
         # Check that we can swap between disable and simulation modes without any impact on the
         # final training performance
@@ -418,7 +421,13 @@
 @pytest.mark.parametrize("label_offset", [0, 1])
 @pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 30, 1.0)])
 def test_encrypted_fit_coherence(
-    fit_intercept, label_offset, n_bits, max_iter, parameter_min_max, check_accuracy
+    fit_intercept,
+    label_offset,
+    n_bits,
+    max_iter,
+    parameter_min_max,
+    check_accuracy,
+    simulation_configuration,
 ):
     """Test that encrypted fitting works properly."""

@@ -439,6 +448,7 @@
         parameters_range,
         max_iter,
         fit_intercept,
+        simulation_configuration,
         check_accuracy=check_accuracy,
         fhe="disable",
     )
@@ -453,6 +463,7 @@
         parameters_range,
         max_iter,
         fit_intercept,
+        simulation_configuration,
         check_accuracy=check_accuracy,
         fhe="simulate",
     )
@@ -474,6 +485,7 @@
         parameters_range,
         max_iter,
         fit_intercept,
+        simulation_configuration,
         check_accuracy=check_accuracy,
         partial_fit=True,
     )
@@ -496,6 +508,7 @@
         parameters_range,
         max_iter,
         fit_intercept,
+        simulation_configuration,
         check_accuracy=check_accuracy,
         warm_fit=True,
         init_kwargs=warm_fit_init_kwargs,
@@ -519,6 +532,7 @@
         parameters_range,
         first_iterations,
         fit_intercept,
+        simulation_configuration,
         check_accuracy=check_accuracy,
         fhe="simulate",
     )
@@ -542,6 +556,7 @@
         parameters_range,
         last_iterations,
         fit_intercept,
+        simulation_configuration,
         check_accuracy=check_accuracy,
         fhe="simulate",
         random_number_generator=rng_coef_init,
@@ -569,6 +584,7 @@
         parameters_range,
         max_iter,
         fit_intercept,
+        simulation_configuration,
         check_accuracy=None,
         fhe="simulate",
         init_kwargs=early_break_kwargs,
     )
@@ -576,7 +592,7 @@

 @pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 2, 1.0)])
-def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max):
+def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max, default_configuration):
     """Test that encrypted fitting works properly when executed in FHE."""

     # Model parameters
@@ -600,6 +616,7 @@
             parameters_range,
             max_iter,
             fit_intercept,
+            default_configuration,
             fhe="disable",
         )
     )
@@ -613,6 +630,7 @@
             parameters_range,
             max_iter,
             fit_intercept,
+            default_configuration,
             fhe="execute",
         )
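
Usage sketch: the patch above adds a `training_fhe_configuration` attribute to `SGDClassifier` that is picked up when the training circuit is compiled. The snippet below is a minimal, assumed example of setting it before an encrypted fit; the dataset, parameter values, and `Configuration` fields are illustrative and not taken from the patch.

    import numpy as np
    from concrete.fhe import Configuration
    from concrete.ml.sklearn import SGDClassifier

    # Toy binary classification data in [0, 1] (illustrative only)
    rng = np.random.default_rng(0)
    x = rng.uniform(0, 1, size=(100, 4)).astype(np.float32)
    y = (x.sum(axis=1) > 2).astype(np.int64)

    # FHE-trainable SGD classifier: parameters_range is required when fit_encrypted=True
    model = SGDClassifier(
        n_bits=7,
        fit_encrypted=True,
        parameters_range=(-1.0, 1.0),
        max_iter=10,
    )

    # Attribute introduced by this patch: a custom FHE configuration applied when the
    # training circuit is compiled (the patch forces it to be composable internally).
    # Enabling the insecure key cache, as done in the test fixtures, speeds up repeated runs.
    model.training_fhe_configuration = Configuration(
        enable_unsafe_features=True,
        use_insecure_key_cache=True,
        insecure_key_cache_location="/tmp/keycache",
    )

    # Train with FHE simulation (use fhe="execute" for actual encrypted training)
    model.fit(x, y, fhe="simulate")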