Skip to content

Commit

Permalink
chore: remove manual warning catch in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed May 28, 2024
1 parent 0931788 commit 95dc517
Show file tree
Hide file tree
Showing 11 changed files with 196 additions and 354 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ filterwarnings = [
"ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning",
"ignore:Converting a tensor to a NumPy array might cause the trace to be incorrect.",
"ignore:torch.from_numpy results are registered as constants in the trace.",
"ignore:ONNX Preprocess - Removing mutation from node aten*:UserWarning",
"ignore:Liblinear failed to converge, increase the number of iterations.*:sklearn.exceptions.ConvergenceWarning",
]

[tool.semantic_release]
Expand Down
14 changes: 1 addition & 13 deletions src/concrete/ml/sklearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,6 @@ def __init__(
if self.fit_encrypted:
self.classes_: Optional[numpy.ndarray] = None

warnings.warn(
"FHE training is an experimental feature. Please be aware that the API might "
"change in future versions.",
stacklevel=2,
)

# Check the presence of mandatory attributes
if self.loss != "log_loss":
raise ValueError(
Expand Down Expand Up @@ -687,13 +681,7 @@ def fit( # type: ignore[override]

# If the model should be trained using FHE training
if self.fit_encrypted:
if fhe is None:
fhe = "disable"
warnings.warn(
"Parameter 'fhe' isn't set while FHE training is enabled.\n"
f"Defaulting to '{fhe=}'",
stacklevel=2,
)
fhe = "disable" if fhe is None else fhe

# Make sure the `fhe` parameter is correct
assert FheMode.is_valid(fhe), (
Expand Down
8 changes: 2 additions & 6 deletions tests/common/test_pbs_error_probability_settings.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""Tests for the sklearn linear models."""

import warnings
from inspect import signature

import numpy
import pytest
from sklearn.exceptions import ConvergenceWarning
from torch import nn

from concrete.ml.pytest.torch_models import FCSmall
Expand Down Expand Up @@ -34,10 +32,8 @@ def test_config_sklearn(model_class, parameters, kwargs, load_data):

model = model_class()

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
# Fit the model
model.fit(x, y)
# Fit the model
model.fit(x, y)

if kwargs.get("p_error", None) is not None and kwargs.get("global_p_error", None) is not None:
with pytest.raises(ValueError) as excinfo:
Expand Down
6 changes: 1 addition & 5 deletions tests/common/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import inspect
import io
import warnings
from functools import partial

import numpy
Expand All @@ -18,7 +17,6 @@
from concrete.fhe.compilation import Circuit
from numpy.random import RandomState
from sklearn.datasets import make_regression
from sklearn.exceptions import ConvergenceWarning
from skops.io.exceptions import UntrustedTypesFoundException
from skorch.dataset import ValidSplit
from torch import nn
Expand Down Expand Up @@ -123,9 +121,7 @@ def test_serialize_sklearn_model(concrete_model_class, load_data):
# Instantiate and fit a Concrete model to recover its underlying Scikit Learn model
concrete_model = concrete_model_class()

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
_, sklearn_model = concrete_model.fit_benchmark(x, y)
_, sklearn_model = concrete_model.fit_benchmark(x, y)

# Both JSON string are not compared as scikit-learn models are serialized using Skops or pickle,
# which does not make string comparison possible
Expand Down
28 changes: 8 additions & 20 deletions tests/deployment/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
import json
import os
import tempfile
import warnings
import zipfile
from functools import partial
from pathlib import Path
from shutil import copyfile

import numpy
import pytest
from sklearn.exceptions import ConvergenceWarning
from torch import nn

from concrete.ml.deployment.fhe_client_server import (
Expand Down Expand Up @@ -115,15 +113,10 @@ def test_client_server_sklearn(
x_test = x[-1:]

# Instantiate the model
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
model = instantiate_model_generic(model_class, n_bits=n_bits)
model = instantiate_model_generic(model_class, n_bits=n_bits)

# Fit the model
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
warnings.simplefilter("ignore", category=UserWarning)
model.fit(x_train, y_train)
model.fit(x_train, y_train)

key_dir = default_configuration.insecure_key_cache_location

Expand Down Expand Up @@ -393,19 +386,14 @@ def test_save_mode_handling(n_bits, fit_encrypted, mode, error_message):
y_train = y[:-1]

# Instantiate the model
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
parameters_range = [-1, 1] if fit_encrypted else None
model = instantiate_model_generic(
partial(SGDClassifier, fit_encrypted=fit_encrypted, parameters_range=parameters_range),
n_bits=n_bits,
)
parameters_range = [-1, 1] if fit_encrypted else None
model = instantiate_model_generic(
partial(SGDClassifier, fit_encrypted=fit_encrypted, parameters_range=parameters_range),
n_bits=n_bits,
)

# Fit the model
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
warnings.simplefilter("ignore", category=UserWarning)
model.fit(x_train, y_train)
model.fit(x_train, y_train)

# Compile
model.compile(X=x_train)
Expand Down
29 changes: 8 additions & 21 deletions tests/parameter_search/test_p_error_binary_search.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""Test binary search class."""

import os
import warnings
from pathlib import Path

import numpy
import pytest
import torch
from sklearn.datasets import make_classification
from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import r2_score, top_k_accuracy_score
from tensorflow import keras

Expand Down Expand Up @@ -126,9 +124,8 @@ def test_update_valid_attr_method(attr, value, model_name, quant_type, metric, l
predict="predict",
n_simulation=1,
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
search.run(x=x_calib, ground_truth=y, strategy=all, **{attr: value})

search.run(x=x_calib, ground_truth=y, strategy=all, **{attr: value})

assert getattr(search, attr) == value

Expand Down Expand Up @@ -167,10 +164,7 @@ def test_non_convergence_for_built_in_models(model_class, parameters, load_data,
max_metric_loss=-10,
is_qat=False,
)

warnings.simplefilter("always")
with pytest.warns(UserWarning, match="ConvergenceWarning: .*"):
search.run(x=x_calib, ground_truth=y, strategy=all)
search.run(x=x_calib, ground_truth=y, strategy=all)


@pytest.mark.parametrize("model_name, quant_type", [("CustomModel", "qat")])
Expand Down Expand Up @@ -205,9 +199,7 @@ def test_non_convergence_for_custom_models(model_name, quant_type):
labels=numpy.arange(MODELS_ARGS[model_name]["dataset"]["n_classes"]),
)

warnings.simplefilter("always")
with pytest.warns(UserWarning, match="ConvergenceWarning: .*"):
search.run(x=x_calib, ground_truth=y, strategy=all)
search.run(x=x_calib, ground_truth=y, strategy=all)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -279,9 +271,8 @@ def test_binary_search_for_custom_models(model_name, quant_type, threshold):
k=1,
labels=numpy.arange(MODELS_ARGS[model_name]["dataset"]["n_classes"]),
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
largest_perror = search.run(x=x_calib, ground_truth=y, strategy=all)

largest_perror = search.run(x=x_calib, ground_truth=y, strategy=all)

assert 1.0 > largest_perror > 0.0
assert (
Expand Down Expand Up @@ -337,9 +328,7 @@ def test_binary_search_for_built_in_models(model_class, parameters, threshold, p
# The model does not have `predict`
return

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
largest_perror = search.run(x=x_calib, ground_truth=y, strategy=all)
largest_perror = search.run(x=x_calib, ground_truth=y, strategy=all)

assert 1.0 > largest_perror > 0.0
assert (
Expand Down Expand Up @@ -475,9 +464,7 @@ def test_success_save_option(model_name, quant_type, metric, directory, log_file
# When instantiating the class, if the file exists, it is deleted, to avoid overwriting it
assert not path.exists()

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
search.run(x=x_calib, ground_truth=y)
search.run(x=x_calib, ground_truth=y)

# Check that the file has been properly created
assert path.exists()
Expand Down
17 changes: 4 additions & 13 deletions tests/seeding/test_seeding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@

import inspect
import random
import warnings

import numpy
import pytest
from sklearn.exceptions import ConvergenceWarning
from sklearn.tree import plot_tree

from concrete.ml.common.utils import get_model_name
from concrete.ml.pytest.utils import MODELS_AND_DATASETS


Expand Down Expand Up @@ -93,16 +90,10 @@ def test_seed_sklearn(model_class, parameters, load_data, default_configuration)
model_params["random_state"] = numpy.random.randint(0, 2**15)

# First case: user gives his own random_state
# Warning skip due to SGDClassifier with fit_encrypted = True
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
model = model_class(**model_params)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
warnings.simplefilter("ignore", category=UserWarning)
# Fit the model
model, sklearn_model = model.fit_benchmark(x, y)
model = model_class(**model_params)

# Fit the model
model, sklearn_model = model.fit_benchmark(x, y)

lpvoid_ptr_plot_tree = getattr(model, "plot_tree", None)
if callable(lpvoid_ptr_plot_tree):
Expand Down
20 changes: 6 additions & 14 deletions tests/sklearn/test_common.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""Tests common to all sklearn models."""

import inspect
import warnings

import numpy
import pytest
from sklearn.exceptions import ConvergenceWarning

from concrete.ml.common.utils import get_model_class
from concrete.ml.pytest.utils import MODELS_AND_DATASETS
Expand Down Expand Up @@ -44,10 +42,8 @@ def test_seed_sklearn(model_class, parameters, load_data):
# First case: user gives his own random_state
model = model_class(random_state=random_state_constructor)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
# Fit the model
model, sklearn_model = model.fit_benchmark(x, y, random_state=random_state_user)
# Fit the model
model, sklearn_model = model.fit_benchmark(x, y, random_state=random_state_user)

assert (
model.random_state == random_state_user and sklearn_model.random_state == random_state_user
Expand All @@ -56,10 +52,8 @@ def test_seed_sklearn(model_class, parameters, load_data):
# Second case: user does not give random_state but seeds the constructor
model = model_class(random_state=random_state_constructor)

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
# Fit the model
model, sklearn_model = model.fit_benchmark(x, y)
# Fit the model
model, sklearn_model = model.fit_benchmark(x, y)

assert (model.random_state == random_state_constructor) and (
sklearn_model.random_state == random_state_constructor
Expand All @@ -69,10 +63,8 @@ def test_seed_sklearn(model_class, parameters, load_data):
model = model_class(random_state=None)
assert model.random_state is None

with warnings.catch_warnings():
warnings.simplefilter("ignore", category=ConvergenceWarning)
# Fit the model
model, sklearn_model = model.fit_benchmark(x, y)
# Fit the model
model, sklearn_model = model.fit_benchmark(x, y)

# model.random_state and sklearn_model.random_state should now be seeded with the same value
assert model.random_state is not None and sklearn_model.random_state is not None
Expand Down
6 changes: 0 additions & 6 deletions tests/sklearn/test_dump_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy
import onnx
import pytest
from sklearn.exceptions import ConvergenceWarning

from concrete.ml.common.utils import is_model_class_in_a_list
from concrete.ml.pytest.utils import UNIQUE_MODELS_AND_DATASETS, get_model_name
Expand Down Expand Up @@ -419,11 +418,6 @@ def check_onnx_file_dump(
# KNN can only be compiled with small quantization bit numbers for now
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979
model.n_bits = 2

with warnings.catch_warnings():
# Sometimes, we miss convergence, which is not a problem for our test
warnings.simplefilter("ignore", category=ConvergenceWarning)

model.fit(x, y)

with warnings.catch_warnings():
Expand Down
Loading

0 comments on commit 95dc517

Please sign in to comment.