Skip to content

Commit

Permalink
chore: fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RomanBredehoft committed Sep 15, 2023
1 parent 8d8e20c commit 9044376
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ def target_classes_(self) -> Optional[numpy.ndarray]: # pragma: no cover
Optional[numpy.ndarray]: The model's classes.
"""
warnings.warn(
"Attribute 'target_classes_' is now deprecated. Please use 'classes_' instead.",
"Attribute 'target_classes_' is now deprecated and will be removed in a future "
"version. Please use 'classes_' instead.",
category=UserWarning,
stacklevel=2,
)
Expand All @@ -725,8 +726,14 @@ def n_classes_(self) -> int: # pragma: no cover
Returns:
int: The model's number of classes.
"""

# Tree-based classifiers from scikit-learn provide a `n_classes_` attribute
if self.sklearn_model is not None and hasattr(self.sklearn_model, "n_classes_"):
return self.sklearn_model.n_classes_

warnings.warn(
"Attribute 'n_classes_' is now deprecated. Please use 'len(classes_)' instead.",
"Attribute 'n_classes_' is now deprecated and will be removed in a future version. "
"Please use 'len(classes_)' instead.",
category=UserWarning,
stacklevel=2,
)
Expand Down
2 changes: 2 additions & 0 deletions src/concrete/ml/sklearn/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ def dump_dict(self) -> Dict[str, Any]:
metadata["history_"] = self.history_
metadata["initialized_"] = self.initialized_
metadata["virtual_params_"] = self.virtual_params_
metadata["classes_"] = self.classes_

assert hasattr(
self, "module__n_layers"
Expand Down Expand Up @@ -622,6 +623,7 @@ def load_dict(cls, metadata: Dict):
# Instantiate the model
obj = NeuralNetClassifier(
module__n_layers=metadata["module__n_layers"],
classes=metadata["classes_"],
)

# Concrete-ML
Expand Down

0 comments on commit 9044376

Please sign in to comment.