Skip to content

Commit

Permalink
chore: update v2
Browse files Browse the repository at this point in the history
  • Loading branch information
kcelia committed Jan 23, 2024
1 parent 7fddece commit 14d9dc0
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 31 deletions.
25 changes: 13 additions & 12 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,22 +1311,22 @@ def __init__(self, n_bits: Union[int, Dict[str, int]]):

#: Whether to perform the sum of the output's tree ensembles in FHE or not.
# By default, the decision of the tree ensembles is made in clear.
self._use_fhe_sum = False
self._fhe_ensembling = False

BaseEstimator.__init__(self)

@property
def use_fhe_sum(self) -> bool:
"""Property getter for `use_fhe_sum`.
def fhe_ensembling(self) -> bool:
"""Property getter for `_fhe_ensembling`.
Returns:
bool: The current setting of the `_use_fhe_sum` attribute.
bool: The current setting of the `fhe_ensembling` attribute.
"""
return self._use_fhe_sum
return self._fhe_ensembling

@use_fhe_sum.setter
def use_fhe_sum(self, value: bool) -> None:
"""Property setter for `use_fhe_sum`.
@fhe_ensembling.setter
def fhe_ensembling(self, value: bool) -> None:
"""Property setter for `fhe_ensembling`.
Args:
value (bool): Whether to enable or disable the feature.
Expand All @@ -1335,17 +1335,18 @@ def use_fhe_sum(self, value: bool) -> None:
assert isinstance(value, bool), "Value must be a boolean type"

if value is True:
print("LAA")
warnings.simplefilter("always")
warnings.warn(
"Enabling `use_fhe_sum` computes the sum of the outputs of tree ensembles in FHE.\n"
"Enabling `fhe_ensembling` computes the sum of the outputs of tree ensembles in FHE.\n"
"This may slow down the computation and increase the maximum bitwidth.\n"
"To optimize performance, consider reducing the quantization leaf precision.\n"
"Additionally, the model must be refitted for these changes to take effect.",
category=UserWarning,
stacklevel=2,
)

self._use_fhe_sum = value
self._fhe_ensembling = value

def fit(self, X: Data, y: Target, **fit_parameters):
# Reset for double fit
Expand Down Expand Up @@ -1395,7 +1396,7 @@ def fit(self, X: Data, y: Target, **fit_parameters):
self.sklearn_model,
q_X,
use_rounding=enable_rounding,
use_fhe_sum=self._use_fhe_sum,
fhe_ensembling=self.fhe_ensembling,
framework=self.framework,
output_n_bits=self.n_bits["op_leaves"],
)
Expand Down Expand Up @@ -1472,7 +1473,7 @@ def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
# Sum all tree outputs
# Remove the sum once we handle multi-precision circuits
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451
if not self._use_fhe_sum:
if not self._fhe_ensembling:
y_preds = numpy.sum(y_preds, axis=-1)

assert_true(y_preds.ndim == 2, "y_preds should be a 2D array")
Expand Down
6 changes: 3 additions & 3 deletions src/concrete/ml/sklearn/tree_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def tree_to_numpy(
x: numpy.ndarray,
framework: str,
use_rounding: bool = True,
use_fhe_sum: bool = False,
fhe_ensembling: bool = False,
output_n_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
) -> Tuple[Callable, List[UniformQuantizer], onnx.ModelProto]:
"""Convert the tree inference to a numpy functions using Hummingbird.
Expand All @@ -341,7 +341,7 @@ def tree_to_numpy(
x (numpy.ndarray): The input data.
use_rounding (bool): Determines whether the rounding feature is enabled or disabled.
Default to True.
use_fhe_sum (bool): Determines whether the sum of the trees' outputs is computed in FHE.
fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed in FHE.
Default to False.
framework (str): The framework from which the ONNX model is generated.
(options: 'xgboost', 'sklearn')
Expand Down Expand Up @@ -379,7 +379,7 @@ def tree_to_numpy(

# ONNX graph pre-processing to make the model FHE friendly
# i.e., delete irrelevant nodes and cut the graph before the final ensemble sum)
tree_onnx_graph_preprocessing(onnx_model, framework, expected_number_of_outputs, use_fhe_sum)
tree_onnx_graph_preprocessing(onnx_model, framework, expected_number_of_outputs, fhe_ensembling)

# Tree values pre-processing
# i.e., mainly predictions quantization
Expand Down
17 changes: 10 additions & 7 deletions tests/sklearn/test_dump_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
# pylint: disable=line-too-long


def check_onnx_file_dump(model_class, parameters, load_data, default_configuration, use_fhe_sum):
def check_onnx_file_dump(
model_class, parameters, load_data, default_configuration, use_fhe_sum=False
):
"""Fit the model and dump the corresponding ONNX."""

model_name = get_model_name(model_class)
Expand Down Expand Up @@ -498,9 +500,10 @@ def test_dump(
callbacks="disable",
)

check_onnx_file_dump(
model_class, parameters, load_data, default_configuration, use_fhe_sum=False
)
check_onnx_file_dump(
model_class, parameters, load_data, default_configuration, use_fhe_sum=True
)
check_onnx_file_dump(model_class, parameters, load_data, default_configuration)

# Additional tests exclusively dedicated for tree ensemble models.
if model_class in _get_sklearn_tree_models()[2:]:
check_onnx_file_dump(
model_class, parameters, load_data, default_configuration, use_fhe_sum=True
)
17 changes: 8 additions & 9 deletions tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,7 @@ def get_n_bits_non_correctness(model_class):
if get_model_name(model_class) == "KNeighborsClassifier":
n_bits = 2

# Adjust the quantization precision for tree-based model based on `TREES_USE_FHE_SUM` setting.
# When enabled, the circuit's bitwidth increases, potentially leading to Out-of-Memory issues.
# Therefore, the maximum quantization precision is 4 bits in this case.
elif model_class in _get_sklearn_tree_models() and os.environ.get("TREES_USE_FHE_SUM") == "1":
n_bits = min(min(N_BITS_REGULAR_BUILDS), 4)
else:
n_bits = min(N_BITS_REGULAR_BUILDS)
n_bits = min(N_BITS_REGULAR_BUILDS)

return n_bits

Expand Down Expand Up @@ -1218,7 +1212,7 @@ def check_fhe_sum_for_tree_based_models(
if is_weekly_option:
fhe_test = get_random_samples(x, n_sample=5)

assert not model.use_fhe_sum, "`use_fhe_sum` is disabled by default."
assert not model.fhe_ensembling, "`fhe_ensembling` is disabled by default."
fit_and_compile(model, x, y)

non_fhe_sum_predict_quantized = predict_method(x, fhe="disable")
Expand All @@ -1231,7 +1225,8 @@ def check_fhe_sum_for_tree_based_models(
if is_weekly_option:
non_fhe_sum_predict_fhe = predict_method(fhe_test, fhe="execute")

model.use_fhe_sum = True
with pytest.warns(UserWarning, match="Enabling `fhe_ensembling` .*"):
model.fhe_ensembling = True

fit_and_compile(model, x, y)

Expand Down Expand Up @@ -1955,6 +1950,8 @@ def test_fhe_sum_for_tree_based_models(
)


# This test should be extended to all built-in models.
# FIXME: https://github.com/zama-ai/concrete-ml-internal#4234
@pytest.mark.parametrize(
"n_bits, error_message",
[
Expand Down Expand Up @@ -1982,6 +1979,8 @@ def test_invalid_n_bits_setting(model_class, n_bits, error_message):
instantiate_model_generic(model_class, n_bits=n_bits)


# This test should be extended to all built-in models.
# FIXME: https://github.com/zama-ai/concrete-ml-internal#4234
@pytest.mark.parametrize("n_bits", [5, {"op_inputs": 5}, {"op_inputs": 2, "op_leaves": 1}])
@pytest.mark.parametrize("model_class, parameters", get_sklearn_tree_models_and_datasets())
def test_valid_n_bits_setting(
Expand Down

0 comments on commit 14d9dc0

Please sign in to comment.