Skip to content

Commit

Permalink
chore: simplify post processing for classifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Sep 15, 2023
1 parent a97f25b commit 8d8e20c
Show file tree
Hide file tree
Showing 8 changed files with 79,533 additions and 52 deletions.
35 changes: 22 additions & 13 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,13 @@ class BaseClassifier(BaseEstimator):

@property
def target_classes_(self) -> Optional[numpy.ndarray]: # pragma: no cover
"""Get the model's classes.
Using this attribute is deprecated.
Returns:
Optional[numpy.ndarray]: The model's classes.
"""
warnings.warn(
"Attribute 'target_classes_' is now deprecated. Please use 'classes_' instead.",
category=UserWarning,
Expand All @@ -710,7 +717,14 @@ def target_classes_(self) -> Optional[numpy.ndarray]: # pragma: no cover
return self.classes_

@property
def n_classes_(self) -> Optional[numpy.ndarray]: # pragma: no cover
def n_classes_(self) -> int: # pragma: no cover
"""Get the model's number of classes.
Using this attribute is deprecated.
Returns:
int: The model's number of classes.
"""
warnings.warn(
"Attribute 'n_classes_' is now deprecated. Please use 'len(classes_)' instead.",
category=UserWarning,
Expand All @@ -719,10 +733,6 @@ def n_classes_(self) -> Optional[numpy.ndarray]: # pragma: no cover

return len(self.classes_)

def _set_post_processing_params(self):
    """Store the parameters needed for post-processing predictions.

    Extends the parent implementation by additionally recording the
    classifier's classes, which post-processing relies on.
    """
    # Let the parent class populate its own post-processing parameters first
    super()._set_post_processing_params()

    # Record the model's classes alongside the inherited parameters
    self.post_processing_params["classes_"] = self.classes_

def fit(self, X: Data, y: Target, **fit_parameters):
X, y = check_X_y_and_assert_multi_output(X, y)

Expand Down Expand Up @@ -765,17 +775,16 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.
def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
y_preds = super().post_processing(y_preds)

# Retrieve the number of target classes
classes = self.post_processing_params["classes_"]
# If the prediction array is 1D, which happens with some models such as XGBClassifier or
# LogisticRegression models, we have a binary classification problem
n_classes = y_preds.shape[1] if y_preds.ndim > 1 and y_preds.shape[1] > 1 else 2

# If the predictions only has one dimension (i.e., binary classification problem), apply the
# sigmoid operator
if len(classes) == 2:
# For binary classification problem, apply the sigmoid operator
if n_classes == 2:
y_preds = numpy_sigmoid(y_preds)[0]

# If the prediction array is 1D (which happens with some models such as XGBCLassifier
# models), transform the output into a 2D array [1-p, p], with p the initial
# output probabilities
# If the prediction array is 1D, transform the output into a 2D array [1-p, p],
# with p the initial output probabilities
if y_preds.ndim == 1 or y_preds.shape[1] == 1:
y_preds = numpy.concatenate((1 - y_preds, y_preds), axis=1)

Expand Down
6 changes: 0 additions & 6 deletions src/concrete/ml/sklearn/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,9 +525,6 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["_q_bias"] = self._q_bias
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["classes_"] = self.classes_

# Scikit-Learn
metadata["penalty"] = self.penalty
metadata["dual"] = self.dual
Expand Down Expand Up @@ -565,9 +562,6 @@ def load_dict(cls, metadata: Dict):
obj._q_bias = metadata["_q_bias"]
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.classes_ = metadata["classes_"]

# Scikit-Learn
obj.penalty = metadata["penalty"]
obj.dual = metadata["dual"]
Expand Down
6 changes: 0 additions & 6 deletions src/concrete/ml/sklearn/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,9 +547,6 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["quantized_module_"] = self.quantized_module_
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["classes_"] = self.classes_

# skorch attributes that cannot be serialized
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/3550
# Disable mypy as running isinstance with a Callable type unexpectedly raises an issue:
Expand Down Expand Up @@ -636,9 +633,6 @@ def load_dict(cls, metadata: Dict):
obj.quantized_module_ = metadata["quantized_module_"]
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.classes_ = metadata["classes_"]

# skorch
obj.lr = metadata["lr"]
obj.max_epochs = metadata["max_epochs"]
Expand Down
6 changes: 0 additions & 6 deletions src/concrete/ml/sklearn/rf.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["framework"] = self.framework
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["classes_"] = self.classes_

# Scikit-Learn
metadata["n_estimators"] = self.n_estimators
metadata["bootstrap"] = self.bootstrap
Expand Down Expand Up @@ -129,9 +126,6 @@ def load_dict(cls, metadata: Dict):
)
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.classes_ = metadata["classes_"]

# Scikit-Learn
obj.n_estimators = metadata["n_estimators"]
obj.bootstrap = metadata["bootstrap"]
Expand Down
6 changes: 0 additions & 6 deletions src/concrete/ml/sklearn/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,6 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["_q_bias"] = self._q_bias
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["classes_"] = self.classes_

# Scikit-Learn
metadata["penalty"] = self.penalty
metadata["loss"] = self.loss
Expand Down Expand Up @@ -230,9 +227,6 @@ def load_dict(cls, metadata: Dict):
obj._q_bias = metadata["_q_bias"]
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.classes_ = metadata["classes_"]

# Scikit-Learn
obj.penalty = metadata["penalty"]
obj.loss = metadata["loss"]
Expand Down
6 changes: 0 additions & 6 deletions src/concrete/ml/sklearn/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,6 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["framework"] = self.framework
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["classes_"] = self.classes_

# Scikit-Learn
metadata["criterion"] = self.criterion
metadata["splitter"] = self.splitter
Expand Down Expand Up @@ -124,9 +121,6 @@ def load_dict(cls, metadata: Dict):
)
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.classes_ = metadata["classes_"]

# Scikit-Learn
obj.criterion = metadata["criterion"]
obj.splitter = metadata["splitter"]
Expand Down
6 changes: 0 additions & 6 deletions src/concrete/ml/sklearn/xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["framework"] = self.framework
metadata["post_processing_params"] = self.post_processing_params

# Classifier
metadata["classes_"] = self.classes_

# XGBoost
metadata["max_depth"] = self.max_depth
metadata["learning_rate"] = self.learning_rate
Expand Down Expand Up @@ -183,9 +180,6 @@ def load_dict(cls, metadata: Dict):
)
obj.post_processing_params = metadata["post_processing_params"]

# Classifier
obj.classes_ = metadata["classes_"]

# XGBoost
obj.max_depth = metadata["max_depth"]
obj.learning_rate = metadata["learning_rate"]
Expand Down
Loading

0 comments on commit 8d8e20c

Please sign in to comment.