diff --git a/src/concrete/ml/quantization/__init__.py b/src/concrete/ml/quantization/__init__.py index 845b5dc11..f9c94793e 100644 --- a/src/concrete/ml/quantization/__init__.py +++ b/src/concrete/ml/quantization/__init__.py @@ -1,6 +1,12 @@ """Modules for quantization.""" from .base_quantized_op import QuantizedOp -from .post_training import PostTrainingAffineQuantization, PostTrainingQATImporter, get_n_bits_dict +from .post_training import ( + PostTrainingAffineQuantization, + PostTrainingQATImporter, + _get_n_bits_dict_trees, + _inspect_tree_n_bits, + get_n_bits_dict, +) from .quantized_module import QuantizedModule from .quantized_ops import ( QuantizedAbs, diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py index 9389ab05f..022508507 100644 --- a/src/concrete/ml/quantization/post_training.py +++ b/src/concrete/ml/quantization/post_training.py @@ -23,6 +23,101 @@ from .quantized_ops import QuantizedBrevitasQuant from .quantizers import QuantizationOptions, QuantizedArray, UniformQuantizer +# pylint: disable=too-many-lines + + +def _inspect_tree_n_bits(n_bits): + """Validate the 'n_bits' parameter for tree-based models. + + This function checks whether 'n_bits' is a valid integer or dictionary. + - If 'n_bits' is an integer, it must be a non-null positive, its value is assigned to + 'op_inputs' and 'op_leaves' bits + - If it is a dictionary, it should contain integer values for keys 'op_leaves' and 'op_inputs', + where 'op_leaves' should not exceed 'op_inputs'. + + The function raises a ValueError with a descriptive message if 'n_bits' does not meet + these criteria. + + Args: + n_bits (int, Dict[str, int]): number of bits for quantization, can be a single value or + a dictionary with the following keys : + - "op_inputs" (mandatory): number of bits to quantize the input values + - "op_leaves" (optional): number of bits to quantize the leaves, must be less than or + equal to 'op_inputs. defaults to the value of 'op_inputs if not specified. + + Raises: + ValueError: If 'n_bits' does not conform to the required format or value constraints. + """ + + detailed_message = ( + "Invalid 'n_bits', either pass a strictly positive integer or a dictionary containing " + "integer values for the following keys:\n" + "- 'op_inputs' (mandatory): number of bits to quantize the input values\n" + "- 'op_leaves' (optional): number of bits to quantize the leaves, must be less than or " + "equal to 'op_inputs'. Defaults to the value of 'op_inputs' if not specified." + "When using a single integer for n_bits, its value is assigned to 'op_inputs' and " + "'op_leaves' bits.\n" + ) + + error_message = "" + + if isinstance(n_bits, int): + if n_bits <= 0: + error_message = "n_bits must be a strictly positive integer" + elif isinstance(n_bits, dict): + if "op_inputs" not in n_bits.keys(): + error_message = "Invalid keys in `n_bits` dictionary. The key 'op_inputs' is mandatory" + elif set(n_bits.keys()) - {"op_leaves", "op_inputs"}: + error_message = ( + "Invalid keys in 'n_bits' dictionary. 
Only 'op_inputs' (mandatory) and 'op_leaves' " + "(optional) are allowed" + ) + elif not all(isinstance(value, int) and value > 0 for value in n_bits.values()): + error_message = "All values in 'n_bits' dictionary must be strictly positive integers" + + elif n_bits.get("op_leaves", 0) > n_bits.get("op_inputs", 0): + error_message = "'op_leaves' must be less than or equal to 'op_inputs'" + else: + error_message = "n_bits must be either an integer or a dictionary" + + if len(error_message) > 0: + raise ValueError( + f"{error_message}. Got '{type(n_bits)}' and '{n_bits}' value.\n{detailed_message}" + ) + + +# Find a better naming to describe leaf quantization in tree-based models +# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4258 +def _get_n_bits_dict_trees(n_bits: Union[int, Dict[str, int]]) -> Dict[str, int]: + """Convert the n_bits parameter into a proper dictionary for tree based-models. + + Args: + n_bits (int, Dict[str, int]): number of bits for quantization, can be a single value or + a dictionary with the following keys : + - "op_inputs" (mandatory): number of bits to quantize the input values + - "op_leaves" (optional): number of bits to quantize the leaves, must be less than or + equal to 'op_inputs'. defaults to the value of "op_inputs" if not specified. + + When using a single integer for n_bits, its value is assigned to "op_inputs" and + "op_leaves" bits. + + Returns: + n_bits_dict (Dict[str, int]): A dictionary properly representing the number of bits to use + for quantization. + """ + + _inspect_tree_n_bits(n_bits) + + # If a single integer is passed, we use a default value for the model's input and leaves + if isinstance(n_bits, int): + return {"op_inputs": n_bits, "op_leaves": n_bits} + + # Default 'op_leaves' to 'op_inputs' if not specified + if "op_leaves" not in n_bits: + n_bits["op_leaves"] = n_bits["op_inputs"] + + return n_bits + def get_n_bits_dict(n_bits: Union[int, Dict[str, int]]) -> Dict[str, int]: """Convert the n_bits parameter into a proper dictionary. diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index d1275c130..6ffdb8e2a 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -49,7 +49,13 @@ # The sigmoid and softmax functions are already defined in the ONNX module and thus are imported # here in order to avoid duplicating them. from ..onnx.ops_impl import numpy_sigmoid, numpy_softmax -from ..quantization import PostTrainingQATImporter, QuantizedArray, get_n_bits_dict +from ..quantization import ( + PostTrainingQATImporter, + QuantizedArray, + _get_n_bits_dict_trees, + _inspect_tree_n_bits, + get_n_bits_dict, +) from ..quantization.quantized_module import QuantizedModule, _get_inputset_generator from ..quantization.quantizers import ( QuantizationOptions, @@ -96,7 +102,7 @@ # Enable rounding feature for all tree-based models by default # Note: This setting is fixed and cannot be altered by users # However, for internal testing purposes, we retain the capability to disable this feature -os.environ["TREES_USE_ROUNDING"] = "1" +os.environ["TREES_USE_ROUNDING"] = os.environ.get("TREES_USE_ROUNDING", "1") # pylint: disable=too-many-public-methods @@ -1281,17 +1287,32 @@ def __init_subclass__(cls): _TREE_MODELS.add(cls) _ALL_SKLEARN_MODELS.add(cls) - def __init__(self, n_bits: int): + def __init__(self, n_bits: Union[int, Dict[str, int]]): """Initialize the TreeBasedEstimatorMixin. Args: - n_bits (int): The number of bits used for quantization. 
+ n_bits (int, Dict[str, int]): Number of bits to quantize the model. If an int is passed + for n_bits, the value will be used for quantizing inputs and leaves. If a dict is + passed, then it should contain "op_inputs" and "op_leaves" as keys with + corresponding number of quantization bits so that: + - op_inputs (mandatory): number of bits to quantize the input values + - op_leaves (optional): number of bits to quantize the leaves + Default to 6. """ - self.n_bits: int = n_bits + + # Check if 'n_bits' is a valid value. + _inspect_tree_n_bits(n_bits) + + self.n_bits: Union[int, Dict[str, int]] = n_bits #: The model's inference function. Is None if the model is not fitted. self._tree_inference: Optional[Callable] = None + #: Wether to perform the sum of the output's tree ensembles in FHE or not. + # By default, the decision of the tree ensembles is made in clear (not in FHE). + # This attribute should not be modified by users. + self._fhe_ensembling = False + BaseEstimator.__init__(self) def fit(self, X: Data, y: Target, **fit_parameters): @@ -1304,9 +1325,14 @@ def fit(self, X: Data, y: Target, **fit_parameters): q_X = numpy.zeros_like(X) + # Convert the n_bits attribute into a proper dictionary + self.n_bits = _get_n_bits_dict_trees(self.n_bits) + # Quantization of each feature in X for i in range(X.shape[1]): - input_quantizer = QuantizedArray(n_bits=self.n_bits, values=X[:, i]).quantizer + input_quantizer = QuantizedArray( + n_bits=self.n_bits["op_inputs"], values=X[:, i] + ).quantizer self.input_quantizers.append(input_quantizer) q_X[:, i] = input_quantizer.quant(X[:, i]) @@ -1319,7 +1345,7 @@ def fit(self, X: Data, y: Target, **fit_parameters): # Check that the underlying sklearn model has been set and fit assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message() - # Convert the tree inference with Numpy operators + # Enable rounding feature enable_rounding = os.environ.get("TREES_USE_ROUNDING", "1") == "1" if not enable_rounding: @@ -1332,12 +1358,14 @@ def fit(self, X: Data, y: Target, **fit_parameters): stacklevel=2, ) + # Convert the tree inference with Numpy operators self._tree_inference, self.output_quantizers, self.onnx_model_ = tree_to_numpy( self.sklearn_model, q_X, use_rounding=enable_rounding, + fhe_ensembling=self._fhe_ensembling, framework=self.framework, - output_n_bits=self.n_bits, + output_n_bits=self.n_bits["op_leaves"], ) self._is_fitted = True @@ -1412,10 +1440,13 @@ def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray: # Sum all tree outputs # Remove the sum once we handle multi-precision circuits # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451 - y_preds = numpy.sum(y_preds, axis=-1) + if not self._fhe_ensembling: + y_preds = numpy.sum(y_preds, axis=-1) - assert_true(y_preds.ndim == 2, "y_preds should be a 2D array") - return y_preds + assert_true(y_preds.ndim == 2, "y_preds should be a 2D array") + return y_preds + + return super().post_processing(y_preds) class BaseTreeRegressorMixin(BaseTreeEstimatorMixin, sklearn.base.RegressorMixin, ABC): @@ -1841,12 +1872,15 @@ def __init__(self, n_bits: int = 3): quantizing inputs and X_fit. Default to 3. """ self.n_bits: int = n_bits + # _q_fit_X: In distance metric algorithms, `_q_fit_X` stores the training set to compute # the similarity or distance measures. 
There is no `weights` attribute because there isn't # a training phase self._q_fit_X: numpy.ndarray + # _y: Labels of `_q_fit_X` self._y: numpy.ndarray + # _q_fit_X_quantizer: The quantizer to use for quantizing the model's training set self._q_fit_X_quantizer: Optional[UniformQuantizer] = None diff --git a/src/concrete/ml/sklearn/rf.py b/src/concrete/ml/sklearn/rf.py index e5f756664..f4521bf06 100644 --- a/src/concrete/ml/sklearn/rf.py +++ b/src/concrete/ml/sklearn/rf.py @@ -1,5 +1,5 @@ """Implement RandomForest models.""" -from typing import Any, Dict +from typing import Any, Dict, Union import numpy import sklearn.ensemble @@ -19,7 +19,7 @@ class RandomForestClassifier(BaseTreeClassifierMixin): # pylint: disable-next=too-many-arguments def __init__( self, - n_bits: int = 6, + n_bits: Union[int, Dict[str, int]] = 6, n_estimators=20, criterion="gini", max_depth=4, @@ -84,6 +84,7 @@ def dump_dict(self) -> Dict[str, Any]: metadata["onnx_model_"] = self.onnx_model_ metadata["framework"] = self.framework metadata["post_processing_params"] = self.post_processing_params + metadata["_fhe_ensembling"] = self._fhe_ensembling # Scikit-Learn metadata["n_estimators"] = self.n_estimators @@ -120,11 +121,13 @@ def load_dict(cls, metadata: Dict): obj.framework = metadata["framework"] obj.onnx_model_ = metadata["onnx_model_"] obj.output_quantizers = metadata["output_quantizers"] + obj._fhe_ensembling = metadata["_fhe_ensembling"] obj._tree_inference = tree_to_numpy( obj.sklearn_model, numpy.zeros((len(obj.input_quantizers),))[None, ...], framework=obj.framework, - output_n_bits=obj.n_bits, + output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits, + fhe_ensembling=obj._fhe_ensembling, )[0] obj.post_processing_params = metadata["post_processing_params"] @@ -162,7 +165,7 @@ class RandomForestRegressor(BaseTreeRegressorMixin): # pylint: disable-next=too-many-arguments def __init__( self, - n_bits: int = 6, + n_bits: Union[int, Dict[str, int]] = 6, n_estimators=20, criterion="squared_error", max_depth=4, @@ -219,6 +222,7 @@ def dump_dict(self) -> Dict[str, Any]: metadata["onnx_model_"] = self.onnx_model_ metadata["framework"] = self.framework metadata["post_processing_params"] = self.post_processing_params + metadata["_fhe_ensembling"] = self._fhe_ensembling # Scikit-Learn metadata["n_estimators"] = self.n_estimators @@ -255,11 +259,13 @@ def load_dict(cls, metadata: Dict): obj.framework = metadata["framework"] obj.onnx_model_ = metadata["onnx_model_"] obj.output_quantizers = metadata["output_quantizers"] + obj._fhe_ensembling = metadata["_fhe_ensembling"] obj._tree_inference = tree_to_numpy( obj.sklearn_model, numpy.zeros((len(obj.input_quantizers),))[None, ...], framework=obj.framework, - output_n_bits=obj.n_bits, + output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits, + fhe_ensembling=obj._fhe_ensembling, )[0] obj.post_processing_params = metadata["post_processing_params"] diff --git a/src/concrete/ml/sklearn/tree.py b/src/concrete/ml/sklearn/tree.py index 1ea972cfd..b496d4e47 100644 --- a/src/concrete/ml/sklearn/tree.py +++ b/src/concrete/ml/sklearn/tree.py @@ -1,5 +1,5 @@ """Implement DecisionTree models.""" -from typing import Any, Dict +from typing import Any, Dict, Union import numpy import sklearn.tree @@ -31,7 +31,7 @@ def __init__( min_impurity_decrease=0.0, class_weight=None, ccp_alpha: float = 0.0, - n_bits: int = 6, + n_bits: Union[int, Dict[str, int]] = 6, ): """Initialize the DecisionTreeClassifier. 
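Reviewer note: to make the new `n_bits` semantics for tree-based estimators concrete, here is a minimal usage sketch (not part of the patch). It assumes `DecisionTreeClassifier` is imported from `concrete.ml.sklearn` and uses a small synthetic data set; no claim is made about the accuracy impact of lowering 'op_leaves'.

# Hypothetical example: quantize input values on 6 bits and leaf values on 3 bits.
import numpy
from concrete.ml.sklearn import DecisionTreeClassifier

X = numpy.random.uniform(size=(100, 4))
y = (X[:, 0] > 0.5).astype(int)

model = DecisionTreeClassifier(n_bits={"op_inputs": 6, "op_leaves": 3})
model.fit(X, y)

# Passing a single integer assigns the same bit-width to inputs and leaves,
# i.e. n_bits=4 is normalized to {"op_inputs": 4, "op_leaves": 4}.
model_single = DecisionTreeClassifier(n_bits=4)
model_single.fit(X, y)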
@@ -84,6 +84,7 @@ def dump_dict(self) -> Dict[str, Any]: metadata["onnx_model_"] = self.onnx_model_ metadata["framework"] = self.framework metadata["post_processing_params"] = self.post_processing_params + metadata["_fhe_ensembling"] = self._fhe_ensembling # Scikit-Learn metadata["criterion"] = self.criterion @@ -115,11 +116,13 @@ def load_dict(cls, metadata: Dict): obj.framework = metadata["framework"] obj.onnx_model_ = metadata["onnx_model_"] obj.output_quantizers = metadata["output_quantizers"] + obj._fhe_ensembling = metadata["_fhe_ensembling"] obj._tree_inference = tree_to_numpy( obj.sklearn_model, numpy.zeros((len(obj.input_quantizers),))[None, ...], framework=obj.framework, - output_n_bits=obj.n_bits, + output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits, + fhe_ensembling=obj._fhe_ensembling, )[0] obj.post_processing_params = metadata["post_processing_params"] @@ -162,7 +165,7 @@ def __init__( max_leaf_nodes=None, min_impurity_decrease=0.0, ccp_alpha=0.0, - n_bits: int = 6, + n_bits: Union[int, Dict[str, int]] = 6, ): """Initialize the DecisionTreeRegressor. @@ -208,6 +211,7 @@ def dump_dict(self) -> Dict[str, Any]: metadata["onnx_model_"] = self.onnx_model_ metadata["framework"] = self.framework metadata["post_processing_params"] = self.post_processing_params + metadata["_fhe_ensembling"] = self._fhe_ensembling # Scikit-Learn metadata["criterion"] = self.criterion @@ -233,16 +237,19 @@ def load_dict(cls, metadata: Dict): # Concrete-ML obj.sklearn_model = metadata["sklearn_model"] obj._is_fitted = metadata["_is_fitted"] + obj._fhe_ensembling = metadata["_fhe_ensembling"] obj._is_compiled = metadata["_is_compiled"] obj.input_quantizers = metadata["input_quantizers"] obj.framework = metadata["framework"] obj.onnx_model_ = metadata["onnx_model_"] obj.output_quantizers = metadata["output_quantizers"] + obj._fhe_ensembling = metadata["_fhe_ensembling"] obj._tree_inference = tree_to_numpy( obj.sklearn_model, numpy.zeros((len(obj.input_quantizers),))[None, ...], framework=obj.framework, - output_n_bits=obj.n_bits, + output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits, + fhe_ensembling=obj._fhe_ensembling, )[0] obj.post_processing_params = metadata["post_processing_params"] diff --git a/src/concrete/ml/sklearn/tree_to_numpy.py b/src/concrete/ml/sklearn/tree_to_numpy.py index 15940e0ce..b50944319 100644 --- a/src/concrete/ml/sklearn/tree_to_numpy.py +++ b/src/concrete/ml/sklearn/tree_to_numpy.py @@ -17,7 +17,11 @@ OPSET_VERSION_FOR_ONNX_EXPORT, get_equivalent_numpy_forward_from_onnx_tree, ) -from ..onnx.onnx_model_manipulations import clean_graph_at_node_op_type, remove_node_types +from ..onnx.onnx_model_manipulations import ( + clean_graph_after_node_op_type, + clean_graph_at_node_op_type, + remove_node_types, +) from ..onnx.onnx_utils import get_op_type from ..quantization import QuantizedArray from ..quantization.quantizers import UniformQuantizer @@ -132,21 +136,31 @@ def assert_add_node_and_constant_in_xgboost_regressor_graph(onnx_model: onnx.Mod ) -def add_transpose_after_last_node(onnx_model: onnx.ModelProto): +def add_transpose_after_last_node(onnx_model: onnx.ModelProto, fhe_ensembling: bool = False): """Add transpose after last node. Args: onnx_model (onnx.ModelProto): The ONNX model. + fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed in FHE. + Default to False. 
""" # Get the output node output_node = onnx_model.graph.output[0] - # Create the node with perm attribute equal to (2, 1, 0) + # The state of the 'fhe_ensembling' variable affects the structure of the model's ONNX graph. + # When the option is enabled, the graph is cut after the ReduceSum node. + # When it is disabled, the graph is cut at the ReduceSum node, which alters the output shape. + # Therefore, it is necessary to adjust this shape with the correct permutation. + + # When using FHE sum for tree ensembles, create the node with perm attribute equal to (1, 0) + # Otherwise, create the node with perm attribute equal to (2, 1, 0) + perm = [1, 0] if fhe_ensembling else [2, 1, 0] + transpose_node = onnx.helper.make_node( "Transpose", inputs=[output_node.name], outputs=["transposed_output"], - perm=[2, 1, 0], + perm=perm, ) onnx_model.graph.node.append(transpose_node) @@ -204,7 +218,10 @@ def preprocess_tree_predictions( def tree_onnx_graph_preprocessing( - onnx_model: onnx.ModelProto, framework: str, expected_number_of_outputs: int + onnx_model: onnx.ModelProto, + framework: str, + expected_number_of_outputs: int, + fhe_ensembling: bool = False, ): """Apply pre-processing onto the ONNX graph. @@ -213,6 +230,8 @@ def tree_onnx_graph_preprocessing( framework (str): The framework from which the ONNX model is generated. (options: 'xgboost', 'sklearn') expected_number_of_outputs (int): The expected number of outputs in the ONNX model. + fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed in FHE. + Default to False. """ # Make sure the ONNX version returned by Hummingbird is OPSET_VERSION_FOR_ONNX_EXPORT onnx_version = get_onnx_opset_version(onnx_model) @@ -237,9 +256,12 @@ def tree_onnx_graph_preprocessing( if len(onnx_model.graph.output) == 1: assert_add_node_and_constant_in_xgboost_regressor_graph(onnx_model) - # Cut the graph at the ReduceSum node as large sum are not yet supported. - # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451 - clean_graph_at_node_op_type(onnx_model, "ReduceSum") + # Cut the graph after the ReduceSum node to remove + # argmax, sigmoid, softmax from the graph. + if fhe_ensembling: + clean_graph_after_node_op_type(onnx_model, "ReduceSum") + else: + clean_graph_at_node_op_type(onnx_model, "ReduceSum") if framework == "xgboost": # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2778 @@ -252,7 +274,7 @@ def tree_onnx_graph_preprocessing( # sklearn models apply the reduce sum before the transpose. # To have equivalent output between xgboost in sklearn, # apply the transpose before returning the output. - add_transpose_after_last_node(onnx_model) + add_transpose_after_last_node(onnx_model, fhe_ensembling) # Cast nodes are not necessary so remove them. remove_node_types(onnx_model, op_types_to_remove=["Cast"]) @@ -277,6 +299,7 @@ def tree_values_preprocessing( # Modify ONNX graph to fit in FHE for i, initializer in enumerate(onnx_model.graph.initializer): + # All constants in our tree should be integers. # Tree thresholds can be rounded up or down (depending on the tree implementation) # while the final probabilities/regression values must be quantized. @@ -289,6 +312,7 @@ def tree_values_preprocessing( # Get the preprocessed tree predictions to replace the current (non-quantized) # values in the onnx_model. init_tensor = q_y.qvalues + elif "bias_1" in initializer.name: if framework == "xgboost": # xgboost uses "<" (Less) operator thus we must round up. 
@@ -307,6 +331,7 @@ def tree_to_numpy( x: numpy.ndarray, framework: str, use_rounding: bool = True, + fhe_ensembling: bool = False, output_n_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE, ) -> Tuple[Callable, List[UniformQuantizer], onnx.ModelProto]: """Convert the tree inference to a numpy functions using Hummingbird. @@ -314,8 +339,10 @@ def tree_to_numpy( Args: model (Callable): The tree model to convert. x (numpy.ndarray): The input data. - use_rounding (bool): This parameter is exclusively used to tree-based models. - It determines whether the rounding feature is enabled or disabled. + use_rounding (bool): Determines whether the rounding feature is enabled or disabled. + Default to True. + fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed in FHE. + Default to False. framework (str): The framework from which the ONNX model is generated. (options: 'xgboost', 'sklearn') output_n_bits (int): The number of bits of the output. Default to 8. @@ -352,7 +379,7 @@ def tree_to_numpy( # ONNX graph pre-processing to make the model FHE friendly # i.e., delete irrelevant nodes and cut the graph before the final ensemble sum) - tree_onnx_graph_preprocessing(onnx_model, framework, expected_number_of_outputs) + tree_onnx_graph_preprocessing(onnx_model, framework, expected_number_of_outputs, fhe_ensembling) # Tree values pre-processing # i.e., mainly predictions quantization diff --git a/src/concrete/ml/sklearn/xgb.py b/src/concrete/ml/sklearn/xgb.py index a10b1400b..8f3925fd7 100644 --- a/src/concrete/ml/sklearn/xgb.py +++ b/src/concrete/ml/sklearn/xgb.py @@ -27,7 +27,7 @@ class XGBClassifier(BaseTreeClassifierMixin): # pylint: disable=too-many-arguments,too-many-locals def __init__( self, - n_bits: int = 6, + n_bits: Union[int, Dict[str, int]] = 6, max_depth: Optional[int] = 3, learning_rate: Optional[float] = None, n_estimators: Optional[int] = 20, @@ -125,6 +125,7 @@ def dump_dict(self) -> Dict[str, Any]: metadata["onnx_model_"] = self.onnx_model_ metadata["framework"] = self.framework metadata["post_processing_params"] = self.post_processing_params + metadata["_fhe_ensembling"] = self._fhe_ensembling # XGBoost metadata["max_depth"] = self.max_depth @@ -174,11 +175,13 @@ def load_dict(cls, metadata: Dict): obj.framework = metadata["framework"] obj.onnx_model_ = metadata["onnx_model_"] obj.output_quantizers = metadata["output_quantizers"] + obj._fhe_ensembling = metadata["_fhe_ensembling"] obj._tree_inference = tree_to_numpy( obj.sklearn_model, numpy.zeros((len(obj.input_quantizers),))[None, ...], framework=obj.framework, - output_n_bits=obj.n_bits, + output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits, + fhe_ensembling=obj._fhe_ensembling, )[0] obj.post_processing_params = metadata["post_processing_params"] @@ -233,7 +236,7 @@ class XGBRegressor(BaseTreeRegressorMixin): # pylint: disable=too-many-arguments,too-many-locals def __init__( self, - n_bits: int = 6, + n_bits: Union[int, Dict[str, int]] = 6, max_depth: Optional[int] = 3, learning_rate: Optional[float] = None, n_estimators: Optional[int] = 20, @@ -354,6 +357,7 @@ def dump_dict(self) -> Dict[str, Any]: metadata["onnx_model_"] = self.onnx_model_ metadata["framework"] = self.framework metadata["post_processing_params"] = self.post_processing_params + metadata["_fhe_ensembling"] = self._fhe_ensembling # XGBoost metadata["max_depth"] = self.max_depth @@ -403,11 +407,13 @@ def load_dict(cls, metadata: Dict): obj.framework = metadata["framework"] obj.onnx_model_ = 
metadata["onnx_model_"] obj.output_quantizers = metadata["output_quantizers"] + obj._fhe_ensembling = metadata["_fhe_ensembling"] obj._tree_inference = tree_to_numpy( obj.sklearn_model, numpy.zeros((len(obj.input_quantizers),))[None, ...], framework=obj.framework, - output_n_bits=obj.n_bits, + output_n_bits=obj.n_bits["op_leaves"] if isinstance(obj.n_bits, Dict) else obj.n_bits, + fhe_ensembling=obj._fhe_ensembling, )[0] obj.post_processing_params = metadata["post_processing_params"] diff --git a/tests/sklearn/test_dump_onnx.py b/tests/sklearn/test_dump_onnx.py index ecfafc879..34e14d242 100644 --- a/tests/sklearn/test_dump_onnx.py +++ b/tests/sklearn/test_dump_onnx.py @@ -1,6 +1,5 @@ """Tests for the sklearn decision trees.""" - import warnings from functools import partial @@ -20,108 +19,22 @@ # pylint: disable=line-too-long -def check_onnx_file_dump(model_class, parameters, load_data, str_expected, default_configuration): +def check_onnx_file_dump( + model_class, parameters, load_data, default_configuration, use_fhe_sum=False +): """Fit the model and dump the corresponding ONNX.""" - # Get the data-set. The data generation is seeded in load_data. - x, y = load_data(model_class, **parameters) + model_name = get_model_name(model_class) + n_classes = parameters.get("n_classes", 2) # Set the model model = model_class() - model_params = model.get_params() - if "random_state" in model_params: - model_params["random_state"] = numpy.random.randint(0, 2**15) - - model.set_params(**model_params) + # Set `_fhe_ensembling` for tree based models only + if model_class in _get_sklearn_tree_models(): - if get_model_name(model) == "KNeighborsClassifier": - # KNN can only be compiled with small quantization bit numbers for now - # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979 - model.n_bits = 2 - - with warnings.catch_warnings(): - # Sometimes, we miss convergence, which is not a problem for our test - warnings.simplefilter("ignore", category=ConvergenceWarning) - - model.fit(x, y) - - with warnings.catch_warnings(): - # Use FHE simulation to not have issues with precision - model.compile(x, default_configuration) - - # Get ONNX model - onnx_model = model.onnx_model - - # Remove initializers, since they change from one seed to the other - model_name = get_model_name(model_class) - if model_name in [ - "DecisionTreeRegressor", - "DecisionTreeClassifier", - "RandomForestClassifier", - "RandomForestRegressor", - "XGBClassifier", - "KNeighborsClassifier", - ]: - while len(onnx_model.graph.initializer) > 0: - del onnx_model.graph.initializer[0] - - str_model = onnx.helper.printable_graph(onnx_model.graph) - print(f"{model_name}:") - print(str_model) - - # Test equality when it does not depend on seeds - # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3266 - if not is_model_class_in_a_list(model_class, _get_sklearn_tree_models(select="RandomForest")): - # The expected graph is usually a string and we therefore directly test if it is equal to - # the retrieved graph's string. However, in some cases such as for TweedieRegressor models, - # this graph can slightly changed depending on some input's values. 
We then expected the - # string to match as least one of them expected strings (as a list) - if isinstance(str_expected, str): - assert str_model == str_expected - else: - assert str_model in str_expected - - -@pytest.mark.parametrize("model_class, parameters", UNIQUE_MODELS_AND_DATASETS) -def test_dump( - model_class, - parameters, - load_data, - default_configuration, -): - """Tests dump.""" - - model_name = get_model_name(model_class) - - # Some models have been done with different n_classes which create different ONNX - if parameters.get("n_classes", 2) != 2 and model_name in ["LinearSVC", "LogisticRegression"]: - return - - if model_name == "NeuralNetClassifier": - model_class = partial( - NeuralNetClassifier, - module__n_layers=3, - module__power_of_two_scaling=False, - max_epochs=1, - verbose=0, - callbacks="disable", - ) - elif model_name == "NeuralNetRegressor": - model_class = partial( - NeuralNetRegressor, - module__n_layers=3, - module__n_w_bits=2, - module__n_a_bits=2, - module__n_accum_bits=7, # Stay with 7 bits for test exec time - module__n_hidden_neurons_multiplier=1, - module__power_of_two_scaling=False, - max_epochs=1, - verbose=0, - callbacks="disable", - ) - - n_classes = parameters.get("n_classes", 2) + # pylint: disable=protected-access + model._fhe_ensembling = use_fhe_sum # Ignore long lines here # ruff: noqa: E501 @@ -222,8 +135,15 @@ def test_dump( %/_operators.0/Reshape_2_output_0 = Reshape[allowzero = 0](%/_operators.0/Equal_output_0, %/_operators.0/Constant_2_output_0) %/_operators.0/MatMul_1_output_0 = MatMul(%_operators.0.weight_3, %/_operators.0/Reshape_2_output_0) %/_operators.0/Reshape_3_output_0 = Reshape[allowzero = 0](%/_operators.0/MatMul_1_output_0, %/_operators.0/Constant_3_output_0) - %transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0) - return %transposed_output + """ + + ( + """%/_operators.0/ReduceSum_output_0 = ReduceSum[keepdims = 0](%/_operators.0/Reshape_3_output_0, %onnx::ReduceSum_22) + %transposed_output = Transpose[perm = [1, 0]](%/_operators.0/ReduceSum_output_0) + """ + if use_fhe_sum + else "%transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0)\n " + ) + + """return %transposed_output }""", "RandomForestClassifier": """graph torch_jit ( %input_0[DOUBLE, symx10] @@ -294,8 +214,15 @@ def test_dump( %/_operators.0/Reshape_2_output_0 = Reshape[allowzero = 0](%/_operators.0/Equal_output_0, %/_operators.0/Constant_2_output_0) %/_operators.0/MatMul_1_output_0 = MatMul(%_operators.0.weight_3, %/_operators.0/Reshape_2_output_0) %/_operators.0/Reshape_3_output_0 = Reshape[allowzero = 0](%/_operators.0/MatMul_1_output_0, %/_operators.0/Constant_3_output_0) - %transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0) - return %transposed_output + """ + + ( + """%/_operators.0/ReduceSum_output_0 = ReduceSum[keepdims = 0](%/_operators.0/Reshape_3_output_0, %onnx::ReduceSum_22) + %transposed_output = Transpose[perm = [1, 0]](%/_operators.0/ReduceSum_output_0) + """ + if use_fhe_sum is True + else "%transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0)\n " + ) + + """return %transposed_output }""", "GammaRegressor": """graph torch_jit ( %input_0[DOUBLE, symx10] @@ -339,8 +266,14 @@ def test_dump( %/_operators.0/Squeeze_output_0 = Squeeze(%/_operators.0/Reshape_3_output_0, %axes_squeeze) %/_operators.0/Transpose_output_0 = Transpose[perm = [1, 0]](%/_operators.0/Squeeze_output_0) %/_operators.0/Reshape_4_output_0 = Reshape[allowzero = 
0](%/_operators.0/Transpose_output_0, %/_operators.0/Constant_4_output_0) - return %/_operators.0/Reshape_4_output_0 -}""", + """ + + ( + """%/_operators.0/ReduceSum_output_0 = ReduceSum[keepdims = 0](%/_operators.0/Reshape_4_output_0, %onnx::ReduceSum_26) + return %/_operators.0/ReduceSum_output_0 +}""" + if use_fhe_sum is True + else "return %/_operators.0/Reshape_4_output_0\n}" + ), "RandomForestRegressor": """graph torch_jit ( %input_0[DOUBLE, symx10] ) { @@ -357,8 +290,15 @@ def test_dump( %/_operators.0/MatMul_1_output_0 = MatMul(%_operators.0.weight_3, %/_operators.0/Reshape_2_output_0) %/_operators.0/Constant_3_output_0 = Constant[value = ]() %/_operators.0/Reshape_3_output_0 = Reshape[allowzero = 0](%/_operators.0/MatMul_1_output_0, %/_operators.0/Constant_3_output_0) - %transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0) - return %transposed_output + """ + + ( + """%/_operators.0/ReduceSum_output_0 = ReduceSum[keepdims = 0](%/_operators.0/Reshape_3_output_0, %onnx::ReduceSum_22) + %transposed_output = Transpose[perm = [1, 0]](%/_operators.0/ReduceSum_output_0) + """ + if use_fhe_sum is True + else "%transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0)" + ) + + """return %transposed_output }""", "XGBRegressor": """graph torch_jit ( %input_0[DOUBLE, symx10] @@ -373,7 +313,9 @@ def test_dump( %/_operators.0/Constant_1_output_0[INT64, 2] %/_operators.0/Constant_2_output_0[INT64, 3] %/_operators.0/Constant_3_output_0[INT64, 3] - %/_operators.0/Constant_4_output_0[INT64, 3] + %/_operators.0/Constant_4_output_0[INT64, 3]""" + + ("\n %onnx::ReduceSum_27[INT64, 1]" if use_fhe_sum is True else "") + + """ ) { %/_operators.0/Gemm_output_0 = Gemm[alpha = 1, beta = 0, transB = 1](%_operators.0.weight_1, %input_0) %/_operators.0/Less_output_0 = Less(%/_operators.0/Gemm_output_0, %_operators.0.bias_1) @@ -387,8 +329,14 @@ def test_dump( %/_operators.0/Squeeze_output_0 = Squeeze(%/_operators.0/Reshape_3_output_0, %axes_squeeze) %/_operators.0/Transpose_output_0 = Transpose[perm = [1, 0]](%/_operators.0/Squeeze_output_0) %/_operators.0/Reshape_4_output_0 = Reshape[allowzero = 0](%/_operators.0/Transpose_output_0, %/_operators.0/Constant_4_output_0) - return %/_operators.0/Reshape_4_output_0 -}""", + """ + + ( + """%/_operators.0/ReduceSum_output_0 = ReduceSum[keepdims = 0](%/_operators.0/Reshape_4_output_0, %onnx::ReduceSum_27) + return %/_operators.0/ReduceSum_output_0 +}""" + if use_fhe_sum is True + else """return %/_operators.0/Reshape_4_output_0\n}""" + ), "LinearRegression": """graph torch_jit ( %input_0[DOUBLE, symx10] ) initializers ( @@ -457,4 +405,106 @@ def test_dump( } str_expected = expected_strings.get(model_name, "") - check_onnx_file_dump(model_class, parameters, load_data, str_expected, default_configuration) + + # Get the data-set. The data generation is seeded in load_data. 
+ x, y = load_data(model_class, **parameters) + + model_params = model.get_params() + if "random_state" in model_params: + model_params["random_state"] = numpy.random.randint(0, 2**15) + + model.set_params(**model_params) + + if model_name == "KNeighborsClassifier": + # KNN can only be compiled with small quantization bit numbers for now + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979 + model.n_bits = 2 + + with warnings.catch_warnings(): + # Sometimes, we miss convergence, which is not a problem for our test + warnings.simplefilter("ignore", category=ConvergenceWarning) + + model.fit(x, y) + + with warnings.catch_warnings(): + # Use FHE simulation to not have issues with precision + model.compile(x, default_configuration) + + # Get ONNX model + onnx_model = model.onnx_model + + # Remove initializers, since they change from one seed to the other + model_name = get_model_name(model_class) + if model_name in [ + "DecisionTreeRegressor", + "DecisionTreeClassifier", + "RandomForestClassifier", + "RandomForestRegressor", + "XGBClassifier", + "KNeighborsClassifier", + ]: + while len(onnx_model.graph.initializer) > 0: + del onnx_model.graph.initializer[0] + + str_model = onnx.helper.printable_graph(onnx_model.graph) + print(f"\nCurrent {model_name=}:\n{str_model}") + print(f"\nExpected {model_name=}:\n{str_expected}") + + # Test equality when it does not depend on seeds + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3266 + if not is_model_class_in_a_list(model_class, _get_sklearn_tree_models(select="RandomForest")): + # The expected graph is usually a string and we therefore directly test if it is equal to + # the retrieved graph's string. However, in some cases such as for TweedieRegressor models, + # this graph can slightly changed depending on some input's values. We then expected the + # string to match as least one of them expected strings (as a list) + if isinstance(str_expected, str): + assert str_model == str_expected + else: + assert str_model in str_expected + + +@pytest.mark.parametrize("model_class, parameters", UNIQUE_MODELS_AND_DATASETS) +def test_dump( + model_class, + parameters, + load_data, + default_configuration, +): + """Tests dump.""" + + model_name = get_model_name(model_class) + + # Some models have been done with different n_classes which create different ONNX + if parameters.get("n_classes", 2) != 2 and model_name in ["LinearSVC", "LogisticRegression"]: + return + + if model_name == "NeuralNetClassifier": + model_class = partial( + NeuralNetClassifier, + module__n_layers=3, + module__power_of_two_scaling=False, + max_epochs=1, + verbose=0, + callbacks="disable", + ) + elif model_name == "NeuralNetRegressor": + model_class = partial( + NeuralNetRegressor, + module__n_layers=3, + module__n_w_bits=2, + module__n_a_bits=2, + module__n_accum_bits=7, # Stay with 7 bits for test exec time + module__n_hidden_neurons_multiplier=1, + module__power_of_two_scaling=False, + max_epochs=1, + verbose=0, + callbacks="disable", + ) + + check_onnx_file_dump(model_class, parameters, load_data, default_configuration) + + # Additional tests exclusively dedicated for tree ensemble models. 
+ if model_class in _get_sklearn_tree_models(): + check_onnx_file_dump( + model_class, parameters, load_data, default_configuration, use_fhe_sum=True + ) diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py index d3e68f731..6666087e1 100644 --- a/tests/sklearn/test_sklearn_models.py +++ b/tests/sklearn/test_sklearn_models.py @@ -46,6 +46,7 @@ from concrete.ml.common.serialization.loaders import load, loads from concrete.ml.common.utils import ( USE_OLD_VL, + array_allclose_and_same_shape, get_model_class, get_model_name, is_classifier_or_partial_classifier, @@ -143,9 +144,9 @@ def preamble(model_class, parameters, n_bits, load_data, is_weekly_option): def get_n_bits_non_correctness(model_class): """Get the number of bits to use for non correctness related tests.""" + # KNN can only be compiled with small quantization bit numbers for now + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979 if get_model_name(model_class) == "KNeighborsClassifier": - # KNN can only be compiled with small quantization bit numbers for now - # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979 n_bits = 2 else: n_bits = min(N_BITS_REGULAR_BUILDS) @@ -1152,8 +1153,7 @@ def check_rounding_consistency( fhe_test = get_random_samples(x, n_sample=5) # Check that rounding is enabled - rounding_enabled = os.getenv("TREES_USE_ROUNDING") == "1" - assert rounding_enabled + assert os.environ.get("TREES_USE_ROUNDING") == "1", "'TREES_USE_ROUNDING' is not enabled" # Fit and compile with rounding enabled fit_and_compile(model, x, y) @@ -1171,8 +1171,7 @@ def check_rounding_consistency( mp_context.setenv("TREES_USE_ROUNDING", "0") # Check that rounding is disabled - rounding_disabled = os.environ.get("TREES_USE_ROUNDING") == "0" - assert rounding_disabled + assert os.environ.get("TREES_USE_ROUNDING") == "0", "'TREES_USE_ROUNDING' is not disabled" with pytest.warns( DeprecationWarning, @@ -1200,6 +1199,52 @@ def check_rounding_consistency( # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4178 +def check_fhe_sum_for_tree_based_models( + model, + x, + y, + predict_method, + is_weekly_option, +): + """Test that Concrete ML without and with FHE sum are 'equivalent'.""" + + # Run the test with more samples during weekly CIs + if is_weekly_option: + fhe_samples = 5 + else: + fhe_samples = 1 + + fhe_test = get_random_samples(x, n_sample=fhe_samples) + + # pylint: disable=protected-access + assert not model._fhe_ensembling, "`_fhe_ensembling` is disabled by default." 
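+    # First fit/compile pass: default behavior, the ensemble sum is computed in the clear
+    # during post-processing; the FHE-sum pass below must produce matching predictions.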
+ fit_and_compile(model, x, y) + + non_fhe_sum_predict_quantized = predict_method(x, fhe="disable") + non_fhe_sum_predict_simulate = predict_method(x, fhe="simulate") + non_fhe_sum_predict_fhe = predict_method(fhe_test, fhe="execute") + + # Sanity check + array_allclose_and_same_shape(non_fhe_sum_predict_quantized, non_fhe_sum_predict_simulate) + + # pylint: disable=protected-access + model._fhe_ensembling = True + + fit_and_compile(model, x, y) + + fhe_sum_predict_quantized = predict_method(x, fhe="disable") + fhe_sum_predict_simulate = predict_method(x, fhe="simulate") + fhe_sum_predict_fhe = predict_method(fhe_test, fhe="execute") + + # Sanity check + array_allclose_and_same_shape(fhe_sum_predict_quantized, fhe_sum_predict_simulate) + + # Check that we have the exact same predictions + array_allclose_and_same_shape(fhe_sum_predict_quantized, non_fhe_sum_predict_quantized) + array_allclose_and_same_shape(fhe_sum_predict_simulate, non_fhe_sum_predict_simulate) + array_allclose_and_same_shape(fhe_sum_predict_fhe, non_fhe_sum_predict_fhe) + + # Neural network models are skipped for this test # The `fit_benchmark` function of QNNs returns a QAT model and a FP32 model that is similar # in structure but trained from scratch. Furthermore, the `n_bits` setting @@ -1658,6 +1703,7 @@ def test_p_error_simulation( The test checks that models compiled with a large p_error value predicts very different results with simulation or in FHE compared to the expected clear quantized ones. """ + n_bits = get_n_bits_non_correctness(model_class) # Get data-set, initialize and fit the model @@ -1834,7 +1880,7 @@ def test_linear_models_have_no_tlu( # Additional tests for this purpose should be added in future updates # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4179 @pytest.mark.parametrize("model_class, parameters", get_sklearn_tree_models_and_datasets()) -@pytest.mark.parametrize("n_bits", [2, 5, 11]) +@pytest.mark.parametrize("n_bits", [2, 5, 10]) def test_rounding_consistency_for_regular_models( model_class, parameters, @@ -1871,3 +1917,91 @@ def test_rounding_consistency_for_regular_models( metric, is_weekly_option, ) + + +@pytest.mark.parametrize("model_class, parameters", get_sklearn_tree_models_and_datasets()) +@pytest.mark.parametrize("n_bits", [2, 5, 10]) +def test_fhe_sum_for_tree_based_models( + model_class, + parameters, + n_bits, + load_data, + is_weekly_option, + verbose=True, +): + """Test that the tree ensembles' output are the same with and without the sum in FHE.""" + + if verbose: + print("Run check_fhe_sum_for_tree_based_models") + + model = instantiate_model_generic(model_class, n_bits=n_bits) + + x, y = get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option) + + predict_method = ( + model.predict_proba if is_classifier_or_partial_classifier(model) else model.predict + ) + + check_fhe_sum_for_tree_based_models( + model, + x, + y, + predict_method, + is_weekly_option, + ) + + +# This test should be extended to all built-in models. +# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4234 +@pytest.mark.parametrize( + "n_bits, error_message", + [ + (0, "n_bits must be a strictly positive integer"), + (-1, "n_bits must be a strictly positive integer"), + ({"op_leaves": 2}, "The key 'op_inputs' is mandatory"), + ( + {"op_inputs": 4, "op_leaves": 2, "op_weights": 2}, + "Invalid keys in 'n_bits' dictionary. 
Only 'op_inputs' \\(mandatory\\) and 'op_leaves' " + "\\(optional\\) are allowed", + ), + ( + {"op_inputs": -2, "op_leaves": -5}, + "All values in 'n_bits' dictionary must be strictly positive integers", + ), + ({"op_inputs": 2, "op_leaves": 5}, "'op_leaves' must be less than or equal to 'op_inputs'"), + (0.5, "n_bits must be either an integer or a dictionary"), + ], +) +@pytest.mark.parametrize("model_class", _get_sklearn_tree_models()) +def test_invalid_n_bits_setting(model_class, n_bits, error_message): + """Check if the model instantiation raises an exception with invalid `n_bits` settings.""" + + with pytest.raises(ValueError, match=f"{error_message}. Got '{type(n_bits)}' and '{n_bits}'.*"): + instantiate_model_generic(model_class, n_bits=n_bits) + + +# This test should be extended to all built-in models. +# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4234 +@pytest.mark.parametrize("n_bits", [5, {"op_inputs": 5}, {"op_inputs": 2, "op_leaves": 1}]) +@pytest.mark.parametrize("model_class, parameters", get_sklearn_tree_models_and_datasets()) +def test_valid_n_bits_setting( + model_class, + n_bits, + parameters, + load_data, + is_weekly_option, + verbose=True, +): + """Check valid `n_bits` settings.""" + + if verbose: + print("Run test_valid_n_bits_setting") + + x, y = get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option) + + model = instantiate_model_generic(model_class, n_bits=n_bits) + + with warnings.catch_warnings(): + # Sometimes, we miss convergence, which is not a problem for our test + warnings.simplefilter("ignore", category=ConvergenceWarning) + model.fit(x, y)
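Reviewer note: the end-to-end behavior exercised by `test_fhe_sum_for_tree_based_models` can be reproduced outside the test suite with the sketch below (illustrative only; it assumes `XGBClassifier` from `concrete.ml.sklearn`, a small synthetic data set, and toggles the private `_fhe_ensembling` flag exactly as the tests do). The patch's tests assert that both configurations yield matching predictions.

import numpy
from concrete.ml.sklearn import XGBClassifier

X = numpy.random.uniform(size=(100, 6))
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)

# Default: the sum of the tree outputs is computed in the clear (post_processing).
clear_sum_model = XGBClassifier(n_bits={"op_inputs": 5, "op_leaves": 3})
clear_sum_model.fit(X, y)
clear_sum_model.compile(X)
clear_sum_preds = clear_sum_model.predict_proba(X, fhe="simulate")

# Internal testing flag: move the ensemble sum inside the FHE circuit.
fhe_sum_model = XGBClassifier(n_bits={"op_inputs": 5, "op_leaves": 3})
fhe_sum_model._fhe_ensembling = True  # private attribute, not a public API
fhe_sum_model.fit(X, y)
fhe_sum_model.compile(X)
fhe_sum_preds = fhe_sum_model.predict_proba(X, fhe="simulate")

numpy.testing.assert_allclose(clear_sum_preds, fhe_sum_preds)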