Skip to content

Commit

Permalink
chore: force the configuration of KNN to run under MONO settings
Browse files Browse the repository at this point in the history
  • Loading branch information
kcelia committed Sep 18, 2023
1 parent ca03c3c commit 18ccae1
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 119 deletions.
20 changes: 2 additions & 18 deletions src/concrete/ml/search_parameters/p_error_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,9 @@

import numpy
import torch
from concrete.fhe import ParameterSelectionStrategy
from concrete.fhe.compilation import Configuration
from tqdm import tqdm

from ..common.utils import get_model_name, is_brevitas_model, is_model_class_in_a_list
from ..common.utils import is_brevitas_model, is_model_class_in_a_list
from ..sklearn import (
get_sklearn_neighbors_models,
get_sklearn_neural_net_models,
Expand Down Expand Up @@ -110,16 +108,6 @@ def compile_and_simulated_fhe_inference(
"""

compile_params: Dict = {}

default_configuration = Configuration(
dump_artifacts_on_unexpected_failures=False,
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location="ConcreteNumpyKeyCache",
parameter_selection_strategy=ParameterSelectionStrategy.MONO
if get_model_name(estimator) == "KNeighborsClassifier"
else ParameterSelectionStrategy.MULTI,
)
compile_function: Callable[..., Any]
dequantized_output: numpy.ndarray

Expand Down Expand Up @@ -150,11 +138,7 @@ def compile_and_simulated_fhe_inference(
if not estimator.is_fitted:
estimator.fit(calibration_data, ground_truth)

estimator.compile(
calibration_data,
p_error=p_error,
configuration=default_configuration,
)
estimator.compile(calibration_data, p_error=p_error)
predict_method = getattr(estimator, predict)
dequantized_output = predict_method(calibration_data, fhe="simulate")

Expand Down
86 changes: 55 additions & 31 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1426,7 +1426,7 @@ class BaseTreeClassifierMixin(
"""


# pylint: disable=invalid-name,too-many-instance-attributes
# pylint: disable-next=invalid-name,too-many-instance-attributes
class SklearnLinearModelMixin(BaseEstimator, sklearn.base.BaseEstimator, ABC):
"""A Mixin class for sklearn linear models with FHE.
Expand Down Expand Up @@ -1697,7 +1697,7 @@ def predict_proba(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) ->
return y_proba


# pylint: disable=invalid-name,too-many-instance-attributes
# pylint: disable-next=invalid-name,too-many-instance-attributes
class SklearnKNeighborsMixin(BaseEstimator, sklearn.base.BaseEstimator, ABC):
"""A Mixin class for sklearn KNeighbors models with FHE.
Expand Down Expand Up @@ -1728,6 +1728,10 @@ def __init__(self, n_bits: Union[int, Dict[str, int]] = 3):

#: The quantizer to use for quantizing the model's weights
self._weight_quantizer: Optional[UniformQuantizer] = None
# In distance metric algorithms, there is no `weights` attribute because there isn't
# a training phase.
# Instead, we have `_X_fit` attribute, commonly used in scikit-learn, which stores the
# training set, in order to compute the pairewise Euclidean distance.
self._q_X_fit_quantizer: Optional[UniformQuantizer] = None
self._q_X_fit: numpy.ndarray

Expand All @@ -1748,7 +1752,7 @@ def _set_onnx_model(self, test_input: numpy.ndarray) -> None:
test_input=test_input,
extra_config={
"onnx_target_opset": OPSET_VERSION_FOR_ONNX_EXPORT,
# pylint: disable=protected-access, no-member
# pylint: disable-next=protected-access, no-member
constants.BATCH_SIZE: self.sklearn_model._fit_X.shape[0],
},
).model
Expand Down Expand Up @@ -1796,7 +1800,7 @@ def fit(self, X: Data, y: Target, **fit_parameters):

# Quantize the _X_fit and store the associated quantizer
# Weights in KNN algorithms are the train data points
# pylint: disable=protected-access
# pylint: disable-next=protected-access
_X_fit = self.sklearn_model._fit_X
q_X_fit = QuantizedArray(
n_bits=n_bits["op_weights"],
Expand Down Expand Up @@ -1951,57 +1955,60 @@ def scatter1d(x, v, indices):
x[i] = v[idx]
return x

def mul_tlu(a, b):
"""Matrix multiplication.
Args:
a (numpy.ndarray): An encrypted array
b (numpy.ndarray): An encrypted array
Returns:
numpy.ndarray: The result of a * b
"""
return a * b

comparisons = numpy.zeros(x.shape)
idx = numpy.arange(x.size) + fhe_zeros(x.shape)

n, k = x.size, self.n_neighbors
ln2n = int(numpy.ceil(numpy.log2(n)))

# Number of stages
for t in range(ln2n - 1, -1, -1):
p = 2**t
r = 0
# d: Length of the bitonic sequence
d = p

for bq in range(ln2n - 1, t - 1, -1):
q = 2**bq
# Determine the range of indices to be compared
range_i = numpy.array(
[i for i in range(0, n - d) if i & p == r and comparisons[i] < k]
)
if len(range_i) == 0:
# Edge case, for k=1
continue

a = gather1d(x, range_i) # x[range_i]
a_i = gather1d(idx, range_i) # idx[range_i]
b = gather1d(x, range_i + d) # x[range_i + d]
b_i = gather1d(idx, range_i + d) # idx[range_i + d]
# Select 2 bitonic sequences `a` and `b` of length `d`
# a = x[range_i]: first bitonic sequence
a = gather1d(x, range_i)
a_i = gather1d(idx, range_i)
# b = x[range_i + d]: Second bitonic sequence
# b_i = idx[range_i]: Indices of a_i elements in the original x
b = gather1d(x, range_i + d)
b_i = gather1d(idx, range_i + d)

# Select max(a, b)
diff = a - b
sign = diff < 0

max_x = a + numpy.maximum(0, b - a)
x = scatter1d(x, a + b - max_x, range_i) # x[range_i] = a + b - max_x
x = scatter1d(x, max_x, range_i + d) # x[range_i + d] = max_x

max_idx = a_i + mul_tlu((b_i - a_i), sign)
# Swap if a > b
# x[range_i] = max_x(a, b): First bitonic sequence gets min(a, b)
x = scatter1d(x, a + b - max_x, range_i)
# x[range_i + d] = min(a, b): Second bitonic sequence gets max(a, b)
x = scatter1d(x, max_x, range_i + d)

# idx[range_i] = a_i + b_i - max_idx
# Max index selection
sign = diff < 0
max_idx = a_i + (b_i - a_i) * sign

# Update index array according to max items
# idx[range_i] = a_i + b_i - max_idx <=> min_idx
idx = scatter1d(idx, a_i + b_i - max_idx, range_i)
idx = scatter1d(idx, max_idx, range_i + d) # idx[range_i + d] = max_idx
# idx[range_i + d] = max_idx
idx = scatter1d(idx, max_idx, range_i + d)

# Update
comparisons[range_i + d] = comparisons[range_i + d] + 1

d = q - p
r = p

Expand All @@ -2011,8 +2018,6 @@ def mul_tlu(a, b):

topk_indexes = fhe_array(topk_indexes)

assert topk_indexes.shape[0] == self.n_neighbors

return topk_indexes

# 1. Pairwise_euclidiean distance
Expand All @@ -2031,6 +2036,25 @@ def mul_tlu(a, b):

return numpy.expand_dims(sorted_args, axis=0)

# KNN works only for MONO in the latest concrete Python version
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3978
def compile(self, *args, **kwargs) -> Circuit:
# If a configuration instance is given as a positional parameter, set the strategy to
# multi-parameter
if len(args) >= 2:
configuration = force_mono_parameter_in_configuration(args[1])
args_list = list(args)
args_list[1] = configuration
args = tuple(args_list)

# Else, retrieve the configuration in kwargs if it exists, or create a new one, and set the
# strategy to multi-parameter
else:
configuration = kwargs.get("configuration", None)
kwargs["configuration"] = force_mono_parameter_in_configuration(configuration)

return BaseEstimator.compile(self, *args, **kwargs)

def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:

X = check_array_and_assert(X)
Expand All @@ -2040,7 +2064,7 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.
# Argsort
arg_sort = super().predict(query[None], fhe)
# Majority vote
# pylint: disable=protected-access
# pylint: disable-next=protected-access
label_indices = self._y[arg_sort.flatten()]
y_pred = self.majority_vote(label_indices)
y_preds.append(y_pred)
Expand Down
6 changes: 2 additions & 4 deletions src/concrete/ml/sklearn/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class KNeighborsClassifier(SklearnKNeighborsClassifierMixin):

def __init__(
self,
n_bits=3,
n_bits=2,
n_neighbors=3,
*,
weights="uniform",
Expand All @@ -53,7 +53,7 @@ def __init__(
self._y = None

def dump_dict(self) -> Dict[str, Any]:
assert self._weight_quantizer is not None, self._is_not_fitted_error_message()
assert self._q_X_fit_quantizer is not None, self._is_not_fitted_error_message()

metadata: Dict[str, Any] = {}

Expand All @@ -63,7 +63,6 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["_is_fitted"] = self._is_fitted
metadata["_is_compiled"] = self._is_compiled
metadata["input_quantizers"] = self.input_quantizers
metadata["_weight_quantizer"] = self._weight_quantizer
metadata["_q_X_fit_quantizer"] = self._q_X_fit_quantizer
metadata["_q_X_fit"] = self._q_X_fit
metadata["_y"] = self._y
Expand Down Expand Up @@ -99,7 +98,6 @@ def load_dict(cls, metadata: Dict):
obj._is_compiled = metadata["_is_compiled"]
obj.input_quantizers = metadata["input_quantizers"]
obj.output_quantizers = metadata["output_quantizers"]
obj._weight_quantizer = metadata["_weight_quantizer"]
obj._q_X_fit_quantizer = metadata["_q_X_fit_quantizer"]
obj._q_X_fit = metadata["_q_X_fit"]
obj._y = metadata["_y"]
Expand Down
21 changes: 3 additions & 18 deletions tests/common/test_pbs_error_probability_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@

import numpy
import pytest
from concrete.fhe.compilation import Configuration
from sklearn.exceptions import ConvergenceWarning
from torch import nn

from concrete import fhe
from concrete.ml.common.utils import get_model_name
from concrete.ml.pytest.torch_models import FCSmall
from concrete.ml.pytest.utils import sklearn_models_and_datasets
from concrete.ml.torch.compile import compile_torch_model
Expand All @@ -29,7 +26,7 @@
{"global_p_error": 0.038, "p_error": 0.39},
],
)
def test_config_sklearn(model_class, parameters, kwargs, load_data, default_configuration):
def test_config_sklearn(model_class, parameters, kwargs, load_data):
"""Testing with p_error and global_p_error configs with sklearn models."""

x, y = load_data(model_class, **parameters)
Expand All @@ -41,24 +38,12 @@ def test_config_sklearn(model_class, parameters, kwargs, load_data, default_conf
# Fit the model
model.fit(x, y)

if get_model_name(model_class) == "KNeighborsClassifier":

default_configuration = Configuration(
dump_artifacts_on_unexpected_failures=False,
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location="ConcreteNumpyKeyCache",
parameter_selection_strategy=fhe.ParameterSelectionStrategy.MONO,
single_precision=True,
)

if kwargs.get("p_error", None) is not None and kwargs.get("global_p_error", None) is not None:
with pytest.raises(ValueError) as excinfo:
model.compile(x, default_configuration, verbose=True, **kwargs)
model.compile(x, verbose=True, **kwargs)
assert "Please only set one of (p_error, global_p_error) values" in str(excinfo.value)
else:

model.compile(x, default_configuration, verbose=True, **kwargs)
model.compile(x, verbose=True, **kwargs)

# We still need to check that we have the expected probabilities
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2206
Expand Down
15 changes: 1 addition & 14 deletions tests/deployment/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@

import numpy
import pytest
from concrete.fhe.compilation import Configuration
from sklearn.exceptions import ConvergenceWarning
from torch import nn

from concrete import fhe
from concrete.ml.common.utils import get_model_name
from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.pytest.torch_models import FCSmall
from concrete.ml.pytest.utils import instantiate_model_generic, sklearn_models_and_datasets
Expand Down Expand Up @@ -98,20 +95,10 @@ def test_client_server_sklearn(
# Compile
extra_params = {"global_p_error": 1 / 100_000}

if get_model_name(model_class) == "KNeighborsClassifier":

default_configuration = Configuration(
dump_artifacts_on_unexpected_failures=False,
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location="ConcreteNumpyKeyCache",
parameter_selection_strategy=fhe.ParameterSelectionStrategy.MONO,
single_precision=True,
)

# Running the simulation using a model that is not compiled should not be possible
with pytest.raises(AttributeError, match=".* model is not compiled.*"):
client_server_simulation(x_train, x_test, model, default_configuration)

# With n_bits = 3, KNN is not compilable
fhe_circuit = model.compile(
x_train, default_configuration, **extra_params, show_mlir=(n_bits <= 8)
Expand Down
10 changes: 5 additions & 5 deletions tests/sklearn/test_dump_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest
from sklearn.exceptions import ConvergenceWarning

from concrete import fhe
from concrete.ml.common.utils import is_model_class_in_a_list
from concrete.ml.pytest.utils import get_model_name, sklearn_models_and_datasets
from concrete.ml.sklearn import get_sklearn_tree_models
Expand Down Expand Up @@ -37,9 +36,9 @@ def check_onnx_file_dump(model_class, parameters, load_data, str_expected, defau
model.set_params(**model_params)

if get_model_name(model) == "KNeighborsClassifier":
model.n_bits = 4
default_configuration.parameter_selection_strategy = fhe.ParameterSelectionStrategy.MONO
default_configuration.single_precision = True
# KNN works only for small quantization bits
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979
model.n_bits = 2

with warnings.catch_warnings():
# Sometimes, we miss convergence, which is not a problem for our test
Expand All @@ -50,6 +49,7 @@ def check_onnx_file_dump(model_class, parameters, load_data, str_expected, defau
with warnings.catch_warnings():
# Use FHE simulation to not have issues with precision
model.compile(x, default_configuration)

# Get ONNX model
onnx_model = model.onnx_model

Expand Down Expand Up @@ -423,7 +423,7 @@ def test_dump(
return %variable
}""",
"KNeighborsClassifier": """graph torch_jit (
%input_0[DOUBLE, symx3]
%input_0[DOUBLE, symx2]
) {
%/_operators.0/Constant_output_0 = Constant[value = <Tensor>]()
%/_operators.0/Unsqueeze_output_0 = Unsqueeze(%input_0, %/_operators.0/Constant_output_0)
Expand Down
Loading

0 comments on commit 18ccae1

Please sign in to comment.