Skip to content

Commit

Permalink
chore: update comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kcelia committed Jan 24, 2024
1 parent ab45587 commit 0ad0f6d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/concrete/ml/quantization/post_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _inspect_tree_n_bits(n_bits):
"""

detailed_message = (
"Invalid 'n_bits', either pass a non-null positive integer or a dictionary containing "
"Invalid 'n_bits', either pass a strictly positive integer or a dictionary containing "
"integer values for the following keys:\n"
"- 'op_inputs' (mandatory): number of bits to quantize the input values\n"
"- 'op_leaves' (optional): number of bits to quantize the leaves, must be less than or "
Expand All @@ -63,7 +63,7 @@ def _inspect_tree_n_bits(n_bits):

if isinstance(n_bits, int):
if n_bits <= 0:
error_message = "n_bits must be a non-null, positive integer"
error_message = "n_bits must be a strictly positive integer"
elif isinstance(n_bits, dict):
if "op_inputs" not in n_bits.keys():
error_message = "Invalid keys in `n_bits` dictionary. The key 'op_inputs' is mandatory"
Expand All @@ -73,7 +73,7 @@ def _inspect_tree_n_bits(n_bits):
"(optional) are allowed"
)
elif not all(isinstance(value, int) and value > 0 for value in n_bits.values()):
error_message = "All values in 'n_bits' dictionary must be non-null, positive integers"
error_message = "All values in 'n_bits' dictionary must be strictly positive integers"

elif n_bits.get("op_leaves", 0) > n_bits.get("op_inputs", 0):
error_message = "'op_leaves' must be less than or equal to 'op_inputs'"
Expand Down
1 change: 0 additions & 1 deletion src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,7 +1335,6 @@ def fhe_ensembling(self, value: bool) -> None:
assert isinstance(value, bool), "Value must be a boolean type"

if value is True:
print("LAA")
warnings.simplefilter("always")
warnings.warn(
"Enabling `fhe_ensembling` computes the sum of the ouputs of tree ensembles in "
Expand Down
10 changes: 5 additions & 5 deletions tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1955,8 +1955,8 @@ def test_fhe_sum_for_tree_based_models(
@pytest.mark.parametrize(
"n_bits, error_message",
[
(0, "n_bits must be a non-null, positive integer"),
(-1, "n_bits must be a non-null, positive integer"),
(0, "n_bits must be a strictly positive integer"),
(-1, "n_bits must be a strictly positive integer"),
({"op_leaves": 2}, "The key 'op_inputs' is mandatory"),
(
{"op_inputs": 4, "op_leaves": 2, "op_weights": 2},
Expand All @@ -1965,15 +1965,15 @@ def test_fhe_sum_for_tree_based_models(
),
(
{"op_inputs": -2, "op_leaves": -5},
"All values in 'n_bits' dictionary must be non-null, positive integers",
"All values in 'n_bits' dictionary must be strictly positive integers",
),
({"op_inputs": 2, "op_leaves": 5}, "'op_leaves' must be less than or equal to 'op_inputs'"),
(0.5, "n_bits must be either an integer or a dictionary"),
],
)
@pytest.mark.parametrize("model_class", _get_sklearn_tree_models())
def test_invalid_n_bits_setting(model_class, n_bits, error_message):
"""Check if the model instantiation raises an exception with invalid 'n_bits' settings."""
"""Check if the model instantiation raises an exception with invalid `n_bits` settings."""

with pytest.raises(ValueError, match=f"{error_message}. Got '{type(n_bits)}' and '{n_bits}'.*"):
instantiate_model_generic(model_class, n_bits=n_bits)
Expand All @@ -1991,7 +1991,7 @@ def test_valid_n_bits_setting(
is_weekly_option,
verbose=True,
):
"""Check valid `n_bits' settings."""
"""Check valid `n_bits` settings."""

if verbose:
print("Run test_valid_n_bits_setting")
Expand Down

0 comments on commit 0ad0f6d

Please sign in to comment.