diff --git a/src/concrete/ml/onnx/convert.py b/src/concrete/ml/onnx/convert.py index 54f7aad92..86ee5f0de 100644 --- a/src/concrete/ml/onnx/convert.py +++ b/src/concrete/ml/onnx/convert.py @@ -4,7 +4,7 @@ import tempfile import warnings from pathlib import Path -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Tuple, Union import numpy import onnx @@ -255,7 +255,7 @@ def get_equivalent_numpy_forward_from_onnx( def get_equivalent_numpy_forward_from_onnx_tree( onnx_model: onnx.ModelProto, check_model: bool = True, - auto_truncate = None, + auto_truncate=None, ) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]: """Get the numpy equivalent forward of the provided ONNX model for tree-based models only. @@ -264,7 +264,7 @@ def get_equivalent_numpy_forward_from_onnx_tree( forward. check_model (bool): set to True to run the onnx checker on the model. Defaults to True. - auto_truncate: This parameter is exclusively used for + auto_truncate (TODO): This parameter is exclusively used for optimizing tree-based models. 
It contains the values of the least significant bits to remove during the tree traversal, where the first value refers to the first comparison (either "less" or "less_or_equal"), while the second value refers to the "Equal" diff --git a/src/concrete/ml/onnx/onnx_impl_utils.py b/src/concrete/ml/onnx/onnx_impl_utils.py index 9a68052f0..156d845e8 100644 --- a/src/concrete/ml/onnx/onnx_impl_utils.py +++ b/src/concrete/ml/onnx/onnx_impl_utils.py @@ -3,10 +3,9 @@ from typing import Callable, Tuple, Union import numpy - -from concrete.fhe import truncate_bit_pattern, round_bit_pattern from concrete.fhe import conv as fhe_conv from concrete.fhe import ones as fhe_ones +from concrete.fhe import truncate_bit_pattern from concrete.fhe.tracing import Tracer from ..common.debugging import assert_true @@ -241,12 +240,12 @@ def onnx_avgpool_compute_norm_const( def rounded_comparison( x: numpy.ndarray, y: numpy.ndarray, auto_truncate, operation: ComparisonOperationType ) -> Tuple[bool]: - """Comparison operation using `round_bit_pattern` function. + """Comparison operation using `truncate_bit_pattern` function. - `round_bit_pattern` rounds the bit pattern of an integer to the closer + `truncate_bit_pattern` truncates the bit pattern of an integer by clearing its LSBs. It also checks for any potential overflow. If so, it readjusts the LSBs accordingly. - The parameter `lsbs_to_remove` in `round_bit_pattern` can either be an integer specifying the + The parameter `lsbs_to_remove` in `truncate_bit_pattern` can either be an integer specifying the number of LSBS to remove, or an `AutoRounder` object that determines the required number of LSBs based on the specified number of MSBs to retain. But in our case, we choose to compute the LSBs manually. 
@@ -254,8 +253,8 @@ def rounded_comparison( Args: x (numpy.ndarray): Input tensor y (numpy.ndarray): Input tensor - lsbs_to_remove (int): Number of the least significant bits to remove - operation (ComparisonOperationType): Comparison operation, which can `<`, `<=` and `==` + auto_truncate: Value passed as `lsbs_to_remove` to `truncate_bit_pattern` + operation (ComparisonOperationType): Comparison applied to the truncated `x - y` Returns: Tuple[bool]: If x and y satisfy the comparison operator. @@ -265,17 +264,3 @@ def rounded_comparison( rounded_subtraction = truncate_bit_pattern(x - y, lsbs_to_remove=auto_truncate) return (operation(rounded_subtraction),) - - # # To determine if 'x' 'operation' 'y' (operation being <, >, >=, <=), we evaluate 'x - y' - # rounded_subtraction = truncate_bit_pattern(x - y, lsbs_to_remove=auto_truncate) - # assert isinstance(lsbs_to_remove, int) - - # # Workaround: in this context, `round_bit_pattern` is used as a truncate operation. - # # Consequently, we subtract a term, called `half` that will subsequently be re-added during the - # # `round_bit_pattern` process. - # half = 1 << (lsbs_to_remove - 1) - - # # To determine if 'x' 'operation' 'y' (operation being <, >, >=, <=), we evaluate 'x - y' - # rounded_subtraction = round_bit_pattern((x - y) - half, lsbs_to_remove=lsbs_to_remove) - - # return (operation(rounded_subtraction),) diff --git a/src/concrete/ml/onnx/onnx_utils.py b/src/concrete/ml/onnx/onnx_utils.py index da2ad80d3..1077e3f4b 100644 --- a/src/concrete/ml/onnx/onnx_utils.py +++ b/src/concrete/ml/onnx/onnx_utils.py @@ -213,7 +213,7 @@ # Original file: # https://github.com/google/jax/blob/f6d329b2d9b5f83c6a59e5739aa1ca8d4d1ffa1c/examples/onnx2xla.py -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Tuple import numpy import onnx @@ -413,7 +413,7 @@ } # All numpy operators used for tree-based models that support auto rounding ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL = { - "Less": rounded_numpy_less_for_trees, + "Less": rounded_numpy_less_for_trees, # type: ignore[dict-item] 
"Equal": rounded_numpy_equal_for_trees, "LessOrEqual": rounded_numpy_less_or_equal_for_trees, } diff --git a/src/concrete/ml/onnx/ops_impl.py b/src/concrete/ml/onnx/ops_impl.py index 661e13c98..7f2730dee 100644 --- a/src/concrete/ml/onnx/ops_impl.py +++ b/src/concrete/ml/onnx/ops_impl.py @@ -892,7 +892,7 @@ def rounded_numpy_equal_for_trees( x: numpy.ndarray, y: numpy.ndarray, *, - auto_truncate = None, + auto_truncate=None, ) -> Tuple[numpy.ndarray]: """Compute rounded equal in numpy according to ONNX spec for tree-based models only. @@ -925,7 +925,6 @@ def rounded_numpy_equal_for_trees( return (numpy.equal(x, y),) - def numpy_equal_float( x: numpy.ndarray, y: numpy.ndarray, @@ -1096,7 +1095,6 @@ def rounded_numpy_less_for_trees( # numpy.less(x, y) is equivalent to : # x - y <= 0 => round_bit_pattern(x - y - half) < 0 if auto_truncate is not None: - #print("Use truncate for <") return rounded_comparison(x, y, auto_truncate, operation=lambda x: x < 0) # Else, default numpy_less operator @@ -1145,7 +1143,7 @@ def rounded_numpy_less_or_equal_for_trees( x: numpy.ndarray, y: numpy.ndarray, *, - auto_truncate = None, + auto_truncate=None, ) -> Tuple[numpy.ndarray]: """Compute rounded less or equal in numpy according to ONNX spec for tree-based models only. 
@@ -1161,21 +1159,15 @@ def rounded_numpy_less_or_equal_for_trees( Tuple[numpy.ndarray]: Output tensor """ - # numpy.less_equal(x, y) <= 0 is equivalent to : - # np.less_equal(x, y), truncate_bit_pattern((y - x), lsbs_to_remove=r) >= 0 - # option 1: x - y <= 0 => round_bit_pattern(x - y) <= 0 - # gives bad results for : 0 < x - y <= 2**lsbs_to_remove because truncate_bit_pattern(x - y, lsb) = 0 - # option 2: y - x >= 0 => round_bit_pattern(y - x) >= 0 - - if auto_truncate is not None: - #print("Use truncate for <=") - return rounded_comparison(y, x, auto_truncate, operation=lambda x: x >= 0) # numpy.less_equal(x, y) <= y is equivalent to : - # option 1: x - y <= 0 => round_bit_pattern(x - y + half) <= 0 or - # option 2: y - x >= 0 => round_bit_pattern(y - x - half) >= 0 + # option 1: x - y <= 0 => truncate_bit_pattern(x - y) <= 0 or + # option 2: y - x >= 0 => truncate_bit_pattern(y - x) >= 0 # Option 2 is selected because it adheres to the established pattern in `rounded_comparison` - # which does: (a - b) - half. + # which does: (a - b). 
+ + if auto_truncate is not None: + return rounded_comparison(y, x, auto_truncate, operation=lambda x: x >= 0) # Else, default numpy_less_or_equal operator return numpy_less_or_equal(x, y) diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index 46dcb6623..3b1275ac3 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -102,10 +102,10 @@ # Most significant bits to retain when applying rounding to the tree MSB_TO_KEEP_FOR_TREES = 1 -# Enable rounding feature for all tree-based models by default +# Enable truncate feature for all tree-based models by default # Note: This setting is fixed and cannot be altered by users # However, for internal testing purposes, we retain the capability to disable this feature -os.environ["TREES_USE_ROUNDING"] = os.environ.get("TREES_USE_ROUNDING", "1") +os.environ["TREES_USE_TRUNCATE"] = os.environ.get("TREES_USE_TRUNCATE", "1") # pylint: disable=too-many-public-methods @@ -1326,8 +1326,8 @@ def fit(self, X: Data, y: Target, **fit_parameters): self.output_quantizers = [] #: Determines the LSB to remove given a `target_msbs` - self.auto_truncate = cp.AutoTruncator(target_msbs=MSB_TO_KEEP_FOR_TREES) - + self._auto_truncate = cp.AutoTruncator(target_msbs=MSB_TO_KEEP_FOR_TREES) + X, y = check_X_y_and_assert_multi_output(X, y) q_X = numpy.zeros_like(X) @@ -1353,25 +1353,23 @@ def fit(self, X: Data, y: Target, **fit_parameters): assert self.sklearn_model is not None, self._sklearn_model_is_not_fitted_error_message() # Enable optimized computation - enable_truncate = os.environ.get("TREES_USE_ROUNDING", "1") == "1" + enable_truncate = os.environ.get("TREES_USE_TRUNCATE", "1") == "1" if not enable_truncate: warnings.simplefilter("always") warnings.warn( - "Using Concrete tree-based models without the `rounding feature` is deprecated. 
" - "Consider setting 'use_rounding' to `True` for making the FHE inference faster " + "Using Concrete tree-based models without the `truncate feature` is deprecated. " + "Consider setting 'use_truncate' to `True` for making the FHE inference faster " "and key generation.", category=DeprecationWarning, stacklevel=2, ) - self.auto_truncate = None + self._auto_truncate = None - print(f"{self.auto_truncate=}") - self._tree_inference, self.output_quantizers, self.onnx_model_ = tree_to_numpy( self.sklearn_model, q_X, - auto_truncate=self.auto_truncate, + auto_truncate=self._auto_truncate, fhe_ensembling=self._fhe_ensembling, framework=self.framework, output_n_bits=self.n_bits["op_leaves"], @@ -1380,7 +1378,7 @@ def fit(self, X: Data, y: Target, **fit_parameters): # Adjust the truncate if enable_truncate: inputset = numpy.array(list(_get_inputset_generator(q_X))).astype(int) - self.auto_truncate.adjust(self._tree_inference, inputset) + self._auto_truncate.adjust(self._tree_inference, inputset) self._tree_inference(q_X.astype("int")) self._is_fitted = True diff --git a/src/concrete/ml/sklearn/tree_to_numpy.py b/src/concrete/ml/sklearn/tree_to_numpy.py index 3e088d321..b3b72f88b 100644 --- a/src/concrete/ml/sklearn/tree_to_numpy.py +++ b/src/concrete/ml/sklearn/tree_to_numpy.py @@ -1,7 +1,7 @@ """Implements the conversion of a tree model to a numpy function.""" import math import warnings -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Tuple import numpy import onnx @@ -330,7 +330,7 @@ def tree_to_numpy( model: Callable, x: numpy.ndarray, framework: str, - auto_truncate: bool = None, + auto_truncate=None, fhe_ensembling: bool = False, output_n_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE, ) -> Tuple[Callable, List[UniformQuantizer], onnx.ModelProto]: @@ -339,7 +339,7 @@ def tree_to_numpy( Args: model (Callable): The tree model to convert. x (numpy.ndarray): The input data. 
- auto_truncate (bool): Determines whether the rounding feature is enabled or disabled. + auto_truncate (TODO): Determines whether the rounding feature is enabled or disabled. Default to True. fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed in FHE. Default to False. @@ -365,15 +365,6 @@ def tree_to_numpy( # Execute with 1 example for efficiency in large data scenarios to prevent slowdown onnx_model = get_onnx_model(model, x[:1], framework) - # # Compute for tree-based models the LSB to remove in stage 1 and stage 2 - # if use_rounding: - # # First LSB refers to Less or LessOrEqual comparisons - # # Second LSB refers to Equal comparison - # lsbs_to_remove_for_trees = _compute_lsb_to_remove_for_trees(onnx_model, x) - - # # mypy - # assert len(lsbs_to_remove_for_trees) == 2 - # Get the expected number of ONNX outputs in the sklearn model. expected_number_of_outputs = 1 if is_regressor_or_partial_regressor(model) else 2 diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py index 58423df78..d77a67cf9 100644 --- a/tests/sklearn/test_sklearn_models.py +++ b/tests/sklearn/test_sklearn_models.py @@ -1138,7 +1138,7 @@ def check_load_fitted_sklearn_linear_models(model_class, n_bits, x, y, check_flo ) -def check_rounding_consistency( +def check_truncate_consistency( model, x, y, @@ -1146,53 +1146,53 @@ def check_rounding_consistency( metric, is_weekly_option, ): - """Test that Concrete ML without and with rounding are 'equivalent'.""" + """Test that Concrete ML without and with truncate are 'equivalent'.""" # Run the test with more samples during weekly CIs if is_weekly_option: fhe_test = get_random_samples(x, n_sample=5) - # Check that rounding is enabled - assert os.environ.get("TREES_USE_ROUNDING") == "1", "'TREES_USE_ROUNDING' is not enabled" + # Check that truncate is enabled + assert os.environ.get("TREES_USE_TRUNCATE") == "1", "'TREES_USE_TRUNCATE' is not enabled" - # Fit and compile with rounding 
enabled + # Fit and compile with truncate enabled fit_and_compile(model, x, y) - rounded_predict_quantized = predict_method(x, fhe="disable") - rounded_predict_simulate = predict_method(x, fhe="simulate") + truncate_predict_quantized = predict_method(x, fhe="disable") + truncate_predict_simulate = predict_method(x, fhe="simulate") # Compute the FHE predictions only during weekly CIs if is_weekly_option: - rounded_predict_fhe = predict_method(fhe_test, fhe="execute") + truncate_predict_fhe = predict_method(fhe_test, fhe="execute") with pytest.MonkeyPatch.context() as mp_context: - # Disable rounding - mp_context.setenv("TREES_USE_ROUNDING", "0") + # Disable truncate + mp_context.setenv("TREES_USE_TRUNCATE", "0") - # Check that rounding is disabled - assert os.environ.get("TREES_USE_ROUNDING") == "0", "'TREES_USE_ROUNDING' is not disabled" + # Check that truncate is disabled + assert os.environ.get("TREES_USE_TRUNCATE") == "0", "'TREES_USE_TRUNCATE' is not disabled" with pytest.warns( DeprecationWarning, match=( - "Using Concrete tree-based models without the `rounding feature` is " "deprecated.*" + "Using Concrete tree-based models without the `truncate feature` is " "deprecated.*" ), ): - # Fit and compile without rounding + # Fit and compile without truncate fit_and_compile(model, x, y) - not_rounded_predict_quantized = predict_method(x, fhe="disable") - not_rounded_predict_simulate = predict_method(x, fhe="simulate") + not_truncate_predict_quantized = predict_method(x, fhe="disable") + not_truncate_predict_simulate = predict_method(x, fhe="simulate") - metric(rounded_predict_quantized, not_rounded_predict_quantized) - metric(rounded_predict_simulate, not_rounded_predict_simulate) + metric(truncate_predict_quantized, not_truncate_predict_quantized) + metric(truncate_predict_simulate, not_truncate_predict_simulate) # Compute the FHE predictions only during weekly CIs if is_weekly_option: - not_rounded_predict_fhe = predict_method(fhe_test, fhe="execute") - 
metric(rounded_predict_fhe, not_rounded_predict_fhe) + not_truncate_predict_fhe = predict_method(fhe_test, fhe="execute") + metric(truncate_predict_fhe, not_truncate_predict_fhe) # Check that the maximum bit-width of the circuit with rounding is at most: # maximum bit-width (of the circuit without rounding) + 2 @@ -1881,7 +1881,7 @@ def test_linear_models_have_no_tlu( # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4179 @pytest.mark.parametrize("model_class, parameters", get_sklearn_tree_models_and_datasets()) @pytest.mark.parametrize("n_bits", [2, 5, 10]) -def test_rounding_consistency_for_regular_models( +def test_truncate_consistency_for_regular_models( model_class, parameters, n_bits, @@ -1894,7 +1894,7 @@ def test_rounding_consistency_for_regular_models( """Test that Concrete ML without and with rounding are 'equivalent'.""" if verbose: - print("Run check_rounding_consistency") + print("Run check_truncate_consistency") model = instantiate_model_generic(model_class, n_bits=n_bits) @@ -1909,7 +1909,7 @@ def test_rounding_consistency_for_regular_models( predict_method = model.predict metric = check_accuracy - check_rounding_consistency( + check_truncate_consistency( model, x, y,