Skip to content

Commit

Permalink
chore: add raise on not implemented serialization + test + fix defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrery committed May 1, 2024
1 parent f4bb1fc commit 32d0037
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 21 deletions.
52 changes: 34 additions & 18 deletions src/concrete/ml/sklearn/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import platform
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import numpy
import xgboost.sklearn
from numpy.random import RandomState
from xgboost.callback import TrainingCallback

from ..common.debugging.custom_assert import assert_true
Expand Down Expand Up @@ -33,7 +34,7 @@ def __init__(
max_depth: Optional[int] = 3,
learning_rate: Optional[float] = None,
n_estimators: Optional[int] = 20,
objective: Optional[str] = None,
objective: Optional[str] = "binary:logistic",
booster: Optional[str] = None,
tree_method: Optional[str] = None,
n_jobs: Optional[int] = None,
Expand All @@ -51,7 +52,7 @@ def __init__(
missing: float = numpy.nan,
num_parallel_tree: Optional[int] = None,
monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None,
interaction_constraints: Optional[Union[str, Sequence[Sequence[str]]]] = None,
importance_type: Optional[str] = None,
gpu_id: Optional[int] = None,
validate_parameters: Optional[bool] = None,
Expand Down Expand Up @@ -180,14 +181,24 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["random_state"] = self.random_state
metadata["verbosity"] = self.verbosity
metadata["max_bin"] = self.max_bin
metadata["callbacks"] = self.callbacks
metadata["early_stopping_rounds"] = self.early_stopping_rounds
metadata["max_leaves"] = self.max_leaves
metadata["eval_metric"] = self.eval_metric
metadata["max_cat_to_onehot"] = self.max_cat_to_onehot
metadata["grow_policy"] = self.grow_policy
metadata["sampling_method"] = self.sampling_method
metadata["kwargs"] = self.kwargs

if callable(self.eval_metric):
raise NotImplementedError("Callable eval_metric is not supported for serialization")

if self.kwargs:
raise NotImplementedError("kwargs are not supported for serialization")

if self.callbacks:
raise NotImplementedError("callbacks are not supported for serialization")

metadata["eval_metric"] = self.eval_metric
metadata["kwargs"] = None
metadata["callbacks"] = None

return metadata

Expand Down Expand Up @@ -277,8 +288,8 @@ def __init__(
n_bits: Union[int, Dict[str, int]] = 6,
max_depth: Optional[int] = 3,
learning_rate: Optional[float] = None,
n_estimators: Optional[int] = 20,
objective: Optional[str] = None,
n_estimators: int = 20,
objective: Optional[str] = "reg:squarederror",
booster: Optional[str] = None,
tree_method: Optional[str] = None,
n_jobs: Optional[int] = None,
Expand All @@ -296,23 +307,23 @@ def __init__(
missing: float = numpy.nan,
num_parallel_tree: Optional[int] = None,
monotone_constraints: Optional[Union[Dict[str, int], str]] = None,
interaction_constraints: Optional[Union[str, List[Tuple[str]]]] = None,
interaction_constraints: Optional[Union[str, Sequence[Sequence[str]]]] = None,
importance_type: Optional[str] = None,
gpu_id: Optional[int] = None,
validate_parameters: Optional[bool] = None,
predictor: Optional[str] = None,
enable_categorical: bool = False,
random_state: Optional[int] = None,
random_state: Optional[Union[RandomState, int]] = None,
verbosity: Optional[int] = None,
eval_metric: Optional[str] = None,
eval_metric: Optional[Union[str, List[str], Callable]] = None,
sampling_method: Optional[str] = None,
max_leaves: Optional[int] = None,
max_bin: Optional[int] = None,
max_cat_to_onehot: Optional[int] = None,
grow_policy: Optional[str] = None,
callbacks: Optional[List[Callable]] = None,
callbacks: Optional[List[TrainingCallback]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs,
**kwargs: Any,
):
# base_score != 0.5 or None does not seem to pass our tests
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/474
Expand Down Expand Up @@ -442,7 +453,6 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["validate_parameters"] = self.validate_parameters
metadata["predictor"] = self.predictor
metadata["enable_categorical"] = self.enable_categorical
metadata["use_label_encoder"] = self.use_label_encoder
metadata["random_state"] = self.random_state
metadata["verbosity"] = self.verbosity
metadata["max_bin"] = self.max_bin
Expand All @@ -451,11 +461,17 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["grow_policy"] = self.grow_policy
metadata["sampling_method"] = self.sampling_method
metadata["early_stopping_rounds"] = self.early_stopping_rounds
metadata["eval_metric"] = self.eval_metric

# Callables are not serializable
assert not self.kwargs, "kwargs are not supported for serialization"
assert not self.callbacks, "callbacks are not supported for serialization"
if callable(self.eval_metric):
raise NotImplementedError("Callable eval_metric is not supported for serialization")

if self.kwargs:
raise NotImplementedError("kwargs are not supported for serialization")

if self.callbacks:
raise NotImplementedError("callbacks are not supported for serialization")

metadata["eval_metric"] = self.eval_metric
metadata["kwargs"] = None
metadata["callbacks"] = None

Expand Down
26 changes: 23 additions & 3 deletions tests/sklearn/test_sklearn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2082,12 +2082,14 @@ def get_params(model):
if cls.__name__ in ["XGBClassifier", "XGBRegressor"]:
params = {}

# Recursively gather parameters from all base classes
# Recursively gather parameters from all base classes, starting from the current class
def gather_params(c):
sig = inspect.signature(c)
params.update({k: v.default for k, v in sig.parameters.items()})
# First, recursively gather from base classes so child class can overwrite
for base in c.__bases__:
gather_params(base)
# Update with the current class's parameters
sig = inspect.signature(c)
params.update({k: v.default for k, v in sig.parameters.items()})

gather_params(cls)
return params
Expand Down Expand Up @@ -2140,3 +2142,21 @@ def is_nan(x):
f"Expected: {[sklearn_params_defaults[param] for param in differing_defaults]}, "
f"Found: {[cml_params_defaults[param] for param in differing_defaults]}"
)


@pytest.mark.parametrize("model_class", _get_sklearn_tree_models())
@pytest.mark.parametrize(
    "param, error_message",
    [
        ({"eval_metric": lambda x: x}, "Callable eval_metric is not supported for serialization"),
        ({"kwargs": {"extra": "param"}}, "kwargs are not supported for serialization"),
        ({"callbacks": [lambda x: x]}, "callbacks are not supported for serialization"),
    ],
)
def test_xgb_serialization_errors(model_class, param, error_message):
    """Check that XGBoost models with unsupported parameters raise on serialization.

    Only the XGBoost wrappers implement the NotImplementedError guards in
    `dump_dict`, so other tree models are explicitly skipped instead of
    silently passing with an empty test body.

    Args:
        model_class: Tree model class under test.
        param: Constructor keyword argument(s) that make the model unserializable.
        error_message: Expected NotImplementedError message (matched as a regex).
    """
    model_name = get_model_name(model_class)
    if model_name not in ["XGBClassifier", "XGBRegressor"]:
        pytest.skip(f"{model_name} has no XGBoost-specific serialization guards")

    # Building the model is expected to succeed; only dumps() should raise,
    # so keep the raises context as narrow as possible
    model = instantiate_model_generic(model_class, 5, **param)
    with pytest.raises(NotImplementedError, match=error_message):
        model.dumps()

0 comments on commit 32d0037

Please sign in to comment.