Skip to content

Commit

Permalink
chore: improve inference tests (decision_function, predict, predict_p…
Browse files Browse the repository at this point in the history
…roba) + post_processing
  • Loading branch information
RomanBredehoft authored Oct 4, 2023
1 parent c92d353 commit 152a2e2
Show file tree
Hide file tree
Showing 8 changed files with 561 additions and 501 deletions.
59 changes: 28 additions & 31 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,41 +288,40 @@ def check_circuit_precision():
return check_circuit_precision_impl


def check_array_equality_impl(actual: Any, expected: Any, verbose: bool = True):
"""Assert that `actual` is equal to `expected`."""

assert numpy.array_equal(actual, expected), (
""
if not verbose
else f"""
@pytest.fixture
def check_array_equal():
"""Fixture to check array equality."""

Expected Output
===============
{expected}
def check_array_equal_impl(actual: Any, expected: Any, verbose: bool = True):
"""Assert that `actual` is equal to `expected`."""

Actual Output
=============
{actual}
assert numpy.array_equal(actual, expected), (
""
if not verbose
else f"""
"""
)
Expected Output
===============
{expected}
Actual Output
=============
{actual}
@pytest.fixture
def check_array_equality():
"""Fixture to check array equality."""
"""
)

return check_array_equality_impl
return check_array_equal_impl


@pytest.fixture
def check_float_arrays_equal():
def check_float_array_equal():
"""Fixture to check if two float arrays are equal with epsilon precision tolerance."""

def check_float_arrays_equal_impl(a, b):
def check_float_array_equal_impl(a, b):
assert numpy.all(numpy.isclose(a, b, rtol=0, atol=0.001))

return check_float_arrays_equal_impl
return check_float_array_equal_impl


@pytest.fixture
Expand Down Expand Up @@ -492,15 +491,13 @@ def check_is_good_execution_for_cml_vs_circuit_impl(
# as much post-processing steps in the clear (that could lead to more flaky
# tests), especially since these results are tested in other tests such as the
# `check_subfunctions_in_fhe`
if is_classifier_or_partial_classifier(model):
if isinstance(model, SklearnKNeighborsMixin):
# For KNN `predict_proba` is not supported for now
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962
results_cnp_circuit = model.predict(*inputs, fhe=fhe_mode)
results_model = model.predict(*inputs, fhe="disable")
else:
results_cnp_circuit = model.predict_proba(*inputs, fhe=fhe_mode)
results_model = model.predict_proba(*inputs, fhe="disable")
# For KNN `predict_proba` is not supported for now
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962
if is_classifier_or_partial_classifier(model) and not isinstance(
model, SklearnKNeighborsMixin
):
results_cnp_circuit = model.predict_proba(*inputs, fhe=fhe_mode)
results_model = model.predict_proba(*inputs, fhe="disable")

else:
results_cnp_circuit = model.predict(*inputs, fhe=fhe_mode)
Expand Down
23 changes: 6 additions & 17 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,6 @@ def fhe_circuit(self) -> Optional[Circuit]:
assert isinstance(self.fhe_circuit_, Circuit) or self.fhe_circuit_ is None
return self.fhe_circuit_

@fhe_circuit.setter
def fhe_circuit(self, value: Circuit) -> None:
"""Set the FHE circuit.
Args:
value (Circuit): The FHE circuit to set.
"""
self.fhe_circuit_ = value

def _sklearn_model_is_not_fitted_error_message(self) -> str:
return (
f"The underlying model (class: {self.sklearn_model_class}) is not fitted and thus "
Expand Down Expand Up @@ -556,7 +547,7 @@ def compile(

# Jit compiler is now deprecated and will soon be removed, it is thus forced to False
# by default
self.fhe_circuit = module_to_compile.compile(
self.fhe_circuit_ = module_to_compile.compile(
inputset,
configuration=configuration,
artifacts=artifacts,
Expand All @@ -570,14 +561,16 @@ def compile(
jit=False,
)

# For mypy
assert isinstance(self.fhe_circuit, Circuit)

# CRT simulation is not supported yet
# TODO: https://github.com/zama-ai/concrete-ml-internal/issues/3841
if not USE_OLD_VL:
self.fhe_circuit.enable_fhe_simulation() # pragma: no cover

self._is_compiled = True

assert isinstance(self.fhe_circuit, Circuit)
return self.fhe_circuit

@abstractmethod
Expand Down Expand Up @@ -883,10 +876,6 @@ def output_quantizers(self, value: List[UniformQuantizer]) -> None:
def fhe_circuit(self) -> Circuit:
return self.quantized_module_.fhe_circuit

@fhe_circuit.setter
def fhe_circuit(self, value: Circuit) -> None:
self.quantized_module_.fhe_circuit = value

def get_params(self, deep: bool = True) -> dict:
"""Get parameters for this estimator.
Expand Down Expand Up @@ -2093,11 +2082,11 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.

topk_labels = []
for query in X:
topk_labels.append(super().predict(query[None], fhe))
topk_labels.append(BaseEstimator.predict(self, query[None], fhe=fhe))

y_preds = self.post_processing(numpy.array(topk_labels))

return numpy.array(y_preds)
return y_preds


class SklearnKNeighborsClassifierMixin(SklearnKNeighborsMixin, sklearn.base.ClassifierMixin, ABC):
Expand Down
28 changes: 26 additions & 2 deletions src/concrete/ml/sklearn/neighbors.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Implement sklearn linear model."""
from typing import Any, Dict
from typing import Any, Dict, Union

import numpy
import sklearn.linear_model

from ..common.debugging.custom_assert import assert_true
from .base import SklearnKNeighborsClassifierMixin
from ..common.utils import FheMode
from .base import Data, SklearnKNeighborsClassifierMixin


# pylint: disable=invalid-name,too-many-instance-attributes
Expand Down Expand Up @@ -123,3 +124,26 @@ def load_dict(cls, metadata: Dict):
obj.metric_params = metadata["metric_params"]
obj.n_jobs = metadata["n_jobs"]
return obj

# KNeighborsClassifier does not provide a predict_proba method for now
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962
def predict_proba(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:
"""Predict class probabilities.
Args:
X (Data): The input values to predict, as a Numpy array, Torch tensor, Pandas DataFrame
or List.
fhe (Union[FheMode, str]): The mode to use for prediction.
Can be FheMode.DISABLE for Concrete ML Python inference,
FheMode.SIMULATE for FHE simulation and FheMode.EXECUTE for actual FHE execution.
Can also be the string representation of any of these values.
Default to FheMode.DISABLE.
Raises:
NotImplementedError: The method is not implemented for now.
"""

raise NotImplementedError(
"The `predict_proba` method is not implemented for KNeighborsClassifier. Please "
"call `predict` instead."
)
Loading

0 comments on commit 152a2e2

Please sign in to comment.