Skip to content

Commit

Permalink
chore: fix drift seen in xgboost regressor
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed Oct 13, 2023
1 parent a678807 commit 6c4b523
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 29 deletions.
50 changes: 25 additions & 25 deletions docs/advanced_examples/RegressorComparison.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docs/advanced_examples/XGBRegressor.ipynb

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions src/concrete/ml/sklearn/tree_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,42 @@ def workaround_squeeze_node_xgboost(onnx_model: onnx.ModelProto):
onnx_model.graph.node[target_node_id_list[0]].input.insert(1, axes_input_name)


def assert_add_node_and_constant_in_xgboost_regressor_graph(onnx_model: onnx.ModelProto):
"""Assert if an Add node with a specific constant exists in the ONNX graph.
Args:
onnx_model (onnx.ModelProto): The ONNX model.
"""

constant_add_name = "_operators.0.base_prediction"
is_expected_add_node_present = False
initializer_value_correct = False

# Find the initializer with the specified name
initializer = next(
(init for init in onnx_model.graph.initializer if init.name == constant_add_name), None
)

# Check if the initializer exists and its value is 0.5
if initializer:
values = onnx.numpy_helper.to_array(initializer)
if values.size == 1 and values[0] == 0.5:
initializer_value_correct = True

# Iterate over all nodes in the model's graph
for node in onnx_model.graph.node:
# Check if the node is an "Add" node and has the
# specified initializer as one of its inputs
if node.op_type == "Add" and constant_add_name in node.input:
is_expected_add_node_present = True
break

assert_true(
is_expected_add_node_present and initializer_value_correct,
"XGBoostRegressor is not supported.",
)


def add_transpose_after_last_node(onnx_model: onnx.ModelProto):
"""Add transpose after last node.
Expand Down Expand Up @@ -183,6 +219,13 @@ def tree_onnx_graph_preprocessing(
on_error_msg=f"{len(onnx_model.graph.output)} != 2",
)

# Check that a XGBoostRegressor onnx graph has the + 0.5 add node.
if framework == "xgboost":
# Make sure it is a regression model
# (by checking it has a single output, as mentioned above)
if len(onnx_model.graph.output) == 1:
assert_add_node_and_constant_in_xgboost_regressor_graph(onnx_model)

# Cut the graph at the ReduceSum node as large sum are not yet supported.
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/451
clean_graph_at_node_op_type(onnx_model, "ReduceSum")
Expand Down
8 changes: 8 additions & 0 deletions src/concrete/ml/sklearn/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,14 @@ def fit(self, X, y, *args, **kwargs) -> Any:
super().fit(X, y, *args, **kwargs)
return self

def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
y_preds = super().post_processing(y_preds)

# Hummingbird Gemm for XGBoostRegressor adds a + 0.5 at the end of the graph.
# We need to add it back here since the graph is cut before this add node.
y_preds += 0.5
return y_preds

def dump_dict(self) -> Dict[str, Any]:
metadata: Dict[str, Any] = {}

Expand Down
2 changes: 1 addition & 1 deletion tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def check_correctness_with_sklearn(
"Lasso": 0.9,
"Ridge": 0.9,
"ElasticNet": 0.9,
"XGBRegressor": -0.2,
"XGBRegressor": 0.9,
"NeuralNetRegressor": -10,
}

Expand Down

0 comments on commit 6c4b523

Please sign in to comment.