diff --git a/ManifoldEM/S2tessellation.py b/ManifoldEM/S2tessellation.py index 46e6b3a..9ffaedc 100644 --- a/ManifoldEM/S2tessellation.py +++ b/ManifoldEM/S2tessellation.py @@ -17,9 +17,10 @@ def quaternion_to_S2(q): def collect_nearest_neighbors(X, Q): + nbins = X.shape[0] nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree').fit(X) - _, neighb_bins = nbrs.kneighbors(Q) - bin_counts = np.bincount(neighb_bins.squeeze()) + neighb_bins = nbrs.kneighbors(Q, return_distance=False) + bin_counts = np.bincount(neighb_bins.squeeze(), minlength=nbins) return neighb_bins, bin_counts