Feat/integrate fhe sum and quantized dict for input and leaves #449
Changes from 37 commits
@@ -49,7 +49,13 @@
 # The sigmoid and softmax functions are already defined in the ONNX module and thus are imported
 # here in order to avoid duplicating them.
 from ..onnx.ops_impl import numpy_sigmoid, numpy_softmax
-from ..quantization import PostTrainingQATImporter, QuantizedArray, get_n_bits_dict
+from ..quantization import (
+    PostTrainingQATImporter,
+    QuantizedArray,
+    _get_n_bits_dict_trees,
+    _inspect_tree_n_bits,
+    get_n_bits_dict,
+)
 from ..quantization.quantized_module import QuantizedModule, _get_inputset_generator
 from ..quantization.quantizers import (
     QuantizationOptions,
@@ -96,7 +102,7 @@
 # Enable rounding feature for all tree-based models by default
 # Note: This setting is fixed and cannot be altered by users
 # However, for internal testing purposes, we retain the capability to disable this feature
-os.environ["TREES_USE_ROUNDING"] = "1"
+os.environ["TREES_USE_ROUNDING"] = os.environ.get("TREES_USE_ROUNDING", "1")
Review comment: This env variable is for ROUNDING, not for the sum of the outputs of the tree ensembles in FHE.

Reply: Still, do we want to keep this env var?
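The new assignment only fills in the "1" default when the variable is unset, so a value exported beforehand (e.g. by an internal test) is preserved instead of being clobbered. A minimal, self-contained sketch of that semantics:

import os

# Fresh environment: the module-level default kicks in.
os.environ.pop("TREES_USE_ROUNDING", None)
os.environ["TREES_USE_ROUNDING"] = os.environ.get("TREES_USE_ROUNDING", "1")
assert os.environ["TREES_USE_ROUNDING"] == "1"

# A test that exported TREES_USE_ROUNDING=0 beforehand keeps its override.
os.environ["TREES_USE_ROUNDING"] = "0"
os.environ["TREES_USE_ROUNDING"] = os.environ.get("TREES_USE_ROUNDING", "1")
assert os.environ["TREES_USE_ROUNDING"] == "0"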
 # pylint: disable=too-many-public-methods
@@ -1281,17 +1287,32 @@ def __init_subclass__(cls):
         _TREE_MODELS.add(cls)
         _ALL_SKLEARN_MODELS.add(cls)

-    def __init__(self, n_bits: int):
+    def __init__(self, n_bits: Union[int, Dict[str, int]]):
         """Initialize the TreeBasedEstimatorMixin.

         Args:
-            n_bits (int): The number of bits used for quantization.
+            n_bits (int, Dict[str, int]): Number of bits to quantize the model. If an int is
+                passed for n_bits, the value will be used for quantizing inputs and leaves. If a
+                dict is passed, then it should contain "op_inputs" and "op_leaves" as keys with
+                corresponding number of quantization bits so that:
+                - op_inputs (mandatory): number of bits to quantize the input values
+                - op_leaves (optional): number of bits to quantize the leaves
+                Default to 6.
         """
-        self.n_bits: int = n_bits
+        # Check if 'n_bits' is a valid value.
+        _inspect_tree_n_bits(n_bits)
+
+        self.n_bits: Union[int, Dict[str, int]] = n_bits

         #: The model's inference function. Is None if the model is not fitted.
         self._tree_inference: Optional[Callable] = None

+        #: Whether to perform the sum of the tree ensembles' outputs in FHE or not.
+        # By default, the decision of the tree ensembles is made in clear (not in FHE).
+        # This attribute should not be modified by users.
+        self._fhe_ensembling = False
+
         BaseEstimator.__init__(self)

     def fit(self, X: Data, y: Target, **fit_parameters):
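With this change, tree-based estimators accept either form of `n_bits`. A minimal usage sketch (`DecisionTreeClassifier` follows Concrete ML's existing sklearn-style API; the exact bit-widths are illustrative assumptions):

from concrete.ml.sklearn import DecisionTreeClassifier

# An int applies the same bit-width to both inputs and leaves
model_int = DecisionTreeClassifier(n_bits=6)

# A dict sets them independently: "op_inputs" is mandatory, "op_leaves" optional
model_dict = DecisionTreeClassifier(n_bits={"op_inputs": 6, "op_leaves": 3})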
@@ -1304,9 +1325,14 @@ def fit(self, X: Data, y: Target, **fit_parameters):

         q_X = numpy.zeros_like(X)

+        # Convert the n_bits attribute into a proper dictionary
+        self.n_bits = _get_n_bits_dict_trees(self.n_bits)
+
         # Quantization of each feature in X
         for i in range(X.shape[1]):
-            input_quantizer = QuantizedArray(n_bits=self.n_bits, values=X[:, i]).quantizer
+            input_quantizer = QuantizedArray(
+                n_bits=self.n_bits["op_inputs"], values=X[:, i]
+            ).quantizer
             self.input_quantizers.append(input_quantizer)
             q_X[:, i] = input_quantizer.quant(X[:, i])
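For context, a hypothetical re-implementation of what `_get_n_bits_dict_trees` appears to do, inferred from the docstring above (the real helper lives in `..quantization` and may differ, notably in the fallback for the optional key):

from typing import Dict, Union


def _get_n_bits_dict_trees_sketch(n_bits: Union[int, Dict[str, int]]) -> Dict[str, int]:
    # An int is broadcast to both inputs and leaves
    if isinstance(n_bits, int):
        return {"op_inputs": n_bits, "op_leaves": n_bits}

    # "op_leaves" is optional and assumed here to fall back to "op_inputs"
    return {
        "op_inputs": n_bits["op_inputs"],
        "op_leaves": n_bits.get("op_leaves", n_bits["op_inputs"]),
    }


assert _get_n_bits_dict_trees_sketch(6) == {"op_inputs": 6, "op_leaves": 6}
assert _get_n_bits_dict_trees_sketch({"op_inputs": 6}) == {"op_inputs": 6, "op_leaves": 6}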
@@ -1319,7 +1345,7 @@ def fit(self, X: Data, y: Target, **fit_parameters):
         # Check that the underlying sklearn model has been set and fit
         assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message()

-        # Convert the tree inference with Numpy operators
+        # Enable rounding feature
         enable_rounding = os.environ.get("TREES_USE_ROUNDING", "1") == "1"
Review comment: Still env var based.

Reply: This one is for rounding, not fhe_sum. If we decide to remove it as well, it should be in another PR.
         if not enable_rounding:
@@ -1332,12 +1358,14 @@ def fit(self, X: Data, y: Target, **fit_parameters):
                 stacklevel=2,
             )

+        # Convert the tree inference with Numpy operators
         self._tree_inference, self.output_quantizers, self.onnx_model_ = tree_to_numpy(
             self.sklearn_model,
             q_X,
             use_rounding=enable_rounding,
+            fhe_ensembling=self._fhe_ensembling,
             framework=self.framework,
-            output_n_bits=self.n_bits,
+            output_n_bits=self.n_bits["op_leaves"],
         )

         self._is_fitted = True
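A hedged sketch of how the new private flag ties together (internal/testing use only per the attribute's comment; `RandomForestClassifier` is one of Concrete ML's tree ensembles, and the manual attribute flip is an illustrative assumption, not public API):

from concrete.ml.sklearn import RandomForestClassifier

model = RandomForestClassifier(n_bits=6)

# Default: the per-tree outputs are summed in the clear during post-processing.
assert model._fhe_ensembling is False

# Flipping the flag before fit would make tree_to_numpy build the ensemble sum
# into the FHE circuit instead (not meant to be modified by users).
model._fhe_ensembling = True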
@@ -1412,10 +1440,13 @@ def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
         # Sum all tree outputs
         # Remove the sum once we handle multi-precision circuits
         # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451
-        y_preds = numpy.sum(y_preds, axis=-1)
+        if not self._fhe_ensembling:
+            y_preds = numpy.sum(y_preds, axis=-1)

-        assert_true(y_preds.ndim == 2, "y_preds should be a 2D array")
-        return y_preds
+            assert_true(y_preds.ndim == 2, "y_preds should be a 2D array")
+            return y_preds
+
+        return super().post_processing(y_preds)
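A toy illustration of the clear-text path: when `_fhe_ensembling` is False, the per-tree outputs are summed over the last axis in post-processing. The shape (n_examples, n_classes, n_trees) is an assumption for illustration:

import numpy

y_preds = numpy.ones((4, 2, 10))  # 4 examples, 2 classes, 10 trees
y_preds = numpy.sum(y_preds, axis=-1)
assert y_preds.ndim == 2 and y_preds.shape == (4, 2)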
 class BaseTreeRegressorMixin(BaseTreeEstimatorMixin, sklearn.base.RegressorMixin, ABC):
@@ -1841,12 +1872,15 @@ def __init__(self, n_bits: int = 3):
                 quantizing inputs and X_fit. Default to 3.
         """
         self.n_bits: int = n_bits

         # _q_fit_X: In distance metric algorithms, `_q_fit_X` stores the training set to compute
         # the similarity or distance measures. There is no `weights` attribute because there isn't
         # a training phase
         self._q_fit_X: numpy.ndarray

         # _y: Labels of `_q_fit_X`
         self._y: numpy.ndarray

         # _q_fit_X_quantizer: The quantizer to use for quantizing the model's training set
         self._q_fit_X_quantizer: Optional[UniformQuantizer] = None
Review comment: Not sure about the namings here. What about n_bits_input / n_bits_output?

Reply: NNs have op_inputs / op_outputs. It keeps that style, I think it's fine like this.

Reply: @jfrery the reason is that QNNs have model_inputs, model_outputs, op_inputs and op_weights, so you need to differentiate both. As Andrei said, here, it's taking the same convention. Although I agree that for trees it's getting a bit misleading without much context, we don't have much choice, as we should keep our namings coherent with one another. My main comment on this naming was about op_leaves (#449 (comment)): it introduces a new name into our convention. Although I get why it's more relevant to tree models, I felt keeping op_weights was fine imo.

Reply: Well, I find op_ not ideal for tree-based models. Also, if we want to keep consistency, then it should be model_inputs instead of op_inputs. I am really not a fan of having op_leaves, for many reasons: most users won't know what "leaves" refers to, and the same goes for "op". Then why not model_input and model_output? It makes sense for the input but would require a change for model_output. This would help us maintain the output at a given precision, which is not easy with op_leaves since it targets every tree instead of the final combination.

Reply: I agree with you, I'm not a fan of op_leaves. But with only model_input and model_output, how are you going to handle the n_bits to use for quantizing the leaves? You'll either need an additional parameter similar to the suggested op_leaves, or a way to infer its value from model_output, but I'm not sure how that could easily/efficiently be done 🤔

Reply: @zama-ai/machine-learning, please let's agree on a proper naming. I suggest tree_input_bits/tree_leaf_bits or quant_input_bits/quant_leaf_bits.
-> model_output: why not, but it doesn't ring like it's the leaves that we are quantizing.
-> op_weights: really not a big fan, we are quantizing leaves, not weights.

Reply: I think it's best if you ask directly in the Slack channel!