diff --git a/src/concrete/ml/search_parameters/p_error_search.py b/src/concrete/ml/search_parameters/p_error_search.py index eec213001e..dbed2c1f7a 100644 --- a/src/concrete/ml/search_parameters/p_error_search.py +++ b/src/concrete/ml/search_parameters/p_error_search.py @@ -58,11 +58,9 @@ import numpy import torch -from concrete.fhe import ParameterSelectionStrategy -from concrete.fhe.compilation import Configuration from tqdm import tqdm -from ..common.utils import get_model_name, is_brevitas_model, is_model_class_in_a_list +from ..common.utils import is_brevitas_model, is_model_class_in_a_list from ..sklearn import ( get_sklearn_neighbors_models, get_sklearn_neural_net_models, @@ -110,16 +108,6 @@ def compile_and_simulated_fhe_inference( """ compile_params: Dict = {} - - default_configuration = Configuration( - dump_artifacts_on_unexpected_failures=False, - enable_unsafe_features=True, - use_insecure_key_cache=True, - insecure_key_cache_location="ConcreteNumpyKeyCache", - parameter_selection_strategy=ParameterSelectionStrategy.MONO - if get_model_name(estimator) == "KNeighborsClassifier" - else ParameterSelectionStrategy.MULTI, - ) compile_function: Callable[..., Any] dequantized_output: numpy.ndarray @@ -150,11 +138,7 @@ def compile_and_simulated_fhe_inference( if not estimator.is_fitted: estimator.fit(calibration_data, ground_truth) - estimator.compile( - calibration_data, - p_error=p_error, - configuration=default_configuration, - ) + estimator.compile(calibration_data, p_error=p_error) predict_method = getattr(estimator, predict) dequantized_output = predict_method(calibration_data, fhe="simulate") diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index 83f40bf1f1..2024457a65 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -1426,7 +1426,7 @@ class BaseTreeClassifierMixin( """ -# pylint: disable=invalid-name,too-many-instance-attributes +# pylint: disable-next=invalid-name,too-many-instance-attributes class SklearnLinearModelMixin(BaseEstimator, sklearn.base.BaseEstimator, ABC): """A Mixin class for sklearn linear models with FHE. @@ -1697,7 +1697,7 @@ def predict_proba(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> return y_proba -# pylint: disable=invalid-name,too-many-instance-attributes +# pylint: disable-next=invalid-name,too-many-instance-attributes class SklearnKNeighborsMixin(BaseEstimator, sklearn.base.BaseEstimator, ABC): """A Mixin class for sklearn KNeighbors models with FHE. @@ -1728,6 +1728,9 @@ def __init__(self, n_bits: Union[int, Dict[str, int]] = 3): #: The quantizer to use for quantizing the model's weights self._weight_quantizer: Optional[UniformQuantizer] = None + # In distance metric algorithms, there is no `weights` attribute because there isn't + # a training phase. Instead, we have `_X_fit` attribute, commonly used in scikit-learn, which + # stores the training set, in order to compute the pairewise Euclidean distance. self._q_X_fit_quantizer: Optional[UniformQuantizer] = None self._q_X_fit: numpy.ndarray @@ -1748,7 +1751,7 @@ def _set_onnx_model(self, test_input: numpy.ndarray) -> None: test_input=test_input, extra_config={ "onnx_target_opset": OPSET_VERSION_FOR_ONNX_EXPORT, - # pylint: disable=protected-access, no-member + # pylint: disable-next=protected-access, no-member constants.BATCH_SIZE: self.sklearn_model._fit_X.shape[0], }, ).model @@ -1796,7 +1799,7 @@ def fit(self, X: Data, y: Target, **fit_parameters): # Quantize the _X_fit and store the associated quantizer # Weights in KNN algorithms are the train data points - # pylint: disable=protected-access + # pylint: disable-next=protected-access _X_fit = self.sklearn_model._fit_X q_X_fit = QuantizedArray( n_bits=n_bits["op_weights"], @@ -1951,57 +1954,60 @@ def scatter1d(x, v, indices): x[i] = v[idx] return x - def mul_tlu(a, b): - """Matrix multiplication. - - Args: - a (numpy.ndarray): An encrypted array - b (numpy.ndarray): An encrypted array - - Returns: - numpy.ndarray: The result of a * b - """ - return a * b - comparisons = numpy.zeros(x.shape) idx = numpy.arange(x.size) + fhe_zeros(x.shape) n, k = x.size, self.n_neighbors ln2n = int(numpy.ceil(numpy.log2(n))) + # Number of stages for t in range(ln2n - 1, -1, -1): p = 2**t r = 0 + # d: Length of the bitonic sequence d = p for bq in range(ln2n - 1, t - 1, -1): q = 2**bq + # Determine the range of indices to be compared range_i = numpy.array( [i for i in range(0, n - d) if i & p == r and comparisons[i] < k] ) if len(range_i) == 0: + # Edge case, for k=1 continue - a = gather1d(x, range_i) # x[range_i] - a_i = gather1d(idx, range_i) # idx[range_i] - b = gather1d(x, range_i + d) # x[range_i + d] - b_i = gather1d(idx, range_i + d) # idx[range_i + d] + # Select 2 bitonic sequences `a` and `b` of length `d` + # a = x[range_i]: first bitonic sequence + a = gather1d(x, range_i) + a_i = gather1d(idx, range_i) + # b = x[range_i + d]: Second bitonic sequence + # b_i = idx[range_i]: Indices of a_i elements in the original x + b = gather1d(x, range_i + d) + b_i = gather1d(idx, range_i + d) + # Select max(a, b) diff = a - b - sign = diff < 0 - max_x = a + numpy.maximum(0, b - a) - x = scatter1d(x, a + b - max_x, range_i) # x[range_i] = a + b - max_x - x = scatter1d(x, max_x, range_i + d) # x[range_i + d] = max_x - max_idx = a_i + mul_tlu((b_i - a_i), sign) + # Swap if a > b + # x[range_i] = max_x(a, b): First bitonic sequence gets min(a, b) + x = scatter1d(x, a + b - max_x, range_i) + # x[range_i + d] = min(a, b): Second bitonic sequence gets max(a, b) + x = scatter1d(x, max_x, range_i + d) - # idx[range_i] = a_i + b_i - max_idx + # Max index selection + sign = diff < 0 + max_idx = a_i + (b_i - a_i) * sign + + # Update index array according to max items + # idx[range_i] = a_i + b_i - max_idx <=> min_idx idx = scatter1d(idx, a_i + b_i - max_idx, range_i) - idx = scatter1d(idx, max_idx, range_i + d) # idx[range_i + d] = max_idx + # idx[range_i + d] = max_idx + idx = scatter1d(idx, max_idx, range_i + d) + # Update comparisons[range_i + d] = comparisons[range_i + d] + 1 - d = q - p r = p @@ -2031,6 +2037,25 @@ def mul_tlu(a, b): return numpy.expand_dims(sorted_args, axis=0) + # KNN works only for MONO in the latest concrete Python version + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3978 + def compile(self, *args, **kwargs) -> Circuit: + # If a configuration instance is given as a positional parameter, set the strategy to + # multi-parameter + if len(args) >= 2: + configuration = force_mono_parameter_in_configuration(args[1]) + args_list = list(args) + args_list[1] = configuration + args = tuple(args_list) + + # Else, retrieve the configuration in kwargs if it exists, or create a new one, and set the + # strategy to multi-parameter + else: + configuration = kwargs.get("configuration", None) + kwargs["configuration"] = force_mono_parameter_in_configuration(configuration) + + return BaseEstimator.compile(self, *args, **kwargs) + def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray: X = check_array_and_assert(X) @@ -2040,7 +2065,7 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy. # Argsort arg_sort = super().predict(query[None], fhe) # Majority vote - # pylint: disable=protected-access + # pylint: disable-next=protected-access label_indices = self._y[arg_sort.flatten()] y_pred = self.majority_vote(label_indices) y_preds.append(y_pred) diff --git a/src/concrete/ml/sklearn/neighbors.py b/src/concrete/ml/sklearn/neighbors.py index d7dad8639e..25d6a6d2f8 100644 --- a/src/concrete/ml/sklearn/neighbors.py +++ b/src/concrete/ml/sklearn/neighbors.py @@ -28,7 +28,7 @@ class KNeighborsClassifier(SklearnKNeighborsClassifierMixin): def __init__( self, - n_bits=3, + n_bits=2, n_neighbors=3, *, weights="uniform", @@ -53,7 +53,7 @@ def __init__( self._y = None def dump_dict(self) -> Dict[str, Any]: - assert self._weight_quantizer is not None, self._is_not_fitted_error_message() + assert self._q_X_fit_quantizer is not None, self._is_not_fitted_error_message() metadata: Dict[str, Any] = {} @@ -63,7 +63,6 @@ def dump_dict(self) -> Dict[str, Any]: metadata["_is_fitted"] = self._is_fitted metadata["_is_compiled"] = self._is_compiled metadata["input_quantizers"] = self.input_quantizers - metadata["_weight_quantizer"] = self._weight_quantizer metadata["_q_X_fit_quantizer"] = self._q_X_fit_quantizer metadata["_q_X_fit"] = self._q_X_fit metadata["_y"] = self._y @@ -99,7 +98,6 @@ def load_dict(cls, metadata: Dict): obj._is_compiled = metadata["_is_compiled"] obj.input_quantizers = metadata["input_quantizers"] obj.output_quantizers = metadata["output_quantizers"] - obj._weight_quantizer = metadata["_weight_quantizer"] obj._q_X_fit_quantizer = metadata["_q_X_fit_quantizer"] obj._q_X_fit = metadata["_q_X_fit"] obj._y = metadata["_y"] diff --git a/tests/common/test_pbs_error_probability_settings.py b/tests/common/test_pbs_error_probability_settings.py index 4066119eb9..31aad3aea9 100644 --- a/tests/common/test_pbs_error_probability_settings.py +++ b/tests/common/test_pbs_error_probability_settings.py @@ -4,12 +4,9 @@ import numpy import pytest -from concrete.fhe.compilation import Configuration from sklearn.exceptions import ConvergenceWarning from torch import nn -from concrete import fhe -from concrete.ml.common.utils import get_model_name from concrete.ml.pytest.torch_models import FCSmall from concrete.ml.pytest.utils import sklearn_models_and_datasets from concrete.ml.torch.compile import compile_torch_model @@ -29,7 +26,7 @@ {"global_p_error": 0.038, "p_error": 0.39}, ], ) -def test_config_sklearn(model_class, parameters, kwargs, load_data, default_configuration): +def test_config_sklearn(model_class, parameters, kwargs, load_data): """Testing with p_error and global_p_error configs with sklearn models.""" x, y = load_data(model_class, **parameters) @@ -41,24 +38,12 @@ def test_config_sklearn(model_class, parameters, kwargs, load_data, default_conf # Fit the model model.fit(x, y) - if get_model_name(model_class) == "KNeighborsClassifier": - - default_configuration = Configuration( - dump_artifacts_on_unexpected_failures=False, - enable_unsafe_features=True, - use_insecure_key_cache=True, - insecure_key_cache_location="ConcreteNumpyKeyCache", - parameter_selection_strategy=fhe.ParameterSelectionStrategy.MONO, - single_precision=True, - ) - if kwargs.get("p_error", None) is not None and kwargs.get("global_p_error", None) is not None: with pytest.raises(ValueError) as excinfo: - model.compile(x, default_configuration, verbose=True, **kwargs) + model.compile(x, verbose=True, **kwargs) assert "Please only set one of (p_error, global_p_error) values" in str(excinfo.value) else: - - model.compile(x, default_configuration, verbose=True, **kwargs) + model.compile(x, verbose=True, **kwargs) # We still need to check that we have the expected probabilities # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2206 diff --git a/tests/deployment/test_client_server.py b/tests/deployment/test_client_server.py index f5e4a8e438..7df681a1a5 100644 --- a/tests/deployment/test_client_server.py +++ b/tests/deployment/test_client_server.py @@ -9,12 +9,9 @@ import numpy import pytest -from concrete.fhe.compilation import Configuration from sklearn.exceptions import ConvergenceWarning from torch import nn -from concrete import fhe -from concrete.ml.common.utils import get_model_name from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer from concrete.ml.pytest.torch_models import FCSmall from concrete.ml.pytest.utils import instantiate_model_generic, sklearn_models_and_datasets @@ -98,20 +95,10 @@ def test_client_server_sklearn( # Compile extra_params = {"global_p_error": 1 / 100_000} - if get_model_name(model_class) == "KNeighborsClassifier": - - default_configuration = Configuration( - dump_artifacts_on_unexpected_failures=False, - enable_unsafe_features=True, - use_insecure_key_cache=True, - insecure_key_cache_location="ConcreteNumpyKeyCache", - parameter_selection_strategy=fhe.ParameterSelectionStrategy.MONO, - single_precision=True, - ) - # Running the simulation using a model that is not compiled should not be possible with pytest.raises(AttributeError, match=".* model is not compiled.*"): client_server_simulation(x_train, x_test, model, default_configuration) + # With n_bits = 3, KNN is not compilable fhe_circuit = model.compile( x_train, default_configuration, **extra_params, show_mlir=(n_bits <= 8) diff --git a/tests/sklearn/test_dump_onnx.py b/tests/sklearn/test_dump_onnx.py index f1949a6ca3..e2957788f1 100644 --- a/tests/sklearn/test_dump_onnx.py +++ b/tests/sklearn/test_dump_onnx.py @@ -9,7 +9,6 @@ import pytest from sklearn.exceptions import ConvergenceWarning -from concrete import fhe from concrete.ml.common.utils import is_model_class_in_a_list from concrete.ml.pytest.utils import get_model_name, sklearn_models_and_datasets from concrete.ml.sklearn import get_sklearn_tree_models @@ -37,9 +36,9 @@ def check_onnx_file_dump(model_class, parameters, load_data, str_expected, defau model.set_params(**model_params) if get_model_name(model) == "KNeighborsClassifier": - model.n_bits = 4 - default_configuration.parameter_selection_strategy = fhe.ParameterSelectionStrategy.MONO - default_configuration.single_precision = True + # KNN works only for small quantization bits + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979 + model.n_bits = 2 with warnings.catch_warnings(): # Sometimes, we miss convergence, which is not a problem for our test @@ -50,6 +49,7 @@ def check_onnx_file_dump(model_class, parameters, load_data, str_expected, defau with warnings.catch_warnings(): # Use FHE simulation to not have issues with precision model.compile(x, default_configuration) + # Get ONNX model onnx_model = model.onnx_model @@ -423,7 +423,7 @@ def test_dump( return %variable }""", "KNeighborsClassifier": """graph torch_jit ( - %input_0[DOUBLE, symx3] + %input_0[DOUBLE, symx2] ) { %/_operators.0/Constant_output_0 = Constant[value = ]() %/_operators.0/Unsqueeze_output_0 = Unsqueeze(%input_0, %/_operators.0/Constant_output_0) diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py index 931d0c3222..90c2eee036 100644 --- a/tests/sklearn/test_sklearn_models.py +++ b/tests/sklearn/test_sklearn_models.py @@ -585,7 +585,6 @@ def cast_input(x, y, input_type): # Sometimes, we miss convergence, which is not a problem for our test with warnings.catch_warnings(): warnings.simplefilter("ignore", category=ConvergenceWarning) - model.fit(x, y) # Make sure `predict` is working when FHE is disabled @@ -656,8 +655,8 @@ def check_pipeline(model_class, x, y): param_grid = { "model__n_bits": [2, 3], } - - grid_search = GridSearchCV(pipe_cv, param_grid, error_score="raise", cv=3) + # Since the data-set is really small for KNN, we have to decrease the number of splits + grid_search = GridSearchCV(pipe_cv, param_grid, error_score="raise", cv=2) # Sometimes, we miss convergence, which is not a problem for our test with warnings.catch_warnings(): @@ -686,9 +685,7 @@ def check_grid_search(model_class, x, y, scoring): "n_jobs": [1], } elif model_class in get_sklearn_neighbors_models(): - param_grid = { - "n_bits": [3], - } + param_grid = {"n_bits": [2], "n_neighbors": [2]} else: param_grid = { "n_bits": [20], @@ -706,8 +703,11 @@ def check_grid_search(model_class, x, y, scoring): # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962 pytest.skip("Skipping predict_proba for KNN, doesn't work for now") + # pylint: disable=invalid-name + cv = 2 if get_model_name(model_class) == "KNeighborsClassifier" else 5 + _ = GridSearchCV( - model_class(), param_grid, cv=5, scoring=scoring, error_score="raise", n_jobs=1 + model_class(), param_grid, cv=cv, scoring=scoring, error_score="raise", n_jobs=1 ).fit(x, y) @@ -807,7 +807,8 @@ def get_hyper_param_combinations(model_class): "base_score": [0.5, None], } elif model_class in get_sklearn_neighbors_models(): - hyper_param_combinations = {"n_neighbors": [2, 4]} + # Use small `n_neighbors` values for KNN, because the data-set is too small for now + hyper_param_combinations = {"n_neighbors": [1, 2]} else: assert is_model_class_in_a_list( @@ -1350,6 +1351,7 @@ def test_input_support( ): """Test all models with Pandas, List or Torch inputs.""" x, y = get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option) + if verbose: print("Run input_support") @@ -1475,11 +1477,6 @@ def test_predict_correctness( print("Compile the model") with warnings.catch_warnings(): - - if get_model_name(model) == "KNeighborsClassifier": - default_configuration.parameter_selection_strategy = ( - ParameterSelectionStrategy.MONO - ) fhe_circuit = model.compile( x, default_configuration, @@ -1553,7 +1550,6 @@ def test_p_error_global_p_error_simulation( parameters, error_param, load_data, - default_configuration, is_weekly_option, ): """Test p_error and global_p_error simulation. @@ -1567,23 +1563,24 @@ def test_p_error_global_p_error_simulation( if "global_p_error" in error_param: pytest.skip("global_p_error behave very differently depending on the type of model.") - # Get data-set - n_bits = min(N_BITS_REGULAR_BUILDS) if get_model_name(model_class) == "KNeighborsClassifier": - n_bits = min(n_bits, 2) - default_configuration.parameter_selection_strategy = ParameterSelectionStrategy.MONO + # KNN works only for smaller quantization bits + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979 + n_bits = min([2] + N_BITS_REGULAR_BUILDS) + else: + n_bits = min(N_BITS_REGULAR_BUILDS) - # Initialize and fit the model + # Get data-set, initialize and fit the model model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option) # Check if model is linear is_linear_model = is_model_class_in_a_list(model_class, get_sklearn_linear_models()) - # Check if model is linear + # Check if model is a distance metrics model is_knn_model = is_model_class_in_a_list(model_class, get_sklearn_neighbors_models()) # Compile with a large p_error to be sure the result is random. - model.compile(x, default_configuration, **error_param) + model.compile(x, **error_param) def check_for_divergent_predictions(x, model, fhe, max_iterations=N_ALLOWED_FHE_RUN): """Detect divergence between simulated/FHE execution and clear run.""" @@ -1595,7 +1592,6 @@ def check_for_divergent_predictions(x, model, fhe, max_iterations=N_ALLOWED_FHE_ else model.predict ) y_expected = predict_function(x, fhe="disable") - for i in range(max_iterations): y_pred = predict_function(x[i : i + 1], fhe=fhe).ravel() if not numpy.array_equal(y_pred, y_expected[i : i + 1].ravel()): @@ -1617,6 +1613,7 @@ def check_for_divergent_predictions(x, model, fhe, max_iterations=N_ALLOWED_FHE_ simulation_diff_found = check_for_divergent_predictions(x, model, fhe="simulate") fhe_diff_found = check_for_divergent_predictions(x, model, fhe="execute") + # Check for differences in predictions # Remark that, with the old VL, linear models (or, more generally, circuits without PBS) were # badly simulated. It has been fixed in the new simulation. @@ -1720,9 +1717,10 @@ def test_mono_parameter_warnings( if is_model_class_in_a_list(model_class, get_sklearn_linear_models()): return - # KNN works only for ParameterSelectionStrategy.MULTI + # KNN is manually forced to use mono-parameter + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3978 if is_model_class_in_a_list(model_class, get_sklearn_neighbors_models()): - pytest.skip("Skipping predict_proba for KNN, doesn't work for now") + return n_bits = min(N_BITS_REGULAR_BUILDS) diff --git a/use_case_examples/credit_scoring/CreditScoring.ipynb b/use_case_examples/credit_scoring/CreditScoring.ipynb index b5af7d35c1..c4ce77f6cc 100644 --- a/use_case_examples/credit_scoring/CreditScoring.ipynb +++ b/use_case_examples/credit_scoring/CreditScoring.ipynb @@ -20,11 +20,7 @@ "from functools import partial\n", "\n", "import numpy as np\n", - "import pandas as pd\n", - "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.pipeline import Pipeline\n", - "from sklearn.preprocessing import StandardScaler" + "import pandas as pd" ] }, { @@ -36,6 +32,10 @@ "# Importing the models, from both scikit-learn and Concrete ML\n", "from sklearn.ensemble import RandomForestClassifier as SklearnRandomForestClassifier\n", "from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression\n", + "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.preprocessing import StandardScaler\n", "from sklearn.tree import DecisionTreeClassifier as SklearnDecisionTreeClassifier\n", "from xgboost import XGBClassifier as SklearnXGBoostClassifier\n", "\n",