Skip to content

Commit

Permalink
test: Added basic test for EnsembleDistanceClassifier
Browse files Browse the repository at this point in the history
- Also renamed EnsembleDistanceMetricClassifier to EnsembleDistanceClassifier
  • Loading branch information
sidchaini committed Oct 21, 2024
1 parent 0eed08a commit 5638bb7
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 5 deletions.
4 changes: 2 additions & 2 deletions distclassipy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@

from .classifier import (
DistanceMetricClassifier,
EnsembleDistanceMetricClassifier,
EnsembleDistanceClassifier,
)
from .distances import Distance, _ALL_METRICS

__version__ = "0.2.0"

__all__ = [
"DistanceMetricClassifier",
"EnsembleDistanceMetricClassifier",
"EnsembleDistanceClassifier",
"Distance",
"_ALL_METRICS",
]
4 changes: 2 additions & 2 deletions distclassipy/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def find_best_metrics(
return quantile_scores_df, best_metrics_per_quantile, group_bins


class EnsembleDistanceMetricClassifier(BaseEstimator, ClassifierMixin):
class EnsembleDistanceClassifier(BaseEstimator, ClassifierMixin):
"""An ensemble classifier that uses different metrics for each quantile."""

def __init__(
Expand Down Expand Up @@ -541,7 +541,7 @@ def __init__(

def fit(
self, X: np.ndarray, y: np.ndarray, n_quantiles: int = 4
) -> "EnsembleDistanceMetricClassifier":
) -> "EnsembleDistanceClassifier":
"""Fit the ensemble classifier using the best metrics for each quantile.
Parameters
Expand Down
21 changes: 20 additions & 1 deletion tests/test_classifier.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from distclassipy.classifier import DistanceMetricClassifier
from distclassipy.classifier import (
DistanceMetricClassifier,
EnsembleDistanceClassifier,
)

import numpy as np

import pytest

from sklearn.datasets import make_classification
from sklearn.utils.estimator_checks import check_estimator


Expand Down Expand Up @@ -134,3 +138,18 @@ def test_confidence_calculation():
clf.predict_and_analyse(X)
distance_confidence = clf.calculate_confidence()
assert distance_confidence.shape == (3, len(np.unique(y)))


# Test basic functionality of EnsembleDistanceClassifier
def test_ensemble_distance_classifier():
X, y = make_classification(
n_samples=1000,
n_features=4,
n_informative=2,
shuffle=True,
)
clf = EnsembleDistanceClassifier(feat_idx=0)
clf.fit(X, y)
predictions = clf.predict(X)
assert len(predictions) == len(y)
assert set(predictions).issubset(set(y))

0 comments on commit 5638bb7

Please sign in to comment.