Skip to content

Commit

Permalink
chore: fix and speed up fhe training tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Jun 11, 2024
1 parent 99fe596 commit 156fb4d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
12 changes: 9 additions & 3 deletions tests/deployment/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ def test_client_server_sklearn_inference(
model = instantiate_model_generic(model_class, n_bits=n_bits)

# Fit the model
model.fit(x_train, y_train, fhe="disable")
if getattr(model, "fit_encrypted", False):
model.fit(x_train, y_train, fhe="disable")
else:
model.fit(x_train, y_train)

key_dir = default_configuration.insecure_key_cache_location

Expand Down Expand Up @@ -361,7 +364,10 @@ def test_save_mode_handling(n_bits, fit_encrypted, mode, error_message):
)

# Fit the model in the clear
model.fit(x_train, y_train, fhe="disable")
if getattr(model, "fit_encrypted", False):
model.fit(x_train, y_train, fhe="disable")
else:
model.fit(x_train, y_train)

# Compile
model.compile(X=x_train)
Expand Down Expand Up @@ -691,7 +697,7 @@ def test_client_server_sklearn_training(
# so we fix a lower value in order to speed-up tests, especially since we do not actually check
# any score here
model.batch_size = batch_size
model.n_bits_training = 2
model.n_bits_training = n_bits

# Generate the min and max values for x_train and y_train
x_min, x_max = x_train.min(axis=0), x_train.max(axis=0)
Expand Down
21 changes: 11 additions & 10 deletions tests/sklearn/test_fhe_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
from concrete.ml.sklearn import SGDClassifier


def get_blob_data(n_classes=2, scale_input=False, parameters_range=None):
def get_blob_data(
n_samples=1000, n_classes=2, n_features=8, scale_input=False, parameters_range=None
):
"""Get the training data."""

n_samples = 1000
n_features = 8

# Generate the input and target values
# pylint: disable-next=unbalanced-tuple-unpacking
x, y = make_blobs(n_samples=n_samples, centers=n_classes, n_features=n_features)
Expand Down Expand Up @@ -417,7 +416,7 @@ def check_encrypted_fit(
# pylint: disable=too-many-statements,protected-access,too-many-locals
@pytest.mark.parametrize("fit_intercept", [True, False])
@pytest.mark.parametrize("label_offset", [0, 1])
@pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 30, 1.0)])
@pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 20, 1.0)])
def test_encrypted_fit_coherence(
fit_intercept, label_offset, n_bits, max_iter, parameter_min_max, check_accuracy
):
Expand Down Expand Up @@ -576,8 +575,8 @@ def test_encrypted_fit_coherence(
)


@pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 5, 1.0)])
def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max, check_accuracy):
@pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 2, 1.0)])
def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max):
"""Test that encrypted fitting works properly when executed in FHE."""

# Model parameters
Expand All @@ -586,9 +585,12 @@ def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max, check_accurac
fit_intercept = True

# Generate a data-set with binary target classes
x, y = get_blob_data(scale_input=True, parameters_range=parameters_range)
x, y = get_blob_data(n_features=2, scale_input=True, parameters_range=parameters_range)
y = y + 1

# Avoid checking the accuracy. Since this test is mostly here to make sure that FHE execution
# properly matches the quantized clear one, some parameters (for example, the number of
# features) were set to make it quicker, without considering the model's accuracy
weights_disable, bias_disable, y_pred_proba_disable, y_pred_class_disable, _ = (
check_encrypted_fit(
x,
Expand All @@ -598,11 +600,11 @@ def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max, check_accurac
parameters_range,
max_iter,
fit_intercept,
check_accuracy=check_accuracy,
fhe="disable",
)
)

# Same, avoid checking the accuracy
weights_fhe, bias_fhe, y_pred_proba_fhe, y_pred_class_fhe, _ = check_encrypted_fit(
x,
y,
Expand All @@ -611,7 +613,6 @@ def test_encrypted_fit_in_fhe(n_bits, max_iter, parameter_min_max, check_accurac
parameters_range,
max_iter,
fit_intercept,
check_accuracy=check_accuracy,
fhe="execute",
)

Expand Down

0 comments on commit 156fb4d

Please sign in to comment.