JAX and Numpy functions: How to structure them? #58

Open
jakharkaran opened this issue Mar 27, 2024 · 4 comments
Labels
enhancement New feature or request question Further information is requested


@jakharkaran
Collaborator

jakharkaran commented Mar 27, 2024

This repository contains a 2D Navier-Stokes equation solver and data processing methods. The solver, written using the JAX library, is computationally expensive and leverages GPU acceleration. The less intensive post-processing methods use NumPy.

Some functions are required by both the solver and post-processing. Currently, duplicate copies exist – one for JAX and one for NumPy. What is the best way to optimize this code structure?

  • Passing a backend variable (backend = 'numpy' or 'jax'): suitable for functions whose underlying structure is identical in NumPy and JAX, with only the library calls differing.
  • Writing all the functions in JAX (needs more thought).
    Pros:
    • User friendliness
    Cons:
    • Requires careful consideration of JAX-specific function implementations.
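A minimal sketch of the first option, assuming a shared spectral vorticity routine. The names `get_xp` and `vorticity`, the grid layout (y along axis 0, x along axis 1), and the 2π-periodic default domain are hypothetical, not taken from this repository. Only the array module changes between backends; the arithmetic is written once:

```python
import numpy as np


def get_xp(backend):
    """Return the array module for a backend name (hypothetical helper)."""
    if backend == "numpy":
        return np
    if backend == "jax":
        # Imported lazily so NumPy-only users do not need JAX installed
        import jax.numpy as jnp
        return jnp
    raise ValueError(f"Unknown backend: {backend!r}")


def vorticity(u, v, Lx=2 * np.pi, Ly=2 * np.pi, backend="numpy"):
    """Spectral vorticity omega = dv/dx - du/dy on a doubly periodic grid.

    Assumes u and v have shape (ny, nx), with y along axis 0 and x along axis 1.
    """
    xp = get_xp(backend)
    ny, nx = u.shape
    # Angular wavenumbers 2*pi*k/L along each axis
    kx = 2 * xp.pi * xp.fft.fftfreq(nx, d=Lx / nx)
    ky = 2 * xp.pi * xp.fft.fftfreq(ny, d=Ly / ny)
    KX, KY = xp.meshgrid(kx, ky)  # both of shape (ny, nx)
    # Differentiate in Fourier space: d/dx -> i*kx, d/dy -> i*ky
    omega_hat = 1j * KX * xp.fft.fft2(v) - 1j * KY * xp.fft.fft2(u)
    return xp.real(xp.fft.ifft2(omega_hat))
```

The solver could then call `vorticity(u, v, backend="jax")` and post-processing `vorticity(u, v)`, sharing one implementation. This only works for functions written in a functional style: in-place updates such as `a[i] = x` have no direct JAX equivalent (JAX arrays are immutable and use `a.at[i].set(x)`), which is one of the JAX-specific considerations listed above.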
jakharkaran added the enhancement (New feature or request) label Mar 27, 2024
jakharkaran self-assigned this Mar 27, 2024
jakharkaran added the question (Further information is requested) label Mar 27, 2024
@rmojgani
Member

rmojgani commented Mar 27, 2024

Also, at low resolutions NumPy is faster than JAX on GPU, so the ability to run on either is beneficial. Falling back to NumPy when JAX is not installed is also favorable.
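A minimal sketch of that fallback, assuming a module-level flag (`HAS_JAX` and `get_xp` are hypothetical names): JAX is imported if available, and NumPy is used otherwise. Here `"auto"` prefers JAX when it is installed; given the low-resolution timing observation, a project might instead choose NumPy below some grid size.

```python
import warnings

import numpy as np

# Try JAX once at import time; remember whether it is available
try:
    import jax.numpy as jnp
    HAS_JAX = True
except ImportError:
    jnp = None
    HAS_JAX = False


def get_xp(backend="auto"):
    """Return the array module for `backend` ('auto', 'jax', or 'numpy')."""
    if backend == "auto":
        if not HAS_JAX:
            warnings.warn("JAX not installed; falling back to NumPy")
            return np
        return jnp
    if backend == "jax":
        # An explicit request for JAX should fail loudly, not fall back
        if not HAS_JAX:
            raise ImportError("backend='jax' requested, but JAX is not installed")
        return jnp
    if backend == "numpy":
        return np
    raise ValueError(f"Unknown backend: {backend!r}")
```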

@jakharkaran
Collaborator Author

For the solver, at low resolution JAX-CPU is faster than JAX-GPU. I haven't tested NumPy. I can mention this in the README for now.

@rmojgani
Member

I mean that, for this reason, it is good to have the backend as a selectable option, even for the solver runs.
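A minimal sketch of such a selectable option for runs, assuming a command-line flag (the `--backend` flag, its choices, and `select_backend` are hypothetical). The `jax-cpu` choice sets JAX's `JAX_PLATFORMS` environment variable to keep JAX on the CPU, which covers the low-resolution case where JAX-CPU beat JAX-GPU:

```python
import argparse
import os


def select_backend(argv=None):
    """Parse a hypothetical --backend flag and return (name, array module)."""
    parser = argparse.ArgumentParser(description="2D Navier-Stokes run")
    parser.add_argument(
        "--backend",
        choices=["numpy", "jax", "jax-cpu"],
        default="jax",
        help="array backend for the run",
    )
    args = parser.parse_args(argv)

    if args.backend == "numpy":
        import numpy as xp
    else:
        if args.backend == "jax-cpu":
            # Must be set before JAX initialises its backends,
            # so call this before any other JAX code runs
            os.environ["JAX_PLATFORMS"] = "cpu"
        import jax.numpy as xp
    return args.backend, xp
```

A run would then be launched with, for example, `--backend jax-cpu` for a small grid and `--backend jax` for a large one, without touching the solver code.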

@jwallwork23

A colleague just told me about the Python array API standard, which might be useful reading for this:
https://data-apis.org/array-api/latest/purpose_and_scope.html#stakeholders
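A minimal sketch of how the standard could remove the backend argument altogether: arrays that implement it expose `__array_namespace__()`, so a shared function can look up the right module from its own input. This assumes NumPy 2.x or a recent JAX release (older NumPy arrays lack the method, hence the NumPy fallback). `array_namespace` and `kinetic_energy` are hypothetical helpers here; the `array-api-compat` package provides a more complete `array_namespace`.

```python
import numpy as np


def array_namespace(x):
    """Return the array-API namespace of x, falling back to NumPy."""
    # Standard-compliant arrays (NumPy >= 2.0, recent jax.Array) provide
    # __array_namespace__(); older NumPy arrays and Python lists do not
    get_ns = getattr(x, "__array_namespace__", None)
    return get_ns() if get_ns is not None else np


def kinetic_energy(u, v):
    """Domain-averaged kinetic energy 0.5 * <u^2 + v^2>, backend-agnostic."""
    xp = array_namespace(u)
    return 0.5 * xp.mean(u**2 + v**2)
```

The same `kinetic_energy` would then accept NumPy arrays in post-processing and JAX arrays inside the solver, with no `backend` argument, as long as it sticks to functions defined by the standard.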

3 participants