Skip to content

Commit

Permalink
chore: add post_processing
Browse files Browse the repository at this point in the history
  • Loading branch information
kcelia committed Sep 21, 2023
1 parent d5b6e46 commit fd2c1c7
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 45 deletions.
82 changes: 47 additions & 35 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1719,14 +1719,14 @@ def __init__(self, n_bits: int = 3):
quantizing inputs and X_fit. Default to 3.
"""
self.n_bits: int = n_bits
# _q_X_fit: In distance metric algorithms, `_q_X_fit` stores the training set to compute
# _q_fit_X: In distance metric algorithms, `_q_fit_X` stores the training set to compute
# the similarity or distance measures. There is no `weights` attribute because there isn't
# a training phase
self._q_X_fit: numpy.ndarray
# _y: Labels of `_q_X_fit`
self._q_fit_X: numpy.ndarray
# _y: Labels of `_q_fit_X`
self._y: numpy.ndarray
# _q_X_fit_quantizer: The quantizer to use for quantizing the model's training set
self._q_X_fit_quantizer: Optional[UniformQuantizer] = None
# _q_fit_X_quantizer: The quantizer to use for quantizing the model's training set
self._q_fit_X_quantizer: Optional[UniformQuantizer] = None

BaseEstimator.__init__(self)

Expand Down Expand Up @@ -1768,8 +1768,6 @@ def fit(self, X: Data, y: Target, **fit_parameters):
# KNeighbors handles multi-labels data
X, y = check_X_y_and_assert_multi_output(X, y)

self._y = numpy.array(y)

# Fit the scikit-learn model
self._fit_sklearn_model(X, y, **fit_parameters)

Expand All @@ -1785,28 +1783,30 @@ def fit(self, X: Data, y: Target, **fit_parameters):
input_quantizer = q_inputs.quantizer
self.input_quantizers.append(input_quantizer)

# Quantize the _X_fit and store the associated quantizer
# Quantize the _fit_X and store the associated quantizer
# pylint: disable-next=protected-access
_X_fit = self.sklearn_model._fit_X
# We assume that the inputs have the same distribution as the _X_fit
q_X_fit = QuantizedArray(
_fit_X = self.sklearn_model._fit_X
# We assume that the inputs have the same distribution as the _fit_X
q_fit_X = QuantizedArray(
n_bits=self.n_bits,
values=numpy.expand_dims(_X_fit, axis=1) if len(_X_fit.shape) == 1 else _X_fit,
values=numpy.expand_dims(_fit_X, axis=1) if len(_fit_X.shape) == 1 else _fit_X,
options=input_options,
)
self._q_X_fit = q_X_fit.qvalues
self._q_X_fit_quantizer = q_X_fit.quantizer
self._q_fit_X = q_fit_X.qvalues
self._q_fit_X_quantizer = q_fit_X.quantizer

# mypy
assert self._q_X_fit_quantizer.scale is not None
assert self._q_fit_X_quantizer.scale is not None

self._y = numpy.array(y)

# We assume that the query has the same distribution as the data in _X_fit.
# therefore, they use the same scaling and zero point.
# https://arxiv.org/abs/1712.05877

self.output_quant_params = UniformQuantizationParameters(
scale=self._q_X_fit_quantizer.scale,
zero_point=self._q_X_fit_quantizer.zero_point,
scale=self._q_fit_X_quantizer.scale,
zero_point=self._q_fit_X_quantizer.zero_point,
offset=0,
)

Expand Down Expand Up @@ -1879,15 +1879,15 @@ def _inference(self, q_X: numpy.ndarray) -> numpy.ndarray:
Returns:
numpy.ndarray: The quantized predicted values.
"""
assert self._q_X_fit_quantizer is not None, self._is_not_fitted_error_message()
assert self._q_fit_X_quantizer is not None, self._is_not_fitted_error_message()

def pairwise_euclidean_distance(q_X):
# 1. Pairwise euclidean distance
# dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y))
return (
numpy.sum(q_X**2, axis=1, keepdims=True)
- 2 * q_X @ self._q_X_fit.T
+ numpy.expand_dims(numpy.sum(self._q_X_fit**2, axis=1), 0)
- 2 * q_X @ self._q_fit_X.T
+ numpy.expand_dims(numpy.sum(self._q_fit_X**2, axis=1), 0)
)

def topk_sorting(x, labels):
Expand All @@ -1896,7 +1896,8 @@ def topk_sorting(x, labels):
Time complexity: O(nlog²(k))
Args:
x (numpy.ndarray): The quantized input values.
x (numpy.ndarray): The quantized input values
labels (numpy.ndarray): The labels of the training data-set
Returns:
numpy.ndarray: The argsort.
Expand Down Expand Up @@ -1982,10 +1983,10 @@ def scatter1d(x, v, indices):
x = scatter1d(x, max_x, range_i + d)

# Max index selection
sign = diff <= 0
is_a_greater_than_b = diff <= 0

# Update labels array according to the max items
max_labels = labels_a + (labels_b - labels_a) * sign
max_labels = labels_a + (labels_b - labels_a) * is_a_greater_than_b
labels = scatter1d(labels, labels_a + labels_b - max_labels, range_i)
labels = scatter1d(labels, max_labels, range_i + d)

Expand All @@ -2002,19 +2003,13 @@ def scatter1d(x, v, indices):
return fhe_array(topk_labels)

# 1. Pairwise_euclidiean distance
# from concrete import fhe
# with fhe.tag(f"distance_matrix"):
distance_matrix = pairwise_euclidean_distance(q_X)

# The square root in the Euclidean distance calculation is not applied to speed up FHE
# computations.
# Being a monotonic function, it does not affect the logic of the calculation, notably for
# the argsort.

# 2. Sorting args
# with fhe.tag(f"sorted_args"):

# pylint: disable-next=protected-access
topk_labels = topk_sorting(distance_matrix.flatten(), self._y)

return numpy.expand_dims(topk_labels, axis=0)
Expand All @@ -2038,17 +2033,34 @@ def compile(self, *args, **kwargs) -> Circuit:

return BaseEstimator.compile(self, *args, **kwargs)

def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
"""Perform the majority.
For KNN, the de-quantization step is not required. Because _inference returns the label of
the k-nearest neighbors.
Args:
y_preds (numpy.ndarray): The topk nearest labels
Returns:
numpy.ndarray: The majority vote.
"""
y_preds_processed = []
for y in y_preds:
vote = self.majority_vote(y.flatten())
y_preds_processed.append(vote)

return numpy.array(y_preds_processed)

def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:

X = check_array_and_assert(X)

y_preds = []
topk_labels = []
for query in X:
# Argsort
topk_labels = super().predict(query[None], fhe)
# Majority vote
y_pred = self.majority_vote(topk_labels.flatten())
y_preds.append(y_pred)
topk_labels.append(super().predict(query[None], fhe))

y_preds = self.post_processing(numpy.array(topk_labels))

return numpy.array(y_preds)

Expand Down
10 changes: 5 additions & 5 deletions src/concrete/ml/sklearn/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
self.weights = weights

def dump_dict(self) -> Dict[str, Any]:
assert self._q_X_fit_quantizer is not None, self._is_not_fitted_error_message()
assert self._q_fit_X_quantizer is not None, self._is_not_fitted_error_message()

metadata: Dict[str, Any] = {}

Expand All @@ -71,8 +71,8 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["_is_fitted"] = self._is_fitted
metadata["_is_compiled"] = self._is_compiled
metadata["input_quantizers"] = self.input_quantizers
metadata["_q_X_fit_quantizer"] = self._q_X_fit_quantizer
metadata["_q_X_fit"] = self._q_X_fit
metadata["_q_fit_X_quantizer"] = self._q_fit_X_quantizer
metadata["_q_fit_X"] = self._q_fit_X
metadata["_y"] = self._y

metadata["output_quantizers"] = self.output_quantizers
Expand Down Expand Up @@ -106,8 +106,8 @@ def load_dict(cls, metadata: Dict):
obj._is_compiled = metadata["_is_compiled"]
obj.input_quantizers = metadata["input_quantizers"]
obj.output_quantizers = metadata["output_quantizers"]
obj._q_X_fit_quantizer = metadata["_q_X_fit_quantizer"]
obj._q_X_fit = metadata["_q_X_fit"]
obj._q_fit_X_quantizer = metadata["_q_fit_X_quantizer"]
obj._q_fit_X = metadata["_q_fit_X"]
obj._y = metadata["_y"]

obj.onnx_model_ = metadata["onnx_model_"]
Expand Down
24 changes: 20 additions & 4 deletions tests/deployment/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@

from concrete.ml.deployment.fhe_client_server import FHEModelClient, FHEModelDev, FHEModelServer
from concrete.ml.pytest.torch_models import FCSmall
from concrete.ml.pytest.utils import instantiate_model_generic, sklearn_models_and_datasets
from concrete.ml.pytest.utils import (
get_model_name,
instantiate_model_generic,
sklearn_models_and_datasets,
)
from concrete.ml.quantization.quantized_module import QuantizedModule
from concrete.ml.torch.compile import compile_torch_model

# pylint: disable=too-many-statements
# pylint: disable=too-many-statements,too-many-locals


class OnDiskNetwork:
Expand Down Expand Up @@ -67,7 +71,7 @@ def cleanup(self):


@pytest.mark.parametrize("model_class, parameters", sklearn_models_and_datasets)
@pytest.mark.parametrize("n_bits", [2])
@pytest.mark.parametrize("n_bits", [3])
def test_client_server_sklearn(
default_configuration,
model_class,
Expand Down Expand Up @@ -99,10 +103,17 @@ def test_client_server_sklearn(
with pytest.raises(AttributeError, match=".* model is not compiled.*"):
client_server_simulation(x_train, x_test, model, default_configuration)

# With n_bits = 3, KNN is not compilable
fhe_circuit = model.compile(
x_train, default_configuration, **extra_params, show_mlir=(n_bits <= 8)
)

if get_model_name(model) == "KNeighborsClassifier":
# Fit the model
with warnings.catch_warnings():
# Sometimes, we miss convergence, which is not a problem for our test
warnings.simplefilter("ignore", category=ConvergenceWarning)
model.fit(x, y)

max_bit_width = fhe_circuit.graph.maximum_integer_bit_width()
print(f"Max width {max_bit_width}")

Expand Down Expand Up @@ -259,5 +270,10 @@ def client_server_simulation(x_train, x_test, model, default_configuration):
y_pred_on_client_dequantized, y_pred_model_server_ds_dequantized
)

# Make sure the clear predictions are the same for the server
if get_model_name(model) == "KNeighborsClassifier":
y_pred_model_clear = model.predict(x_test, fhe="disable")
numpy.testing.assert_array_equal(y_pred_model_clear, y_pred_model_server_ds_dequantized)

# Clean up
network.cleanup()
3 changes: 2 additions & 1 deletion tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,8 @@ def check_for_divergent_predictions(x, model, fhe, max_iterations=N_ALLOWED_FHE_
predict_function = (
model.predict_proba
if is_classifier_or_partial_classifier(model)
# predict_prob not implemented yet for KNeighborsClassifier
# `predict_prob` not implemented yet for KNeighborsClassifier
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3962
and get_model_name(model) != "KNeighborsClassifier"
else model.predict
)
Expand Down

0 comments on commit fd2c1c7

Please sign in to comment.