Skip to content

Commit

Permalink
chore: update v3
Browse files Browse the repository at this point in the history
  • Loading branch information
kcelia committed Jan 23, 2024
1 parent 14d9dc0 commit 3248216
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
16 changes: 8 additions & 8 deletions src/concrete/ml/sklearn/tree_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,25 +136,25 @@ def assert_add_node_and_constant_in_xgboost_regressor_graph(onnx_model: onnx.Mod
)


def add_transpose_after_last_node(onnx_model: onnx.ModelProto, use_fhe_sum: bool):
def add_transpose_after_last_node(onnx_model: onnx.ModelProto, fhe_ensembling: bool):
"""Add transpose after last node.
Args:
onnx_model (onnx.ModelProto): The ONNX model.
use_fhe_sum (bool): Determines whether the sum of the trees' outputs is computed in FHE.
fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed in FHE.
Default to False.
"""
# Get the output node
output_node = onnx_model.graph.output[0]

# The state of the 'use_fhe_sum' variable affects the structure of the model's ONNX graph.
# The state of the 'fhe_ensembling' variable affects the structure of the model's ONNX graph.
# When the option is enabled, the graph is cut after the ReduceSum node.
# When it is disabled, the graph is cut at the ReduceSum node, which alters the output shape.
# Therefore, it is necessary to adjust this shape with the correct permutation.

# When using FHE sum for tree ensembles, create the node with perm attribute equal to (1, 0)
# Otherwise, create the node with perm attribute equal to (2, 1, 0)
perm = [1, 0] if use_fhe_sum else [2, 1, 0]
perm = [1, 0] if fhe_ensembling else [2, 1, 0]

transpose_node = onnx.helper.make_node(
"Transpose",
Expand Down Expand Up @@ -221,7 +221,7 @@ def tree_onnx_graph_preprocessing(
onnx_model: onnx.ModelProto,
framework: str,
expected_number_of_outputs: int,
use_fhe_sum: bool = False,
fhe_ensembling: bool = False,
):
"""Apply pre-processing onto the ONNX graph.
Expand All @@ -230,7 +230,7 @@ def tree_onnx_graph_preprocessing(
framework (str): The framework from which the ONNX model is generated.
(options: 'xgboost', 'sklearn')
expected_number_of_outputs (int): The expected number of outputs in the ONNX model.
use_fhe_sum (bool): Determines whether the sum of the trees' outputs is computed in FHE.
fhe_ensembling (bool): Determines whether the sum of the trees' outputs is computed in FHE.
Default to False.
"""
# Make sure the ONNX version returned by Hummingbird is OPSET_VERSION_FOR_ONNX_EXPORT
Expand Down Expand Up @@ -258,7 +258,7 @@ def tree_onnx_graph_preprocessing(

# Cut the graph after the ReduceSum node to remove
# argmax, sigmoid, softmax from the graph.
if use_fhe_sum:
if fhe_ensembling:
clean_graph_after_node_op_type(onnx_model, "ReduceSum")
else:
clean_graph_at_node_op_type(onnx_model, "ReduceSum")
Expand All @@ -274,7 +274,7 @@ def tree_onnx_graph_preprocessing(
# sklearn models apply the reduce sum before the transpose.
# To have equivalent output between xgboost in sklearn,
# apply the transpose before returning the output.
add_transpose_after_last_node(onnx_model, use_fhe_sum)
add_transpose_after_last_node(onnx_model, fhe_ensembling)

# Cast nodes are not necessary so remove them.
remove_node_types(onnx_model, op_types_to_remove=["Cast"])
Expand Down
4 changes: 2 additions & 2 deletions tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def get_n_bits_non_correctness(model_class):
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3979
if get_model_name(model_class) == "KNeighborsClassifier":
n_bits = 2

n_bits = min(N_BITS_REGULAR_BUILDS)
else:
n_bits = min(N_BITS_REGULAR_BUILDS)

return n_bits

Expand Down

0 comments on commit 3248216

Please sign in to comment.