Fix indexing of nupts for batched nufft (n_tot > 1) #47

Merged: 5 commits, Dec 6, 2023

lgarrison
Member

jax-finufft has two levels of batching: an inner level where finufft does multiple transforms for the same set of NU points (n_transf > 1), and an outer level where jax-finufft does multiple transforms with different NU points (looping over n_tot). Therefore the NU points arrays have shape [n_tot, n_j], and the source array has shape [n_tot, n_transf, n_j]. However, the NU points arrays were being indexed as if they had the latter shape. This was leading to out-of-bounds memory accesses on the GPU when trying to use n_tot > 1, e.g. as a result of jax.vmap.
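The mismatch described above can be illustrated with plain offset arithmetic. This is a minimal sketch, not the actual jax-finufft code: the function names and the example sizes are hypothetical, and it assumes row-major flattened arrays with the shapes given in the description.

```python
# Illustrative only: flat-offset arithmetic for the two array layouts.
# Function names are hypothetical, not taken from jax-finufft.

def source_offset(i, n_transf, n_j):
    """Start of outer-batch element i in a source array of shape [n_tot, n_transf, n_j]."""
    return i * n_transf * n_j


def points_offset_buggy(i, n_transf, n_j):
    """Wrong: indexes the points array as if it had the source array's shape."""
    return i * n_transf * n_j


def points_offset_fixed(i, n_j):
    """Right: the points array has shape [n_tot, n_j], so the stride is n_j."""
    return i * n_j


n_tot, n_transf, n_j = 3, 4, 5
points_size = n_tot * n_j  # 15 elements in the flattened points array

for i in range(n_tot):
    buggy = points_offset_buggy(i, n_transf, n_j)
    fixed = points_offset_fixed(i, n_j)
    # Reading n_j points from an offset is in bounds only if it ends inside the array.
    print(i, fixed, fixed + n_j <= points_size, buggy, buggy + n_j <= points_size)
```

For `i = 0` the two offsets agree, which is consistent with the problem only showing up when `n_tot > 1`. For any later batch element with `n_transf > 1`, the buggy offset runs past the end of the points array.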

This fixes the GPU runtime error in #37.

The submodule update points us at a more recent finufft. We're still using a fork while we wait for upstream work to finish, but this update brings us much closer to the current state of upstream. The fork also uses fewer threads per block in certain register-intensive 3D operations, which should fix CUDA errors about not having enough resources.

CC @Matematija

@lgarrison lgarrison requested a review from dfm December 5, 2023 21:29
@Matematija

Amazing. I will report back as soon as I test things out on my side.

@Matematija

Works for me now. Thank you again.

Collaborator

@dfm dfm left a comment

Looks good to me! And it looks like this isn't a problem in the CPU version, right?

@lgarrison
Member Author

Interestingly, it looks like the same bug was fixed on the CPU a few years ago (d4622b8), but that was after the GPU fork, and we never ported the fix to the GPU code. Oops!

@lgarrison lgarrison merged commit 772ca6d into main Dec 6, 2023
5 checks passed
@lgarrison lgarrison deleted the fix-setpts branch December 6, 2023 20:38