Questions related to HEALPix support #250
Hi @Magwos, apologies for our slow reply!
At the moment the
As of #244 being merged, the
We updated the implementation of the
`s2fft/s2fft/transforms/c_backend_spherical.py`, lines 300 to 301 in 8f6e4d5
`s2fft/s2fft/transforms/c_backend_spherical.py`, lines 347 to 348 in 8f6e4d5
When the function is JIT-transformed, JAX bypasses these implementation rules and instead looks for lowering rules for the current platform, resulting in something like
One option, I think, would be to define a lowering rule for the CPU that uses the compiled C++ functions.
Thanks for the heads up, we'll have a look at using this alternative!
Hello,
I have been working on interfacing s2fft with https://github.com/CMBSciPol/jax-healpy (currently only the spin-0 transform is included; I am working on the interface for spin -2 and 2 and on adapting other healpy routines). Our goal is to have jittable and differentiable JAX adaptations of the healpy routines.
For this interface, I have some questions about the `jax` and `jax_healpy` methods.
First, I am not sure whether the best practice would be to use only the `jax` method on both CPU and GPU, or whether on CPU it would be better to switch to `jax_healpy` (for instance using `jax.lax.platform`).

Related to this, the forward function with the `jax_healpy` method is currently not jittable: `healpy_forward` expects a `jnp.array` but immediately converts it into a `np.array`, which is not traceable, here: `s2fft/s2fft/transforms/c_backend_spherical.py`
Line 331 in 5d1d13f
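For the first question (a CPU/GPU split), a minimal sketch of choosing the method name in plain Python is shown below. The helper `choose_sht_method` is hypothetical, and the policy it encodes (`jax_healpy` on CPU, `jax` elsewhere) is only the option being asked about, not an s2fft recommendation. It relies on `jax.default_backend()`, which reports the default backend rather than the device a particular array lives on.

```python
import jax


def choose_sht_method(backend=None):
    """Hypothetical helper: return an s2fft method name for a backend.

    Encodes the policy asked about above (C++-backed "jax_healpy" on CPU,
    pure-JAX "jax" otherwise); this is an assumption, not s2fft guidance.
    """
    if backend is None:
        # Name of the default backend, e.g. "cpu", "gpu" or "tpu".
        backend = jax.default_backend()
    return "jax_healpy" if backend == "cpu" else "jax"


print(choose_sht_method())  # result depends on the installed backend
```

Because the choice is made in Python before tracing, it stays fixed inside a jitted function. JAX also offers `jax.lax.platform_dependent` for per-platform branches inside a traced function, but every branch must itself be traceable, which the `jax_healpy` path currently is not (the jit issue raised above).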
Would it be possible to make it jittable? This could probably be solved with a `pure_callback`, and since you already have a `custom_jvp` in place it would remain auto-differentiable, but I am not sure this is the best way to proceed.
Otherwise, I guess the only option is to use the `jax` method instead?
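The `pure_callback` route can be sketched as follows. Under `jax.jit`, calling `np.asarray` on a traced value raises `jax.errors.TracerArrayConversionError`, whereas `jax.pure_callback` defers the host call until concrete arrays exist, provided the output shape and dtype are declared up front. `numpy_forward` and `_MATRIX` are hypothetical stand-ins for a healpy-backed transform; because a spherical harmonic transform is linear, the JVP rule in this sketch simply reapplies the transform to the tangent.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for a host-side transform (e.g. a healpy-backed
# forward SHT): a fixed linear map from 4 "pixels" to 3 "coefficients".
_MATRIX = np.arange(12.0).reshape(3, 4)


def numpy_forward(x):
    # Called by JAX on the host with concrete NumPy arrays, never tracers,
    # so converting with np.asarray is safe here.
    return (_MATRIX @ np.asarray(x)).astype(np.float32)


@jax.custom_jvp
def forward(x):
    # Declare the callback's output shape/dtype so JAX can trace through it.
    out_spec = jax.ShapeDtypeStruct((_MATRIX.shape[0],), jnp.float32)
    return jax.pure_callback(numpy_forward, out_spec, x)


@forward.defjvp
def forward_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # The map is linear, so its JVP is the same map applied to the tangent.
    return forward(x), forward(x_dot)


x = jnp.ones(4, dtype=jnp.float32)
y = jax.jit(forward)(x)                  # host callback runs under jit
_, y_dot = jax.jvp(forward, (x,), (x,))  # forward-mode derivative
```

This supports `jax.jit` and forward-mode `jax.jvp`. Reverse mode (`jax.grad`) would fail in this sketch because JAX cannot transpose a `pure_callback`; the usual fix is `jax.custom_vjp` with a backward callback implementing the adjoint transform.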
I also saw that you were using `jnp.repeat` here: `s2fft/s2fft/transforms/spherical.py`
Line 733 in 5d1d13f
I have found `jnp.broadcast_to` to be generally more performant for the same use. In the specific line I highlighted, you may be able to replace it with `indices = jnp.broadcast_to(jnp.arange(L), (2 * L - 1, L)).T` to obtain the same result, with an average speed-up of a factor of 2 for this specific execution (for L = 255 in my benchmark).

Again, thanks a lot for all of your hard work!
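The equivalence of the two forms in the `jnp.repeat` suggestion can be checked with a short sketch. `indices_repeat` is a hypothetical reconstruction of a `jnp.repeat`-based version (the exact code at line 733 is not reproduced here), and `indices_broadcast` is the suggested replacement. Both build an `(L, 2L - 1)` array whose row `i` is filled with the value `i`.

```python
import jax
import jax.numpy as jnp


def indices_repeat(L):
    # Hypothetical repeat-based form: column vector 0..L-1 copied 2L-1 times.
    return jnp.repeat(jnp.arange(L)[:, None], 2 * L - 1, axis=1)


def indices_broadcast(L):
    # Suggested form: broadcast to (2L-1, L), then transpose to (L, 2L-1).
    return jnp.broadcast_to(jnp.arange(L), (2 * L - 1, L)).T


# L fixes the output shape, so it must be a static argument under jit.
indices_jit = jax.jit(indices_broadcast, static_argnums=0)
print(indices_jit(3))
```

Writing `jnp.broadcast_to(jnp.arange(L)[:, None], (L, 2 * L - 1))` gives the same array without the transpose. Relative timings depend on the backend and on how XLA fuses the surrounding code, so any speed-up is worth re-measuring in context.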