diff --git a/conftest.py b/conftest.py index 5209a325e..32ba7bae0 100644 --- a/conftest.py +++ b/conftest.py @@ -33,6 +33,7 @@ from concrete.ml.sklearn.base import ( BaseTreeEstimatorMixin, QuantizedTorchEstimatorMixin, + SklearnKNeighborsMixin, SklearnLinearModelMixin, ) @@ -482,7 +483,12 @@ def check_is_good_execution_for_cml_vs_circuit_impl( else: assert isinstance( model, - (QuantizedTorchEstimatorMixin, BaseTreeEstimatorMixin, SklearnLinearModelMixin), + ( + QuantizedTorchEstimatorMixin, + BaseTreeEstimatorMixin, + SklearnLinearModelMixin, + SklearnKNeighborsMixin, + ), ) if model._is_a_public_cml_model: # pylint: disable=protected-access @@ -492,8 +498,14 @@ def check_is_good_execution_for_cml_vs_circuit_impl( # tests), especially since these results are tested in other tests such as the # `check_subfunctions_in_fhe` if is_classifier_or_partial_classifier(model): - results_cnp_circuit = model.predict_proba(*inputs, fhe=fhe_mode) - results_model = model.predict_proba(*inputs, fhe="disable") + if isinstance(model, SklearnKNeighborsMixin): + # For KNN `predict_proba` is not supported for now + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962 + results_cnp_circuit = model.predict(*inputs, fhe=fhe_mode) + results_model = model.predict(*inputs, fhe="disable") + else: + results_cnp_circuit = model.predict_proba(*inputs, fhe=fhe_mode) + results_model = model.predict_proba(*inputs, fhe="disable") else: results_cnp_circuit = model.predict(*inputs, fhe=fhe_mode) diff --git a/src/concrete/ml/common/serialization/decoder.py b/src/concrete/ml/common/serialization/decoder.py index eebe4e25a..bd2f8ee74 100644 --- a/src/concrete/ml/common/serialization/decoder.py +++ b/src/concrete/ml/common/serialization/decoder.py @@ -87,6 +87,9 @@ def _get_fully_qualified_name(object_class: Type) -> str: "skorch.dataset.Dataset", "skorch.dataset.ValidSplit", "inspect._empty", + "sklearn.neighbors._classification.KNeighborsClassifier", + 
"sklearn.metrics._dist_metrics.EuclideanDistance", + "sklearn.neighbors._kd_tree.KDTree", ] ) diff --git a/src/concrete/ml/pytest/utils.py b/src/concrete/ml/pytest/utils.py index ce130991c..b996f15c4 100644 --- a/src/concrete/ml/pytest/utils.py +++ b/src/concrete/ml/pytest/utils.py @@ -18,6 +18,7 @@ DecisionTreeRegressor, ElasticNet, GammaRegressor, + KNeighborsClassifier, Lasso, LinearRegression, LinearSVC, @@ -66,6 +67,7 @@ ] _classifier_models = [ + KNeighborsClassifier, DecisionTreeClassifier, RandomForestClassifier, XGBClassifier, @@ -95,9 +97,25 @@ id=get_model_name(model), ) for model in _classifier_models + if get_model_name(model) != "KNeighborsClassifier" for n_classes in [2, 4] +] + [ + pytest.param( + model, + { + "n_samples": 6, + "n_features": 2, + "n_classes": n_classes, + "n_informative": 2, + "n_redundant": 0, + }, + id=get_model_name(model), + ) + for model in [KNeighborsClassifier] + for n_classes in [2] ] + # Get the data-sets. The data generation is seeded in load_data. 
# Only LinearRegression supports multi targets # GammaRegressor, PoissonRegressor and TweedieRegressor only handle positive target values @@ -141,8 +159,8 @@ def get_random_extract_of_sklearn_models_and_datasets(): unique_model_classes.append(m) # To avoid to make mistakes and return empty list - assert len(sklearn_models_and_datasets) == 28 - assert len(unique_model_classes) == 18 + assert len(sklearn_models_and_datasets) == 29 + assert len(unique_model_classes) == 19 return unique_model_classes diff --git a/src/concrete/ml/search_parameters/p_error_search.py b/src/concrete/ml/search_parameters/p_error_search.py index bc882937c..dbed2c1f7 100644 --- a/src/concrete/ml/search_parameters/p_error_search.py +++ b/src/concrete/ml/search_parameters/p_error_search.py @@ -61,7 +61,11 @@ from tqdm import tqdm from ..common.utils import is_brevitas_model, is_model_class_in_a_list -from ..sklearn import get_sklearn_neural_net_models, get_sklearn_tree_models +from ..sklearn import ( + get_sklearn_neighbors_models, + get_sklearn_neural_net_models, + get_sklearn_tree_models, +) from ..torch.compile import compile_brevitas_qat_model, compile_torch_model @@ -126,7 +130,10 @@ def compile_and_simulated_fhe_inference( dequantized_output = quantized_module.forward(calibration_data, fhe="simulate") elif is_model_class_in_a_list( - estimator, get_sklearn_neural_net_models() + get_sklearn_tree_models() + estimator, + get_sklearn_neural_net_models() + + get_sklearn_tree_models() + + get_sklearn_neighbors_models(), ): if not estimator.is_fitted: estimator.fit(calibration_data, ground_truth) diff --git a/src/concrete/ml/sklearn/__init__.py b/src/concrete/ml/sklearn/__init__.py index 1b938ac4d..06e5545f3 100644 --- a/src/concrete/ml/sklearn/__init__.py +++ b/src/concrete/ml/sklearn/__init__.py @@ -3,9 +3,16 @@ from ..common.debugging.custom_assert import assert_true from ..common.utils import is_classifier_or_partial_classifier, is_regressor_or_partial_regressor -from .base import 
_ALL_SKLEARN_MODELS, _LINEAR_MODELS, _NEURALNET_MODELS, _TREE_MODELS +from .base import ( + _ALL_SKLEARN_MODELS, + _LINEAR_MODELS, + _NEIGHBORS_MODELS, + _NEURALNET_MODELS, + _TREE_MODELS, +) from .glm import GammaRegressor, PoissonRegressor, TweedieRegressor from .linear_model import ElasticNet, Lasso, LinearRegression, LogisticRegression, Ridge +from .neighbors import KNeighborsClassifier from .qnn import NeuralNetClassifier, NeuralNetRegressor from .rf import RandomForestClassifier, RandomForestRegressor from .svm import LinearSVC, LinearSVR @@ -31,6 +38,7 @@ def get_sklearn_models(): "linear": sorted(list(_LINEAR_MODELS), key=lambda m: m.__name__), "tree": sorted(list(_TREE_MODELS), key=lambda m: m.__name__), "neural_net": sorted(list(_NEURALNET_MODELS), key=lambda m: m.__name__), + "neighbors": sorted(list(_NEIGHBORS_MODELS), key=lambda m: m.__name__), } return ans @@ -123,3 +131,21 @@ def get_sklearn_neural_net_models( """ prelist = get_sklearn_models()["neural_net"] return _filter_models(prelist, classifier, regressor, str_in_class_name) + + +def get_sklearn_neighbors_models( + classifier: bool = True, regressor: bool = True, str_in_class_name: List[str] = None +): + """Return the list of available neighbor models in Concrete ML. 
+ + Args: + classifier (bool): whether you want classifiers or not + regressor (bool): whether you want regressors or not + str_in_class_name (List[str]): if not None, only return models with the given string or + list of strings as a substring in their class name + + Returns: + the lists of neighbor models in Concrete ML + """ + prelist = get_sklearn_models()["neighbors"] + return _filter_models(prelist, classifier, regressor, str_in_class_name) diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index 5ac220efd..0184615a7 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -19,6 +19,8 @@ import skorch.net import torch from brevitas.export.onnx.qonnx.manager import QONNXManager as BrevitasONNXManager +from concrete.fhe import array as fhe_array +from concrete.fhe import zeros as fhe_zeros from concrete.fhe.compilation.artifacts import DebugArtifacts from concrete.fhe.compilation.circuit import Circuit from concrete.fhe.compilation.compiler import Compiler @@ -60,11 +62,13 @@ # Silence Hummingbird warnings warnings.filterwarnings("ignore") from hummingbird.ml import convert as hb_convert # noqa: E402 +from hummingbird.ml.operator_converters import constants # noqa: E402 _ALL_SKLEARN_MODELS: Set[Type] = set() _LINEAR_MODELS: Set[Type] = set() _TREE_MODELS: Set[Type] = set() _NEURALNET_MODELS: Set[Type] = set() +_NEIGHBORS_MODELS: Set[Type] = set() # Define the supported types for both the input data and the target values. Since the Pandas # library is currently only a dev dependencies, we cannot import it. 
We therefore need to use type @@ -1690,3 +1694,382 @@ def predict_proba(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> y_logits = self.decision_function(X, fhe=fhe) y_proba = self.post_processing(y_logits) return y_proba + + +# pylint: disable-next=invalid-name,too-many-instance-attributes +class SklearnKNeighborsMixin(BaseEstimator, sklearn.base.BaseEstimator, ABC): + """A Mixin class for sklearn KNeighbors models with FHE. + + This class inherits from sklearn.base.BaseEstimator in order to have access to scikit-learn's + `get_params` and `set_params` methods. + """ + + def __init_subclass__(cls): + for klass in cls.__mro__: + # pylint: disable-next=protected-access + if getattr(klass, "_is_a_public_cml_model", False): + _NEIGHBORS_MODELS.add(cls) + _ALL_SKLEARN_MODELS.add(cls) + + def __init__(self, n_bits: int = 3): + """Initialize the FHE knn model. + + Args: + n_bits (int): Number of bits to quantize the model. The value will be used for + quantizing inputs and X_fit. Default to 3. + """ + self.n_bits: int = n_bits + # _q_fit_X: In distance metric algorithms, `_q_fit_X` stores the training set to compute + # the similarity or distance measures. There is no `weights` attribute because there isn't + # a training phase + self._q_fit_X: numpy.ndarray + # _y: Labels of `_q_fit_X` + self._y: numpy.ndarray + # _q_fit_X_quantizer: The quantizer to use for quantizing the model's training set + self._q_fit_X_quantizer: Optional[UniformQuantizer] = None + + BaseEstimator.__init__(self) + + def _set_onnx_model(self, test_input: numpy.ndarray) -> None: + """Retrieve the model's ONNX graph using Hummingbird conversion. + + Args: + test_input (numpy.ndarray): An input data used to trace the model execution.
+ """ + # Check that the underlying sklearn model has been set and fit + assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message() + + self.onnx_model_ = hb_convert( + self.sklearn_model, + backend="onnx", + test_input=test_input, + extra_config={ + "onnx_target_opset": OPSET_VERSION_FOR_ONNX_EXPORT, + # pylint: disable-next=protected-access, no-member + constants.BATCH_SIZE: self.sklearn_model._fit_X.shape[0], + }, + ).model + + self._clean_graph() + + def _clean_graph(self) -> None: + """Clean the ONNX graph from undesired nodes.""" + assert self.onnx_model_ is not None, self._is_not_fitted_error_message() + + # Remove cast operators as they are not needed + remove_node_types(onnx_model=self.onnx_model_, op_types_to_remove=["Cast"]) + + def fit(self, X: Data, y: Target, **fit_parameters): + # Reset for double fit + self._is_fitted = False + self.input_quantizers = [] + self.output_quantizers = [] + + # KNeighbors handles multi-labels data + X, y = check_X_y_and_assert_multi_output(X, y) + + # Fit the scikit-learn model + self._fit_sklearn_model(X, y, **fit_parameters) + + # Check that the underlying sklearn model has been set and fit + assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message() + + # Retrieve the ONNX graph + self._set_onnx_model(X) + + # Quantize the inputs and store the associated quantizer + input_options = QuantizationOptions(n_bits=self.n_bits, is_signed=True) + q_inputs = QuantizedArray(n_bits=self.n_bits, values=X, options=input_options) + input_quantizer = q_inputs.quantizer + self.input_quantizers.append(input_quantizer) + + # Quantize the _fit_X and store the associated quantizer + # pylint: disable-next=protected-access + _fit_X = self.sklearn_model._fit_X + # We assume that the inputs have the same distribution as the _fit_X + q_fit_X = QuantizedArray( + n_bits=self.n_bits, + values=numpy.expand_dims(_fit_X, axis=1) if len(_fit_X.shape) == 1 else _fit_X, + 
options=input_options, + ) + self._q_fit_X = q_fit_X.qvalues + self._q_fit_X_quantizer = q_fit_X.quantizer + + # mypy + assert self._q_fit_X_quantizer.scale is not None + + self._y = numpy.array(y) + + # We assume that the query has the same distribution as the data in _X_fit. + # Therefore, they use the same scaling and zero point. + # https://arxiv.org/abs/1712.05877 + + self.output_quant_params = UniformQuantizationParameters( + scale=self._q_fit_X_quantizer.scale, + zero_point=self._q_fit_X_quantizer.zero_point, + offset=0, + ) + + output_quantizer = UniformQuantizer(params=self.output_quant_params, no_clipping=True) + + assert output_quantizer.zero_point is not None + self.output_quantizers.append(output_quantizer) + + # Updating post-processing parameters + self._set_post_processing_params() + + self._is_fitted = True + + return self + + def quantize_input(self, X: numpy.ndarray) -> numpy.ndarray: + self.check_model_is_fitted() + q_X = self.input_quantizers[0].quant(X) + + assert q_X.dtype == numpy.int64, "Inputs were not quantized to int64 values" + return q_X + + def dequantize_output(self, q_y_preds: numpy.ndarray) -> numpy.ndarray: + self.check_model_is_fitted() + # We compute the sorted argmax in FHE, which are integers. + # No need to de-quantize the output values + return q_y_preds + + def _get_module_to_compile(self) -> Union[Compiler, QuantizedModule]: + # Define the inference function to compile. + # This function can neither be a class method nor a static one because we want to avoid + # having self as a parameter while still being able to access some of its attributes + def inference_to_compile(q_X: numpy.ndarray) -> numpy.ndarray: + """Compile the circuit in FHE using only the inputs as parameters. + + Args: + q_X (numpy.ndarray): The quantized input data + + Returns: + numpy.ndarray: The circuit's outputs.
+ """ + return self._inference(q_X) + + # Create the compiler instance + compiler = Compiler(inference_to_compile, {"q_X": "encrypted"}) + + return compiler + + @staticmethod + def majority_vote(nearest_classes: numpy.ndarray): + """Determine the most common class among nearest neighbors for each query. + + Args: + nearest_classes (numpy.ndarray): The class labels of the nearest neighbors for a query + + Returns: + numpy.ndarray: The majority-voted class label for the corresponding query. + """ + class_counts = numpy.bincount(nearest_classes) + majority_votes = numpy.argmax(class_counts) + + return majority_votes + + def _inference(self, q_X: numpy.ndarray) -> numpy.ndarray: + """Inference function. + + Args: + q_X (numpy.ndarray): The quantized input values. + + Returns: + numpy.ndarray: The quantized predicted values. + """ + assert self._q_fit_X_quantizer is not None, self._is_not_fitted_error_message() + + def pairwise_euclidean_distance(q_X): + # 1. Pairwise euclidean distance + # dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y)) + return ( + numpy.sum(q_X**2, axis=1, keepdims=True) + - 2 * q_X @ self._q_fit_X.T + + numpy.expand_dims(numpy.sum(self._q_fit_X**2, axis=1), 0) + ) + + def topk_sorting(x, labels): + """Argsort in FHE. + + Time complexity: O(n log²(k)) + + Args: + x (numpy.ndarray): The quantized input values + labels (numpy.ndarray): The labels of the training data-set + + Returns: + numpy.ndarray: The argsort. + """ + + def gather1d(x, indices): + """Select elements from the input array `x` using the provided `indices`. + + Args: + x (numpy.ndarray): The encrypted input array + indices (numpy.ndarray): The desired indexes + + Returns: + numpy.ndarray: The selected encrypted elements. + """ + arr = [] + for i in indices: + arr.append(x[i]) + enc_arr = fhe_array(arr) + return enc_arr + + def scatter1d(x, v, indices): + """Rearrange elements of `x` with values from `v` at the specified `indices`.
+ + Args: + x (numpy.ndarray): The encrypted input array in which items will be updated + v (numpy.ndarray): The array containing values to be inserted into `x` + at the specified `indices`. + indices (numpy.ndarray): The indices indicating where to insert the elements + from `v` into `x`. + + Returns: + numpy.ndarray: The updated encrypted `x` + """ + for idx, i in enumerate(indices): + x[i] = v[idx] + return x + + comparisons = numpy.zeros(x.shape) + labels = labels + fhe_zeros(labels.shape) + + n, k = x.size, self.n_neighbors + ln2n = int(numpy.ceil(numpy.log2(n))) + + # Number of stages + for t in range(ln2n - 1, -1, -1): + p = 2**t + r = 0 + # d: Length of the bitonic sequence + d = p + + for bq in range(ln2n - 1, t - 1, -1): + q = 2**bq + # Determine the range of indexes to be compared + range_i = numpy.array( + [i for i in range(0, n - d) if i & p == r and comparisons[i] < k] + ) + if len(range_i) == 0: + # Edge case, for k=1 + continue + + # Select 2 bitonic sequences `a` and `b` of length `d` + # a = x[range_i]: first bitonic sequence + # a_i = idx[range_i]: Indexes of a_i elements in the original x + a = gather1d(x, range_i) + # a_i = gather1d(idx, range_i) + # b = x[range_i + d]: Second bitonic sequence + # b_i = idx[range_i + d]: Indexes of b_i elements in the original x + b = gather1d(x, range_i + d) + # b_i = gather1d(idx, range_i + d) + + labels_a = gather1d(labels, range_i) # + labels_b = gather1d(labels, range_i + d) # idx[range_i + d] + + # Select max(a, b) + diff = a - b + max_x = a + numpy.maximum(0, b - a) + + # Swap if a > b + # x[range_i] = max_x(a, b): First bitonic sequence gets min(a, b) + x = scatter1d(x, a + b - max_x, range_i) + # x[range_i + d] = min(a, b): Second bitonic sequence gets max(a, b) + x = scatter1d(x, max_x, range_i + d) + + # Max index selection + is_a_greater_than_b = diff <= 0 + + # Update labels array according to the max items + max_labels = labels_a + (labels_b - labels_a) * is_a_greater_than_b + labels = 
scatter1d(labels, labels_a + labels_b - max_labels, range_i) + labels = scatter1d(labels, max_labels, range_i + d) + + # Update + comparisons[range_i + d] = comparisons[range_i + d] + 1 + d = q - p + r = p + + # Return only the topk labels + topk_labels = [] + for i in range((self.n_neighbors)): + topk_labels.append(labels[i]) + + return fhe_array(topk_labels) + + # 1. Pairwise Euclidean distance + distance_matrix = pairwise_euclidean_distance(q_X) + + # The square root in the Euclidean distance calculation is not applied to speed up FHE + # computations. + # Being a monotonic function, it does not affect the logic of the calculation, notably for + # the argsort. + + topk_labels = topk_sorting(distance_matrix.flatten(), self._y) + + return numpy.expand_dims(topk_labels, axis=0) + + # KNN works only for MONO in the latest concrete Python version + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3978 + def compile(self, *args, **kwargs) -> Circuit: + # If a configuration instance is given as a positional parameter, set the strategy to + # mono-parameter + if len(args) >= 2: + configuration = force_mono_parameter_in_configuration(args[1]) + args_list = list(args) + args_list[1] = configuration + args = tuple(args_list) + + # Else, retrieve the configuration in kwargs if it exists, or create a new one, and set the + # strategy to mono-parameter + else: + configuration = kwargs.get("configuration", None) + kwargs["configuration"] = force_mono_parameter_in_configuration(configuration) + + return BaseEstimator.compile(self, *args, **kwargs) + + def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray: + """Perform the majority vote. + + For KNN, the de-quantization step is not required because _inference returns the labels of + the k-nearest neighbors. + + Args: + y_preds (numpy.ndarray): The topk nearest labels + + Returns: + numpy.ndarray: The majority vote.
+ """ + y_preds_processed = [] + for y in y_preds: + vote = self.majority_vote(y.flatten()) + y_preds_processed.append(vote) + + return numpy.array(y_preds_processed) + + def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray: + + X = check_array_and_assert(X) + + topk_labels = [] + for query in X: + topk_labels.append(super().predict(query[None], fhe)) + + y_preds = self.post_processing(numpy.array(topk_labels)) + + return numpy.array(y_preds) + + +class SklearnKNeighborsClassifierMixin(SklearnKNeighborsMixin, sklearn.base.ClassifierMixin, ABC): + """A Mixin class for sklearn KNeighbors classifiers with FHE. + + This class is used to create a KNeighbors classifier class that inherits from + SklearnKNeighborsMixin and sklearn.base.ClassifierMixin. + By inheriting from sklearn.base.ClassifierMixin, it allows this class to be recognized + as a classifier. + """ diff --git a/src/concrete/ml/sklearn/neighbors.py b/src/concrete/ml/sklearn/neighbors.py new file mode 100644 index 000000000..368c9690b --- /dev/null +++ b/src/concrete/ml/sklearn/neighbors.py @@ -0,0 +1,125 @@ +"""Implement sklearn neighbors model.""" +from typing import Any, Dict + +import numpy +import sklearn.linear_model + +from ..common.debugging.custom_assert import assert_true +from .base import SklearnKNeighborsClassifierMixin + + +# pylint: disable=invalid-name,too-many-instance-attributes +class KNeighborsClassifier(SklearnKNeighborsClassifierMixin): + """A k-nearest neighbors classifier model with FHE. + + Parameters: + n_bits (int): Number of bits to quantize the model. The value will be used for quantizing + inputs and X_fit. Default to 2.
+ + For more details on KNeighborsClassifier please refer to the scikit-learn documentation: + https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html + """ + + sklearn_model_class = sklearn.neighbors.KNeighborsClassifier + _is_a_public_cml_model = True + + def __init__( + self, + n_bits=2, + n_neighbors=3, + *, + weights="uniform", + algorithm="auto", + leaf_size=30, + p=2, + metric="minkowski", + metric_params=None, + n_jobs=None, + ): + # Call SklearnKNeighborsClassifierMixin's __init__ method + super().__init__(n_bits=n_bits) + + assert_true( + algorithm in ["brute", "auto"], f"Algorithm = `{algorithm}` is not supported in FHE." + ) + assert_true( + not callable(metric), "The KNeighborsClassifier does not support custom metrics." + ) + assert_true( + p == 2 and metric == "minkowski", + "Only `L2` norm is supported with `p=2` and `metric = 'minkowski'`", + ) + + self._y: numpy.ndarray + self.n_neighbors = n_neighbors + self.algorithm = algorithm + self.leaf_size = leaf_size + self.p = p + self.metric = metric + self.metric_params = metric_params + self.n_jobs = n_jobs + self.weights = weights + + def dump_dict(self) -> Dict[str, Any]: + assert self._q_fit_X_quantizer is not None, self._is_not_fitted_error_message() + + metadata: Dict[str, Any] = {} + + # Concrete ML + metadata["n_bits"] = self.n_bits + metadata["sklearn_model"] = self.sklearn_model + metadata["_is_fitted"] = self._is_fitted + metadata["_is_compiled"] = self._is_compiled + metadata["input_quantizers"] = self.input_quantizers + metadata["_q_fit_X_quantizer"] = self._q_fit_X_quantizer + metadata["_q_fit_X"] = self._q_fit_X + metadata["_y"] = self._y + + metadata["output_quantizers"] = self.output_quantizers + metadata["onnx_model_"] = self.onnx_model_ + metadata["post_processing_params"] = self.post_processing_params + metadata["cml_dumped_class_name"] = type(self).__name__ + + # scikit-learn + metadata["sklearn_model_class"] = self.sklearn_model_class + 
metadata["n_neighbors"] = self.n_neighbors + metadata["algorithm"] = self.algorithm + metadata["weights"] = self.weights + metadata["leaf_size"] = self.leaf_size + metadata["p"] = self.p + metadata["metric"] = self.metric + metadata["metric_params"] = self.metric_params + metadata["n_jobs"] = self.n_jobs + + return metadata + + @classmethod + def load_dict(cls, metadata: Dict): + + # Instantiate the model + obj = KNeighborsClassifier() + + # Concrete-ML + obj.n_bits = metadata["n_bits"] + obj.sklearn_model = metadata["sklearn_model"] + obj._is_fitted = metadata["_is_fitted"] + obj._is_compiled = metadata["_is_compiled"] + obj.input_quantizers = metadata["input_quantizers"] + obj.output_quantizers = metadata["output_quantizers"] + obj._q_fit_X_quantizer = metadata["_q_fit_X_quantizer"] + obj._q_fit_X = metadata["_q_fit_X"] + obj._y = metadata["_y"] + + obj.onnx_model_ = metadata["onnx_model_"] + + obj.post_processing_params = metadata["post_processing_params"] + + # Scikit-Learn + obj.n_neighbors = metadata["n_neighbors"] + obj.weights = metadata["weights"] + obj.algorithm = metadata["algorithm"] + obj.p = metadata["p"] + obj.metric = metadata["metric"] + obj.metric_params = metadata["metric_params"] + obj.n_jobs = metadata["n_jobs"] + return obj diff --git a/tests/common/test_skearn_model_lists.py b/tests/common/test_skearn_model_lists.py index cd7fe34a2..dc38c716b 100644 --- a/tests/common/test_skearn_model_lists.py +++ b/tests/common/test_skearn_model_lists.py @@ -8,6 +8,7 @@ LogisticRegression, Ridge, ) +from concrete.ml.sklearn.neighbors import KNeighborsClassifier from concrete.ml.sklearn.qnn import NeuralNetClassifier, NeuralNetRegressor from concrete.ml.sklearn.rf import RandomForestClassifier, RandomForestRegressor from concrete.ml.sklearn.svm import LinearSVC, LinearSVR @@ -18,10 +19,12 @@ def test_get_sklearn_models(): """List all available models in Concrete ML.""" dic = get_sklearn_models() + cml_list = dic["all"] linear_list = dic["linear"] tree_list = 
dic["tree"] neuralnet_list = dic["neural_net"] + neighbors_list = dic["neighbors"] print("All models: ") for m in cml_list: @@ -39,6 +42,10 @@ def test_get_sklearn_models(): for m in neuralnet_list: print(f" {m}") + print("Neighbors models: ") + for m in neighbors_list: + print(f" {m}") + # Check values expected_neuralnet_list = [NeuralNetClassifier, NeuralNetRegressor] assert ( @@ -69,12 +76,18 @@ def test_get_sklearn_models(): Ridge, TweedieRegressor, ] + + expected_neighbors_list = [KNeighborsClassifier] + assert ( linear_list == expected_linear_list ), "Please change the expected number of models if you add new models" # Check number assert cml_list == sorted( - expected_linear_list + expected_neuralnet_list + expected_tree_list, + expected_linear_list + + expected_neuralnet_list + + expected_tree_list + + expected_neighbors_list, key=lambda m: m.__name__, ), "Please change the expected number of models if you add new models" diff --git a/tests/deployment/test_client_server.py b/tests/deployment/test_client_server.py index 783cd07ab..05c7fd53a 100644 --- a/tests/deployment/test_client_server.py +++ b/tests/deployment/test_client_server.py @@ -14,11 +14,15 @@ from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer from concrete.ml.pytest.torch_models import FCSmall -from concrete.ml.pytest.utils import instantiate_model_generic, sklearn_models_and_datasets +from concrete.ml.pytest.utils import ( + get_model_name, + instantiate_model_generic, + sklearn_models_and_datasets, +) from concrete.ml.quantization.quantized_module import QuantizedModule from concrete.ml.torch.compile import compile_torch_model -# pylint: disable=too-many-statements +# pylint: disable=too-many-statements,too-many-locals class OnDiskNetwork: @@ -102,6 +106,14 @@ def test_client_server_sklearn( fhe_circuit = model.compile( x_train, default_configuration, **extra_params, show_mlir=(n_bits <= 8) ) + + if get_model_name(model) == "KNeighborsClassifier": 
+ # Fit the model + with warnings.catch_warnings(): + # Sometimes, we miss convergence, which is not a problem for our test + warnings.simplefilter("ignore", category=ConvergenceWarning) + model.fit(x, y) + max_bit_width = fhe_circuit.graph.maximum_integer_bit_width() print(f"Max width {max_bit_width}") @@ -258,5 +270,10 @@ def client_server_simulation(x_train, x_test, model, default_configuration): y_pred_on_client_dequantized, y_pred_model_server_ds_dequantized ) + # Make sure the clear predictions are the same for the server + if get_model_name(model) == "KNeighborsClassifier": + y_pred_model_clear = model.predict(x_test, fhe="disable") + numpy.testing.assert_array_equal(y_pred_model_clear, y_pred_model_server_ds_dequantized) + # Clean up network.cleanup() diff --git a/tests/parameter_search/test_p_error_binary_search.py b/tests/parameter_search/test_p_error_binary_search.py index d4cc20495..5ab3ffee6 100644 --- a/tests/parameter_search/test_p_error_binary_search.py +++ b/tests/parameter_search/test_p_error_binary_search.py @@ -312,7 +312,13 @@ def test_binary_search_for_built_in_models(model_class, parameters, threshold, p # Skorch but since Scikit-Learn does not, we don't as well. This issue could be fixed by making # neural networks not inherit from Skorch. # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3373 - if predict == "predict_proba" and get_model_name(model_class) == "NeuralNetRegressor": + # Skipping predict_proba for KNN, doesn't work for now. 
+ # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962 + + if predict == "predict_proba" and get_model_name(model_class) in [ + "NeuralNetRegressor", + "KNeighborsClassifier", + ]: return metric = r2_score if is_regressor_or_partial_regressor(model) else binary_classification_metric diff --git a/tests/sklearn/test_common.py b/tests/sklearn/test_common.py index 3ce9dcede..54ba6d378 100644 --- a/tests/sklearn/test_common.py +++ b/tests/sklearn/test_common.py @@ -10,6 +10,7 @@ from concrete.ml.pytest.utils import sklearn_models_and_datasets from concrete.ml.sklearn import ( get_sklearn_linear_models, + get_sklearn_neighbors_models, get_sklearn_neural_net_models, get_sklearn_tree_models, ) @@ -19,7 +20,10 @@ def test_sklearn_args(): """Check that all arguments from the underlying sklearn model are exposed.""" test_counter = 0 for model_class in ( - get_sklearn_linear_models() + get_sklearn_neural_net_models() + get_sklearn_tree_models() + get_sklearn_linear_models() + + get_sklearn_neural_net_models() + + get_sklearn_tree_models() + + get_sklearn_neighbors_models() ): model_class = get_model_class(model_class) @@ -32,7 +36,7 @@ def test_sklearn_args(): ) test_counter += 1 - assert test_counter == 18 + assert test_counter == 19 @pytest.mark.parametrize("model_class, parameters", sklearn_models_and_datasets) diff --git a/tests/sklearn/test_dump_onnx.py b/tests/sklearn/test_dump_onnx.py index 00b22c4a9..e2957788f 100644 --- a/tests/sklearn/test_dump_onnx.py +++ b/tests/sklearn/test_dump_onnx.py @@ -35,6 +35,11 @@ def check_onnx_file_dump(model_class, parameters, load_data, str_expected, defau model.set_params(**model_params) + if get_model_name(model) == "KNeighborsClassifier": + # KNN works only for small quantization bits + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979 + model.n_bits = 2 + with warnings.catch_warnings(): # Sometimes, we miss convergence, which is not a problem for our test warnings.simplefilter("ignore", 
category=ConvergenceWarning) @@ -56,6 +61,7 @@ def check_onnx_file_dump(model_class, parameters, load_data, str_expected, defau "RandomForestClassifier", "RandomForestRegressor", "XGBClassifier", + "KNeighborsClassifier", ]: while len(onnx_model.graph.initializer) > 0: del onnx_model.graph.initializer[0] @@ -415,6 +421,43 @@ def test_dump( ) { %variable = Gemm[alpha = 1, beta = 1](%input_0, %_operators.0.coefficients, %_operators.0.intercepts) return %variable +}""", + "KNeighborsClassifier": """graph torch_jit ( + %input_0[DOUBLE, symx2] +) { + %/_operators.0/Constant_output_0 = Constant[value = ]() + %/_operators.0/Unsqueeze_output_0 = Unsqueeze(%input_0, %/_operators.0/Constant_output_0) + %/_operators.0/Constant_1_output_0 = Constant[value = ]() + %/_operators.0/Sub_output_0 = Sub(%/_operators.0/Unsqueeze_output_0, %onnx::Sub_46) + %/_operators.0/Constant_2_output_0 = Constant[value = ]() + %/_operators.0/Pow_output_0 = Pow(%/_operators.0/Sub_output_0, %/_operators.0/Constant_2_output_0) + %/_operators.0/Constant_3_output_0 = Constant[value = ]() + %/_operators.0/ReduceSum_output_0 = ReduceSum[keepdims = 0, noop_with_empty_axes = 0](%/_operators.0/Pow_output_0, %/_operators.0/Constant_3_output_0) + %/_operators.0/Pow_1_output_0 = Pow(%/_operators.0/ReduceSum_output_0, %/_operators.0/Constant_1_output_0) + %/_operators.0/Constant_4_output_0 = Constant[value = ]() + %/_operators.0/TopK_output_0, %/_operators.0/TopK_output_1 = TopK[axis = 1, largest = 0, sorted = 1](%/_operators.0/Pow_1_output_0, %/_operators.0/Constant_4_output_0) + %/_operators.0/Constant_5_output_0 = Constant[value = ]() + %/_operators.0/Reshape_output_0 = Reshape[allowzero = 0](%/_operators.0/TopK_output_1, %/_operators.0/Constant_5_output_0) + %/_operators.0/Gather_output_0 = Gather[axis = 0](%_operators.0.train_labels, %/_operators.0/Reshape_output_0) + %/_operators.0/Shape_output_0 = Shape(%/_operators.0/TopK_output_1) + %/_operators.0/ConstantOfShape_output_0 = ConstantOfShape[value = 
](%/_operators.0/Shape_output_0) + %/_operators.0/Constant_6_output_0 = Constant[value = ]() + %/_operators.0/Reshape_1_output_0 = Reshape[allowzero = 0](%/_operators.0/Gather_output_0, %/_operators.0/Constant_6_output_0) + %/_operators.0/Constant_7_output_0 = Constant[value = ]() + %/_operators.0/ScatterElements_output_0 = ScatterElements[axis = 1](%/_operators.0/Constant_7_output_0, %/_operators.0/Reshape_1_output_0, %/_operators.0/ConstantOfShape_output_0) + %/_operators.0/Constant_8_output_0 = Constant[value = ]() + %/_operators.0/Add_output_0 = Add(%/_operators.0/Constant_8_output_0, %/_operators.0/ScatterElements_output_0) + %onnx::ReduceSum_36 = Constant[value = ]() + %/_operators.0/ReduceSum_1_output_0 = ReduceSum[keepdims = 1](%/_operators.0/Add_output_0, %onnx::ReduceSum_36) + %/_operators.0/Constant_9_output_0 = Constant[value = ]() + %/_operators.0/Equal_output_0 = Equal(%/_operators.0/ReduceSum_1_output_0, %/_operators.0/Constant_9_output_0) + %/_operators.0/Constant_10_output_0 = Constant[value = ]() + %/_operators.0/Where_output_0 = Where(%/_operators.0/Equal_output_0, %/_operators.0/Constant_10_output_0, %/_operators.0/ReduceSum_1_output_0) + %/_operators.0/Constant_11_output_0 = Constant[value = ]() + %/_operators.0/Pow_2_output_0 = Pow(%/_operators.0/Where_output_0, %/_operators.0/Constant_11_output_0) + %onnx::ArgMax_44 = Mul(%/_operators.0/Pow_2_output_0, %/_operators.0/Add_output_0) + %variable = ArgMax[axis = 1, keepdims = 0, select_last_index = 0](%onnx::ArgMax_44) + return %variable, %onnx::ArgMax_44 }""", } diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py index 13c33e12c..307e412d3 100644 --- a/tests/sklearn/test_sklearn_models.py +++ b/tests/sklearn/test_sklearn_models.py @@ -41,7 +41,7 @@ import torch from concrete.fhe import ParameterSelectionStrategy from sklearn.decomposition import PCA -from sklearn.exceptions import ConvergenceWarning +from sklearn.exceptions import ConvergenceWarning, 
UndefinedMetricWarning from sklearn.metrics import make_scorer, matthews_corrcoef, top_k_accuracy_score from sklearn.model_selection import GridSearchCV from sklearn.pipeline import Pipeline @@ -64,6 +64,7 @@ ) from concrete.ml.sklearn import ( get_sklearn_linear_models, + get_sklearn_neighbors_models, get_sklearn_neural_net_models, get_sklearn_tree_models, ) @@ -116,7 +117,9 @@ def get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option): """Prepare the the (x, y) data-set.""" - if not is_model_class_in_a_list(model_class, get_sklearn_linear_models()): + if not is_model_class_in_a_list( + model_class, get_sklearn_linear_models() + get_sklearn_neighbors_models() + ): if n_bits in N_BITS_WEEKLY_ONLY_BUILDS and not is_weekly_option: pytest.skip("Skipping some tests in non-weekly builds, except for linear models") @@ -129,7 +132,9 @@ def get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option): def preamble(model_class, parameters, n_bits, load_data, is_weekly_option): """Prepare the fitted model, and the (x, y) data-set.""" - if not is_model_class_in_a_list(model_class, get_sklearn_linear_models()): + if not is_model_class_in_a_list( + model_class, get_sklearn_linear_models() + get_sklearn_neighbors_models() + ): if n_bits in N_BITS_WEEKLY_ONLY_BUILDS and not is_weekly_option: pytest.skip("Skipping some tests in non-weekly builds") @@ -199,6 +204,7 @@ def check_correctness_with_sklearn( "XGBClassifier": 0.7, "RandomForestClassifier": 0.8, "NeuralNetClassifier": 0.7, + "KNeighborsClassifier": 0.9, } model_name = get_model_name(model_class) @@ -219,6 +225,7 @@ def check_correctness_with_sklearn( def check_double_fit(model_class, n_bits, x_1, x_2, y_1, y_2): """Check double fit.""" + model = instantiate_model_generic(model_class, n_bits=n_bits) # Sometimes, we miss convergence, which is not a problem for our test @@ -273,6 +280,7 @@ def check_double_fit(model_class, n_bits, x_1, x_2, y_1, y_2): # Check that the new quantizers are 
different from the first ones. This is because we # currently expect all quantizers to be re-computed when re-fitting a model + assert all( quantizer_1 != quantizer_2 for (quantizer_1, quantizer_2) in zip(quantizers_1, quantizers_2) @@ -296,6 +304,7 @@ def check_double_fit(model_class, n_bits, x_1, x_2, y_1, y_2): # Check that the new quantizers are identical from the first ones. Again, we expect the # quantizers to be re-computed when re-fitting. Since we used the same dataset as the first # fit, we also expect these quantizers to be the same. + assert all( quantizer_1 == quantizer_3 for (quantizer_1, quantizer_3) in zip( @@ -471,6 +480,10 @@ def check_subfunctions(fitted_model, model_class, x): ): fitted_model.predict_proba(x) + if get_model_name(fitted_model) == "KNeighborsClassifier": + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962 + pytest.skip("Skipping subfunctions test for KNN, doesn't work for now") + if is_classifier_or_partial_classifier(model_class): fitted_model.predict_proba(x) @@ -566,6 +579,9 @@ def cast_input(x, y, input_type): # Similarly, we test `predict_proba` for classifiers if is_classifier_or_partial_classifier(model): + if get_model_name(model_class) == "KNeighborsClassifier": + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962 + pytest.skip("Skipping predict_proba for KNN, doesn't work for now") model.predict_proba(x) # If n_bits is above N_BITS_LINEAR_MODEL_CRYPTO_PARAMETERS, do not compile the model @@ -626,8 +642,8 @@ def check_pipeline(model_class, x, y): param_grid = { "model__n_bits": [2, 3], } - - grid_search = GridSearchCV(pipe_cv, param_grid, error_score="raise", cv=3) + # We need a small number of splits, especially for the KNN model, which has a small data-set + grid_search = GridSearchCV(pipe_cv, param_grid, error_score="raise", cv=2) # Sometimes, we miss convergence, which is not a problem for our test with warnings.catch_warnings(): @@ -655,6 +671,8 @@ def 
check_grid_search(model_class, x, y, scoring): "n_estimators": [5, 10], "n_jobs": [1], } + elif model_class in get_sklearn_neighbors_models(): + param_grid = {"n_bits": [2], "n_neighbors": [2]} else: param_grid = { "n_bits": [20], @@ -663,9 +681,17 @@ def check_grid_search(model_class, x, y, scoring): with warnings.catch_warnings(): # Sometimes, we miss convergence, which is not a problem for our test warnings.simplefilter("ignore", category=ConvergenceWarning) + warnings.simplefilter("ignore", category=UndefinedMetricWarning) + + if get_model_name(model_class) == "KNeighborsClassifier" and scoring in [ + "roc_auc", + "average_precision", + ]: + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962 + pytest.skip("Skipping predict_proba for KNN, doesn't work for now") _ = GridSearchCV( - model_class(), param_grid, cv=5, scoring=scoring, error_score="raise", n_jobs=1 + model_class(), param_grid, cv=2, scoring=scoring, error_score="raise", n_jobs=1 ).fit(x, y) @@ -705,7 +731,9 @@ def check_sklearn_equivalence(model_class, n_bits, x, y, check_accuracy, check_r y_pred_sklearn = sklearn_model.decision_function(x) # Else, compute the model's predicted probabilities - else: + # predict_proba not implemented for KNeighborsClassifier for now + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962 + elif get_model_name(model_class) != "KNeighborsClassifier": y_pred_cml = model.predict_proba(x) y_pred_sklearn = sklearn_model.predict_proba(x) @@ -762,6 +790,9 @@ def get_hyper_param_combinations(model_class): "importance_type": ["weight", "gain"], "base_score": [0.5, None], } + elif model_class in get_sklearn_neighbors_models(): + # Use small `n_neighbors` values for KNN, because the data-set is too small for now + hyper_param_combinations = {"n_neighbors": [1, 2]} else: assert is_model_class_in_a_list( @@ -852,6 +883,8 @@ def check_fitted_compiled_error_raises(model_class, n_bits, x, y): model.predict(x) if 
is_classifier_or_partial_classifier(model_class): + if get_model_name(model) == "KNeighborsClassifier": + pytest.skip("predict_proba not implement for KNN") # Predicting probabilities using an untrained linear or tree-based classifier should not # be possible if not is_model_class_in_a_list(model_class, get_sklearn_neural_net_models()): @@ -1405,6 +1438,10 @@ def test_predict_correctness( "Inference in the clear (with " f"number_of_tests_in_non_fhe = {number_of_tests_in_non_fhe})" ) + # KNN works only for smaller quantization bits + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979 + if n_bits > 5 and get_model_name(model) == "KNeighborsClassifier": + pytest.skip("Use less than 5 bits with KNN.") y_pred = model.predict(x[:number_of_tests_in_non_fhe]) @@ -1511,10 +1548,14 @@ def test_p_error_global_p_error_simulation( if "global_p_error" in error_param: pytest.skip("global_p_error behave very differently depending on the type of model.") - # Get data-set - n_bits = min(N_BITS_REGULAR_BUILDS) + if get_model_name(model_class) == "KNeighborsClassifier": + # KNN works only for smaller quantization bits + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979 + n_bits = min([2] + N_BITS_REGULAR_BUILDS) + else: + n_bits = min(N_BITS_REGULAR_BUILDS) - # Initialize and fit the model + # Get data-set, initialize and fit the model model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option) # Check if model is linear @@ -1526,7 +1567,12 @@ def test_p_error_global_p_error_simulation( def check_for_divergent_predictions(x, model, fhe, max_iterations=N_ALLOWED_FHE_RUN): """Detect divergence between simulated/FHE execution and clear run.""" predict_function = ( - model.predict_proba if is_classifier_or_partial_classifier(model) else model.predict + model.predict_proba + if is_classifier_or_partial_classifier(model) + # `predict_proba` not implemented yet for KNeighborsClassifier + # FIXME:
https://github.com/zama-ai/concrete-ml-internal/issues/3962 + and get_model_name(model) != "KNeighborsClassifier" + else model.predict ) y_expected = predict_function(x, fhe="disable") for i in range(max_iterations): @@ -1641,6 +1687,11 @@ def test_mono_parameter_warnings( if is_model_class_in_a_list(model_class, get_sklearn_linear_models()): return + # KNN is manually forced to use mono-parameter + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3978 + if is_model_class_in_a_list(model_class, get_sklearn_neighbors_models()): + return + n_bits = min(N_BITS_REGULAR_BUILDS) model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option) diff --git a/use_case_examples/credit_scoring/CreditScoring.ipynb b/use_case_examples/credit_scoring/CreditScoring.ipynb index b5af7d35c..c4ce77f6c 100644 --- a/use_case_examples/credit_scoring/CreditScoring.ipynb +++ b/use_case_examples/credit_scoring/CreditScoring.ipynb @@ -20,11 +20,7 @@ "from functools import partial\n", "\n", "import numpy as np\n", - "import pandas as pd\n", - "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.pipeline import Pipeline\n", - "from sklearn.preprocessing import StandardScaler" + "import pandas as pd" ] }, { @@ -36,6 +32,10 @@ "# Importing the models, from both scikit-learn and Concrete ML\n", "from sklearn.ensemble import RandomForestClassifier as SklearnRandomForestClassifier\n", "from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression\n", + "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import StandardScaler\n", "from sklearn.tree import DecisionTreeClassifier as SklearnDecisionTreeClassifier\n", "from xgboost import XGBClassifier as 
SklearnXGBoostClassifier\n", "\n",