Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shield against breaking changes from scikit-learn 1.3.0 release #598

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions hdbscan/hdbscan_.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
of Applications with Noise
"""

import sklearn
import numpy as np

from packaging.version import Version
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.metrics import pairwise_distances
from scipy.sparse import issparse
Expand Down Expand Up @@ -37,7 +39,14 @@
from .plots import CondensedTree, SingleLinkageTree, MinimumSpanningTree
from .prediction import PredictionData

FAST_METRICS = KDTree.valid_metrics + BallTree.valid_metrics + ["cosine", "arccos"]
if Version(sklearn.__version__) >= Version("1.3.0"):
kdtree_valid_metrics = KDTree.valid_metrics()
balltree_valid_metrics = BallTree.valid_metrics()
else:
kdtree_valid_metrics = KDTree.valid_metrics
balltree_valid_metrics = BallTree.valid_metrics

FAST_METRICS = kdtree_valid_metrics + balltree_valid_metrics + ["cosine", "arccos"]

# Author: Leland McInnes <[email protected]>
# Steve Astels <[email protected]>
Expand Down Expand Up @@ -742,19 +751,19 @@ def hdbscan(
_hdbscan_generic
)(X, min_samples, alpha, metric, p, leaf_size, gen_min_span_tree, **kwargs)
elif algorithm == "prims_kdtree":
if metric not in KDTree.valid_metrics:
if metric not in kdtree_valid_metrics:
raise ValueError("Cannot use Prim's with KDTree for this" " metric!")
(single_linkage_tree, result_min_span_tree) = memory.cache(
_hdbscan_prims_kdtree
)(X, min_samples, alpha, metric, p, leaf_size, gen_min_span_tree, **kwargs)
elif algorithm == "prims_balltree":
if metric not in BallTree.valid_metrics:
if metric not in balltree_valid_metrics:
raise ValueError("Cannot use Prim's with BallTree for this" " metric!")
(single_linkage_tree, result_min_span_tree) = memory.cache(
_hdbscan_prims_balltree
)(X, min_samples, alpha, metric, p, leaf_size, gen_min_span_tree, **kwargs)
elif algorithm == "boruvka_kdtree":
if metric not in BallTree.valid_metrics:
if metric not in balltree_valid_metrics:
raise ValueError("Cannot use Boruvka with KDTree for this" " metric!")
(single_linkage_tree, result_min_span_tree) = memory.cache(
_hdbscan_boruvka_kdtree
Expand All @@ -771,7 +780,7 @@ def hdbscan(
**kwargs
)
elif algorithm == "boruvka_balltree":
if metric not in BallTree.valid_metrics:
if metric not in balltree_valid_metrics:
raise ValueError("Cannot use Boruvka with BallTree for this" " metric!")
if (X.shape[0] // leaf_size) > 16000:
warn(
Expand Down Expand Up @@ -802,7 +811,7 @@ def hdbscan(
(single_linkage_tree, result_min_span_tree) = memory.cache(
_hdbscan_generic
)(X, min_samples, alpha, metric, p, leaf_size, gen_min_span_tree, **kwargs)
elif metric in KDTree.valid_metrics:
elif metric in kdtree_valid_metrics:
# TO DO: Need heuristic to decide when to go to boruvka;
# still debugging for now
if X.shape[1] > 60:
Expand Down Expand Up @@ -1237,9 +1246,9 @@ def generate_prediction_data(self):

if self.metric in FAST_METRICS:
min_samples = self.min_samples or self.min_cluster_size
if self.metric in KDTree.valid_metrics:
if self.metric in kdtree_valid_metrics:
tree_type = "kdtree"
elif self.metric in BallTree.valid_metrics:
elif self.metric in balltree_valid_metrics:
tree_type = "balltree"
else:
warn("Metric {} not supported for prediction data!".format(self.metric))
Expand Down
13 changes: 11 additions & 2 deletions hdbscan/robust_single_linkage_.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
"""
Robust Single Linkage: Density based single linkage clustering.
"""
import sklearn
import numpy as np

from packaging.version import Version
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.metrics import pairwise_distances
from scipy.sparse import issparse
Expand All @@ -24,7 +26,14 @@
#
# License: BSD 3 clause

FAST_METRICS = KDTree.valid_metrics + BallTree.valid_metrics
if Version(sklearn.__version__) >= Version("1.3.0"):
kdtree_valid_metrics = KDTree.valid_metrics()
balltree_valid_metrics = BallTree.valid_metrics()
else:
kdtree_valid_metrics = KDTree.valid_metrics
balltree_valid_metrics = BallTree.valid_metrics

FAST_METRICS = kdtree_valid_metrics + balltree_valid_metrics


def _rsl_generic(X, k=5, alpha=1.4142135623730951, metric='euclidean',
Expand Down Expand Up @@ -266,7 +275,7 @@ def robust_single_linkage(X, cut, k=5, alpha=1.4142135623730951,
# We can't do much with sparse matrices ...
single_linkage_tree = memory.cache(_rsl_generic)(
X, k, alpha, metric, **kwargs)
elif metric in KDTree.valid_metrics:
elif metric in kdtree_valid_metrics:
# Need heuristic to decide when to go to boruvka;
# still debugging for now
if X.shape[1] > 128:
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
cython>=0.27
numpy>=1.20
packaging
scipy>= 1.0
scikit-learn>=0.20
joblib>=1.0
joblib>=1.0