diff --git a/distclassipy/__init__.py b/distclassipy/__init__.py
index 914f472..2f53be7 100644
--- a/distclassipy/__init__.py
+++ b/distclassipy/__init__.py
@@ -25,4 +25,4 @@
 from .classifier import DistanceMetricClassifier  # noqa
 from .distances import Distance  # noqa
 
-__version__ = "0.1.6a0"
+__version__ = "0.2.0a0"
diff --git a/distclassipy/classifier.py b/distclassipy/classifier.py
index 2e08c01..0616431 100644
--- a/distclassipy/classifier.py
+++ b/distclassipy/classifier.py
@@ -40,6 +40,79 @@
 }
 
 
+def initialize_metric_function(metric):
+    """Set the metric function based on the provided metric.
+
+    If the metric is a string, the function will look for a corresponding
+    function in scipy.spatial.distance or distances.Distance. If the metric
+    is a function, it will be used directly.
+
+    Parameters
+    ----------
+    metric : str or callable
+        The name of a distance metric, or a function computing the distance.
+
+    Returns
+    -------
+    metric_fn_ : callable
+        The function implementing the metric.
+    metric_arg_ : str or callable
+        The value to pass as ``metric`` to ``scipy.spatial.distance.cdist``: the
+        original string for SciPy metrics, otherwise the function itself.
+
+    Raises
+    ------
+    TypeError
+        If ``metric`` is neither a string nor a callable.
+    ValueError
+        If ``metric`` is a string that does not name a known metric.
+    """
+    if callable(metric):
+        metric_fn_ = metric
+        metric_arg_ = metric
+
+    elif isinstance(metric, str):
+        metric_str_lowercase = metric.lower()
+        metric_found = False
+        for package_str, source in METRIC_SOURCES_.items():
+
+            # Don't use scipy for jaccard as their implementation only works with
+            # booleans - use custom jaccard instead
+            if (
+                package_str == "scipy.spatial.distance"
+                and metric_str_lowercase == "jaccard"
+            ):
+                continue
+
+            if hasattr(source, metric_str_lowercase):
+                metric_fn_ = getattr(source, metric_str_lowercase)
+                metric_found = True
+
+                # Use the string as an argument if it belongs to scipy as it is
+                # optimized
+                metric_arg_ = (
+                    metric if package_str == "scipy.spatial.distance" else metric_fn_
+                )
+                break
+        if not metric_found:
+            raise ValueError(
+                f"{metric} metric not found. Please pass a string of the "
+                "name of a metric in scipy.spatial.distance or "
+                "distances.Distance, or pass a metric function directly. For a "
+                "list of available metrics, see: "
+                "https://sidchaini.github.io/DistClassiPy/distances.html or "
+                "https://docs.scipy.org/doc/scipy/reference/spatial.distance.html"
+            )
+
+    else:
+        # Fail with a clear error; otherwise the return below hits unbound locals.
+        raise TypeError(
+            "metric must be a string or a callable, got "
+            f"{type(metric).__name__} instead."
+        )
+    return metric_fn_, metric_arg_
+
+
 class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
     """A distance-based classifier that supports different distance metrics.
 
@@ -53,8 +126,6 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
 
     Parameters
     ----------
-    metric : str or callable, default="euclidean"
-        The distance metric to use for calculating the distance between features.
     scale : bool, default=True
         Whether to scale the distance between the test object and the centroid for a
         class in the feature space. If True, the data will be scaled based on the
@@ -72,8 +143,6 @@ class in the feature space. If True, the data will be scaled based on the
 
     Attributes
     ----------
-    metric : str or callable
-        The distance metric used for classification.
     scale : bool
         Indicates whether the data is scaled.
     central_stat : str
@@ -81,17 +150,6 @@ class in the feature space. If True, the data will be scaled based on the
     dispersion_stat : str
         The statistic used for calculating dispersion.
 
-    See Also
-    --------
-    scipy.spatial.dist : Other distance metrics provided in SciPy
-    distclassipy.Distance : Distance metrics included with DistClassiPy
-
-    Notes
-    -----
-    If using distance metrics supported by SciPy, it is desirable to pass a string,
-    which allows SciPy to use an optimized C version of the code instead of the slower
-    Python version.
-
     References
     ----------
     .. [1] "Light Curve Classification with DistClassiPy: a new distance-based
@@ -113,63 +171,15 @@ class in the feature space. If True, the data will be scaled based on the
 
     def __init__(
         self,
-        metric: str | Callable = "euclidean",
         scale: bool = True,
         central_stat: str = "median",
         dispersion_stat: str = "std",
     ):
         """Initialize the classifier with specified parameters."""
-        self.metric = metric
         self.scale = scale
         self.central_stat = central_stat
         self.dispersion_stat = dispersion_stat
 
-    def initialize_metric_function(self):
-        """Set the metric function based on the provided metric.
-
-        If the metric is a string, the function will look for a corresponding
-        function in scipy.spatial.distance or distances.Distance. If the metric
-        is a function, it will be used directly.
-        """
-        if callable(self.metric):
-            self.metric_fn_ = self.metric
-            self.metric_arg_ = self.metric
-
-        elif isinstance(self.metric, str):
-            metric_str_lowercase = self.metric.lower()
-            metric_found = False
-            for package_str, source in METRIC_SOURCES_.items():
-
-                # Don't use scipy for jaccard as their implementation only works with
-                # booleans - use custom jaccard instead
-                if (
-                    package_str == "scipy.spatial.distance"
-                    and metric_str_lowercase == "jaccard"
-                ):
-                    continue
-
-                if hasattr(source, metric_str_lowercase):
-                    self.metric_fn_ = getattr(source, metric_str_lowercase)
-                    metric_found = True
-
-                    # Use the string as an argument if it belongs to scipy as it is
-                    # optimized
-                    self.metric_arg_ = (
-                        self.metric
-                        if package_str == "scipy.spatial.distance"
-                        else self.metric_fn_
-                    )
-                    break
-            if not metric_found:
-                raise ValueError(
-                    f"{self.metric} metric not found. Please pass a string of the "
-                    "name of a metric in scipy.spatial.distance or "
-                    "distances.Distance, or pass a metric function directly. For a "
-                    "list of available metrics, see: "
-                    "https://sidchaini.github.io/DistClassiPy/distances.html or "
-                    "https://docs.scipy.org/doc/scipy/reference/spatial.distance.html"
-                )
-
     def fit(self, X: np.array, y: np.array, feat_labels: list[str] = None):
         """Calculate the feature space centroid for all classes.
 
@@ -200,8 +210,6 @@ def fit(self, X: np.array, y: np.array, feat_labels: list[str] = None):
             1
         ]  # Number of features seen during fit - required for sklearn compatibility.
 
-        self.initialize_metric_function()
-
         if feat_labels is None:
             feat_labels = [f"Feature_{x}" for x in range(X.shape[1])]
 
@@ -249,7 +257,11 @@ def fit(self, X: np.array, y: np.array, feat_labels: list[str] = None):
 
         return self
 
-    def predict(self, X: np.array):
+    def predict(
+        self,
+        X: np.array,
+        metric: str | Callable = "euclidean",
+    ):
         """Predict the class labels for the provided X.
 
         The prediction is based on the distance of each data point in the input sample
@@ -260,18 +272,33 @@ def predict(self, X: np.array):
         ----------
         X : array-like of shape (n_samples, n_features)
             The input samples.
+        metric : str or callable, default="euclidean"
+            The distance metric to use for calculating the distance between features.
 
         Returns
         -------
         y : ndarray of shape (n_samples,)
             The predicted classes.
+
+        See Also
+        --------
+        scipy.spatial.distance : Other distance metrics provided in SciPy
+        distclassipy.Distance : Distance metrics included with DistClassiPy
+
+        Notes
+        -----
+        If using distance metrics supported by SciPy, it is desirable to pass a string,
+        which allows SciPy to use an optimized C version of the code instead of the slower
+        Python version.
         """
         check_is_fitted(self, "is_fitted_")
         X = check_array(X)
 
+        metric_fn_, metric_arg_ = initialize_metric_function(metric)
+
         if not self.scale:
             dist_arr = scipy.spatial.distance.cdist(
-                XA=X, XB=self.df_centroid_.to_numpy(), metric=self.metric_arg_
+                XA=X, XB=self.df_centroid_.to_numpy(), metric=metric_arg_
             )
 
         else:
@@ -288,16 +315,18 @@ def predict(self, X: np.array):
                 w = wtdf.loc[cl].to_numpy()  # 1/std dev
                 XB = XB * w  # w is for this class only
                 XA = X * w  # w is for this class only
-                cl_dist = scipy.spatial.distance.cdist(
-                    XA=XA, XB=XB, metric=self.metric_arg_
-                )
+                cl_dist = scipy.spatial.distance.cdist(XA=XA, XB=XB, metric=metric_arg_)
                 dist_arr_list.append(cl_dist)
             dist_arr = np.column_stack(dist_arr_list)
 
         y_pred = self.classes_[dist_arr.argmin(axis=1)]
         return y_pred
 
-    def predict_and_analyse(self, X: np.array):
+    def predict_and_analyse(
+        self,
+        X: np.array,
+        metric: str | Callable = "euclidean",
+    ):
         """Predict the class labels for the provided X and perform analysis.
 
         The prediction is based on the distance of each data point in the input sample
@@ -311,18 +340,35 @@ def predict_and_analyse(self, X: np.array):
         ----------
         X : array-like of shape (n_samples, n_features)
             The input samples.
+        metric : str or callable, default="euclidean"
+            The distance metric to use for calculating the distance between features.
+
 
         Returns
         -------
         y : ndarray of shape (n_samples,)
             The predicted classes.
+
+        See Also
+        --------
+        scipy.spatial.distance : Other distance metrics provided in SciPy
+        distclassipy.Distance : Distance metrics included with DistClassiPy
+
+        Notes
+        -----
+        If using distance metrics supported by SciPy, it is desirable to pass a string,
+        which allows SciPy to use an optimized C version of the code instead of the slower
+        Python version.
+
         """
         check_is_fitted(self, "is_fitted_")
         X = check_array(X)
 
+        metric_fn_, metric_arg_ = initialize_metric_function(metric)
+
         if not self.scale:
             dist_arr = scipy.spatial.distance.cdist(
-                XA=X, XB=self.df_centroid_.to_numpy(), metric=self.metric_arg_
+                XA=X, XB=self.df_centroid_.to_numpy(), metric=metric_arg_
             )
 
         else:
@@ -339,9 +385,7 @@ def predict_and_analyse(self, X: np.array):
                 w = wtdf.loc[cl].to_numpy()  # 1/std dev
                 XB = XB * w  # w is for this class only
                 XA = X * w  # w is for this class only
-                cl_dist = scipy.spatial.distance.cdist(
-                    XA=XA, XB=XB, metric=self.metric_arg_
-                )
+                cl_dist = scipy.spatial.distance.cdist(XA=XA, XB=XB, metric=metric_arg_)
                 dist_arr_list.append(cl_dist)
             dist_arr = np.column_stack(dist_arr_list)
 
diff --git a/tests/test_classifier.py b/tests/test_classifier.py
index e2eb81b..b5bb72d 100644
--- a/tests/test_classifier.py
+++ b/tests/test_classifier.py
@@ -9,8 +9,7 @@
 
 # Test initialization of the classifier with specific parameters
 def test_init():
-    clf = DistanceMetricClassifier(metric="euclidean", scale=True)
-    assert clf.metric == "euclidean"
+    clf = DistanceMetricClassifier(scale=True)
     assert clf.scale is True
 
 
@@ -53,18 +52,20 @@ def test_predict_without_stdscale():
 def test_metric_scipy():
     X = np.array([[1, 2], [3, 4], [5, 6]])  # Sample feature set
     y = np.array([0, 1, 0])  # Sample target values
-    clf = DistanceMetricClassifier(metric="cityblock")
+    clf = DistanceMetricClassifier()
     clf.fit(X, y)
-    assert clf.metric == "cityblock"
+    y_pred = clf.predict(X, metric="cityblock")
+    assert len(y_pred) == len(X)
 
 
 # Test using different distance metrics - from distclassipy
 def test_metric_dcpy():
     X = np.array([[1, 2], [3, 4], [5, 6]])  # Sample feature set
     y = np.array([0, 1, 0])  # Sample target values
-    clf = DistanceMetricClassifier(metric="soergel")
+    clf = DistanceMetricClassifier()
     clf.fit(X, y)
-    assert clf.metric == "soergel"
+    y_pred = clf.predict(X, metric="soergel")
+    assert len(y_pred) == len(X)
 
 
 # Test using custom defined metric
@@ -75,9 +76,10 @@ def test_metric_custom():
     def metric_euc(u, v):
         return np.sqrt(np.sum((u - v) ** 2))
 
-    clf = DistanceMetricClassifier(metric=metric_euc)
+    clf = DistanceMetricClassifier()
     clf.fit(X, y)
-    assert callable(clf.metric)
+    y_pred = clf.predict(X, metric=metric_euc)
+    assert len(y_pred) == len(X)
 
 
 # Test using invalid metric
@@ -86,8 +88,9 @@ def test_metric_invalid():
     y = np.array([0, 1, 0])  # Sample target values
 
-    with pytest.raises(ValueError):
-        clf = DistanceMetricClassifier(metric="chaini")
-        clf.fit(X, y)
+    clf = DistanceMetricClassifier()
+    clf.fit(X, y)
+    with pytest.raises(ValueError):
+        clf.predict(X, metric="chaini")
 
 
 # Test setting central statistical method to median