Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/integrate fhe sum and quantized dict for input and leaves #449

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
e284bd2
chore: run ensemble model aggregation in FHE
jfrery Dec 18, 2023
8c1e99a
chore: refresh notebooks
jfrery Dec 18, 2023
70f1775
chore: update celia
kcelia Dec 25, 2023
6262f12
chore: add op_input and op_leaves
kcelia Jan 11, 2024
67e3572
chore: restore non fhe computation
kcelia Jan 11, 2024
8e2e056
chore: update dump test
kcelia Jan 11, 2024
3e2289a
chore: update test dump
kcelia Jan 11, 2024
a9c8385
chore: fix pipeline test
kcelia Jan 12, 2024
a43f04c
chore: fix rounding test by decreasing the n_bits value because no cr…
kcelia Jan 12, 2024
ecf5c66
chore: reduce n_bits in simulation test to 4 bits otherwise OOM
kcelia Jan 12, 2024
9007362
chore: add a test for fhe sum
kcelia Jan 12, 2024
8fab929
chore: update
kcelia Jan 15, 2024
ba26a5c
chore: update
kcelia Jan 15, 2024
05256b2
chore: update
kcelia Jan 16, 2024
03d498b
chore: remove useless prints
kcelia Jan 16, 2024
cf879d3
chore: update get_n_bits_dict_trees
kcelia Jan 17, 2024
ff1c6b1
chore: update
kcelia Jan 17, 2024
d4ca140
chore: update comment
kcelia Jan 17, 2024
f96333b
chore: update simulated p_error test
kcelia Jan 17, 2024
cc4781f
chore: update coverage
kcelia Jan 17, 2024
a652001
chore: update tests
kcelia Jan 18, 2024
9839ce9
chore: update assert
kcelia Jan 18, 2024
7cb13e0
chore: update comment
kcelia Jan 22, 2024
7d93575
chore: update comment
kcelia Jan 22, 2024
70adfd5
chore: test dump in both cases (sum_fhe enabled and disabled)
kcelia Jan 22, 2024
783e7af
chore: remove env var
kcelia Jan 22, 2024
39b2972
chore: restore knn notebook
kcelia Jan 23, 2024
07b2f2a
chore: restore exp notebotebook
kcelia Jan 23, 2024
7fddece
chore: update v1
kcelia Jan 23, 2024
14d9dc0
chore: update v2
kcelia Jan 23, 2024
3248216
chore: update v3
kcelia Jan 23, 2024
ab45587
chore: update
kcelia Jan 23, 2024
0ad0f6d
chore: update comments
kcelia Jan 24, 2024
9b58948
chore: update
kcelia Jan 24, 2024
25bf839
chore: fix test dump
kcelia Jan 24, 2024
5be7255
chore: update comments
kcelia Jan 24, 2024
ee696ec
chore: remove comment
kcelia Jan 25, 2024
1c4ebc0
chore: update
kcelia Jan 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/concrete/ml/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Modules for quantization."""
from .base_quantized_op import QuantizedOp
from .post_training import PostTrainingAffineQuantization, PostTrainingQATImporter, get_n_bits_dict
from .post_training import (
PostTrainingAffineQuantization,
PostTrainingQATImporter,
_get_n_bits_dict_trees,
_inspect_tree_n_bits,
get_n_bits_dict,
)
from .quantized_module import QuantizedModule
from .quantized_ops import (
QuantizedAbs,
Expand Down
95 changes: 95 additions & 0 deletions src/concrete/ml/quantization/post_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,101 @@
from .quantized_ops import QuantizedBrevitasQuant
from .quantizers import QuantizationOptions, QuantizedArray, UniformQuantizer

# pylint: disable=too-many-lines


def _inspect_tree_n_bits(n_bits):
"""Validate the 'n_bits' parameter for tree-based models.

This function checks whether 'n_bits' is a valid integer or dictionary.
- If 'n_bits' is an integer, it must be a non-null positive, its value is assigned to
'op_inputs' and 'op_leaves' bits
- If it is a dictionary, it should contain integer values for keys 'op_leaves' and 'op_inputs',
where 'op_leaves' should not exceed 'op_inputs'.
Comment on lines +33 to +36
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the namings here. What about n_bits_input n_bits_output?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NNs have op_inputs op_outputs. It keeps that style, I think it's fine like this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jfrery the reason is that QNNs have model_inputs, model_outputs , op_inputs and op_weights, so you need to differentiate both. As Andrei said, here, it's taking the same convention

although I agree for trees it's getting q bit misleading without much context but we don't have much choice as we should keep our namings coherent with one another

my main comment on this naming was about op_leave (#449 (comment)), this introduces a new name into our convention. Although I get why it's more relevant to tree models, I felt keeping op_weights was fine imo

Copy link
Collaborator

@jfrery jfrery Jan 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well I find op_ not ideal for tree based models. Also if we want to keep consistency then it should be model_inputs instead of op_inputs.

I am really not a fan of having op_leave for many reasons. Most users won't know what leave refer to. Same for op.

Then why not model_input and model_output? It makes sense for the input but would require a change for model_output. This will help us maintain the output at a given precision which is not easy with op_leave since this targets every tree instead of the final combination.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you I'm not a fan of op_leave. but about having model_input and model_output only, how are you going to handle the n_bits to use for quantizing the leaves ? you'll either need an additional parameter similar to the suggested op_leave or a way to infer its value based on model_output but not sure how this could easily/efficiently be done 🤔

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zama-ai/machine-learning, please let's agree on a proper naming.
I suggest: tree_input_bits/tree_leaf_bits or quant_input_bits/quant_leaf_bits

-> model_ouput why not, but doesn't ring like it's the leaves that we are quantizing
-> op_weight, really not a big fan, we are quantizing leaves not weights.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's best if you ask directly in the slack channel !


The function raises a ValueError with a descriptive message if 'n_bits' does not meet
these criteria.

Args:
n_bits (int, Dict[str, int]): number of bits for quantization, can be a single value or
a dictionary with the following keys :
- "op_inputs" (mandatory): number of bits to quantize the input values
- "op_leaves" (optional): number of bits to quantize the leaves, must be less than or
equal to 'op_inputs. defaults to the value of 'op_inputs if not specified.

Raises:
ValueError: If 'n_bits' does not conform to the required format or value constraints.
"""

detailed_message = (
"Invalid 'n_bits', either pass a strictly positive integer or a dictionary containing "
"integer values for the following keys:\n"
"- 'op_inputs' (mandatory): number of bits to quantize the input values\n"
"- 'op_leaves' (optional): number of bits to quantize the leaves, must be less than or "
"equal to 'op_inputs'. Defaults to the value of 'op_inputs' if not specified."
"When using a single integer for n_bits, its value is assigned to 'op_inputs' and "
"'op_leaves' bits.\n"
)

error_message = ""

if isinstance(n_bits, int):
if n_bits <= 0:
error_message = "n_bits must be a strictly positive integer"
elif isinstance(n_bits, dict):
if "op_inputs" not in n_bits.keys():
error_message = "Invalid keys in `n_bits` dictionary. The key 'op_inputs' is mandatory"
elif set(n_bits.keys()) - {"op_leaves", "op_inputs"}:
error_message = (
"Invalid keys in 'n_bits' dictionary. Only 'op_inputs' (mandatory) and 'op_leaves' "
"(optional) are allowed"
)
elif not all(isinstance(value, int) and value > 0 for value in n_bits.values()):
error_message = "All values in 'n_bits' dictionary must be strictly positive integers"

elif n_bits.get("op_leaves", 0) > n_bits.get("op_inputs", 0):
error_message = "'op_leaves' must be less than or equal to 'op_inputs'"
else:
error_message = "n_bits must be either an integer or a dictionary"

if len(error_message) > 0:
raise ValueError(
f"{error_message}. Got '{type(n_bits)}' and '{n_bits}' value.\n{detailed_message}"
)


# Find a better naming to describe leaf quantization in tree-based models
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4258
def _get_n_bits_dict_trees(n_bits: Union[int, Dict[str, int]]) -> Dict[str, int]:
"""Convert the n_bits parameter into a proper dictionary for tree based-models.

Args:
n_bits (int, Dict[str, int]): number of bits for quantization, can be a single value or
a dictionary with the following keys :
- "op_inputs" (mandatory): number of bits to quantize the input values
- "op_leaves" (optional): number of bits to quantize the leaves, must be less than or
equal to 'op_inputs'. defaults to the value of "op_inputs" if not specified.

When using a single integer for n_bits, its value is assigned to "op_inputs" and
"op_leaves" bits.

Returns:
n_bits_dict (Dict[str, int]): A dictionary properly representing the number of bits to use
for quantization.
"""

_inspect_tree_n_bits(n_bits)

# If a single integer is passed, we use a default value for the model's input and leaves
if isinstance(n_bits, int):
return {"op_inputs": n_bits, "op_leaves": n_bits}

# Default 'op_leaves' to 'op_inputs' if not specified
if "op_leaves" not in n_bits:
n_bits["op_leaves"] = n_bits["op_inputs"]

return n_bits


def get_n_bits_dict(n_bits: Union[int, Dict[str, int]]) -> Dict[str, int]:
"""Convert the n_bits parameter into a proper dictionary.
Expand Down
56 changes: 45 additions & 11 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,13 @@
# The sigmoid and softmax functions are already defined in the ONNX module and thus are imported
# here in order to avoid duplicating them.
from ..onnx.ops_impl import numpy_sigmoid, numpy_softmax
from ..quantization import PostTrainingQATImporter, QuantizedArray, get_n_bits_dict
from ..quantization import (
PostTrainingQATImporter,
QuantizedArray,
_get_n_bits_dict_trees,
_inspect_tree_n_bits,
get_n_bits_dict,
)
from ..quantization.quantized_module import QuantizedModule, _get_inputset_generator
from ..quantization.quantizers import (
QuantizationOptions,
Expand Down Expand Up @@ -96,7 +102,7 @@
# Enable rounding feature for all tree-based models by default
# Note: This setting is fixed and cannot be altered by users
# However, for internal testing purposes, we retain the capability to disable this feature
os.environ["TREES_USE_ROUNDING"] = "1"
os.environ["TREES_USE_ROUNDING"] = os.environ.get("TREES_USE_ROUNDING", "1")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This env variable is for ROUNDING not for the sum of the outputs of the tree ensembles in fhe

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still do we want to keep this env var?


# pylint: disable=too-many-public-methods

Expand Down Expand Up @@ -1281,17 +1287,32 @@ def __init_subclass__(cls):
_TREE_MODELS.add(cls)
_ALL_SKLEARN_MODELS.add(cls)

def __init__(self, n_bits: int):
def __init__(self, n_bits: Union[int, Dict[str, int]]):
"""Initialize the TreeBasedEstimatorMixin.

Args:
n_bits (int): The number of bits used for quantization.
n_bits (int, Dict[str, int]): Number of bits to quantize the model. If an int is passed
for n_bits, the value will be used for quantizing inputs and leaves. If a dict is
passed, then it should contain "op_inputs" and "op_leaves" as keys with
corresponding number of quantization bits so that:
- op_inputs (mandatory): number of bits to quantize the input values
- op_leaves (optional): number of bits to quantize the leaves
Default to 6.
"""
self.n_bits: int = n_bits

# Check if 'n_bits' is a valid value.
_inspect_tree_n_bits(n_bits)

self.n_bits: Union[int, Dict[str, int]] = n_bits
RomanBredehoft marked this conversation as resolved.
Show resolved Hide resolved

#: The model's inference function. Is None if the model is not fitted.
self._tree_inference: Optional[Callable] = None

#: Wether to perform the sum of the output's tree ensembles in FHE or not.
# By default, the decision of the tree ensembles is made in clear (not in FHE).
# This attribute should not be modified by users.
self._fhe_ensembling = False

BaseEstimator.__init__(self)

def fit(self, X: Data, y: Target, **fit_parameters):
Expand All @@ -1304,9 +1325,14 @@ def fit(self, X: Data, y: Target, **fit_parameters):

q_X = numpy.zeros_like(X)

# Convert the n_bits attribute into a proper dictionary
self.n_bits = _get_n_bits_dict_trees(self.n_bits)

# Quantization of each feature in X
for i in range(X.shape[1]):
input_quantizer = QuantizedArray(n_bits=self.n_bits, values=X[:, i]).quantizer
input_quantizer = QuantizedArray(
n_bits=self.n_bits["op_inputs"], values=X[:, i]
).quantizer
self.input_quantizers.append(input_quantizer)
q_X[:, i] = input_quantizer.quant(X[:, i])

Expand All @@ -1319,7 +1345,7 @@ def fit(self, X: Data, y: Target, **fit_parameters):
# Check that the underlying sklearn model has been set and fit
assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message()

# Convert the tree inference with Numpy operators
# Enable rounding feature
enable_rounding = os.environ.get("TREES_USE_ROUNDING", "1") == "1"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still env var based

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this one, is for rounding, not fhe_sum.

If we decide to remove it as well, it should be in another PR.


if not enable_rounding:
Expand All @@ -1332,12 +1358,14 @@ def fit(self, X: Data, y: Target, **fit_parameters):
stacklevel=2,
)

# Convert the tree inference with Numpy operators
self._tree_inference, self.output_quantizers, self.onnx_model_ = tree_to_numpy(
self.sklearn_model,
q_X,
use_rounding=enable_rounding,
fhe_ensembling=self._fhe_ensembling,
framework=self.framework,
output_n_bits=self.n_bits,
output_n_bits=self.n_bits["op_leaves"],
)

self._is_fitted = True
Expand Down Expand Up @@ -1412,10 +1440,13 @@ def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
# Sum all tree outputs
# Remove the sum once we handle multi-precision circuits
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451
y_preds = numpy.sum(y_preds, axis=-1)
if not self._fhe_ensembling:
y_preds = numpy.sum(y_preds, axis=-1)

assert_true(y_preds.ndim == 2, "y_preds should be a 2D array")
return y_preds
assert_true(y_preds.ndim == 2, "y_preds should be a 2D array")
return y_preds

return super().post_processing(y_preds)


class BaseTreeRegressorMixin(BaseTreeEstimatorMixin, sklearn.base.RegressorMixin, ABC):
Expand Down Expand Up @@ -1841,12 +1872,15 @@ def __init__(self, n_bits: int = 3):
quantizing inputs and X_fit. Default to 3.
"""
self.n_bits: int = n_bits

# _q_fit_X: In distance metric algorithms, `_q_fit_X` stores the training set to compute
# the similarity or distance measures. There is no `weights` attribute because there isn't
# a training phase
self._q_fit_X: numpy.ndarray

# _y: Labels of `_q_fit_X`
self._y: numpy.ndarray

# _q_fit_X_quantizer: The quantizer to use for quantizing the model's training set
self._q_fit_X_quantizer: Optional[UniformQuantizer] = None

Expand Down
16 changes: 11 additions & 5 deletions src/concrete/ml/sklearn/rf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Implement RandomForest models."""
from typing import Any, Dict
from typing import Any, Dict, Union

import numpy
import sklearn.ensemble
Expand All @@ -19,7 +19,7 @@ class RandomForestClassifier(BaseTreeClassifierMixin):
# pylint: disable-next=too-many-arguments
def __init__(
self,
n_bits: int = 6,
n_bits: Union[int, Dict[str, int]] = 6,
n_estimators=20,
criterion="gini",
max_depth=4,
Expand Down Expand Up @@ -84,6 +84,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["onnx_model_"] = self.onnx_model_
metadata["framework"] = self.framework
metadata["post_processing_params"] = self.post_processing_params
metadata["_fhe_ensembling"] = self._fhe_ensembling
kcelia marked this conversation as resolved.
Show resolved Hide resolved

# Scikit-Learn
metadata["n_estimators"] = self.n_estimators
Expand Down Expand Up @@ -120,11 +121,13 @@ def load_dict(cls, metadata: Dict):
obj.framework = metadata["framework"]
obj.onnx_model_ = metadata["onnx_model_"]
obj.output_quantizers = metadata["output_quantizers"]
obj._fhe_ensembling = metadata["_fhe_ensembling"]
obj._tree_inference = tree_to_numpy(
obj.sklearn_model,
numpy.zeros((len(obj.input_quantizers),))[None, ...],
framework=obj.framework,
output_n_bits=obj.n_bits,
output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits,
RomanBredehoft marked this conversation as resolved.
Show resolved Hide resolved
RomanBredehoft marked this conversation as resolved.
Show resolved Hide resolved
fhe_ensembling=obj._fhe_ensembling,
)[0]
obj.post_processing_params = metadata["post_processing_params"]

Expand Down Expand Up @@ -162,7 +165,7 @@ class RandomForestRegressor(BaseTreeRegressorMixin):
# pylint: disable-next=too-many-arguments
def __init__(
self,
n_bits: int = 6,
n_bits: Union[int, Dict[str, int]] = 6,
n_estimators=20,
criterion="squared_error",
max_depth=4,
Expand Down Expand Up @@ -219,6 +222,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["onnx_model_"] = self.onnx_model_
metadata["framework"] = self.framework
metadata["post_processing_params"] = self.post_processing_params
metadata["_fhe_ensembling"] = self._fhe_ensembling

# Scikit-Learn
metadata["n_estimators"] = self.n_estimators
Expand Down Expand Up @@ -255,11 +259,13 @@ def load_dict(cls, metadata: Dict):
obj.framework = metadata["framework"]
obj.onnx_model_ = metadata["onnx_model_"]
obj.output_quantizers = metadata["output_quantizers"]
obj._fhe_ensembling = metadata["_fhe_ensembling"]
obj._tree_inference = tree_to_numpy(
obj.sklearn_model,
numpy.zeros((len(obj.input_quantizers),))[None, ...],
framework=obj.framework,
output_n_bits=obj.n_bits,
output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits,
fhe_ensembling=obj._fhe_ensembling,
)[0]
obj.post_processing_params = metadata["post_processing_params"]

Expand Down
17 changes: 12 additions & 5 deletions src/concrete/ml/sklearn/tree.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Implement DecisionTree models."""
from typing import Any, Dict
from typing import Any, Dict, Union

import numpy
import sklearn.tree
Expand Down Expand Up @@ -31,7 +31,7 @@ def __init__(
min_impurity_decrease=0.0,
class_weight=None,
ccp_alpha: float = 0.0,
n_bits: int = 6,
n_bits: Union[int, Dict[str, int]] = 6,
):
"""Initialize the DecisionTreeClassifier.

Expand Down Expand Up @@ -84,6 +84,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["onnx_model_"] = self.onnx_model_
metadata["framework"] = self.framework
metadata["post_processing_params"] = self.post_processing_params
metadata["_fhe_ensembling"] = self._fhe_ensembling

# Scikit-Learn
metadata["criterion"] = self.criterion
Expand Down Expand Up @@ -115,11 +116,13 @@ def load_dict(cls, metadata: Dict):
obj.framework = metadata["framework"]
obj.onnx_model_ = metadata["onnx_model_"]
obj.output_quantizers = metadata["output_quantizers"]
obj._fhe_ensembling = metadata["_fhe_ensembling"]
obj._tree_inference = tree_to_numpy(
obj.sklearn_model,
numpy.zeros((len(obj.input_quantizers),))[None, ...],
framework=obj.framework,
output_n_bits=obj.n_bits,
output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits,
fhe_ensembling=obj._fhe_ensembling,
)[0]
obj.post_processing_params = metadata["post_processing_params"]

Expand Down Expand Up @@ -162,7 +165,7 @@ def __init__(
max_leaf_nodes=None,
min_impurity_decrease=0.0,
ccp_alpha=0.0,
n_bits: int = 6,
n_bits: Union[int, Dict[str, int]] = 6,
):
"""Initialize the DecisionTreeRegressor.

Expand Down Expand Up @@ -208,6 +211,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["onnx_model_"] = self.onnx_model_
metadata["framework"] = self.framework
metadata["post_processing_params"] = self.post_processing_params
metadata["_fhe_ensembling"] = self._fhe_ensembling

# Scikit-Learn
metadata["criterion"] = self.criterion
Expand All @@ -233,16 +237,19 @@ def load_dict(cls, metadata: Dict):
# Concrete-ML
obj.sklearn_model = metadata["sklearn_model"]
obj._is_fitted = metadata["_is_fitted"]
obj._fhe_ensembling = metadata["_fhe_ensembling"]
obj._is_compiled = metadata["_is_compiled"]
obj.input_quantizers = metadata["input_quantizers"]
obj.framework = metadata["framework"]
obj.onnx_model_ = metadata["onnx_model_"]
obj.output_quantizers = metadata["output_quantizers"]
obj._fhe_ensembling = metadata["_fhe_ensembling"]
obj._tree_inference = tree_to_numpy(
obj.sklearn_model,
numpy.zeros((len(obj.input_quantizers),))[None, ...],
framework=obj.framework,
output_n_bits=obj.n_bits,
output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits,
fhe_ensembling=obj._fhe_ensembling,
)[0]
obj.post_processing_params = metadata["post_processing_params"]

Expand Down
Loading
Loading