
Performance with larger k #64

Open
ErlerPhilipp opened this issue Mar 15, 2022 · 7 comments

@ErlerPhilipp
Hi,

I'm trying to get a faster spatial data structure than the SciPy KDTree. I'm using the following test code with pykdtree==1.3.4 and SciPy==1.7.1:

import numpy as np
from scipy.spatial import kdtree
import time

data_pts = np.random.random(size=(35000, 3)).astype(np.float32)

query_pts = np.random.random(size=(9*6, 3)).astype(np.float32)
num_neighbors = 1000

from pykdtree.kdtree import KDTree
pykd_tree = KDTree(data_pts, leafsize=10)

start_pykd = time.time()
for _ in range(100):
    dist_pykd, idx_pykd = pykd_tree.query(query_pts, k=num_neighbors)
end_pykd = time.time()
print('pykdtree: {}'.format(end_pykd - start_pykd))

scipy_kd_tree = kdtree.KDTree(data_pts, leafsize=10)

start_scipy = time.time()
for _ in range(100):
    dist_scipy, idx_scipy = scipy_kd_tree.query(query_pts, k=num_neighbors)
end_scipy = time.time()
print('scipy: {}'.format(end_scipy - start_scipy))

I get these timings in seconds:

| k     | pykdtree | scipy  |
|-------|----------|--------|
| 1     | 0.009    | 0.0875 |
| 10    | 0.0189   | 0.1035 |
| 100   | 0.0939   | 0.1889 |
| 1000  | 3.269    | 0.952  |
| 10000 | >60      | 7.4145 |

Am I doing anything wrong? Why does pykdtree get so much slower with larger k? Is pykdtree perhaps growing (re-allocating) its output arrays many times?

@djhoese
Collaborator

djhoese commented Mar 15, 2022

Very interesting. What operating system are you using? Version of Python? And how did you install pykdtree?

The reallocation seems like a good guess, but I'd also be curious if OpenMP is doing something weird like deciding to spawn a ton of extra threads. @storpipfugl would have to comment on if the algorithm used is expected to perform well with such a high k. In my own work I've only ever used k<10.

@ErlerPhilipp
Author

Installed with conda; no OpenMP (at least I didn't set anything up myself). I want it single-threaded since I'd use it in worker processes of my ML pipeline.
Python 3.8.5 on Windows 10 (64-bit).
I got similar scaling with Ubuntu in WSL.

@djhoese
Collaborator

djhoese commented Mar 15, 2022

OpenMP support should be automatic if it was available when pykdtree was installed; I think conda-forge includes it in all the builds. You can force single-threaded execution by setting the environment variable OMP_NUM_THREADS=1.

With your example script I see similar timings (actually worse). Theoretically we should be able to add some options to the Cython .pyx file and profile the execution. I don't have time to try this out though. If that is something you'd be willing to look into, let me know how I can help.
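
(For anyone picking this up: a rough sketch of what that profiling setup could look like, assuming the directive below is added to pykdtree/kdtree.pyx and the extension is rebuilt. Time spent inside the C core would still not show up in the Python profiler, so this only narrows things down on the Cython/Python side.)

# Assumed setup: add this directive at the top of pykdtree/kdtree.pyx, then rebuild:
#   # cython: profile=True
# Afterwards, profile a query from Python:
import cProfile
import numpy as np
from pykdtree.kdtree import KDTree

data_pts = np.random.random(size=(35000, 3)).astype(np.float32)
query_pts = np.random.random(size=(54, 3)).astype(np.float32)
tree = KDTree(data_pts, leafsize=10)

cProfile.run('tree.query(query_pts, k=1000)', sort='cumtime')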

@ErlerPhilipp
Author

I just tried it with:

export OMP_NUM_THREADS=1
python test.py

and got the same timings.

I'm no expert in Cython or OpenMP, but the code looks fine, with no obvious re-allocation. Maybe it's some low-level OpenMP issue, I don't know. It's probably not worth my time right now. After all, simple brute force might be an option for me.

Anyway, thanks for the very quick confirmation!

@ErlerPhilipp
Author

Quick update: brute force is not an option for ~35k points.

# cartesian_dist_1_n is presumably a Euclidean distance helper; a plausible definition:
def cartesian_dist_1_n(query_pt, data_pts):
    return np.linalg.norm(data_pts - query_pt, axis=1)

start_brute = time.time()
for _ in range(100):
    idx_brute = np.empty(shape=(query_pts.shape[0], num_neighbors), dtype=np.uint32)
    dist_brute = np.empty(shape=(query_pts.shape[0], num_neighbors), dtype=np.float32)
    for qi, query_pt in enumerate(query_pts):
        dist_brute_query = cartesian_dist_1_n(query_pt, data_pts)
        sort_ids = np.argsort(dist_brute_query)  # full sort of all 35k distances per query
        idx_brute_knn = sort_ids[:num_neighbors]
        idx_brute[qi] = idx_brute_knn
        dist_brute[qi] = dist_brute_query[idx_brute_knn]
end_brute = time.time()
print('brute: {}'.format(end_brute - start_brute))

This results in ~13.5 sec for k=1000. I guess I'll stick with SciPy for now.
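
(For reference, the full argsort per query is a big part of that cost; np.argpartition selects the num_neighbors smallest distances in roughly linear time and only those get sorted. An untested sketch, assuming the same cartesian_dist_1_n helper as above:)

# Sketch: brute force with partial selection instead of a full sort.
# Assumes cartesian_dist_1_n as defined above.
def brute_force_knn(query_pts, data_pts, k):
    idx = np.empty(shape=(query_pts.shape[0], k), dtype=np.uint32)
    dist = np.empty(shape=(query_pts.shape[0], k), dtype=np.float32)
    for qi, query_pt in enumerate(query_pts):
        d = cartesian_dist_1_n(query_pt, data_pts)
        cand = np.argpartition(d, k)[:k]   # k smallest distances, unordered
        order = np.argsort(d[cand])        # sort only the k candidates
        idx[qi] = cand[order]
        dist[qi] = d[cand[order]]
    return dist, idx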

@ingowald
The reason this is getting so slow for larger k is that I currently maintain the "sorted list of currently closest elements" by simply re-sorting that list every time I insert a new element. For small k that's just as good as using a heap, and much simpler to code / easier to read, so that's why I picked it (I hadn't expected anybody to use a k larger than 10!). I'll switch to a heap implementation; that should fix this.
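
(To illustrate the difference described above: this is not the pykdtree code itself, which lives in C, just a small Python sketch. A bounded max-heap handles each new candidate in O(log k), while re-sorting a list costs O(k log k) per insertion, which is what hurts at large k.)

# Illustrative sketch only (not pykdtree's implementation): track the k current
# best candidates with a bounded max-heap instead of a re-sorted list.
import heapq

def push_candidate(heap, k, dist, idx):
    # heap stores (-dist, idx) so the root is the current worst candidate
    if len(heap) < k:
        heapq.heappush(heap, (-dist, idx))       # O(log k)
    elif dist < -heap[0][0]:
        heapq.heapreplace(heap, (-dist, idx))    # O(log k), drops the current worst

# Example: keep the 3 nearest out of a stream of distances
heap = []
for i, d in enumerate([5.0, 1.0, 3.0, 0.5, 4.0, 2.0]):
    push_candidate(heap, 3, d, i)
print(sorted((-negd, i) for negd, i in heap))  # [(0.5, 3), (1.0, 1), (2.0, 5)]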

@ErlerPhilipp
Author

Oh nice! Would still be useful for me. I will try it when available. I guess deep learning on point clouds adds some unexpected use cases.
