Skip to content

Commit

Permalink
chore: simplify post processing for classifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Sep 14, 2023
1 parent 09323a2 commit 15c9ddc
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,10 +719,6 @@ def n_classes_(self) -> Optional[numpy.ndarray]: # pragma: no cover

return len(self.classes_)

def _set_post_processing_params(self):
super()._set_post_processing_params()
self.post_processing_params.update({"classes_": self.classes_})

def fit(self, X: Data, y: Target, **fit_parameters):
X, y = check_X_y_and_assert_multi_output(X, y)

Expand Down Expand Up @@ -765,17 +761,16 @@ def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.
def post_processing(self, y_preds: numpy.ndarray) -> numpy.ndarray:
y_preds = super().post_processing(y_preds)

# Retrieve the number of target classes
classes = self.post_processing_params["classes_"]
# If the prediction array is 1D, which happens with some models such as XGBCLassifier or
# LogisticRegression models, we have a binary classification problem
n_classes = y_preds.shape[1] if y_preds.ndim > 1 and y_preds.shape[1] > 1 else 2

# If the predictions only has one dimension (i.e., binary classification problem), apply the
# sigmoid operator
if len(classes) == 2:
# For binary classification problem, apply the sigmoid operator
if n_classes == 2:
y_preds = numpy_sigmoid(y_preds)[0]

# If the prediction array is 1D (which happens with some models such as XGBCLassifier
# models), transform the output into a 2D array [1-p, p], with p the initial
# output probabilities
# If the prediction array is 1D, transform the output into a 2D array [1-p, p],
# with p the initial output probabilities
if y_preds.ndim == 1 or y_preds.shape[1] == 1:
y_preds = numpy.concatenate((1 - y_preds, y_preds), axis=1)

Expand Down

0 comments on commit 15c9ddc

Please sign in to comment.