Skip to content

Commit

Permalink
chore: avoid checking accuracy when early stop in FHE training
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Apr 23, 2024
1 parent 3405978 commit b8c7dd7
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions tests/sklearn/test_fhe_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_clear_fit(
model.predict_proba(x, fhe="simulate")


# pylint: disable=too-many-arguments
# pylint: disable=too-many-arguments, too-many-branches
def check_encrypted_fit(
x,
y,
Expand All @@ -318,7 +318,7 @@ def check_encrypted_fit(
parameters_range,
max_iter,
fit_intercept,
check_accuracy,
check_accuracy=None,
fhe=None,
partial_fit=False,
warm_fit=False,
Expand Down Expand Up @@ -414,8 +414,9 @@ def check_encrypted_fit(
(1, 1)
), "When the model is fitted without bias, the bias term should only be made of zeros."

# Check that we overfit properly a linearly separable dataset
check_accuracy(y, y_pred_class)
# If relevant, check that we overfit properly a linearly separable dataset
if check_accuracy is not None:
check_accuracy(y, y_pred_class)

return weights, bias, y_pred_proba, y_pred_class, model.random_number_generator

Expand Down Expand Up @@ -446,7 +447,7 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
check_accuracy,
check_accuracy=check_accuracy,
fhe="disable",
)
)
Expand All @@ -460,7 +461,7 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
check_accuracy,
check_accuracy=check_accuracy,
fhe="simulate",
)
)
Expand All @@ -475,7 +476,8 @@ def test_encrypted_fit_coherence(
# Define early break parameters, with a very high tolerance
early_break_kwargs = {"early_stopping": True, "tol": 1e100}

# We don't have any way to detect early break
# We don't have any way to properly test early break, we therefore disable the accuracy check
# in order to avoid flaky issues
check_encrypted_fit(
x,
y,
Expand All @@ -484,7 +486,7 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
check_accuracy,
check_accuracy=None,
fhe="simulate",
init_kwargs=early_break_kwargs,
)
Expand All @@ -498,7 +500,7 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
check_accuracy,
check_accuracy=check_accuracy,
partial_fit=True,
)
)
Expand All @@ -520,7 +522,7 @@ def test_encrypted_fit_coherence(
parameters_range,
max_iter,
fit_intercept,
check_accuracy,
check_accuracy=check_accuracy,
warm_fit=True,
init_kwargs=warm_fit_init_kwargs,
)
Expand All @@ -543,7 +545,7 @@ def test_encrypted_fit_coherence(
parameters_range,
first_iterations,
fit_intercept,
check_accuracy,
check_accuracy=check_accuracy,
fhe="simulate",
)

Expand All @@ -566,7 +568,7 @@ def test_encrypted_fit_coherence(
parameters_range,
last_iterations,
fit_intercept,
check_accuracy,
check_accuracy=check_accuracy,
fhe="simulate",
random_number_generator=rng_coef_init,
fit_kwargs=coef_init_fit_kwargs,
Expand Down

0 comments on commit b8c7dd7

Please sign in to comment.