chore: add default value check + fix
jfrery committed Apr 30, 2024
1 parent 20b18e0 commit f4bb1fc
Showing 3 changed files with 93 additions and 42 deletions.
2 changes: 1 addition & 1 deletion src/concrete/ml/sklearn/rf.py
@@ -173,7 +173,7 @@ def __init__(
         min_samples_split=2,
         min_samples_leaf=1,
         min_weight_fraction_leaf=0.0,
-        max_features="sqrt",
+        max_features=1.0,
         max_leaf_nodes=None,
         min_impurity_decrease=0.0,
         bootstrap=True,
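Note: the new value matches upstream scikit-learn, where RandomForestRegressor's max_features default moved from "auto" to 1.0 in version 1.1 (the classifier uses "sqrt"); given the new value, this hunk appears to sit in the regressor's __init__. A minimal verification sketch, assuming scikit-learn >= 1.1:

    import inspect

    from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor

    # Regressor default changed from "auto" to 1.0 in scikit-learn 1.1
    print(inspect.signature(RandomForestRegressor).parameters["max_features"].default)  # 1.0
    # Classifier default is "sqrt"
    print(inspect.signature(RandomForestClassifier).parameters["max_features"].default)  # sqrt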
12 changes: 8 additions & 4 deletions src/concrete/ml/sklearn/xgb.py
@@ -33,7 +33,7 @@ def __init__(
         max_depth: Optional[int] = 3,
         learning_rate: Optional[float] = None,
         n_estimators: Optional[int] = 20,
-        objective: Optional[str] = "binary:logistic",
+        objective: Optional[str] = None,
         booster: Optional[str] = None,
         tree_method: Optional[str] = None,
         n_jobs: Optional[int] = None,
@@ -278,7 +278,7 @@ def __init__(
         max_depth: Optional[int] = 3,
         learning_rate: Optional[float] = None,
         n_estimators: Optional[int] = 20,
-        objective: Optional[str] = "reg:squarederror",
+        objective: Optional[str] = None,
         booster: Optional[str] = None,
         tree_method: Optional[str] = None,
         n_jobs: Optional[int] = None,
@@ -450,10 +450,14 @@ def dump_dict(self) -> Dict[str, Any]:
         metadata["max_cat_to_onehot"] = self.max_cat_to_onehot
         metadata["grow_policy"] = self.grow_policy
         metadata["sampling_method"] = self.sampling_method
-        metadata["callbacks"] = self.callbacks
         metadata["early_stopping_rounds"] = self.early_stopping_rounds
         metadata["eval_metric"] = self.eval_metric
-        metadata["kwargs"] = self.kwargs
 
+        # Callables are not serializable
+        assert not self.kwargs, "kwargs are not supported for serialization"
+        assert not self.callbacks, "callbacks are not supported for serialization"
+        metadata["kwargs"] = None
+        metadata["callbacks"] = None
 
         return metadata

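The new asserts exist because callables cannot round-trip through JSON-based serialization. A minimal illustration with the standard json module (Concrete ML's own serializer is assumed here to hit the same limitation):

    import json

    try:
        json.dumps({"callbacks": [print]})
    except TypeError as exc:
        print(exc)  # Object of type builtin_function_or_method is not JSON serializable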
121 changes: 84 additions & 37 deletions tests/sklearn/test_sklearn_models.py
@@ -24,6 +24,7 @@
 import copy
 import inspect
 import json
+import math
 import os
 import sys
 import tempfile
@@ -106,6 +107,16 @@
 # the CRT.
 N_BITS_THRESHOLD_FOR_CRT_FHE_CIRCUITS = 9
 
+# Expected different default parameters for some models
+EXPECTED_DIFFERENT_DEFAULT_PARAMETERS = {
+    "KNeighborsClassifier": {"n_neighbors": 3},
+    "SGDClassifier": {"loss": "log_loss"},
+    "RandomForestClassifier": {"n_estimators": 20, "max_depth": 4},
+    "RandomForestRegressor": {"n_estimators": 20, "max_depth": 4},
+    "XGBClassifier": {"n_estimators": 20, "max_depth": 3},
+    "XGBRegressor": {"n_estimators": 20, "max_depth": 3},
+}
+
 
 def get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option):
     """Prepare the the (x, y) data-set."""
@@ -2041,9 +2052,13 @@ def test_error_raise_unsupported_pandas_values(model_class, bad_value, expected_
     + get_sklearn_tree_models_and_datasets()
     + get_sklearn_neighbors_models_and_datasets(),
 )
-def test_initialization_variables_match(model_class, parameters, load_data, is_weekly_option):
-    """Test that CML models can be initialized with the same parameters scikit-learn models."""
-    n_bits = N_BITS_THRESHOLD_FOR_SKLEARN_CORRECTNESS_TESTS
+def test_initialization_variables_and_defaults_match(
+    model_class, parameters, load_data, is_weekly_option
+):
+    """Test CML models init parameters and default values vs scikit-learn models."""
+    n_bits = get_n_bits_non_correctness(model_class)
+
+    model_name = get_model_name(model_class)
 
     x, y = get_dataset(model_class, parameters, n_bits, load_data, is_weekly_option)
 
@@ -2052,44 +2067,76 @@ def test_initialization_variables_match(model_class, parameters, load_data, is_weekly_option):
 
     # Fit the model to create the equivalent sklearn model
     with warnings.catch_warnings():
-        # Sometimes, we miss convergence, which is not a problem for our test
+        # Ignore convergence warnings
         warnings.simplefilter("ignore", category=ConvergenceWarning)
         model.fit(x, y)
 
     # Assert the sklearn model has been created
     assert hasattr(model, "sklearn_model"), "Sklearn model not found"
 
+    # Function to retrieve the parameters from any model
+    # XGBoost init params are in base classes, so we need to gather them recursively
+    def get_params(model):
+        """Get the initializer parameters of the given model."""
+        cls = get_model_class(model.sklearn_model)
+        if cls.__name__ in ["XGBClassifier", "XGBRegressor"]:
+            params = {}
+
+            # Recursively gather parameters from all base classes
+            def gather_params(c):
+                sig = inspect.signature(c)
+                params.update({k: v.default for k, v in sig.parameters.items()})
+                for base in c.__bases__:
+                    gather_params(base)
+
+            gather_params(cls)
+            return params
+
+        # Else, return parameters for non-xgboost models
+        sig = inspect.signature(cls)
+        return {k: v.default for k, v in sig.parameters.items()}
+
-    # Get the constructor parameters of both the custom and sklearn models
-    cml_params = set(inspect.signature(model.__class__).parameters.keys())
-
-    # Accumulate parameters from all base classes of the sklearn model
-    def get_params(cls):
-        params = set(inspect.signature(cls).parameters.keys())
-        for base in cls.__bases__:
-            params.update(get_params(base))
-        return params
-
-    # Conditionally gather parameters from base classes for XGBClassifier and XGBRegressor
-    if model.sklearn_model.__class__.__name__ in ["XGBClassifier", "XGBRegressor"]:
-        sklearn_params = get_params(model.sklearn_model.__class__)
-    else:
-        sklearn_params = set(inspect.signature(model.sklearn_model.__class__).parameters.keys())
-
-    # Allow 'n_bits' as an additional parameter for CML models
-    expected_difference = {"n_bits"}
-
-    # Allow fit_encrypted and parameters_range for SGDClassifier
-    if model.__class__.__name__ == "SGDClassifier":
-        expected_difference.add("fit_encrypted")
-        expected_difference.add("parameters_range")
-
-    # Calculate differences
-    missing_params = sklearn_params - cml_params
-    extra_params = (cml_params - sklearn_params) - expected_difference
-
-    assert (
-        not missing_params
-    ), f"Concrete ML {model.__class__.__name__} is missing these init parameters: {missing_params}"
-    assert (
-        not extra_params
-    ), f"Concrete ML {model.__class__.__name__} has extra init parameters: {extra_params}"
+    cml_params_defaults = {
+        k: v.default for k, v in inspect.signature(model.__class__).parameters.items()
+    }
+    sklearn_params_defaults = get_params(model)
+
+    # Calculate differences in parameters and defaults
+    missing_params = set(sklearn_params_defaults.keys()) - set(cml_params_defaults.keys())
+    extra_params = (set(cml_params_defaults.keys()) - set(sklearn_params_defaults.keys())) - {
+        "n_bits"
+    }
+
+    # Allow 'fit_encrypted' and 'parameters_range' for SGDClassifier
+    if model_name == "SGDClassifier":
+        extra_params -= {"fit_encrypted", "parameters_range"}
+
+    def is_nan(x):
+        """Check if a variable is nan."""
+        return isinstance(x, float) and math.isnan(x)
+
+    differing_defaults = {
+        param
+        for param in sklearn_params_defaults.keys() & cml_params_defaults.keys()
+        if not (
+            sklearn_params_defaults[param] == cml_params_defaults[param]
+            # Some parameters can be nan, which cannot be compared using equality
+            or (is_nan(sklearn_params_defaults[param]) and is_nan(cml_params_defaults[param]))
+        )
+    }
+
+    # Remove expected different param defaults from differing_defaults
+    expected_differences = EXPECTED_DIFFERENT_DEFAULT_PARAMETERS.get(model_name, {})
+    # For mypy
+    assert isinstance(expected_differences, dict)
+    differing_defaults.difference_update(expected_differences.keys())
+
+    # Assert parameters exist and defaults match
+    assert not missing_params, f"{model_name} is missing these init parameters: {missing_params}"
+    assert not extra_params, f"{model_name} has extra init parameters: {extra_params}"
+    assert not differing_defaults, (
+        f"Default values do not match for: {differing_defaults}. "
+        f"Expected: {[sklearn_params_defaults[param] for param in differing_defaults]}, "
+        f"Found: {[cml_params_defaults[param] for param in differing_defaults]}"
+    )
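The is_nan helper is needed because NaN defaults (e.g. xgboost's missing parameter defaults to NaN) compare unequal to themselves, so a plain equality check would wrongly flag them as differing. A two-line demonstration:

    import math

    nan = float("nan")
    print(nan == nan)       # False: IEEE 754 NaN is never equal to itself
    print(math.isnan(nan))  # True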
