Questions related to HEALPix support #250

Open
Magwos opened this issue Dec 3, 2024 · 1 comment

Magwos commented Dec 3, 2024

Hello,
I have been working on interfacing s2fft with https://github.com/CMBSciPol/jax-healpy (currently only the transform for spin 0 is included, I am working on the interface towards spin=-2 and 2 and the adaptation of other healpy routines). Our goal is to have jittable and differentiable JAX adaptations of the healpy routines.

For this interface, I had some questions related to the methods jax and jax_healpy.
First, I am not sure whether the best practice would be to use only the jax method for both CPU and GPU, or whether on CPU it would be better to switch to jax_healpy (for instance using jax.lax.platform).

Following this point, the forward function with the jax_healpy method is currently not jittable, as healpy_forward expects a jnp.array but immediately converts it into an np.array, which is not traceable, here:

flm = jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=iter))

Would it be possible to make it jittable? This could probably be solved with a pure_callback, around which you already have the custom_jvp in place to make it auto-differentiable, but I am not sure if that is the best way to proceed.

Otherwise, I guess the only way to proceed is to use the jax method instead?

I also saw that you were using jnp.repeat here:

indices = jnp.repeat(jnp.expand_dims(jnp.arange(L), -1), 2 * L - 1, axis=-1)
in forward_jax. On my side, I have found jnp.broadcast_to to be generally more performant for this use. In the specific line highlighted above, you may be able to replace it with indices = jnp.broadcast_to(jnp.arange(L), (2 * L - 1, L)).T to obtain the same result, with roughly a factor-of-2 speed-up on average (for L = 255 in my benchmark) for this specific operation.
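As a quick check (a minimal snippet for illustration, not code from s2fft), the two constructions produce identical arrays:

import jax.numpy as jnp

L = 255
# Original construction: an (L, 2L-1) array whose row i is filled with the value i.
via_repeat = jnp.repeat(jnp.expand_dims(jnp.arange(L), -1), 2 * L - 1, axis=-1)
# Suggested alternative: broadcast a single arange row and transpose.
via_broadcast = jnp.broadcast_to(jnp.arange(L), (2 * L - 1, L)).T

assert via_repeat.shape == via_broadcast.shape == (L, 2 * L - 1)
assert jnp.array_equal(via_repeat, via_broadcast)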

Again, thanks a lot for all of your hard work!

Magwos changed the title from "Questions related to HEALPix" to "Questions related to HEALPix support" on Dec 3, 2024
@matt-graham
Collaborator

Hi @Magwos, apologies for our slow reply!

For this interface, I had some questions related to the methods jax and jax_healpy. First, I am not sure whether the best practice would be to use only the jax method for both CPU and GPU, or whether on CPU it would be better to switch to jax_healpy (for instance using jax.lax.platform).

At the moment the healpy wrapper implementation corresponding to method = "jax_healpy" is significantly faster on CPU than the alternative pure-JAX implementation corresponding to method = "jax" for larger bandlimits L, so on CPU that is probably the best bet in terms of performance for now; a small dispatch sketch follows the timings below. Here are some illustrative benchmark results comparing s2fft.forward and s2fft.inverse with method = "jax" and method = "jax_healpy" for various L when running on a CPU:

forward
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
(method: jax, L: 64, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):  0.0095s, max(time):   0.011s
(method: jax, L: 128, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):   0.081s, max(time):   0.084s
(method: jax, L: 256, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):    0.77s, max(time):    0.78s
(method: jax_healpy, L: 64, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):  0.0040s, max(time):  0.0044s
(method: jax_healpy, L: 128, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):   0.015s, max(time):   0.035s
(method: jax_healpy, L: 256, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):   0.041s, max(time):   0.050s
inverse
(method: jax, L: 64, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):   0.012s, max(time):   0.014s
(method: jax, L: 128, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):   0.092s, max(time):   0.096s
(method: jax, L: 256, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):    0.84s, max(time):    0.90s
(method: jax_healpy, L: 64, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time): 0.00037s, max(time): 0.00045s
(method: jax_healpy, L: 128, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time): 0.00080s, max(time): 0.00086s
(method: jax_healpy, L: 256, L_lower: 0, sampling: healpix, spin: 0, reality: True, spmd: False):
    min(time):  0.0031s, max(time):  0.0032s
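
For illustration, a minimal sketch of dispatching on the platform at the call site (the forward_on_best_method helper name is hypothetical, it uses jax.default_backend() rather than the jax.lax utility you mention, and it assumes the s2fft.forward keyword arguments shown in the timings above):

import jax
import s2fft

def forward_on_best_method(f, L, nside):
    # Use the healpy wrapper on CPU, where it is currently faster,
    # and the pure JAX implementation otherwise (e.g. on GPU).
    method = "jax_healpy" if jax.default_backend() == "cpu" else "jax"
    return s2fft.forward(
        f, L, nside=nside, sampling="healpix", method=method, reality=True
    )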

As of #244 being merged, the healpy wrapper also supports forward-mode and higher-order automatic differentiation.

Following this point, the forward function with the jax_healpy method is currently not jittable, as healpy_forward expects a jnp.array but immediately converts it into an np.array, which is not traceable, here:

flm = jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=iter))

Would it be possible to make it jittable? This could probably be solved with a pure_callback, around which you already have the custom_jvp in place to make it auto-differentiable, but I am not sure if that is the best way to proceed.

Otherwise, I guess the only way to proceed is to use the jax method instead?

We updated the implementation of the healpy wrappers in #244, I think after you first wrote the above. We now define custom JAX (linear) primitives healpy_map2alm and healpy_alm2map in s2fft.transforms.c_backend_spherical which wrap the corresponding healpy.map2alm and healpy.alm2map functions, and we define custom transposition rules to support both forward- and reverse-mode automatic differentiation. However, we do not currently define lowerings for these primitives, so they are still not compatible with jax.jit. Unfortunately, now that these are defined as custom primitives, it is also not possible to use jax.pure_callback to make them JIT-able: even if we use jax.pure_callback in the implementations

def _healpy_map2alm_impl(f: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
    return jnp.array(healpy.map2alm(np.array(f), lmax=L - 1, iter=0))

def _healpy_alm2map_impl(flm: jnp.ndarray, L: int, nside: int) -> jnp.ndarray:
    return jnp.array(healpy.alm2map(np.array(flm), lmax=L - 1, nside=nside))

JAX bypasses these implementation rules when the functions are JIT-transformed and instead looks for lowering rules for the current platform, resulting in something like

 NotImplementedError: MLIR translation rule for primitive 'healpy_alm2map' not found for platform cpu

One option, I think, would be to define a lowering rule for CPU which uses the compiled C++ functions that healpy wraps. Alternatively, if jax-ml/jax#24726 or something like it is merged, we may be able to use jax.pure_callback and still define custom derivative behaviour for both forward and reverse mode without using the internal primitive API.
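For concreteness, a rough sketch of what the jax.pure_callback route could look like (this is not how s2fft is implemented today; the map2alm_via_callback name is hypothetical, and the sketch assumes double precision is enabled so that healpy's complex128 output matches the declared result type):

import healpy
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # healpy.map2alm returns complex128 coefficients

def map2alm_via_callback(f: jnp.ndarray, L: int) -> jnp.ndarray:
    # Number of (ell, m) coefficients returned by healpy.map2alm for lmax = L - 1.
    n_alm = L * (L + 1) // 2
    result_shape = jax.ShapeDtypeStruct((n_alm,), jnp.complex128)
    # pure_callback defers the host-side healpy call to execution time, so
    # tracing under jax.jit only needs the declared output shape and dtype.
    return jax.pure_callback(
        lambda x: healpy.map2alm(np.asarray(x), lmax=L - 1, iter=0),
        result_shape,
        f,
    )

As noted above, attaching custom derivative rules for both forward and reverse mode around such a callback is the part that currently requires the internal primitive API.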

I also saw that you were using jnp.repeat here:

indices = jnp.repeat(jnp.expand_dims(jnp.arange(L), -1), 2 * L - 1, axis=-1)

in forward_jax. On my side, I have found jnp.broadcast_to to be generally more performant for this use. In the specific line highlighted above, you may be able to replace it with indices = jnp.broadcast_to(jnp.arange(L), (2 * L - 1, L)).T to obtain the same result, with roughly a factor-of-2 speed-up on average (for L = 255 in my benchmark) for this specific operation.

Thanks for the heads up, we'll have a look at using this alternative!
