Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

chore: improve binary classification check in encrypted training #671

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/concrete/ml/sklearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,10 @@ def _fit_encrypted(

assert isinstance(self.classes_, numpy.ndarray)

if len(self.classes_) != 2:
# Allow the training set to only provide a single class. This can happen, for example,
# when running 'partial_fit' on a small batch of values. Even with a single class, the
# model remains binary.
if len(self.classes_) not in [1, 2]:
raise NotImplementedError(
f"Only binary classification is currently supported when FHE training is "
f"enabled. Got {len(self.classes_)} labels: {self.classes_}."
Expand Down
43 changes: 36 additions & 7 deletions tests/sklearn/test_fhe_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,15 @@
from concrete.ml.sklearn import SGDClassifier


def get_blob_data(binary_targets=True, scale_input=False, parameters_range=None):
def get_blob_data(n_classes=2, scale_input=False, parameters_range=None):
"""Get the training data."""

n_samples = 1000
n_features = 8

# Determine the number of target classes to generate
centers = 2 if binary_targets else 3

# Generate the input and target values
# pylint: disable-next=unbalanced-tuple-unpacking
x, y = make_blobs(n_samples=n_samples, centers=centers, n_features=n_features)
x, y = make_blobs(n_samples=n_samples, centers=n_classes, n_features=n_features)

# Scale the input values if needed
if scale_input:
Expand Down Expand Up @@ -107,7 +104,7 @@ def test_fit_error_if_non_binary_targets(n_bits, max_iter, parameter_min_max):
parameters_range = (-parameter_min_max, parameter_min_max)

# Generate a data-set with three target classes
x, y = get_blob_data(binary_targets=False)
x, y = get_blob_data(n_classes=3)

with warnings.catch_warnings():

Expand Down Expand Up @@ -136,6 +133,37 @@ def test_fit_error_if_non_binary_targets(n_bits, max_iter, parameter_min_max):
model.partial_fit(x, y, fhe="disable")


@pytest.mark.parametrize("n_bits, max_iter, parameter_min_max", [pytest.param(7, 30, 1.0)])
def test_fit_single_target_class(n_bits, max_iter, parameter_min_max):
    """Test that training in FHE on a data-set with a single target class works properly."""

    # Build a data-set in which every sample shares the same target class
    x, y = get_blob_data(n_classes=1)

    # Draw a random seed and derive the symmetric range for the trainable parameters
    seed = numpy.random.randint(0, 2**15)
    weight_range = (-parameter_min_max, parameter_min_max)

    with warnings.catch_warnings():

        # Constructing the model with `fit_encrypted=True` emits an experimental-feature
        # warning each time, so silence it for this test
        warnings.filterwarnings("ignore", message="FHE training is an experimental feature.*")

        model = SGDClassifier(
            n_bits=n_bits,
            fit_encrypted=True,
            random_state=seed,
            parameters_range=weight_range,
            max_iter=max_iter,
        )

        # Fitting goes through ONNX pre-processing, which is expected to warn about
        # mutation removal
        with pytest.warns(UserWarning, match="ONNX Preprocess - Removing mutation from node .*"):
            model.fit(x, y, fhe="disable")

        model.partial_fit(x, y, fhe="disable")


def test_clear_fit_error_raises():
"""Test that training in clear using wrong parameters raises proper errors."""

Expand Down Expand Up @@ -285,9 +313,10 @@ def test_clear_fit(
# Model parameters
random_state = numpy.random.randint(0, 2**15)
parameters_range = (-parameter_min_max, parameter_min_max)
n_classes = 2 if binary else 3

# Generate a data-set
x, y = get_blob_data(binary_targets=binary, scale_input=True, parameters_range=parameters_range)
x, y = get_blob_data(n_classes=n_classes, scale_input=True, parameters_range=parameters_range)

random_state = numpy.random.randint(0, 2**15)

Expand Down
Loading