diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py
index 1aacbcd65..3f76b509f 100644
--- a/src/concrete/ml/onnx/convert.py
+++ b/src/concrete/ml/onnx/convert.py
@@ -4,7 +4,7 @@
 import tempfile
 import warnings
 from pathlib import Path
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Tuple, Union
 
 import numpy
 import onnx
@@ -265,7 +265,7 @@ def get_equivalent_numpy_forward_from_onnx(
 def get_equivalent_numpy_forward_from_onnx_tree(
     onnx_model: onnx.ModelProto,
     check_model: bool = True,
-    lsbs_to_remove_for_trees: Optional[Tuple[int, int]] = None,
+    auto_truncate=None,
 ) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
     """Get the numpy equivalent forward of the provided ONNX model for tree-based models only.
 
@@ -274,7 +274,7 @@ def get_equivalent_numpy_forward_from_onnx_tree(
             forward.
         check_model (bool): set to True to run the onnx checker on the model.
             Defaults to True.
-        lsbs_to_remove_for_trees (Optional[Tuple[int, int]]): This parameter is exclusively used for
+        auto_truncate (Optional[AutoTruncator]): This parameter is exclusively used for
             optimizing tree-based models. It contains the values of the least significant bits to
             remove during the tree traversal, where the first value refers to the first comparison
             (either "less" or "less_or_equal"), while the second value refers to the "Equal"
@@ -290,6 +290,6 @@ def get_equivalent_numpy_forward_from_onnx_tree(
     # Return lambda of numpy equivalent of onnx execution
     return (
         lambda *args: execute_onnx_with_numpy_trees(
-            equivalent_onnx_model.graph, lsbs_to_remove_for_trees, *args
+            equivalent_onnx_model.graph, auto_truncate, *args
         )
     ), equivalent_onnx_model
diff --git a/src/concrete/ml/onnx/onnx_impl_utils.py b/src/concrete/ml/onnx/onnx_impl_utils.py
index 158f513ae..156d845e8 100644
--- a/src/concrete/ml/onnx/onnx_impl_utils.py
+++ b/src/concrete/ml/onnx/onnx_impl_utils.py
@@ -5,7 +5,7 @@
 import numpy
 from concrete.fhe import conv as fhe_conv
 from concrete.fhe import ones as fhe_ones
-from concrete.fhe import round_bit_pattern
+from concrete.fhe import truncate_bit_pattern
 from concrete.fhe.tracing import Tracer
 
 from ..common.debugging import assert_true
@@ -238,14 +238,14 @@ def onnx_avgpool_compute_norm_const(
 # - Adjust the typing
 # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4143
 def rounded_comparison(
-    x: numpy.ndarray, y: numpy.ndarray, lsbs_to_remove: int, operation: ComparisonOperationType
+    x: numpy.ndarray, y: numpy.ndarray, auto_truncate, operation: ComparisonOperationType
 ) -> Tuple[bool]:
-    """Comparison operation using `round_bit_pattern` function.
+    """Comparison operation using `truncate_bit_pattern` function.
 
-    `round_bit_pattern` rounds the bit pattern of an integer to the closer
+    `truncate_bit_pattern` clears the least significant bits of an integer's bit pattern.
     It also checks for any potential overflow. If so, it readjusts the LSBs accordingly.
 
-    The parameter `lsbs_to_remove` in `round_bit_pattern` can either be an integer specifying the
+    The parameter `lsbs_to_remove` in `truncate_bit_pattern` can either be an integer specifying the
     number of LSBS to remove, or an `AutoRounder` object that determines the required number of
     LSBs based on the specified number of MSBs to retain. But in our case, we choose to compute
    the LSBs manually.
@@ -253,21 +253,14 @@ def rounded_comparison(
     Args:
         x (numpy.ndarray): Input tensor
         y (numpy.ndarray): Input tensor
-        lsbs_to_remove (int): Number of the least significant bits to remove
-        operation (ComparisonOperationType): Comparison operation, which can `<`, `<=` and `==`
+        auto_truncate (AutoTruncator): Determines the number of least significant bits to remove
+        operation (ComparisonOperationType): Comparison operation, which can be `<`, `<=` and `==`
 
     Returns:
         Tuple[bool]: If x and y satisfy the comparison operator.
     """
-    assert isinstance(lsbs_to_remove, int)
-
-    # Workaround: in this context, `round_bit_pattern` is used as a truncate operation.
-    # Consequently, we subtract a term, called `half` that will subsequently be re-added during the
-    # `round_bit_pattern` process.
-    half = 1 << (lsbs_to_remove - 1)
-
     # To determine if 'x' 'operation' 'y' (operation being <, >, >=, <=), we evaluate 'x - y'
-    rounded_subtraction = round_bit_pattern((x - y) - half, lsbs_to_remove=lsbs_to_remove)
+    rounded_subtraction = truncate_bit_pattern(x - y, lsbs_to_remove=auto_truncate)
 
     return (operation(rounded_subtraction),)
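Reviewer note (not part of the patch): `round_bit_pattern` rounds to the nearest multiple of `2**lsbs_to_remove`, which is why the removed code subtracted a `half` term to emulate truncation. `truncate_bit_pattern` floors instead, so the workaround disappears and the number of LSBs to remove can be delegated to an `AutoTruncator`. A minimal sketch of that usage, assuming the `concrete.fhe` truncation API behaves as it is called in this diff:

```python
import numpy
from concrete import fhe

# An AutoTruncator whose `lsbs_to_remove` is determined later from an input-set,
# mirroring the single MSB kept for the tree comparisons in this patch.
truncator = fhe.AutoTruncator(target_msbs=1)

@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def less_than(x, y):
    # Truncate the difference, then test its sign: no `half` correction is needed
    # because flooring never changes the sign of `x - y`.
    diff = fhe.truncate_bit_pattern(x - y, lsbs_to_remove=truncator)
    return diff < 0

inputset = [(numpy.random.randint(0, 64), numpy.random.randint(0, 64)) for _ in range(100)]
fhe.AutoTruncator.adjust(less_than, inputset)  # fixes `lsbs_to_remove` for this circuit
circuit = less_than.compile(inputset)
```

This is the same calibrate-then-compile flow that `fit` performs below for the whole tree ensemble.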
diff --git a/src/concrete/ml/onnx/onnx_utils.py b/src/concrete/ml/onnx/onnx_utils.py
index 60a8cb02c..ff1f1ef67 100644
--- a/src/concrete/ml/onnx/onnx_utils.py
+++ b/src/concrete/ml/onnx/onnx_utils.py
@@ -213,7 +213,7 @@
 # Original file:
 # https://github.com/google/jax/blob/f6d329b2d9b5f83c6a59e5739aa1ca8d4d1ffa1c/examples/onnx2xla.py
 
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Tuple
 
 import numpy
 import onnx
@@ -415,7 +415,7 @@
 }
 # All numpy operators used for tree-based models that support auto rounding
 ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL = {
-    "Less": rounded_numpy_less_for_trees,
+    "Less": rounded_numpy_less_for_trees,  # type: ignore[dict-item]
     "Equal": rounded_numpy_equal_for_trees,
     "LessOrEqual": rounded_numpy_less_or_equal_for_trees,
 }
@@ -485,14 +485,14 @@ def execute_onnx_with_numpy(
 def execute_onnx_with_numpy_trees(
     graph: onnx.GraphProto,
-    lsbs_to_remove_for_trees: Optional[Tuple[int, int]],
+    auto_truncate,
     *inputs: numpy.ndarray,
 ) -> Tuple[numpy.ndarray, ...]:
     """Execute the provided ONNX graph on the given inputs for tree-based models only.
 
     Args:
         graph (onnx.GraphProto): The ONNX graph to execute.
-        lsbs_to_remove_for_trees (Optional[Tuple[int, int]]): This parameter is exclusively used for
+        auto_truncate: This parameter is exclusively used for
             optimizing tree-based models. It contains the values of the least significant bits to
             remove during the tree traversal, where the first value refers to the first comparison
             (either "less" or "less_or_equal"), while the second value refers to the "Equal"
@@ -507,7 +507,7 @@ def execute_onnx_with_numpy_trees(
     op_type: Callable[..., Tuple[numpy.ndarray[Any, Any], ...]]
 
     # If no tree-based optimization is specified, return standard execution
-    if lsbs_to_remove_for_trees is None:
+    if auto_truncate is None:
         return execute_onnx_with_numpy(graph, *inputs)
 
     node_results: Dict[str, numpy.ndarray] = dict(
@@ -523,11 +523,12 @@ def execute_onnx_with_numpy_trees(
         attributes = {attribute.name: get_attribute(attribute) for attribute in node.attribute}
 
         if node.op_type in ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL:
+            attributes["auto_truncate"] = auto_truncate
 
-            # The first LSB refers to `Less` or `LessOrEqual` comparisons
-            # The second LSB refers to `Equal` comparison
-            stage = 0 if node.op_type != "Equal" else 1
-            attributes["lsbs_to_remove_for_trees"] = lsbs_to_remove_for_trees[stage]
+            # # The first LSB refers to `Less` or `LessOrEqual` comparisons
+            # # The second LSB refers to `Equal` comparison
+            # stage = 0 if node.op_type != "Equal" else 1
+            # attributes["lsbs_to_remove_for_trees"] = lsbs_to_remove_for_trees[stage]
 
             # Use rounded numpy operation to relevant comparison nodes
             op_type = ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL[node.op_type]
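For context on the execution path above (illustration only, simplified stand-in names, not Concrete ML code): when a truncator is provided, the comparison nodes are executed with the tree-specific implementations and receive the truncator through their ONNX node attributes; otherwise the graph falls back to the standard numpy execution. The dispatch boils down to:

```python
from typing import Any, Callable, Dict, Tuple

import numpy

# Simplified stand-ins for the truncated and default comparison implementations
def truncated_less(x, y, *, auto_truncate=None) -> Tuple[numpy.ndarray]:
    # A real implementation would truncate `x - y` before testing its sign.
    return (x - y < 0,)

def plain_less(x, y) -> Tuple[numpy.ndarray]:
    return (numpy.less(x, y),)

TRUNCATED_IMPLS: Dict[str, Callable[..., Any]] = {"Less": truncated_less}
DEFAULT_IMPLS: Dict[str, Callable[..., Any]] = {"Less": plain_less}

def run_comparison_node(op_type: str, inputs, attributes: Dict[str, Any], auto_truncate=None):
    if auto_truncate is not None and op_type in TRUNCATED_IMPLS:
        # Same pattern as `execute_onnx_with_numpy_trees`: forward the truncator
        # through the node attributes of Less / LessOrEqual / Equal.
        attributes = {**attributes, "auto_truncate": auto_truncate}
        return TRUNCATED_IMPLS[op_type](*inputs, **attributes)
    return DEFAULT_IMPLS[op_type](*inputs, **attributes)

x, y = numpy.array([3, 10]), numpy.array([5, 2])
print(run_comparison_node("Less", (x, y), {}))                          # default path
print(run_comparison_node("Less", (x, y), {}, auto_truncate=object()))  # truncated path
```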
diff --git a/src/concrete/ml/onnx/ops_impl.py b/src/concrete/ml/onnx/ops_impl.py
index 6ebf96900..9bbe412fc 100644
--- a/src/concrete/ml/onnx/ops_impl.py
+++ b/src/concrete/ml/onnx/ops_impl.py
@@ -892,7 +892,7 @@ def rounded_numpy_equal_for_trees(
     x: numpy.ndarray,
     y: numpy.ndarray,
     *,
-    lsbs_to_remove_for_trees: Optional[int] = None,
+    auto_truncate=None,
 ) -> Tuple[numpy.ndarray]:
     """Compute rounded equal in numpy according to ONNX spec for tree-based models only.
 
@@ -901,7 +901,7 @@ def rounded_numpy_equal_for_trees(
     Args:
         x (numpy.ndarray): Input tensor
         y (numpy.ndarray): Input tensor
-        lsbs_to_remove_for_trees (Optional[int]): Number of the least significant bits to remove
+        auto_truncate: Number of the least significant bits to remove
             for tree-based models only.
 
     Returns:
@@ -916,9 +916,9 @@ def rounded_numpy_equal_for_trees(
     # Option 2 is selected because it adheres to the established pattern in `rounded_comparison`
     # which does: (a - b) - half.
 
-    if lsbs_to_remove_for_trees is not None and lsbs_to_remove_for_trees > 0:
+    if auto_truncate is not None:
         return rounded_comparison(
-            y, x, lsbs_to_remove_for_trees, operation=lambda x: x >= 0
+            y, x, auto_truncate, operation=lambda x: x >= 0
         )  # pragma: no cover
 
     # Else, default numpy_equal operator
@@ -1076,7 +1076,7 @@ def rounded_numpy_less_for_trees(
     x: numpy.ndarray,
     y: numpy.ndarray,
     *,
-    lsbs_to_remove_for_trees: Optional[int] = None,
+    auto_truncate,
 ) -> Tuple[numpy.ndarray]:
     """Compute rounded less in numpy according to ONNX spec for tree-based models only.
 
@@ -1085,7 +1085,7 @@ def rounded_numpy_less_for_trees(
     Args:
         x (numpy.ndarray): Input tensor
         y (numpy.ndarray): Input tensor
-        lsbs_to_remove_for_trees (Optional[int]): Number of the least significant bits to remove
+        auto_truncate: Number of the least significant bits to remove
             for tree-based models only.
 
     Returns:
@@ -1094,8 +1094,8 @@ def rounded_numpy_less_for_trees(
 
     # numpy.less(x, y) is equivalent to :
     # x - y <= 0 => round_bit_pattern(x - y - half) < 0
-    if lsbs_to_remove_for_trees is not None and lsbs_to_remove_for_trees > 0:
-        return rounded_comparison(x, y, lsbs_to_remove_for_trees, operation=lambda x: x < 0)
+    if auto_truncate is not None:
+        return rounded_comparison(x, y, auto_truncate, operation=lambda x: x < 0)
 
     # Else, default numpy_less operator
     return numpy_less(x, y)
@@ -1143,7 +1143,7 @@ def rounded_numpy_less_or_equal_for_trees(
     x: numpy.ndarray,
     y: numpy.ndarray,
     *,
-    lsbs_to_remove_for_trees: Optional[int] = None,
+    auto_truncate=None,
 ) -> Tuple[numpy.ndarray]:
     """Compute rounded less or equal in numpy according to ONNX spec for tree-based models only.
 
@@ -1152,7 +1152,7 @@ def rounded_numpy_less_or_equal_for_trees(
     Args:
         x (numpy.ndarray): Input tensor
         y (numpy.ndarray): Input tensor
-        lsbs_to_remove_for_trees (Optional[int]): Number of the least significant bits to remove
+        auto_truncate: Number of the least significant bits to remove
             for tree-based models only.
 
     Returns:
@@ -1160,13 +1160,14 @@ def rounded_numpy_less_or_equal_for_trees(
     """
 
     # numpy.less_equal(x, y) <= y is equivalent to :
-    # option 1: x - y <= 0 => round_bit_pattern(x - y + half) <= 0 or
-    # option 2: y - x >= 0 => round_bit_pattern(y - x - half) >= 0
+    # option 1: x - y <= 0 => truncate_bit_pattern(x - y) <= 0 or
+    # option 2: y - x >= 0 => truncate_bit_pattern(y - x) >= 0
     # Option 2 is selected because it adheres to the established pattern in `rounded_comparison`
-    # which does: (a - b) - half.
-    if lsbs_to_remove_for_trees is not None and lsbs_to_remove_for_trees > 0:
-        return rounded_comparison(y, x, lsbs_to_remove_for_trees, operation=lambda x: x >= 0)
+    # which does: (a - b).
+
+    if auto_truncate is not None:
+        return rounded_comparison(y, x, auto_truncate, operation=lambda x: x >= 0)
 
     # Else, default numpy_less_or_equal operator
     return numpy_less_or_equal(x, y)
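The rewrites above rely on floor-truncation preserving the sign of a difference, so `x < y` can be evaluated as `truncate(x - y) < 0` and `x <= y` as `truncate(y - x) >= 0` without the `half` correction that rounding required. A plain-integer sketch of that property (no FHE, illustration only):

```python
import numpy

def truncate(v, lsbs_to_remove):
    # Floor-truncation: clear the `lsbs_to_remove` least significant bits.
    step = 1 << lsbs_to_remove
    return (v // step) * step

x = numpy.array([3, 10, -4, 7])
y = numpy.array([5, 2, -4, 7])

# `x < y` evaluated on the truncated difference matches the exact comparison,
# because flooring never turns a negative value non-negative or vice versa.
assert numpy.array_equal(truncate(x - y, 2) < 0, x < y)

# `x <= y` through "option 2" above: test `y - x >= 0` on the truncated difference.
assert numpy.array_equal(truncate(y - x, 2) >= 0, x <= y)
```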
""" + print("compilation stage 2") # Reset for double compile self._is_compiled = False @@ -1321,6 +1325,9 @@ def fit(self, X: Data, y: Target, **fit_parameters): self.input_quantizers = [] self.output_quantizers = [] + #: Determines the LSB to remove given a `target_msbs` + self._auto_truncate = cp.AutoTruncator(target_msbs=MSB_TO_KEEP_FOR_TREES) + X, y = check_X_y_and_assert_multi_output(X, y) q_X = numpy.zeros_like(X) @@ -1345,29 +1352,36 @@ def fit(self, X: Data, y: Target, **fit_parameters): # Check that the underlying sklearn model has been set and fit assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message() - # Enable rounding feature - enable_rounding = os.environ.get("TREES_USE_ROUNDING", "1") == "1" + # Enable optimized computation + enable_truncate = os.environ.get("TREES_USE_TRUNCATE", "1") == "1" - if not enable_rounding: + if not enable_truncate: warnings.simplefilter("always") warnings.warn( - "Using Concrete tree-based models without the `rounding feature` is deprecated. " - "Consider setting 'use_rounding' to `True` for making the FHE inference faster " + "Using Concrete tree-based models without the `truncate feature` is deprecated. " + "Consider setting 'use_truncate' to `True` for making the FHE inference faster " "and key generation.", category=DeprecationWarning, stacklevel=2, ) + self._auto_truncate = None # Convert the tree inference with Numpy operators self._tree_inference, self.output_quantizers, self.onnx_model_ = tree_to_numpy( self.sklearn_model, q_X, - use_rounding=enable_rounding, + auto_truncate=self._auto_truncate, fhe_ensembling=self._fhe_ensembling, framework=self.framework, output_n_bits=self.n_bits["op_leaves"], ) + # Adjust the truncate + if enable_truncate: + inputset = numpy.array(list(_get_inputset_generator(q_X))).astype(int) + self._auto_truncate.adjust(self._tree_inference, inputset) + self._tree_inference(q_X.astype("int")) + self._is_fitted = True return self @@ -1407,6 +1421,7 @@ def _get_module_to_compile(self) -> Union[Compiler, QuantizedModule]: return compiler def compile(self, *args, **kwargs) -> Circuit: + print("Compilation base.py") BaseEstimator.compile(self, *args, **kwargs) # Check that the graph only has a single output diff --git a/src/concrete/ml/sklearn/tree_to_numpy.py b/src/concrete/ml/sklearn/tree_to_numpy.py index b50944319..b3b72f88b 100644 --- a/src/concrete/ml/sklearn/tree_to_numpy.py +++ b/src/concrete/ml/sklearn/tree_to_numpy.py @@ -1,7 +1,7 @@ """Implements the conversion of a tree model to a numpy function.""" import math import warnings -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Tuple import numpy import onnx @@ -330,7 +330,7 @@ def tree_to_numpy( model: Callable, x: numpy.ndarray, framework: str, - use_rounding: bool = True, + auto_truncate=None, fhe_ensembling: bool = False, output_n_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE, ) -> Tuple[Callable, List[UniformQuantizer], onnx.ModelProto]: @@ -339,7 +339,7 @@ def tree_to_numpy( Args: model (Callable): The tree model to convert. x (numpy.ndarray): The input data. - use_rounding (bool): Determines whether the rounding feature is enabled or disabled. + auto_truncate (TODO): Determines whether the rounding feature is enabled or disabled. Default to True. fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed in FHE. Default to False. 
diff --git a/src/concrete/ml/sklearn/tree_to_numpy.py b/src/concrete/ml/sklearn/tree_to_numpy.py
index b50944319..b3b72f88b 100644
--- a/src/concrete/ml/sklearn/tree_to_numpy.py
+++ b/src/concrete/ml/sklearn/tree_to_numpy.py
@@ -1,7 +1,7 @@
 """Implements the conversion of a tree model to a numpy function."""
 import math
 import warnings
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Tuple
 
 import numpy
 import onnx
@@ -330,7 +330,7 @@ def tree_to_numpy(
     model: Callable,
     x: numpy.ndarray,
     framework: str,
-    use_rounding: bool = True,
+    auto_truncate=None,
     fhe_ensembling: bool = False,
     output_n_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
 ) -> Tuple[Callable, List[UniformQuantizer], onnx.ModelProto]:
@@ -339,7 +339,7 @@ def tree_to_numpy(
     Args:
         model (Callable): The tree model to convert.
         x (numpy.ndarray): The input data.
-        use_rounding (bool): Determines whether the rounding feature is enabled or disabled.
+        auto_truncate (Optional[AutoTruncator]): Determines whether the truncation feature is enabled or disabled.
             Default to True.
         fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed
             in FHE. Default to False.
@@ -355,7 +355,7 @@ def tree_to_numpy(
 
     # mypy
     assert output_n_bits is not None
 
-    lsbs_to_remove_for_trees: Optional[Tuple[int, int]] = None
+    # lsbs_to_remove_for_trees: Optional[Tuple[int, int]] = None
 
     assert_true(
         framework in ["xgboost", "sklearn"],
@@ -365,15 +365,6 @@ def tree_to_numpy(
     # Execute with 1 example for efficiency in large data scenarios to prevent slowdown
     onnx_model = get_onnx_model(model, x[:1], framework)
 
-    # Compute for tree-based models the LSB to remove in stage 1 and stage 2
-    if use_rounding:
-        # First LSB refers to Less or LessOrEqual comparisons
-        # Second LSB refers to Equal comparison
-        lsbs_to_remove_for_trees = _compute_lsb_to_remove_for_trees(onnx_model, x)
-
-        # mypy
-        assert len(lsbs_to_remove_for_trees) == 2
-
     # Get the expected number of ONNX outputs in the sklearn model.
     expected_number_of_outputs = 1 if is_regressor_or_partial_regressor(model) else 2
@@ -387,7 +378,7 @@ def tree_to_numpy(
     q_y = tree_values_preprocessing(onnx_model, framework, output_n_bits)
 
     _tree_inference, onnx_model = get_equivalent_numpy_forward_from_onnx_tree(
-        onnx_model, lsbs_to_remove_for_trees=lsbs_to_remove_for_trees
+        onnx_model, auto_truncate=auto_truncate
     )
 
     return (_tree_inference, [q_y.quantizer], onnx_model)
diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py
index a75d3c87b..4a8af6d47 100644
--- a/tests/sklearn/test_sklearn_models.py
+++ b/tests/sklearn/test_sklearn_models.py
@@ -1138,7 +1138,7 @@ def check_load_fitted_sklearn_linear_models(model_class, n_bits, x, y, check_flo
 )
 
 
-def check_rounding_consistency(
+def check_truncate_consistency(
     model,
     x,
     y,
@@ -1146,53 +1146,53 @@ def check_rounding_consistency(
     metric,
     is_weekly_option,
 ):
-    """Test that Concrete ML without and with rounding are 'equivalent'."""
+    """Test that Concrete ML without and with truncate are 'equivalent'."""
 
     # Run the test with more samples during weekly CIs
     if is_weekly_option:
         fhe_test = get_random_samples(x, n_sample=5)
 
-    # Check that rounding is enabled
-    assert os.environ.get("TREES_USE_ROUNDING") == "1", "'TREES_USE_ROUNDING' is not enabled"
+    # Check that truncate is enabled
+    assert os.environ.get("TREES_USE_TRUNCATE") == "1", "'TREES_USE_TRUNCATE' is not enabled"
 
-    # Fit and compile with rounding enabled
+    # Fit and compile with truncate enabled
     fit_and_compile(model, x, y)
 
-    rounded_predict_quantized = predict_method(x, fhe="disable")
-    rounded_predict_simulate = predict_method(x, fhe="simulate")
+    truncate_predict_quantized = predict_method(x, fhe="disable")
+    truncate_predict_simulate = predict_method(x, fhe="simulate")
 
     # Compute the FHE predictions only during weekly CIs
     if is_weekly_option:
-        rounded_predict_fhe = predict_method(fhe_test, fhe="execute")
+        truncate_predict_fhe = predict_method(fhe_test, fhe="execute")
 
     with pytest.MonkeyPatch.context() as mp_context:
 
-        # Disable rounding
-        mp_context.setenv("TREES_USE_ROUNDING", "0")
+        # Disable truncate
+        mp_context.setenv("TREES_USE_TRUNCATE", "0")
 
-        # Check that rounding is disabled
-        assert os.environ.get("TREES_USE_ROUNDING") == "0", "'TREES_USE_ROUNDING' is not disabled"
+        # Check that truncate is disabled
+        assert os.environ.get("TREES_USE_TRUNCATE") == "0", "'TREES_USE_TRUNCATE' is not disabled"
 
         with pytest.warns(
             DeprecationWarning,
             match=(
-                "Using Concrete tree-based models without the `rounding feature` is " "deprecated.*"
+                "Using Concrete tree-based models without the `truncate feature` is " "deprecated.*"
            ),
        ):
-            # Fit and compile without rounding
+            # Fit and compile without truncate
             fit_and_compile(model, x, y)
 
-        not_rounded_predict_quantized = predict_method(x, fhe="disable")
-        not_rounded_predict_simulate = predict_method(x, fhe="simulate")
+        not_truncate_predict_quantized = predict_method(x, fhe="disable")
+        not_truncate_predict_simulate = predict_method(x, fhe="simulate")
 
-        metric(rounded_predict_quantized, not_rounded_predict_quantized)
-        metric(rounded_predict_simulate, not_rounded_predict_simulate)
+        metric(truncate_predict_quantized, not_truncate_predict_quantized)
+        metric(truncate_predict_simulate, not_truncate_predict_simulate)
 
         # Compute the FHE predictions only during weekly CIs
         if is_weekly_option:
-            not_rounded_predict_fhe = predict_method(fhe_test, fhe="execute")
-            metric(rounded_predict_fhe, not_rounded_predict_fhe)
+            not_truncate_predict_fhe = predict_method(fhe_test, fhe="execute")
+            metric(truncate_predict_fhe, not_truncate_predict_fhe)
 
     # Check that the maximum bit-width of the circuit with rounding is at most:
     # maximum bit-width (of the circuit without rounding) + 2
@@ -1238,10 +1238,7 @@ def check_fhe_sum_for_tree_based_models(
 
     # Sanity check
     array_allclose_and_same_shape(fhe_sum_predict_quantized, fhe_sum_predict_simulate)
-
-    # Check that we have the exact same predictions
     array_allclose_and_same_shape(fhe_sum_predict_quantized, non_fhe_sum_predict_quantized)
-
     array_allclose_and_same_shape(fhe_sum_predict_simulate, non_fhe_sum_predict_simulate)
     array_allclose_and_same_shape(fhe_sum_predict_fhe, non_fhe_sum_predict_fhe)
 
 
@@ -1879,7 +1876,7 @@ def test_linear_models_have_no_tlu(
 # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4179
 @pytest.mark.parametrize("model_class, parameters", get_sklearn_tree_models_and_datasets())
 @pytest.mark.parametrize("n_bits", [2, 5, 10])
-def test_rounding_consistency_for_regular_models(
+def test_truncate_consistency_for_regular_models(
     model_class,
     parameters,
     n_bits,
@@ -1892,7 +1889,7 @@ def test_rounding_consistency_for_regular_models(
     """Test that Concrete ML without and with rounding are 'equivalent'."""
 
     if verbose:
-        print("Run check_rounding_consistency")
+        print("Run check_truncate_consistency")
 
     model = instantiate_model_generic(model_class, n_bits=n_bits)
@@ -1907,7 +1904,7 @@ def test_rounding_consistency_for_regular_models(
         predict_method = model.predict
         metric = check_accuracy
 
-    check_rounding_consistency(
+    check_truncate_consistency(
         model,
         x,
         y,
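Side note on the test changes (illustration only, not part of the patch): the with/without-truncate comparison relies on flipping the `TREES_USE_TRUNCATE` environment variable inside a `pytest.MonkeyPatch.context()` block, so the override is scoped to that block and cannot leak into other tests. A minimal self-contained version of the pattern:

```python
import os

import pytest

def truncate_enabled() -> bool:
    # Mirrors how the estimators read the flag at fit time.
    return os.environ.get("TREES_USE_TRUNCATE", "1") == "1"

def test_flag_can_be_disabled_locally():
    assert truncate_enabled()
    with pytest.MonkeyPatch.context() as mp_context:
        mp_context.setenv("TREES_USE_TRUNCATE", "0")
        assert not truncate_enabled()
    # Leaving the context restores the environment for the remaining tests.
    assert truncate_enabled()
```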