
Add optional support of JAX to accelerate some partial derivatives #418

Open
kanekosh opened this issue Dec 19, 2023 · 1 comment
kanekosh commented Dec 19, 2023

Description of feature

When using a dense VLM mesh, compute_partials in some components (e.g., eval_mtx in aerodynamics) becomes a bottleneck for derivative computation. These partials can be accelerated by replacing the current analytic derivatives with AD.
Aditya Deshpande and Sriram Bommakanti tried this out for their AE588 project and showed that AD does accelerate these partials. Their prototype implementation can be found in their fork. Note that they applied AD to only part of the compute_partials computations, not to the entire partial computation.

AD support should be optional because we don't want to add JAX as a hard dependency (for now), and AD likely doesn't offer a performance benefit for moderate mesh sizes.
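A minimal sketch of the optional-dependency pattern: guard the JAX import at module load and fall back to the analytic partials when it is missing. The names `HAS_JAX` and `use_ad_partials` are illustrative, not from the OAS codebase.

```python
# Guarded import: fall back to the analytic partials when JAX is unavailable.
try:
    import jax

    jax.config.update("jax_enable_x64", True)  # match OpenMDAO's float64 arrays
    HAS_JAX = True
except ImportError:
    HAS_JAX = False


def use_ad_partials(user_requested):
    """Use AD only when the user asked for it and JAX is importable."""
    return bool(user_requested and HAS_JAX)
```

A component could then branch on `use_ad_partials(...)` inside `compute_partials`, so installations without JAX behave exactly as today.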

Potential solution

  1. Profile and identify the components that can be accelerated by AD. eval_mtx is one, but there may be others.
  2. Replace (part of) the compute_partials methods with AD. We'll need to try multiple AD options, as Aditya and Sriram did.
  3. Add an optional dependency on JAX in setup.py.
  4. Add a documentation page on AD; ideally, suggest a mesh-size threshold at which AD becomes faster than the default analytic partials.
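As a sketch of step 2, one option is to differentiate just the inner kernel with jax.jacfwd and leave the OpenMDAO wiring unchanged. The `kernel` below is a stand-in for an eval_mtx-like dense operation, not the actual OAS implementation:

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # match OpenMDAO's float64 arrays


def kernel(mesh):
    # Stand-in for an eval_mtx-like operation: regularized pairwise
    # inverse distances between mesh points, shape (n, n).
    d = mesh[:, None, :] - mesh[None, :, :]
    r2 = jnp.sum(d * d, axis=-1) + 1.0  # +1 regularizes the diagonal
    return 1.0 / jnp.sqrt(r2)


# Dense Jacobian of the flattened output w.r.t. the mesh coordinates,
# ready to be reshaped into what compute_partials expects.
jac_fn = jax.jit(jax.jacfwd(lambda x: kernel(x).ravel()))

mesh = np.random.default_rng(0).random((5, 3))
J = jac_fn(mesh)  # shape (25, 5, 3)
```

This gives a dense Jacobian, which is exactly the regime where AD pays off only for large meshes; the sparsity discussion below is about avoiding that cost.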
kanekosh commented Jul 8, 2024

Ultimately, we want to fully automate compute_partials using AD. The OpenMDAO docs now include examples of doing so.

One of the current obstacles to using AD in OAS is that we cannot automate sparsity exploitation. Many components in OAS have array inputs and array outputs whose partials are usually sparse, and we declare the sparsity pattern via declare_partials. I don't think this sparsity declaration is automated by OpenMDAO or JAX at the moment.
Therefore, we'd still have to manually figure out the sparsity pattern and set the corresponding AD seeds to feed jvp or vjp. This doesn't really save implementation effort, because figuring out the sparsity pattern is often more time-consuming than implementing compute_partials. Alternatively, we could ignore the sparsity pattern and treat all partials as dense, but that is bad for performance.
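To illustrate the manual seeding this comment describes, here is a toy sketch (not OAS code) of coloring: when same-colored columns of a known-sparse Jacobian never share a nonzero row, one jvp per color recovers every entry. The function `f` and its cyclic tridiagonal pattern are invented for the example.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def f(x):
    # Toy component with a cyclic tridiagonal Jacobian: output i depends
    # only on x[i-1], x[i], and x[i+1].
    return x * jnp.roll(x, 1) + jnp.roll(x, -1)


n = 9
x = jnp.arange(1.0, n + 1.0)

# Columns j, j+3, j+6, ... never share a nonzero row, so 3 colored seed
# vectors recover every Jacobian entry with 3 jvp calls instead of n.
seeds = [jnp.asarray(np.arange(n) % 3 == c, dtype=x.dtype) for c in range(3)]
probes = [jax.jvp(f, (x,), (s,))[1] for s in seeds]
# probes[c][i] equals J[i, j] for the unique column j of color c that is
# structurally nonzero in row i.
```

The hard part, as noted above, is knowing the pattern and coloring in the first place; that is the step nobody currently automates for OAS components.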

If sparse partials with AD are supported by OpenMDAO in the future, then I think OAS can benefit a lot from AD.
