Fix indexing of nupts for batched nufft (n_tot > 1
)
#47
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
jax-finufft has two levels of batching: an inner level where finufft does multiple transforms for the same set of NU points (
n_transf > 1
), and an outer level where jax-finufft does multiple transforms with different NU points (looping overn_tot
). Therefore the NU points arrays have shape[n_tot, n_j]
, and the source array has shape[n_tot, n_transf, n_j]
. However, the NU points arrays were being indexed as if they had the latter shape. This was leading to out-of-bounds memory accesses on the GPU when trying to usen_tot > 1
, e.g. as a result ofjax.vmap
.This fixes the GPU runtime error in #37.
The submodule update points us at a more recent finufft. We're still using at a fork while we wait for upstream work to finish, but this update brings us much closer to the current state of the upstream. The fork also uses fewer threads per block in certain register-intensive 3D operations, which should fix CUDA errors about not enough resources.
CC @Matematija