From 0ad0f6d1dad2014151cb839e6f18bb02878b7c44 Mon Sep 17 00:00:00 2001 From: kcelia Date: Wed, 24 Jan 2024 10:25:51 +0100 Subject: [PATCH] chore: update comments --- src/concrete/ml/quantization/post_training.py | 6 +++--- src/concrete/ml/sklearn/base.py | 1 - tests/sklearn/test_sklearn_models.py | 10 +++++----- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/concrete/ml/quantization/post_training.py b/src/concrete/ml/quantization/post_training.py index 621697f14..46bed0214 100644 --- a/src/concrete/ml/quantization/post_training.py +++ b/src/concrete/ml/quantization/post_training.py @@ -50,7 +50,7 @@ def _inspect_tree_n_bits(n_bits): """ detailed_message = ( - "Invalid 'n_bits', either pass a non-null positive integer or a dictionary containing " + "Invalid 'n_bits', either pass a strictly positive integer or a dictionary containing " "integer values for the following keys:\n" "- 'op_inputs' (mandatory): number of bits to quantize the input values\n" "- 'op_leaves' (optional): number of bits to quantize the leaves, must be less than or " @@ -63,7 +63,7 @@ def _inspect_tree_n_bits(n_bits): if isinstance(n_bits, int): if n_bits <= 0: - error_message = "n_bits must be a non-null, positive integer" + error_message = "n_bits must be a strictly positive integer" elif isinstance(n_bits, dict): if "op_inputs" not in n_bits.keys(): error_message = "Invalid keys in `n_bits` dictionary. The key 'op_inputs' is mandatory" @@ -73,7 +73,7 @@ def _inspect_tree_n_bits(n_bits): "(optional) are allowed" ) elif not all(isinstance(value, int) and value > 0 for value in n_bits.values()): - error_message = "All values in 'n_bits' dictionary must be non-null, positive integers" + error_message = "All values in 'n_bits' dictionary must be strictly positive integers" elif n_bits.get("op_leaves", 0) > n_bits.get("op_inputs", 0): error_message = "'op_leaves' must be less than or equal to 'op_inputs'" diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index b1ce3d55d..8b947b163 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -1335,7 +1335,6 @@ def fhe_ensembling(self, value: bool) -> None: assert isinstance(value, bool), "Value must be a boolean type" if value is True: - print("LAA") warnings.simplefilter("always") warnings.warn( "Enabling `fhe_ensembling` computes the sum of the ouputs of tree ensembles in " diff --git a/tests/sklearn/test_sklearn_models.py b/tests/sklearn/test_sklearn_models.py index cd156f808..6e5bd2a9e 100644 --- a/tests/sklearn/test_sklearn_models.py +++ b/tests/sklearn/test_sklearn_models.py @@ -1955,8 +1955,8 @@ def test_fhe_sum_for_tree_based_models( @pytest.mark.parametrize( "n_bits, error_message", [ - (0, "n_bits must be a non-null, positive integer"), - (-1, "n_bits must be a non-null, positive integer"), + (0, "n_bits must be a strictly positive integer"), + (-1, "n_bits must be a strictly positive integer"), ({"op_leaves": 2}, "The key 'op_inputs' is mandatory"), ( {"op_inputs": 4, "op_leaves": 2, "op_weights": 2}, @@ -1965,7 +1965,7 @@ def test_fhe_sum_for_tree_based_models( ), ( {"op_inputs": -2, "op_leaves": -5}, - "All values in 'n_bits' dictionary must be non-null, positive integers", + "All values in 'n_bits' dictionary must be strictly positive integers", ), ({"op_inputs": 2, "op_leaves": 5}, "'op_leaves' must be less than or equal to 'op_inputs'"), (0.5, "n_bits must be either an integer or a dictionary"), @@ -1973,7 +1973,7 @@ def test_fhe_sum_for_tree_based_models( ) @pytest.mark.parametrize("model_class", _get_sklearn_tree_models()) def test_invalid_n_bits_setting(model_class, n_bits, error_message): - """Check if the model instantiation raises an exception with invalid 'n_bits' settings.""" + """Check if the model instantiation raises an exception with invalid `n_bits` settings.""" with pytest.raises(ValueError, match=f"{error_message}. Got '{type(n_bits)}' and '{n_bits}'.*"): instantiate_model_generic(model_class, n_bits=n_bits) @@ -1991,7 +1991,7 @@ def test_valid_n_bits_setting( is_weekly_option, verbose=True, ): - """Check valid `n_bits' settings.""" + """Check valid `n_bits` settings.""" if verbose: print("Run test_valid_n_bits_setting")