From 61c36847e4c6c3d7cffd22cc6bfa63cbf6f51264 Mon Sep 17 00:00:00 2001 From: Roman Bredehoft Date: Tue, 23 Apr 2024 14:36:48 +0200 Subject: [PATCH] chore: avoid checking accuracy when early stop in FHE training --- tests/sklearn/test_fhe_training.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/sklearn/test_fhe_training.py b/tests/sklearn/test_fhe_training.py index ed86d37ea..fd6ea0fdb 100644 --- a/tests/sklearn/test_fhe_training.py +++ b/tests/sklearn/test_fhe_training.py @@ -309,7 +309,7 @@ def test_clear_fit( model.predict_proba(x, fhe="simulate") -# pylint: disable=too-many-arguments +# pylint: disable=too-many-arguments, too-many-branches def check_encrypted_fit( x, y, @@ -318,7 +318,7 @@ def check_encrypted_fit( parameters_range, max_iter, fit_intercept, - check_accuracy, + check_accuracy=None, fhe=None, partial_fit=False, warm_fit=False, @@ -414,8 +414,9 @@ def check_encrypted_fit( (1, 1) ), "When the model is fitted without bias, the bias term should only be made of zeros." 
- # Check that we overfit properly a linearly separable dataset - check_accuracy(y, y_pred_class) + # If relevant, check that we overfit properly a linearly separable dataset + if check_accuracy is not None: + check_accuracy(y, y_pred_class) return weights, bias, y_pred_proba, y_pred_class, model.random_number_generator @@ -446,7 +447,7 @@ def test_encrypted_fit_coherence( parameters_range, max_iter, fit_intercept, - check_accuracy, + check_accuracy=check_accuracy, fhe="disable", ) ) @@ -460,7 +461,7 @@ def test_encrypted_fit_coherence( parameters_range, max_iter, fit_intercept, - check_accuracy, + check_accuracy=check_accuracy, fhe="simulate", ) ) @@ -475,7 +476,8 @@ def test_encrypted_fit_coherence( # Define early break parameters, with a very high tolerance early_break_kwargs = {"early_stopping": True, "tol": 1e100} - # We don't have any way to detect early break + # We don't have any way to properly test early break, so we disable the accuracy check + # in order to avoid flaky issues check_encrypted_fit( x, y, @@ -484,7 +486,7 @@ parameters_range, max_iter, fit_intercept, - check_accuracy, + check_accuracy=None, fhe="simulate", init_kwargs=early_break_kwargs, ) @@ -498,7 +500,7 @@ parameters_range, max_iter, fit_intercept, - check_accuracy, + check_accuracy=check_accuracy, partial_fit=True, ) ) @@ -520,7 +522,7 @@ parameters_range, max_iter, fit_intercept, - check_accuracy, + check_accuracy=check_accuracy, warm_fit=True, init_kwargs=warm_fit_init_kwargs, ) @@ -543,7 +545,7 @@ parameters_range, first_iterations, fit_intercept, - check_accuracy, + check_accuracy=check_accuracy, fhe="simulate", ) @@ -566,7 +568,7 @@ parameters_range, last_iterations, fit_intercept, - check_accuracy, + check_accuracy=check_accuracy, fhe="simulate", random_number_generator=rng_coef_init,
fit_kwargs=coef_init_fit_kwargs,