diff --git a/distclassipy/__init__.py b/distclassipy/__init__.py index 921299b..0713b18 100644 --- a/distclassipy/__init__.py +++ b/distclassipy/__init__.py @@ -24,7 +24,7 @@ from .classifier import ( DistanceMetricClassifier, - EnsembleDistanceMetricClassifier, + EnsembleDistanceClassifier, ) from .distances import Distance, _ALL_METRICS @@ -32,7 +32,7 @@ __all__ = [ "DistanceMetricClassifier", - "EnsembleDistanceMetricClassifier", + "EnsembleDistanceClassifier", "Distance", "_ALL_METRICS", ] diff --git a/distclassipy/classifier.py b/distclassipy/classifier.py index 71cf870..bde4ab7 100644 --- a/distclassipy/classifier.py +++ b/distclassipy/classifier.py @@ -506,7 +506,7 @@ def find_best_metrics( return quantile_scores_df, best_metrics_per_quantile, group_bins -class EnsembleDistanceMetricClassifier(BaseEstimator, ClassifierMixin): +class EnsembleDistanceClassifier(BaseEstimator, ClassifierMixin): """An ensemble classifier that uses different metrics for each quantile.""" def __init__( @@ -541,7 +541,7 @@ def __init__( def fit( self, X: np.ndarray, y: np.ndarray, n_quantiles: int = 4 - ) -> "EnsembleDistanceMetricClassifier": + ) -> "EnsembleDistanceClassifier": """Fit the ensemble classifier using the best metrics for each quantile. Parameters diff --git a/tests/test_classifier.py b/tests/test_classifier.py index d34a10d..dbe48ce 100644 --- a/tests/test_classifier.py +++ b/tests/test_classifier.py @@ -1,9 +1,13 @@ -from distclassipy.classifier import DistanceMetricClassifier +from distclassipy.classifier import ( + DistanceMetricClassifier, + EnsembleDistanceClassifier, +) import numpy as np import pytest +from sklearn.datasets import make_classification from sklearn.utils.estimator_checks import check_estimator @@ -134,3 +138,18 @@ def test_confidence_calculation(): clf.predict_and_analyse(X) distance_confidence = clf.calculate_confidence() assert distance_confidence.shape == (3, len(np.unique(y))) + + +# Test basic functionality of EnsembleDistanceClassifier +def test_ensemble_distance_classifier(): + X, y = make_classification( + n_samples=1000, + n_features=4, + n_informative=2, + shuffle=True, + ) + clf = EnsembleDistanceClassifier(feat_idx=0) + clf.fit(X, y) + predictions = clf.predict(X) + assert len(predictions) == len(y) + assert set(predictions).issubset(set(y))