Skip to content

Commit

Permalink
feature: Big change, the metric is now passed during prediction and n…
Browse files Browse the repository at this point in the history
…ot during fitting.
  • Loading branch information
sidchaini committed Sep 17, 2024
1 parent abb20c4 commit 303a26b
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 85 deletions.
2 changes: 1 addition & 1 deletion distclassipy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@
from .classifier import DistanceMetricClassifier # noqa
from .distances import Distance # noqa

__version__ = "0.1.6a0"
__version__ = "0.2.0a0"
167 changes: 92 additions & 75 deletions distclassipy/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,52 @@
}


def initialize_metric_function(metric):
"""Set the metric function based on the provided metric.
If the metric is a string, the function will look for a corresponding
function in scipy.spatial.distance or distances.Distance. If the metric
is a function, it will be used directly.
"""
if callable(metric):
metric_fn_ = metric
metric_arg_ = metric

elif isinstance(metric, str):
metric_str_lowercase = metric.lower()
metric_found = False
for package_str, source in METRIC_SOURCES_.items():

# Don't use scipy for jaccard as their implementation only works with
# booleans - use custom jaccard instead
if (
package_str == "scipy.spatial.distance"
and metric_str_lowercase == "jaccard"
):
continue

if hasattr(source, metric_str_lowercase):
metric_fn_ = getattr(source, metric_str_lowercase)
metric_found = True

# Use the string as an argument if it belongs to scipy as it is
# optimized
metric_arg_ = (
metric if package_str == "scipy.spatial.distance" else metric_fn_
)
break
if not metric_found:
raise ValueError(
f"{metric} metric not found. Please pass a string of the "
"name of a metric in scipy.spatial.distance or "
"distances.Distance, or pass a metric function directly. For a "
"list of available metrics, see: "
"https://sidchaini.github.io/DistClassiPy/distances.html or "
"https://docs.scipy.org/doc/scipy/reference/spatial.distance.html"
)
return metric_fn_, metric_arg_


class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
"""A distance-based classifier that supports different distance metrics.
Expand All @@ -53,8 +99,6 @@ class DistanceMetricClassifier(BaseEstimator, ClassifierMixin):
Parameters
----------
metric : str or callable, default="euclidean"
The distance metric to use for calculating the distance between features.
scale : bool, default=True
Whether to scale the distance between the test object and the centroid for a
class in the feature space. If True, the data will be scaled based on the
Expand All @@ -72,26 +116,13 @@ class in the feature space. If True, the data will be scaled based on the
Attributes
----------
metric : str or callable
The distance metric used for classification.
scale : bool
Indicates whether the data is scaled.
central_stat : str
The statistic used for calculating central tendency.
dispersion_stat : str
The statistic used for calculating dispersion.
See Also
--------
scipy.spatial.dist : Other distance metrics provided in SciPy
distclassipy.Distance : Distance metrics included with DistClassiPy
Notes
-----
If using distance metrics supported by SciPy, it is desirable to pass a string,
which allows SciPy to use an optimized C version of the code instead of the slower
Python version.
References
----------
.. [1] "Light Curve Classification with DistClassiPy: a new distance-based
Expand All @@ -113,63 +144,15 @@ class in the feature space. If True, the data will be scaled based on the

def __init__(
self,
metric: str | Callable = "euclidean",
scale: bool = True,
central_stat: str = "median",
dispersion_stat: str = "std",
):
"""Initialize the classifier with specified parameters."""
self.metric = metric
self.scale = scale
self.central_stat = central_stat
self.dispersion_stat = dispersion_stat

def initialize_metric_function(self):
"""Set the metric function based on the provided metric.
If the metric is a string, the function will look for a corresponding
function in scipy.spatial.distance or distances.Distance. If the metric
is a function, it will be used directly.
"""
if callable(self.metric):
self.metric_fn_ = self.metric
self.metric_arg_ = self.metric

elif isinstance(self.metric, str):
metric_str_lowercase = self.metric.lower()
metric_found = False
for package_str, source in METRIC_SOURCES_.items():

# Don't use scipy for jaccard as their implementation only works with
# booleans - use custom jaccard instead
if (
package_str == "scipy.spatial.distance"
and metric_str_lowercase == "jaccard"
):
continue

if hasattr(source, metric_str_lowercase):
self.metric_fn_ = getattr(source, metric_str_lowercase)
metric_found = True

# Use the string as an argument if it belongs to scipy as it is
# optimized
self.metric_arg_ = (
self.metric
if package_str == "scipy.spatial.distance"
else self.metric_fn_
)
break
if not metric_found:
raise ValueError(
f"{self.metric} metric not found. Please pass a string of the "
"name of a metric in scipy.spatial.distance or "
"distances.Distance, or pass a metric function directly. For a "
"list of available metrics, see: "
"https://sidchaini.github.io/DistClassiPy/distances.html or "
"https://docs.scipy.org/doc/scipy/reference/spatial.distance.html"
)

def fit(self, X: np.array, y: np.array, feat_labels: list[str] = None):
"""Calculate the feature space centroid for all classes.
Expand Down Expand Up @@ -200,8 +183,6 @@ def fit(self, X: np.array, y: np.array, feat_labels: list[str] = None):
1
] # Number of features seen during fit - required for sklearn compatibility.

self.initialize_metric_function()

if feat_labels is None:
feat_labels = [f"Feature_{x}" for x in range(X.shape[1])]

Expand Down Expand Up @@ -249,7 +230,11 @@ def fit(self, X: np.array, y: np.array, feat_labels: list[str] = None):

return self

def predict(self, X: np.array):
def predict(
self,
X: np.array,
metric: str | Callable = "euclidean",
):
"""Predict the class labels for the provided X.
The prediction is based on the distance of each data point in the input sample
Expand All @@ -260,18 +245,33 @@ def predict(self, X: np.array):
----------
X : array-like of shape (n_samples, n_features)
The input samples.
metric : str or callable, default="euclidean"
The distance metric to use for calculating the distance between features.
Returns
-------
y : ndarray of shape (n_samples,)
The predicted classes.
See Also
--------
scipy.spatial.dist : Other distance metrics provided in SciPy
distclassipy.Distance : Distance metrics included with DistClassiPy
Notes
-----
If using distance metrics supported by SciPy, it is desirable to pass a string,
which allows SciPy to use an optimized C version of the code instead of the slower
Python version.
"""
check_is_fitted(self, "is_fitted_")
X = check_array(X)

metric_fn_, metric_arg_ = initialize_metric_function(metric)

if not self.scale:
dist_arr = scipy.spatial.distance.cdist(
XA=X, XB=self.df_centroid_.to_numpy(), metric=self.metric_arg_
XA=X, XB=self.df_centroid_.to_numpy(), metric=metric_arg_
)

else:
Expand All @@ -288,16 +288,18 @@ def predict(self, X: np.array):
w = wtdf.loc[cl].to_numpy() # 1/std dev
XB = XB * w # w is for this class only
XA = X * w # w is for this class only
cl_dist = scipy.spatial.distance.cdist(
XA=XA, XB=XB, metric=self.metric_arg_
)
cl_dist = scipy.spatial.distance.cdist(XA=XA, XB=XB, metric=metric_arg_)
dist_arr_list.append(cl_dist)
dist_arr = np.column_stack(dist_arr_list)

y_pred = self.classes_[dist_arr.argmin(axis=1)]
return y_pred

def predict_and_analyse(self, X: np.array):
def predict_and_analyse(
self,
X: np.array,
metric: str | Callable = "euclidean",
):
"""Predict the class labels for the provided X and perform analysis.
The prediction is based on the distance of each data point in the input sample
Expand All @@ -311,18 +313,35 @@ def predict_and_analyse(self, X: np.array):
----------
X : array-like of shape (n_samples, n_features)
The input samples.
metric : str or callable, default="euclidean"
The distance metric to use for calculating the distance between features.
Returns
-------
y : ndarray of shape (n_samples,)
The predicted classes.
See Also
--------
scipy.spatial.dist : Other distance metrics provided in SciPy
distclassipy.Distance : Distance metrics included with DistClassiPy
Notes
-----
If using distance metrics supported by SciPy, it is desirable to pass a string,
which allows SciPy to use an optimized C version of the code instead of the slower
Python version.
"""
check_is_fitted(self, "is_fitted_")
X = check_array(X)

metric_fn_, metric_arg_ = initialize_metric_function(metric)

if not self.scale:
dist_arr = scipy.spatial.distance.cdist(
XA=X, XB=self.df_centroid_.to_numpy(), metric=self.metric_arg_
XA=X, XB=self.df_centroid_.to_numpy(), metric=metric_arg_
)

else:
Expand All @@ -339,9 +358,7 @@ def predict_and_analyse(self, X: np.array):
w = wtdf.loc[cl].to_numpy() # 1/std dev
XB = XB * w # w is for this class only
XA = X * w # w is for this class only
cl_dist = scipy.spatial.distance.cdist(
XA=XA, XB=XB, metric=self.metric_arg_
)
cl_dist = scipy.spatial.distance.cdist(XA=XA, XB=XB, metric=metric_arg_)
dist_arr_list.append(cl_dist)
dist_arr = np.column_stack(dist_arr_list)

Expand Down
21 changes: 12 additions & 9 deletions tests/test_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@

# Test initialization of the classifier with specific parameters
def test_init():
clf = DistanceMetricClassifier(metric="euclidean", scale=True)
assert clf.metric == "euclidean"
clf = DistanceMetricClassifier(scale=True)
assert clf.scale is True


Expand Down Expand Up @@ -53,18 +52,20 @@ def test_predict_without_stdscale():
def test_metric_scipy():
X = np.array([[1, 2], [3, 4], [5, 6]]) # Sample feature set
y = np.array([0, 1, 0]) # Sample target values
clf = DistanceMetricClassifier(metric="cityblock")
clf = DistanceMetricClassifier()
clf.fit(X, y)
assert clf.metric == "cityblock"
clf.predict(X, metric="cityblock")
pass


# Test using different distance metrics - from distclassipy
def test_metric_dcpy():
X = np.array([[1, 2], [3, 4], [5, 6]]) # Sample feature set
y = np.array([0, 1, 0]) # Sample target values
clf = DistanceMetricClassifier(metric="soergel")
clf = DistanceMetricClassifier()
clf.fit(X, y)
assert clf.metric == "soergel"
clf.predict(X, metric="soergel")
pass


# Test using custom defined metric
Expand All @@ -75,9 +76,10 @@ def test_metric_custom():
def metric_euc(u, v):
return np.sqrt(np.sum((u - v) ** 2))

clf = DistanceMetricClassifier(metric=metric_euc)
clf = DistanceMetricClassifier()
clf.fit(X, y)
assert callable(clf.metric)
clf.predict(X, metric=metric_euc)
pass


# Test using invalid metric
Expand All @@ -86,8 +88,9 @@ def test_metric_invalid():
y = np.array([0, 1, 0]) # Sample target values

with pytest.raises(ValueError):
clf = DistanceMetricClassifier(metric="chaini")
clf = DistanceMetricClassifier()
clf.fit(X, y)
clf.predict(X, metric="chaini")


# Test setting central statistical method to median
Expand Down

0 comments on commit 303a26b

Please sign in to comment.