From 65026a7b2fcbbd9c0d42f50a51fba410840e226f Mon Sep 17 00:00:00 2001 From: kcelia Date: Mon, 15 Jan 2024 15:21:54 +0100 Subject: [PATCH] chore: update --- src/concrete/ml/sklearn/base.py | 17 ++--- src/concrete/ml/sklearn/tree_to_numpy.py | 18 +++-- tests/sklearn/test_dump_onnx.py | 15 ++--- tests/sklearn/test_sklearn_models.py | 85 ++++++++++++++---------- 4 files changed, 77 insertions(+), 58 deletions(-) diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index 278a8231b4..b1b5eae199 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -1318,7 +1318,6 @@ def fit(self, X: Data, y: Target, **fit_parameters): # Convert the n_bits attribute into a proper dictionary self.n_bits = get_n_bits_dict_trees(self.n_bits) - print(f"{self.n_bits=}") # Quantization of each feature in X for i in range(X.shape[1]): @@ -1430,16 +1429,18 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy. y_pred = self.post_processing(y_pred) return y_pred - # def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray: + def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray: + # Sum all tree outputs + # Remove the sum once we handle multi-precision circuits + # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451 + if os.getenv("TREES_USE_FHE_SUM") == "0": + y_preds = numpy.sum(y_preds, axis=-1) + assert_true(y_preds.ndim == 2, "y_preds should be a 2D array") + return y_preds - # # Sum all tree outputs - # # Remove the sum once we handle multi-precision circuits - # # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451 - # y_preds = numpy.sum(y_preds, axis=-1) + return super().post_processing(y_preds) - # assert_true(y_preds.ndim == 2, "y_preds should be a 2D array") - # return y_preds class BaseTreeRegressorMixin(BaseTreeEstimatorMixin, sklearn.base.RegressorMixin, ABC): """Mixin class for tree-based regressors. diff --git a/src/concrete/ml/sklearn/tree_to_numpy.py b/src/concrete/ml/sklearn/tree_to_numpy.py index 49b86705ef..65a115759c 100644 --- a/src/concrete/ml/sklearn/tree_to_numpy.py +++ b/src/concrete/ml/sklearn/tree_to_numpy.py @@ -18,7 +18,11 @@ OPSET_VERSION_FOR_ONNX_EXPORT, get_equivalent_numpy_forward_from_onnx_tree, ) -from ..onnx.onnx_model_manipulations import clean_graph_after_node_op_type, remove_node_types +from ..onnx.onnx_model_manipulations import ( + clean_graph_after_node_op_type, + clean_graph_at_node_op_type, + remove_node_types, +) from ..onnx.onnx_utils import get_op_type from ..quantization import QuantizedArray from ..quantization.quantizers import UniformQuantizer @@ -142,13 +146,14 @@ def add_transpose_after_last_node(onnx_model: onnx.ModelProto): # Get the output node output_node = onnx_model.graph.output[0] + # When using FHE sum for tree ensembles, create the node with perm attribute equal to (1, 0) if os.getenv("TREES_USE_FHE_SUM") == "1": - # Create the node with perm attribute equal to (1, 0) perm = [1, 0] + + # Otherwise, create the node with perm attribute equal to (2, 1, 0) else: - # Create the node with perm attribute equal to (2, 1, 0) perm = [2, 1, 0] - + transpose_node = onnx.helper.make_node( "Transpose", inputs=[output_node.name], @@ -246,7 +251,10 @@ def tree_onnx_graph_preprocessing( # Cut the graph after the ReduceSum node to remove # argmax, sigmoid, softmax from the graph. - clean_graph_after_node_op_type(onnx_model, "ReduceSum") + if os.getenv("TREES_USE_FHE_SUM") == "1": + clean_graph_after_node_op_type(onnx_model, "ReduceSum") + else: + clean_graph_at_node_op_type(onnx_model, "ReduceSum") if framework == "xgboost": # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/2778 diff --git a/tests/sklearn/test_dump_onnx.py b/tests/sklearn/test_dump_onnx.py index a5fe6497a1..484e2c0d75 100644 --- a/tests/sklearn/test_dump_onnx.py +++ b/tests/sklearn/test_dump_onnx.py @@ -69,7 +69,6 @@ def check_onnx_file_dump(model_class, parameters, load_data, str_expected, defau str_model = onnx.helper.printable_graph(onnx_model.graph) print(f"{model_name}:") print(str_model) - # Test equality when it does not depend on seeds # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3266 if not is_model_class_in_a_list(model_class, _get_sklearn_tree_models(select="RandomForest")): @@ -228,7 +227,7 @@ def test_dump( %transposed_output = Transpose[perm = [1, 0]](%/_operators.0/ReduceSum_output_0) """ if os.getenv("TREES_USE_FHE_SUM") == "1" - else "%transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0)" + else "%transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0)\n " ) + """return %transposed_output }""", @@ -307,7 +306,7 @@ def test_dump( %transposed_output = Transpose[perm = [1, 0]](%/_operators.0/ReduceSum_output_0) """ if os.getenv("TREES_USE_FHE_SUM") == "1" - else "" + else "%transposed_output = Transpose[perm = [2, 1, 0]](%/_operators.0/Reshape_3_output_0)\n " ) + """return %transposed_output }""", @@ -359,8 +358,7 @@ def test_dump( return %/_operators.0/ReduceSum_output_0 }""" if os.getenv("TREES_USE_FHE_SUM") == "1" - else """return %/_operators.0/Reshape_4_output_0 - }""" + else "return %/_operators.0/Reshape_4_output_0\n}" ), "RandomForestRegressor": """graph torch_jit ( %input_0[DOUBLE, symx10] @@ -401,9 +399,8 @@ def test_dump( %/_operators.0/Constant_1_output_0[INT64, 2] %/_operators.0/Constant_2_output_0[INT64, 3] %/_operators.0/Constant_3_output_0[INT64, 3] - %/_operators.0/Constant_4_output_0[INT64, 3] - """ - + ("%onnx::ReduceSum_27[INT64, 1]" if os.getenv("TREES_USE_FHE_SUM") == "1" else "") + %/_operators.0/Constant_4_output_0[INT64, 3]""" + + ("\n %onnx::ReduceSum_27[INT64, 1]" if os.getenv("TREES_USE_FHE_SUM") == "1" else "") + """ ) { %/_operators.0/Gemm_output_0 = Gemm[alpha = 1, beta = 0, transB = 1](%_operators.0.weight_1, %input_0) @@ -424,7 +421,7 @@ def test_dump( return %/_operators.0/ReduceSum_output_0 }""" if os.getenv("TREES_USE_FHE_SUM") == "1" - else "return %/_operators.0/Reshape_4_output_0" + else """return %/_operators.0/Reshape_4_output_0\n}""" ), "LinearRegression": """graph torch_jit ( %input_0[DOUBLE, symx10] diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py index 1910d15b41..33dd945205 100644 --- a/tests/sklearn/test_sklearn_models.py +++ b/tests/sklearn/test_sklearn_models.py @@ -46,6 +46,7 @@ from concrete.ml.common.serialization.loaders import load, loads from concrete.ml.common.utils import ( USE_OLD_VL, + array_allclose_and_same_shape, get_model_class, get_model_name, is_classifier_or_partial_classifier, @@ -1206,47 +1207,73 @@ def check_rounding_consistency( def check_fhe_sum_consistency( + model_class, x, - predict_method, - metric, + y, + n_bits, is_weekly_option, ): - """Test that Concrete ML without and with rounding are 'equivalent'.""" + """Test that Concrete ML without and with FHE sum are 'equivalent'.""" # Run the test with more samples during weekly CIs if is_weekly_option: fhe_test = get_random_samples(x, n_sample=5) - # By default, FHE_SUM is disabled - fhe_sum_disabled = os.getenv("TREES_USE_FHE_SUM") == "1" + # By default, the summation of tree ensemble outputs is done in clear + fhe_sum_disabled = os.getenv("TREES_USE_FHE_SUM") == "0" assert fhe_sum_disabled + model_ref = instantiate_model_generic(model_class, n_bits=n_bits) + fit_and_compile(model_ref, x, y) + + # Check `predict_proba` for classifiers and `predict` for regressors + predict_method = ( + model_ref.predict_proba + if is_classifier_or_partial_classifier(model_class) + else model_ref.predict + ) + non_fhe_sum_predict_quantized = predict_method(x, fhe="disable") non_fhe_sum_predict_simulate = predict_method(x, fhe="simulate") + # Sanity check + array_allclose_and_same_shape(non_fhe_sum_predict_quantized, non_fhe_sum_predict_simulate) + # Compute the FHE predictions only during weekly CIs if is_weekly_option: - rounded_predict_fhe = predict_method(fhe_test, fhe="execute") + non_fhe_sum_predict_fhe = predict_method(fhe_test, fhe="execute") with pytest.MonkeyPatch.context() as mp_context: - # Enable FHE sum - mp_context.setenv("TREES_USE_FHE_SUM", "0") + # Enable the FHE summation of tree ensemble outputs + mp_context.setenv("TREES_USE_FHE_SUM", "1") - # Check that rounding is disabled - fhe_sum_enbled = os.environ.get("TREES_USE_FHE_SUM") == "0" - assert fhe_sum_enbled + # Check that the summation of tree ensemble outputs is enabled + fhe_sum_enabled = os.environ.get("TREES_USE_FHE_SUM") == "1" + assert fhe_sum_enabled + + model = model_class(**model_ref.get_params()) + fit_and_compile(model, x, y) + + # Check `predict_proba` for classifiers and `predict` for regressors + predict_method = ( + model.predict_proba + if is_classifier_or_partial_classifier(model_class) + else model.predict + ) fhe_sum_predict_quantized = predict_method(x, fhe="disable") fhe_sum_predict_simulate = predict_method(x, fhe="simulate") - metric(non_fhe_sum_predict_quantized, fhe_sum_predict_quantized) - metric(non_fhe_sum_predict_simulate, fhe_sum_predict_simulate) + # Sanity check + array_allclose_and_same_shape(fhe_sum_predict_quantized, fhe_sum_predict_simulate) - # Compute the FHE predictions only during weekly CIs - if is_weekly_option: - not_rounded_predict_fhe = predict_method(fhe_test, fhe="execute") - metric(rounded_predict_fhe, not_rounded_predict_fhe) + # Check that we have the exact same predictions + array_allclose_and_same_shape(fhe_sum_predict_quantized, non_fhe_sum_predict_quantized) + array_allclose_and_same_shape(fhe_sum_predict_simulate, non_fhe_sum_predict_simulate) + if is_weekly_option: + fhe_sum_predict_fhe = predict_method(fhe_test, fhe="execute") + array_allclose_and_same_shape(fhe_sum_predict_fhe, non_fhe_sum_predict_fhe) # Neural network models are skipped for this test @@ -1937,34 +1964,20 @@ def test_fhe_sum_for_tree_based_models( parameters, n_bits, load_data, - check_r2_score, - check_accuracy, is_weekly_option, - default_configuration, verbose=True, ): """Test that Concrete ML without and with rounding are 'equivalent'.""" if verbose: - print("Run check_rounding_consistency") + print("Run check_fhe_sum_consistency") - model, x = preamble(model_class, parameters, n_bits, load_data, is_weekly_option) - - # Compile the model to make sure we consider all possible attributes during the serialization - model.compile(x, default_configuration) - - # Check `predict_proba` for classifiers - if is_classifier_or_partial_classifier(model): - predict_method = model.predict_proba - metric = check_r2_score - else: - # Check `predict` for regressors - predict_method = model.predict - metric = check_accuracy + x, y = get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option) check_fhe_sum_consistency( + model_class, x, - predict_method, - metric, + y, + n_bits, is_weekly_option, )