
Commit

chore: fix tests
jfrery committed May 29, 2024
1 parent 35b4e9c commit 3fc2c0a
Showing 7 changed files with 31 additions and 92 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -137,7 +137,9 @@ filterwarnings = [
"ignore:Converting a tensor to a NumPy array might cause the trace to be incorrect.",
"ignore:torch.from_numpy results are registered as constants in the trace.",
"ignore:ONNX Preprocess - Removing mutation from node aten*:UserWarning",
"ignore:Liblinear failed to converge, increase the number of iterations.*:sklearn.exceptions.ConvergenceWarning",
"ignore:Liblinear failed to converge,*:sklearn.exceptions.ConvergenceWarning",
"ignore:lbfgs failed to converge,*:sklearn.exceptions.ConvergenceWarning",
"ignore:Maximum number of iteration reached before convergence.*:sklearn.exceptions.ConvergenceWarning",
]

[tool.semantic_release]
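
Each entry above uses pytest's "action:message:category" filter syntax, where the message part is a regular expression matched against the start of the warning text. Shortening the Liblinear pattern to the "failed to converge," prefix therefore covers every variant of scikit-learn's message, and the two new entries do the same for the lbfgs solver and the iteration-limit warning. A minimal sketch of the equivalent runtime filters, assuming the standard sklearn ConvergenceWarning class:

import warnings

from sklearn.exceptions import ConvergenceWarning

# Same effect as the "ignore:<regex>:<category>" entries above: each regex is
# matched against the beginning of the warning message, so a short prefix
# silences all variants of the convergence warning.
for prefix in (
    "Liblinear failed to converge,",
    "lbfgs failed to converge,",
    "Maximum number of iteration reached before convergence",
):
    warnings.filterwarnings("ignore", message=prefix, category=ConvergenceWarning)
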
13 changes: 6 additions & 7 deletions src/concrete/ml/deployment/fhe_client_server.py
@@ -151,9 +151,11 @@ def run(

result = self.server.run(*deserialized_data, evaluation_keys=deserialized_keys)

if isinstance(result, tuple):
return tuple(res.serialize() for res in result)
return result.serialize()
return (
tuple(res.serialize() for res in result)
if isinstance(result, tuple)
else result.serialize()
)


class FHEModelDev:
@@ -384,10 +386,7 @@ def quantize_encrypt_serialize(
serialized_enc_qx = tuple(e.serialize() for e in enc_qx)

# Return a single value if the original input was a single value
if len(serialized_enc_qx) == 1:
return serialized_enc_qx[0]

return serialized_enc_qx
return serialized_enc_qx[0] if len(serialized_enc_qx) == 1 else serialized_enc_qx

def deserialize_decrypt(
self, serialized_encrypted_quantized_result: Union[bytes, Tuple[bytes, ...]]
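
Both hunks in this file fold an if/return pair into a single conditional expression with identical behavior: tuple results (multi-output models) are serialized element by element, while single results are handled bare. A self-contained sketch of the shared pattern, where Blob is a hypothetical stand-in for Concrete's serializable values:

from typing import Tuple, Union

class Blob:
    """Hypothetical stand-in for a serializable Concrete value."""

    def __init__(self, payload: bytes):
        self.payload = payload

    def serialize(self) -> bytes:
        return self.payload

def serialize_result(
    result: Union[Blob, Tuple[Blob, ...]]
) -> Union[bytes, Tuple[bytes, ...]]:
    # Shape-preserving rule used by both `run` and `quantize_encrypt_serialize`:
    # tuples stay tuples, single values stay single.
    return (
        tuple(res.serialize() for res in result)
        if isinstance(result, tuple)
        else result.serialize()
    )

assert serialize_result(Blob(b"x")) == b"x"
assert serialize_result((Blob(b"a"), Blob(b"b"))) == (b"a", b"b")
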
2 changes: 1 addition & 1 deletion src/concrete/ml/quantization/quantized_module.py
@@ -266,7 +266,7 @@ def post_processing(
Returns:
Union[numpy.ndarray, Tuple[numpy.ndarray, ...]]: The post-processed values.
"""
return values
return values[0] if len(values) == 1 else values

def _set_output_quantizers(self) -> List[UniformQuantizer]:
"""Get the output quantizers.
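
The default post_processing previously returned the output tuple untouched; it now unwraps single-element results so that single-output models hand back a bare array. A minimal sketch of the new behavior, assuming `values` arrives as the tuple of dequantized model outputs:

import numpy

def post_processing(*values: numpy.ndarray):
    # Unwrap single-output models; keep the tuple for multi-output models.
    return values[0] if len(values) == 1 else values

assert isinstance(post_processing(numpy.zeros(3)), numpy.ndarray)
assert isinstance(post_processing(numpy.zeros(3), numpy.ones(3)), tuple)
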
9 changes: 1 addition & 8 deletions src/concrete/ml/sklearn/linear_model.py
@@ -2,7 +2,6 @@

import itertools
import time
import warnings
from typing import Any, Dict, Optional, Union

import numpy
@@ -751,13 +750,7 @@ def partial_fit(
# A partial fit is similar to a fit with a single iteration. The slight differences between
# both are handled in the encrypted method when setting `is_partial_fit` to True.
if self.fit_encrypted:
if fhe is None:
fhe = "disable"
warnings.warn(
"Parameter 'fhe' isn't set while FHE training is enabled.\n"
f"Defaulting to '{fhe=}'",
stacklevel=2,
)
fhe = "disable" if fhe is None else fhe

# Make sure the `fhe` parameter is correct
assert FheMode.is_valid(fhe), (
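
With the warnings import gone and the warn call removed, leaving `fhe` unset during encrypted training now falls back silently to "disable" (quantized training executed in the clear) instead of emitting a UserWarning. A hypothetical usage sketch, with random toy data standing in for a real dataset:

import numpy

from concrete.ml.sklearn import SGDClassifier

X = numpy.random.uniform(-1.0, 1.0, (32, 4))
y = numpy.random.randint(0, 2, 32)

model = SGDClassifier(
    n_bits=8,
    fit_encrypted=True,
    parameters_range=(-1.0, 1.0),
)

# `fhe` left unset: previously triggered a UserWarning, now quietly
# equivalent to passing fhe="disable".
model.partial_fit(X, y)
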
4 changes: 2 additions & 2 deletions tests/common/test_skearn_model_lists.py
@@ -110,5 +110,5 @@ def test_get_sklearn_models():
def test_models_and_datasets():
"""Check that the tested model's configuration lists remain fixed."""

assert len(MODELS_AND_DATASETS) == 32
assert len(UNIQUE_MODELS_AND_DATASETS) == 21
assert len(MODELS_AND_DATASETS) == 34
assert len(UNIQUE_MODELS_AND_DATASETS) == 22
3 changes: 2 additions & 1 deletion tests/sklearn/test_dump_onnx.py
@@ -418,7 +418,8 @@ def check_onnx_file_dump(
# KNN can only be compiled with small quantization bit numbers for now
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979
model.n_bits = 2
model.fit(x, y)

model.fit(x, y)

with warnings.catch_warnings():
# Use FHE simulation to not have issues with precision
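
The dedent suggests a pure control-flow fix: `model.fit(x, y)` sat inside the KNN-specific branch and was skipped for every other model, while the new version fits unconditionally and keeps only the n_bits override conditional. A self-contained sketch with stand-ins for the test objects (Model, get_model_name, x, y are assumptions, not the real helpers):

class Model:
    n_bits = 8
    fitted = False

    def fit(self, x, y):
        self.fitted = True

def get_model_name(model):
    return type(model).__name__

model, x, y = Model(), [[0.0]], [0]

# KNN-specific override stays conditional...
if get_model_name(model) == "KNeighborsClassifier":
    model.n_bits = 2

# ...but fitting now runs for every model.
model.fit(x, y)
assert model.fitted
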
88 changes: 16 additions & 72 deletions tests/sklearn/test_fhe_training.py
@@ -41,53 +41,39 @@ def test_init_error_raises(n_bits, parameter_min_max):
random_state = numpy.random.randint(0, 2**15)
parameters_range = (-parameter_min_max, parameter_min_max)

with pytest.warns(
UserWarning,
match=(
"FHE training is an experimental feature. Please be aware that the API might change "
"in future versions."
with pytest.raises(
ValueError,
match=re.escape(
"Only 'log_loss' is currently supported if FHE training is enabled"
" (fit_encrypted=True). Got loss='perceptron'"
),
):
SGDClassifier(
n_bits=n_bits,
fit_encrypted=True,
loss="perceptron",
random_state=random_state,
parameters_range=parameters_range,
)

with pytest.raises(
ValueError,
match=re.escape(
"Only 'log_loss' is currently supported if FHE training is enabled"
" (fit_encrypted=True). Got loss='perceptron'"
),
ValueError, match="Setting 'parameter_range' is mandatory if FHE training is enabled."
):
SGDClassifier(
n_bits=n_bits,
fit_encrypted=True,
loss="perceptron",
random_state=random_state,
parameters_range=parameters_range,
parameters_range=None,
fit_intercept=True,
)

with pytest.raises(
ValueError, match="Setting 'parameter_range' is mandatory if FHE training is enabled."
):
SGDClassifier(
n_bits=n_bits,
fit_encrypted=True,
random_state=random_state,
parameters_range=None,
fit_intercept=True,
)

SGDClassifier(
n_bits=n_bits,
fit_encrypted=True,
random_state=random_state,
parameters_range=parameters_range,
fit_intercept=False,
)
SGDClassifier(
n_bits=n_bits,
fit_encrypted=True,
random_state=random_state,
parameters_range=parameters_range,
fit_intercept=False,
)


@pytest.mark.parametrize("n_classes", [1, 3])
@@ -208,20 +194,6 @@ def test_encrypted_fit_warning_error_raises(n_bits, max_iter, parameter_min_max)
max_iter=max_iter,
)

with pytest.warns(
UserWarning,
match="Parameter 'fhe' isn't set while FHE training is enabled.\n"
"Defaulting to 'fhe='disable''",
):
model.fit(x, y, fhe=None)

with pytest.warns(
UserWarning,
match="Parameter 'fhe' isn't set while FHE training is enabled.\n"
"Defaulting to 'fhe='disable''",
):
model.partial_fit(x, y, fhe=None)

with pytest.raises(
NotImplementedError,
match="Parameter 'sample_weight' is currently not supported for FHE training.",
@@ -244,41 +216,13 @@ def test_encrypted_fit_warning_error_raises(n_bits, max_iter, parameter_min_max)
with pytest.raises(NotImplementedError, match="Target values must be 1D.*"):
model.partial_fit(x, y_2d, fhe="disable")

with pytest.warns(
UserWarning,
match="FHE training is an experimental feature. "
"Please be aware that the API might change in future versions.",
):
model = SGDClassifier(
n_bits=n_bits,
fit_encrypted=True,
random_state=random_state,
parameters_range=parameters_range,
max_iter=max_iter,
loss="log_loss",
)

with pytest.warns(UserWarning, match="ONNX Preprocess - Removing mutation from node .*"):
model.fit(x, y, fhe="disable")

with pytest.raises(NotImplementedError, match=""):
model.loss = "random"
model.predict_proba(x)

with pytest.warns(
UserWarning,
match="FHE training is an experimental feature. "
"Please be aware that the API might change in future versions.",
):
model = SGDClassifier(
n_bits=n_bits,
fit_encrypted=True,
random_state=random_state,
parameters_range=parameters_range,
max_iter=max_iter,
loss="log_loss",
)

assert isinstance(y, numpy.ndarray)

with pytest.raises(
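
The deletions in this file all follow from one library change: SGDClassifier no longer emits the "experimental feature" and "fhe isn't set" UserWarnings, and pytest.warns fails a test when the expected warning never fires, so the deleted blocks had to be removed rather than left in place. A minimal repro of that failure mode:

import pytest

def construct_model_without_warning():
    """Stand-in for the new SGDClassifier behavior: no UserWarning emitted."""

# pytest.warns raises a Failed error when its block completes without the
# expected warning, which is exactly why the removed blocks would now break.
with pytest.raises(pytest.fail.Exception):
    with pytest.warns(UserWarning):
        construct_model_without_warning()
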
