chore: add post_processing
kcelia committed Sep 20, 2023
1 parent d5b6e46 commit 69f11f8
Showing 2 changed files with 50 additions and 37 deletions.
src/concrete/ml/sklearn/base.py (77 changes: 45 additions & 32 deletions)
@@ -1719,14 +1719,14 @@ def __init__(self, n_bits: int = 3):
                 quantizing inputs and X_fit. Default to 3.
         """
         self.n_bits: int = n_bits
-        # _q_X_fit: In distance metric algorithms, `_q_X_fit` stores the training set to compute
+        # _q_fit_X: In distance metric algorithms, `_q_fit_X` stores the training set to compute
         # the similarity or distance measures. There is no `weights` attribute because there isn't
         # a training phase
-        self._q_X_fit: numpy.ndarray
-        # _y: Labels of `_q_X_fit`
+        self._q_fit_X: numpy.ndarray
+        # _y: Labels of `_q_fit_X`
         self._y: numpy.ndarray
-        # _q_X_fit_quantizer: The quantizer to use for quantizing the model's training set
-        self._q_X_fit_quantizer: Optional[UniformQuantizer] = None
+        # _q_fit_X_quantizer: The quantizer to use for quantizing the model's training set
+        self._q_fit_X_quantizer: Optional[UniformQuantizer] = None

         BaseEstimator.__init__(self)

@@ -1768,8 +1768,6 @@ def fit(self, X: Data, y: Target, **fit_parameters):
         # KNeighbors handles multi-labels data
         X, y = check_X_y_and_assert_multi_output(X, y)

-        self._y = numpy.array(y)
-
         # Fit the scikit-learn model
         self._fit_sklearn_model(X, y, **fit_parameters)

@@ -1785,28 +1783,30 @@ def fit(self, X: Data, y: Target, **fit_parameters):
         input_quantizer = q_inputs.quantizer
         self.input_quantizers.append(input_quantizer)

-        # Quantize the _X_fit and store the associated quantizer
+        # Quantize the _fit_X and store the associated quantizer
         # pylint: disable-next=protected-access
-        _X_fit = self.sklearn_model._fit_X
-        # We assume that the inputs have the same distribution as the _X_fit
-        q_X_fit = QuantizedArray(
+        _fit_X = self.sklearn_model._fit_X
+        # We assume that the inputs have the same distribution as the _fit_X
+        q_fit_X = QuantizedArray(
             n_bits=self.n_bits,
-            values=numpy.expand_dims(_X_fit, axis=1) if len(_X_fit.shape) == 1 else _X_fit,
+            values=numpy.expand_dims(_fit_X, axis=1) if len(_fit_X.shape) == 1 else _fit_X,
             options=input_options,
         )
-        self._q_X_fit = q_X_fit.qvalues
-        self._q_X_fit_quantizer = q_X_fit.quantizer
+        self._q_fit_X = q_fit_X.qvalues
+        self._q_fit_X_quantizer = q_fit_X.quantizer

         # mypy
-        assert self._q_X_fit_quantizer.scale is not None
+        assert self._q_fit_X_quantizer.scale is not None
+
+        self._y = numpy.array(y)

         # We assume that the query has the same distribution as the data in _X_fit,
         # therefore, they use the same scaling and zero point.
         # https://arxiv.org/abs/1712.05877

         self.output_quant_params = UniformQuantizationParameters(
-            scale=self._q_X_fit_quantizer.scale,
-            zero_point=self._q_X_fit_quantizer.zero_point,
+            scale=self._q_fit_X_quantizer.scale,
+            zero_point=self._q_fit_X_quantizer.zero_point,
             offset=0,
         )

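The shared-quantizer assumption above (one scale and zero point for both the query and the stored training set, following the quantization scheme of arXiv:1712.05877) is what later allows distances to be computed directly on integer values. A minimal NumPy sketch of the idea, outside Concrete ML and with a hypothetical hand-rolled uniform quantizer:

    import numpy

    rng = numpy.random.default_rng(0)
    X_fit = rng.uniform(-1, 1, size=(5, 3))  # stands in for sklearn_model._fit_X
    query = rng.uniform(-1, 1, size=(1, 3))  # a clear-text query

    n_bits = 3
    # One (scale, zero_point) pair shared by both arrays, as in fit() above
    vmin, vmax = X_fit.min(), X_fit.max()
    scale = (vmax - vmin) / (2**n_bits - 1)
    zero_point = numpy.round(-vmin / scale)

    def quantize(values):
        return numpy.round(values / scale) + zero_point

    q_fit_X, q_query = quantize(X_fit), quantize(query)

    # In the difference q_query - q_fit_X the shared zero_point cancels exactly
    # and scale factors out, so integer distances preserve the neighbor ranking
    # up to quantization rounding error.
    d_int = numpy.sum((q_query - q_fit_X) ** 2, axis=1)
    d_float = numpy.sum((query - X_fit) ** 2, axis=1)
    print(numpy.argsort(d_int), numpy.argsort(d_float))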
@@ -1879,15 +1879,15 @@ def _inference(self, q_X: numpy.ndarray) -> numpy.ndarray:
         Returns:
             numpy.ndarray: The quantized predicted values.
         """
-        assert self._q_X_fit_quantizer is not None, self._is_not_fitted_error_message()
+        assert self._q_fit_X_quantizer is not None, self._is_not_fitted_error_message()

         def pairwise_euclidean_distance(q_X):
             # 1. Pairwise euclidean distance
             # dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y))
             return (
                 numpy.sum(q_X**2, axis=1, keepdims=True)
-                - 2 * q_X @ self._q_X_fit.T
-                + numpy.expand_dims(numpy.sum(self._q_X_fit**2, axis=1), 0)
+                - 2 * q_X @ self._q_fit_X.T
+                + numpy.expand_dims(numpy.sum(self._q_fit_X**2, axis=1), 0)
             )

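The expansion used by pairwise_euclidean_distance is the standard identity ||x - y||^2 = <x, x> - 2 <x, y> + <y, y>, which yields all query/training-point distances at once through a single matrix product. A quick stand-alone NumPy check (a sketch, not part of the diff):

    import numpy

    rng = numpy.random.default_rng(0)
    q_X = rng.integers(0, 8, size=(2, 4))      # two quantized queries
    q_fit_X = rng.integers(0, 8, size=(6, 4))  # six quantized training points

    expanded = (
        numpy.sum(q_X**2, axis=1, keepdims=True)
        - 2 * q_X @ q_fit_X.T
        + numpy.expand_dims(numpy.sum(q_fit_X**2, axis=1), 0)
    )
    direct = numpy.sum((q_X[:, None, :] - q_fit_X[None, :, :]) ** 2, axis=2)
    assert numpy.array_equal(expanded, direct)  # shape (2, 6), exact on integers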
def topk_sorting(x, labels):
Expand All @@ -1896,7 +1896,8 @@ def topk_sorting(x, labels):
Time complexity: O(nlog²(k))
Args:
x (numpy.ndarray): The quantized input values.
x (numpy.ndarray): The quantized input values
labels (numpy.ndarray): The labels of the training data-set
Returns:
numpy.ndarray: The argsort.
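
In the clear, topk_sorting amounts to taking the labels of the k smallest distances; the FHE version replaces the data-dependent argsort with a data-independent sorting network, which is where the O(nlog²(k)) cost quoted in the docstring comes from. A clear-text equivalent (sketch only, with a hypothetical k):

    import numpy

    def topk_sorting_clear(x, labels, k=3):
        # Labels of the k smallest entries of x, in increasing distance order
        return labels[numpy.argsort(x)[:k]]

    distances = numpy.array([5.0, 1.0, 3.0, 2.0])
    labels = numpy.array([0, 1, 1, 0])
    print(topk_sorting_clear(distances, labels))  # [1 0 1]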
@@ -2002,18 +2003,13 @@ def scatter1d(x, v, indices):
            return fhe_array(topk_labels)

        # 1. Pairwise euclidean distance
-        # from concrete import fhe
-        # with fhe.tag(f"distance_matrix"):
        distance_matrix = pairwise_euclidean_distance(q_X)

        # The square root in the Euclidean distance calculation is not applied to speed up FHE
        # computations.
        # Being a monotonic function, it does not affect the logic of the calculation, notably for
        # the argsort.

        # 2. Sorting args
-        # with fhe.tag(f"sorted_args"):
-
        # pylint: disable-next=protected-access
        topk_labels = topk_sorting(distance_matrix.flatten(), self._y)

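The square-root omission mentioned in the comments is safe because sqrt is strictly increasing on non-negative inputs, so the ordering (and hence the argsort feeding topk_sorting) is unchanged, as this small sketch illustrates:

    import numpy

    squared = numpy.array([9.0, 1.0, 16.0, 4.0])
    assert numpy.array_equal(
        numpy.argsort(squared), numpy.argsort(numpy.sqrt(squared))
    )  # same ranking with or without the square root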
@@ -2038,17 +2034,34 @@ def compile(self, *args, **kwargs) -> Circuit:

         return BaseEstimator.compile(self, *args, **kwargs)

+    def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
+        """Perform the majority vote.
+
+        For KNN, the dequantization step is not required because _inference returns the labels
+        of the k nearest neighbors.
+
+        Args:
+            y_preds (numpy.ndarray): The top-k nearest labels
+
+        Returns:
+            numpy.ndarray: The majority vote.
+        """
+        y_preds_processed = []
+        for y in y_preds:
+            vote = self.majority_vote(y.flatten())
+            y_preds_processed.append(vote)
+
+        return numpy.array(y_preds_processed)
+
     def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:

         X = check_array_and_assert(X)

-        y_preds = []
+        topk_labels = []
         for query in X:
-            # Argsort
-            topk_labels = super().predict(query[None], fhe)
-            # Majority vote
-            y_pred = self.majority_vote(topk_labels.flatten())
-            y_preds.append(y_pred)
+            topk_labels.append(super().predict(query[None], fhe))
+
+        y_preds = self.post_processing(numpy.array(topk_labels))

         return numpy.array(y_preds)

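post_processing reduces each row of top-k labels to a single prediction via the class's majority_vote helper. A stand-alone sketch of that reduction (numpy.bincount-based; illustrative only, not the helper's actual implementation):

    import numpy

    def majority_vote(topk_labels):
        # Most frequent label among the k nearest neighbors
        return int(numpy.argmax(numpy.bincount(topk_labels.astype(numpy.int64))))

    topk = numpy.array([2, 0, 2, 1, 2])  # labels of the 5 nearest neighbors
    print(majority_vote(topk))           # 2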
src/concrete/ml/sklearn/neighbors.py (10 changes: 5 additions & 5 deletions)
@@ -61,7 +61,7 @@ def __init__(
         self.weights = weights

     def dump_dict(self) -> Dict[str, Any]:
-        assert self._q_X_fit_quantizer is not None, self._is_not_fitted_error_message()
+        assert self._q_fit_X_quantizer is not None, self._is_not_fitted_error_message()

         metadata: Dict[str, Any] = {}
@@ -71,8 +71,8 @@ def dump_dict(self) -> Dict[str, Any]:
         metadata["_is_fitted"] = self._is_fitted
         metadata["_is_compiled"] = self._is_compiled
         metadata["input_quantizers"] = self.input_quantizers
-        metadata["_q_X_fit_quantizer"] = self._q_X_fit_quantizer
-        metadata["_q_X_fit"] = self._q_X_fit
+        metadata["_q_fit_X_quantizer"] = self._q_fit_X_quantizer
+        metadata["_q_fit_X"] = self._q_fit_X
         metadata["_y"] = self._y

         metadata["output_quantizers"] = self.output_quantizers
@@ -106,8 +106,8 @@ def load_dict(cls, metadata: Dict):
         obj._is_compiled = metadata["_is_compiled"]
         obj.input_quantizers = metadata["input_quantizers"]
         obj.output_quantizers = metadata["output_quantizers"]
-        obj._q_X_fit_quantizer = metadata["_q_X_fit_quantizer"]
-        obj._q_X_fit = metadata["_q_X_fit"]
+        obj._q_fit_X_quantizer = metadata["_q_fit_X_quantizer"]
+        obj._q_fit_X = metadata["_q_fit_X"]
         obj._y = metadata["_y"]

         obj.onnx_model_ = metadata["onnx_model_"]