Skip to content

Updating Healpix CUDA primitive #290

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open

Updating Healpix CUDA primitive #290

wants to merge 11 commits into from

Conversation

ASKabalan
Copy link
Collaborator

@ASKabalan ASKabalan commented Mar 26, 2025

Adding a few updates

  • Updating to the newest custom call API (API 4) using FFI
  • implementing a grad rule for healpix cuda FFT
  • Implementing a Batching rule

A batching rule seems to be very important for two things
Being able to jacrev/ jacfwd
and because in most cases .. the size of a healpix map can fit on a single GPU but sometimes we want to batch the spherical transform

I will be doing that next

@ASKabalan ASKabalan marked this pull request as draft March 26, 2025 16:25
@ASKabalan ASKabalan marked this pull request as ready for review March 28, 2025 16:08
@ASKabalan
Copy link
Collaborator Author

Hello @matt-graham @jasonmcewen @CosmoMatt

Just a quick PR to wrap up a few stuff

  1. Updated the binding API to the newest FFI
  2. Added a vmap implementation of the cuda primitive
  3. Added a transpose rule which allows jacfwd and jacrev (consequently grad aswell)
  4. added more tests https://github.com/astro-informatics/s2fft/blob/ASKabalan/tests/test_healpix_ffts.py#L100
  5. Removed two files which are now no longer needed with the FFI API (kernel helpers) (so maybe they should be removed from the license section)
  6. Constrained nanobind to be nanobind >=2.0,<2.6" because of a regression [BUG]: Regression when using scikit build tools and nanobind wjakob/nanobind#982

And finally I added cudastreamhandler which is used to split the XLA provided stream for the VMAP lowering (this header is my own work)

There is an issue with building pyssht not sure that this is my fault

I will check the failing worflows when I get the chance, but in the meantime a review is appreciated

Copy link
Collaborator

@matt-graham matt-graham left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @matt-graham @jasonmcewen @CosmoMatt

Just a quick PR to wrap up a few stuff

1. Updated the binding API to the newest [FFI](https://docs.jax.dev/en/latest/ffi.html)

2. Added a vmap implementation of the cuda primitive

3. Added a transpose rule which allows jacfwd and jacrev (consequently grad aswell)

4. added more tests https://github.com/astro-informatics/s2fft/blob/ASKabalan/tests/test_healpix_ffts.py#L100

5. Removed two files which are now no longer needed with the FFI API ([kernel helpers](https://github.com/astro-informatics/s2fft/blob/main/lib/include/kernel_helpers.h)) (so maybe they should be removed from the license section)

6. Constrained nanobind to be nanobind >=2.0,<2.6" because of a regression [[BUG]: Regression when using scikit build tools and nanobind wjakob/nanobind#982](https://github.com/wjakob/nanobind/issues/982)

And finally I added cudastreamhandler which is used to split the XLA provided stream for the VMAP lowering (this header is my own work)

There is an issue with building pyssht not sure that this is my fault

I will check the failing worflows when I get the chance, but in the meantime a review is appreciated

Hi @ASKabalan, sorry for the delay in getting back to you.

This all sounds great - thanks for picking up #237 in particular and for the updates to use the newer FFI interface.

With regards to the failing workflows - this was probably due to #292 which was fixed in #293. If you merge in latest main here that should hopefully resolve the upstream dependency build problems that were causing the test workflows to fail.

I've added some initial review comments below. Will have a closer look next week and try testing this out, but don't have access to GPU machine atm.

Comment on lines 150 to 151
flm_hp = samples.flm_2d_to_hp(flm, L)
f = hp.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could use s2fft.inverse(flm, L=L, reality=False, method="jax", sampling="healpix") here instead of going via healpy? Rationale being that I would have a slight preference for minimising the number of additional tests that depend on healpy as it we are no longer requiring it as direct dependency for package and in the long run it might be possible to also remove it as a test dependency.

@matt-graham
Copy link
Collaborator

I've tried building, installing and running this on a system with CUDA 12.6 + a NVIDIA A100, and running the HEALPix FFT tests with

pytest tests/test_healpix_ffts.py

consistently the tests hang when trying to run the first test_healpix_fft_cuda instance.

Running just the IFFT tests with

pytest tests/test_healpix_ffts.py::test_healpix_ifft_cuda

the tests for both set of test parameters pass.

Trying to dig into this a bit, running the following locally

import healpy
import jax
import s2fft
import numpy

jax.config.update("jax_enable_x64", True)

seed = 20250416
nside = 4
L = 2 * nside
reality = False

rng = numpy.random.default_rng(seed)
flm = s2fft.utils.signal_generator.generate_flm(rng=rng, L=L, reality=reality)
flm_hp = s2fft.sampling.s2_samples.flm_2d_to_hp(flm, L)
f = healpy.sphtfunc.alm2map(flm_hp, nside, lmax=L - 1)
flm_cuda = s2fft.utils.healpix_ffts.healpix_fft_cuda(f=f, L=L, nside=nside, reality=reality).block_until_ready()

raises an error

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CUDA error: : CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered

so it looks like there is some memory addressing issue somewhere in the healpix_fft_cuda implementation?

@ASKabalan
Copy link
Collaborator Author

Thank you

I was able to reproduce with 12.4.1 but not locally with 12.4

I will take a look

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Check autodiff and batching support for healpix_fft_cuda primitive and add if needed
2 participants