diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index 5ac220efd..0ddbafd71 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -699,31 +699,55 @@ class BaseClassifier(BaseEstimator): the predicted values as well as handling a mapping of classes in case they are not ordered. """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + @property + def target_classes_(self) -> Optional[numpy.ndarray]: # pragma: no cover + """Get the model's classes. + + Using this attribute is deprecated. + + Returns: + Optional[numpy.ndarray]: The model's classes. + """ + warnings.warn( + "Attribute 'target_classes_' is now deprecated and will be removed in a future " + "version. Please use 'classes_' instead.", + category=UserWarning, + stacklevel=2, + ) + + return self.classes_ + + @property + def n_classes_(self) -> int: # pragma: no cover + """Get the model's number of classes. + + Using this attribute is deprecated. + + Returns: + int: The model's number of classes. + """ - #: The classifier's different target classes. Is None if the model is not fitted. - self.target_classes_: Optional[numpy.ndarray] = None + # Tree-based classifiers from scikit-learn provide a `n_classes_` attribute + if self.sklearn_model is not None and hasattr(self.sklearn_model, "n_classes_"): + return self.sklearn_model.n_classes_ - #: The classifier's number of different target classes. Is None if the model is not fitted. - self.n_classes_: Optional[int] = None + warnings.warn( + "Attribute 'n_classes_' is now deprecated and will be removed in a future version. " + "Please use 'len(classes_)' instead.", + category=UserWarning, + stacklevel=2, + ) - def _set_post_processing_params(self): - super()._set_post_processing_params() - self.post_processing_params.update({"n_classes_": self.n_classes_}) + return len(self.classes_) def fit(self, X: Data, y: Target, **fit_parameters): X, y = check_X_y_and_assert_multi_output(X, y) # Retrieve the different target classes classes = numpy.unique(y) - self.target_classes_ = classes - - # Compute the number of target classes - self.n_classes_ = len(classes) # Make sure y contains at least two classes - assert_true(self.n_classes_ > 1, "You must provide at least 2 classes in y.") + assert_true(len(classes) > 1, "You must provide at least 2 classes in y.") # Change to composition in order to avoid diamond inheritance and indirect super() calls # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3249 @@ -753,24 +777,21 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy. # Retrieve the class with the highest probability y_preds = numpy.argmax(y_proba, axis=1) - assert self.target_classes_ is not None, self._is_not_fitted_error_message() - - return self.target_classes_[y_preds] + return self.classes_[y_preds] def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray: y_preds = super().post_processing(y_preds) - # Retrieve the number of target classes - n_classes_ = self.post_processing_params["n_classes_"] + # If the prediction array is 1D, which happens with some models such as XGBCLassifier or + # LogisticRegression models, we have a binary classification problem + n_classes = y_preds.shape[1] if y_preds.ndim > 1 and y_preds.shape[1] > 1 else 2 - # If the predictions only has one dimension (i.e., binary classification problem), apply the - # sigmoid operator - if n_classes_ == 2: + # For binary classification problem, apply the sigmoid operator + if n_classes == 2: y_preds = numpy_sigmoid(y_preds)[0] - # If the prediction array is 1D (which happens with some models such as XGBCLassifier - # models), transform the output into a 2D array [1-p, p], with p the initial - # output probabilities + # If the prediction array is 1D, transform the output into a 2D array [1-p, p], + # with p the initial output probabilities if y_preds.ndim == 1 or y_preds.shape[1] == 1: y_preds = numpy.concatenate((1 - y_preds, y_preds), axis=1) diff --git a/src/concrete/ml/sklearn/linear_model.py b/src/concrete/ml/sklearn/linear_model.py index 19dbf2d0e..b50a51a1f 100644 --- a/src/concrete/ml/sklearn/linear_model.py +++ b/src/concrete/ml/sklearn/linear_model.py @@ -525,10 +525,6 @@ def dump_dict(self) -> Dict[str, Any]: metadata["_q_bias"] = self._q_bias metadata["post_processing_params"] = self.post_processing_params - # Classifier - metadata["target_classes_"] = self.target_classes_ - metadata["n_classes_"] = self.n_classes_ - # Scikit-Learn metadata["penalty"] = self.penalty metadata["dual"] = self.dual @@ -566,10 +562,6 @@ def load_dict(cls, metadata: Dict): obj._q_bias = metadata["_q_bias"] obj.post_processing_params = metadata["post_processing_params"] - # Classifier - obj.target_classes_ = metadata["target_classes_"] - obj.n_classes_ = metadata["n_classes_"] - # Scikit-Learn obj.penalty = metadata["penalty"] obj.dual = metadata["dual"] diff --git a/src/concrete/ml/sklearn/qnn.py b/src/concrete/ml/sklearn/qnn.py index 7e4d94823..1bbbbcfd5 100644 --- a/src/concrete/ml/sklearn/qnn.py +++ b/src/concrete/ml/sklearn/qnn.py @@ -547,10 +547,6 @@ def dump_dict(self) -> Dict[str, Any]: metadata["quantized_module_"] = self.quantized_module_ metadata["post_processing_params"] = self.post_processing_params - # Classifier - metadata["target_classes_"] = self.target_classes_ - metadata["n_classes_"] = self.n_classes_ - # skorch attributes that cannot be serialized # FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3550 # Disable mypy as running isinstance with a Callable type unexpectedly raises an issue: @@ -594,6 +590,7 @@ def dump_dict(self) -> Dict[str, Any]: metadata["history_"] = self.history_ metadata["initialized_"] = self.initialized_ metadata["virtual_params_"] = self.virtual_params_ + metadata["classes_"] = self.classes_ assert hasattr( self, "module__n_layers" @@ -626,6 +623,7 @@ def load_dict(cls, metadata: Dict): # Instantiate the model obj = NeuralNetClassifier( module__n_layers=metadata["module__n_layers"], + classes=metadata["classes_"], ) # Concrete-ML @@ -637,10 +635,6 @@ def load_dict(cls, metadata: Dict): obj.quantized_module_ = metadata["quantized_module_"] obj.post_processing_params = metadata["post_processing_params"] - # Classifier - obj.target_classes_ = metadata["target_classes_"] - obj.n_classes_ = metadata["n_classes_"] - # skorch obj.lr = metadata["lr"] obj.max_epochs = metadata["max_epochs"] diff --git a/src/concrete/ml/sklearn/rf.py b/src/concrete/ml/sklearn/rf.py index 16022fd09..00685a047 100644 --- a/src/concrete/ml/sklearn/rf.py +++ b/src/concrete/ml/sklearn/rf.py @@ -85,10 +85,6 @@ def dump_dict(self) -> Dict[str, Any]: metadata["framework"] = self.framework metadata["post_processing_params"] = self.post_processing_params - # Classifier - metadata["target_classes_"] = self.target_classes_ - metadata["n_classes_"] = self.n_classes_ - # Scikit-Learn metadata["n_estimators"] = self.n_estimators metadata["bootstrap"] = self.bootstrap @@ -130,10 +126,6 @@ def load_dict(cls, metadata: Dict): ) obj.post_processing_params = metadata["post_processing_params"] - # Classifier - obj.target_classes_ = metadata["target_classes_"] - obj.n_classes_ = metadata["n_classes_"] - # Scikit-Learn obj.n_estimators = metadata["n_estimators"] obj.bootstrap = metadata["bootstrap"] diff --git a/src/concrete/ml/sklearn/svm.py b/src/concrete/ml/sklearn/svm.py index 746bab6c4..1636a0061 100644 --- a/src/concrete/ml/sklearn/svm.py +++ b/src/concrete/ml/sklearn/svm.py @@ -192,10 +192,6 @@ def dump_dict(self) -> Dict[str, Any]: metadata["_q_bias"] = self._q_bias metadata["post_processing_params"] = self.post_processing_params - # Classifier - metadata["target_classes_"] = self.target_classes_ - metadata["n_classes_"] = self.n_classes_ - # Scikit-Learn metadata["penalty"] = self.penalty metadata["loss"] = self.loss @@ -231,10 +227,6 @@ def load_dict(cls, metadata: Dict): obj._q_bias = metadata["_q_bias"] obj.post_processing_params = metadata["post_processing_params"] - # Classifier - obj.target_classes_ = metadata["target_classes_"] - obj.n_classes_ = metadata["n_classes_"] - # Scikit-Learn obj.penalty = metadata["penalty"] obj.loss = metadata["loss"] diff --git a/src/concrete/ml/sklearn/tree.py b/src/concrete/ml/sklearn/tree.py index 07373d979..b81558b77 100644 --- a/src/concrete/ml/sklearn/tree.py +++ b/src/concrete/ml/sklearn/tree.py @@ -85,10 +85,6 @@ def dump_dict(self) -> Dict[str, Any]: metadata["framework"] = self.framework metadata["post_processing_params"] = self.post_processing_params - # Classifier - metadata["target_classes_"] = self.target_classes_ - metadata["n_classes_"] = self.n_classes_ - # Scikit-Learn metadata["criterion"] = self.criterion metadata["splitter"] = self.splitter @@ -125,10 +121,6 @@ def load_dict(cls, metadata: Dict): ) obj.post_processing_params = metadata["post_processing_params"] - # Classifier - obj.target_classes_ = metadata["target_classes_"] - obj.n_classes_ = metadata["n_classes_"] - # Scikit-Learn obj.criterion = metadata["criterion"] obj.splitter = metadata["splitter"] diff --git a/src/concrete/ml/sklearn/xgb.py b/src/concrete/ml/sklearn/xgb.py index 2a700ab55..86301a5cb 100644 --- a/src/concrete/ml/sklearn/xgb.py +++ b/src/concrete/ml/sklearn/xgb.py @@ -126,10 +126,6 @@ def dump_dict(self) -> Dict[str, Any]: metadata["framework"] = self.framework metadata["post_processing_params"] = self.post_processing_params - # Classifier - metadata["target_classes_"] = self.target_classes_ - metadata["n_classes_"] = self.n_classes_ - # XGBoost metadata["max_depth"] = self.max_depth metadata["learning_rate"] = self.learning_rate @@ -184,10 +180,6 @@ def load_dict(cls, metadata: Dict): ) obj.post_processing_params = metadata["post_processing_params"] - # Classifier - obj.target_classes_ = metadata["target_classes_"] - obj.n_classes_ = metadata["n_classes_"] - # XGBoost obj.max_depth = metadata["max_depth"] obj.learning_rate = metadata["learning_rate"] diff --git a/tests/sklearn/test_qnn.py b/tests/sklearn/test_qnn.py index f164acb38..18a1d9d00 100644 --- a/tests/sklearn/test_qnn.py +++ b/tests/sklearn/test_qnn.py @@ -219,7 +219,7 @@ def test_compile_and_calib( if is_classifier_or_partial_classifier(model_class): y_pred_clear = model.predict(x_train, fhe="disable") # Check that the predicted classes are all contained in the model class list - assert set(numpy.unique(y_pred_clear)).issubset(set(model.target_classes_)) + assert set(numpy.unique(y_pred_clear)).issubset(set(model.classes_)) # Compile the model model.compile(