Skip to content

Commit

Permalink
chore: remove custom classifier attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Sep 15, 2023
1 parent 9520589 commit a97f25b
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 42 deletions.
39 changes: 22 additions & 17 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,31 +699,38 @@ class BaseClassifier(BaseEstimator):
the predicted values as well as handling a mapping of classes in case they are not ordered.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@property
def target_classes_(self) -> Optional[numpy.ndarray]: # pragma: no cover
warnings.warn(
"Attribute 'target_classes_' is now deprecated. Please use 'classes_' instead.",
category=UserWarning,
stacklevel=2,
)

#: The classifier's different target classes. Is None if the model is not fitted.
self.target_classes_: Optional[numpy.ndarray] = None
return self.classes_

#: The classifier's number of different target classes. Is None if the model is not fitted.
self.n_classes_: Optional[int] = None
@property
def n_classes_(self) -> int:  # pragma: no cover
    """Return the number of target classes (deprecated attribute).

    Kept as a read-only property so that code still reading 'n_classes_' keeps working.
    Each access emits a UserWarning pointing to 'len(classes_)' as the replacement.

    Returns:
        int: The length of 'classes_'. The value is computed from 'self.classes_' on every
            access, so that attribute must be available on the instance.
    """
    warnings.warn(
        "Attribute 'n_classes_' is now deprecated. Please use 'len(classes_)' instead.",
        category=UserWarning,
        # Attribute the warning to the code accessing the property, not to this getter
        stacklevel=2,
    )

    return len(self.classes_)

def _set_post_processing_params(self):
super()._set_post_processing_params()
self.post_processing_params.update({"n_classes_": self.n_classes_})
self.post_processing_params.update({"classes_": self.classes_})

def fit(self, X: Data, y: Target, **fit_parameters):
X, y = check_X_y_and_assert_multi_output(X, y)

# Retrieve the different target classes
classes = numpy.unique(y)
self.target_classes_ = classes

# Compute the number of target classes
self.n_classes_ = len(classes)

# Make sure y contains at least two classes
assert_true(self.n_classes_ > 1, "You must provide at least 2 classes in y.")
assert_true(len(classes) > 1, "You must provide at least 2 classes in y.")

# Change to composition in order to avoid diamond inheritance and indirect super() calls
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3249
Expand Down Expand Up @@ -753,19 +760,17 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.
# Retrieve the class with the highest probability
y_preds = numpy.argmax(y_proba, axis=1)

assert self.target_classes_ is not None, self._is_not_fitted_error_message()

return self.target_classes_[y_preds]
return self.classes_[y_preds]

def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
y_preds = super().post_processing(y_preds)

# Retrieve the number of target classes
n_classes_ = self.post_processing_params["n_classes_"]
classes = self.post_processing_params["classes_"]

# If the predictions only have one dimension (i.e., binary classification problem), apply the
# sigmoid operator
if n_classes_ == 2:
if len(classes) == 2:
y_preds = numpy_sigmoid(y_preds)[0]

# If the prediction array is 1D (which happens with some models such as XGBClassifier
Expand Down
6 changes: 2 additions & 4 deletions src/concrete/ml/sklearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["target_classes_"] = self.target_classes_
metadata["n_classes_"] = self.n_classes_
metadata["classes_"] = self.classes_

# Scikit-Learn
metadata["penalty"] = self.penalty
Expand Down Expand Up @@ -567,8 +566,7 @@ def load_dict(cls, metadata: Dict):
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.target_classes_ = metadata["target_classes_"]
obj.n_classes_ = metadata["n_classes_"]
obj.classes_ = metadata["classes_"]

# Scikit-Learn
obj.penalty = metadata["penalty"]
Expand Down
6 changes: 2 additions & 4 deletions src/concrete/ml/sklearn/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["target_classes_"] = self.target_classes_
metadata["n_classes_"] = self.n_classes_
metadata["classes_"] = self.classes_

# skorch attributes that cannot be serialized
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3550
Expand Down Expand Up @@ -638,8 +637,7 @@ def load_dict(cls, metadata: Dict):
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.target_classes_ = metadata["target_classes_"]
obj.n_classes_ = metadata["n_classes_"]
obj.classes_ = metadata["classes_"]

# skorch
obj.lr = metadata["lr"]
Expand Down
6 changes: 2 additions & 4 deletions src/concrete/ml/sklearn/rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["target_classes_"] = self.target_classes_
metadata["n_classes_"] = self.n_classes_
metadata["classes_"] = self.classes_

# Scikit-Learn
metadata["n_estimators"] = self.n_estimators
Expand Down Expand Up @@ -131,8 +130,7 @@ def load_dict(cls, metadata: Dict):
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.target_classes_ = metadata["target_classes_"]
obj.n_classes_ = metadata["n_classes_"]
obj.classes_ = metadata["classes_"]

# Scikit-Learn
obj.n_estimators = metadata["n_estimators"]
Expand Down
6 changes: 2 additions & 4 deletions src/concrete/ml/sklearn/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["target_classes_"] = self.target_classes_
metadata["n_classes_"] = self.n_classes_
metadata["classes_"] = self.classes_

# Scikit-Learn
metadata["penalty"] = self.penalty
Expand Down Expand Up @@ -232,8 +231,7 @@ def load_dict(cls, metadata: Dict):
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.target_classes_ = metadata["target_classes_"]
obj.n_classes_ = metadata["n_classes_"]
obj.classes_ = metadata["classes_"]

# Scikit-Learn
obj.penalty = metadata["penalty"]
Expand Down
6 changes: 2 additions & 4 deletions src/concrete/ml/sklearn/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["target_classes_"] = self.target_classes_
metadata["n_classes_"] = self.n_classes_
metadata["classes_"] = self.classes_

# Scikit-Learn
metadata["criterion"] = self.criterion
Expand Down Expand Up @@ -126,8 +125,7 @@ def load_dict(cls, metadata: Dict):
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.target_classes_ = metadata["target_classes_"]
obj.n_classes_ = metadata["n_classes_"]
obj.classes_ = metadata["classes_"]

# Scikit-Learn
obj.criterion = metadata["criterion"]
Expand Down
6 changes: 2 additions & 4 deletions src/concrete/ml/sklearn/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["target_classes_"] = self.target_classes_
metadata["n_classes_"] = self.n_classes_
metadata["classes_"] = self.classes_

# XGBoost
metadata["max_depth"] = self.max_depth
Expand Down Expand Up @@ -185,8 +184,7 @@ def load_dict(cls, metadata: Dict):
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.target_classes_ = metadata["target_classes_"]
obj.n_classes_ = metadata["n_classes_"]
obj.classes_ = metadata["classes_"]

# XGBoost
obj.max_depth = metadata["max_depth"]
Expand Down
2 changes: 1 addition & 1 deletion tests/sklearn/test_qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_compile_and_calib(
if is_classifier_or_partial_classifier(model_class):
y_pred_clear = model.predict(x_train, fhe="disable")
# Check that the predicted classes are all contained in the model class list
assert set(numpy.unique(y_pred_clear)).issubset(set(model.target_classes_))
assert set(numpy.unique(y_pred_clear)).issubset(set(model.classes_))

# Compile the model
model.compile(
Expand Down

0 comments on commit a97f25b

Please sign in to comment.