
Support for Multiple GPUs #1495

Open · wants to merge 84 commits into master

Conversation

@YigitElma (Collaborator) commented Dec 25, 2024

Initial support for multi-GPU optimization.

  • Adds a convenience function for building an ObjectiveFunction whose multiple ForceBalance objectives are distributed across multiple GPUs
  • Adds a jit_with_device decorator to jit a function on a specific device (without this, everything runs on GPU id:0)
  • Adds pconcat to concatenate, hstack, or vstack a list of arrays that live on different devices onto a single device. If the resulting array fits on GPU id:0, it is placed on the GPU; otherwise it is placed on the CPU. This function is used by the compute and jac methods. If the CPU is used, QR and other linear algebra on the Jacobian gets slower, but there is no memory restriction.
  • Adds _device_id to the _Objective class (defaults to 0) so that parallelization works with other objectives (currently not automated; maybe add jax.device_put logic to the build method)
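For illustration, the jit_with_device decorator described above could be sketched roughly as follows. The decorator name matches the PR, but the body here is an assumption about how such a helper might work, not the actual implementation:

```python
import functools

import jax
import jax.numpy as jnp


def jit_with_device(device_id=0):
    """Sketch: jit a function and pin its (positional) inputs to one device.

    Without explicit placement, JAX runs everything on device 0, so a
    decorator like this is one way to target other GPUs.
    """

    def decorator(fun):
        jitted = jax.jit(fun)

        @functools.wraps(fun)
        def wrapper(*args):
            device = jax.devices()[device_id]
            # Move inputs onto the requested device before calling the
            # jitted function; outputs then stay on that device.
            args = [jax.device_put(a, device) for a in args]
            return jitted(*args)

        return wrapper

    return decorator


@jit_with_device(device_id=0)
def residual(x):
    return jnp.sum(x**2)


print(residual(jnp.arange(4.0)))  # 0² + 1² + 2² + 3² = 14.0
```

A real version would also need to handle keyword arguments and pytrees, but the placement idea is the same.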

We won't see any speed improvement for the trust_region_subproblem solvers, because JAX doesn't support distributed linear algebra yet.

TODO:

  • Implement for ProximalProjection (a new jit_if_not_parallel decorator)
  • Remove the redundant lines of code
  • Implement support for multiple CPUs
  • If we add a new GitHub Action, we can test with virtual devices, which make JAX treat different CPU cores as different devices
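The virtual-device trick mentioned in the last TODO item can be done with an XLA flag that splits the host CPU into several fake devices, so multi-device code paths can be exercised on a GPU-less CI runner. A minimal sketch (the flag must be set before jax is first imported):

```python
import os

# Make XLA expose the host CPU as 4 separate devices. This must happen
# before the first `import jax`, or the flag is ignored.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax

print(jax.device_count())  # 4 virtual CPU devices
```

In a GitHub Actions workflow, the same flag could simply be set in the job's `env:` section before running the tests.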

Things to consider:

  • This is not the most efficient approach. I used _jvp_blocked for the derivatives; ideally, each Jacobian calculation should run in parallel using MPI (or similar)
  • Maybe implement a new optimizer that uses distributed matrix operations instead of QR and SVD. Probably a future PR
  • Make the default grid of get_forcabalance_parallel smaller

Resolves #1071 (currently without MPI4JAX, but we should use it to run operations on GPUs in parallel)
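A minimal sketch of what a pconcat-style helper might look like, assuming the behavior described in the bullets above (the name matches the PR, but the body and the fallback policy shown here are illustrative assumptions):

```python
import jax
import jax.numpy as jnp


def pconcat(arrays, mode="concatenate"):
    """Sketch: combine arrays that may live on different devices.

    The real helper would try GPU id:0 first and fall back to CPU when
    the result does not fit in GPU memory; this sketch always gathers
    onto the first available device.
    """
    target = jax.devices()[0]
    # device_put moves each piece to the target device (a no-op for
    # arrays already there), so the stacking op sees co-located inputs.
    gathered = [jax.device_put(a, target) for a in arrays]
    if mode == "hstack":
        return jnp.hstack(gathered)
    if mode == "vstack":
        return jnp.vstack(gathered)
    return jnp.concatenate(gathered)


parts = [jnp.ones((2, 3)), jnp.zeros((2, 3))]
print(pconcat(parts, mode="vstack").shape)  # (4, 3)
```

The memory-based GPU/CPU fallback is the interesting part of the real implementation and is deliberately omitted here.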


codecov bot commented Dec 26, 2024

Codecov Report

Attention: Patch coverage is 64.56693% with 45 lines in your changes missing coverage. Please review.

Project coverage is 95.63%. Comparing base (0658b9f) to head (46ed909).

Files with missing lines Patch % Lines
desc/objectives/getters.py 3.70% 26 Missing ⚠️
desc/backend.py 52.38% 10 Missing ⚠️
desc/objectives/objective_funs.py 86.53% 7 Missing ⚠️
desc/optimize/_constraint_wrappers.py 92.59% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1495      +/-   ##
==========================================
- Coverage   95.69%   95.63%   -0.07%     
==========================================
  Files         101      100       -1     
  Lines       26348    26372      +24     
==========================================
+ Hits        25215    25220       +5     
- Misses       1133     1152      +19     
Files with missing lines Coverage Δ
desc/objectives/_bootstrap.py 97.18% <ø> (ø)
desc/objectives/_coils.py 99.38% <ø> (ø)
desc/objectives/_equilibrium.py 95.08% <ø> (ø)
desc/objectives/_fast_ion.py 98.78% <ø> (ø)
desc/objectives/_free_boundary.py 96.17% <ø> (ø)
desc/objectives/_generic.py 99.40% <ø> (ø)
desc/objectives/_geometry.py 96.79% <ø> (ø)
desc/objectives/_neoclassical.py 98.75% <ø> (ø)
desc/objectives/_omnigenity.py 97.06% <ø> (ø)
desc/objectives/_power_balance.py 91.83% <ø> (ø)
... and 7 more

... and 2 files with indirect coverage changes

github-actions bot (Contributor) commented Dec 26, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_midres         |     +1.34 +/- 3.85     | +8.72e-03 +/- 2.52e-02 |  6.62e-01 +/- 2.1e-02  |  6.53e-01 +/- 1.4e-02  |
 test_build_transform_fft_highres        |     +0.83 +/- 6.54     | +7.87e-03 +/- 6.22e-02 |  9.59e-01 +/- 5.9e-02  |  9.51e-01 +/- 2.0e-02  |
 test_equilibrium_init_lowres            |     +1.21 +/- 2.42     | +5.04e-02 +/- 1.01e-01 |  4.21e+00 +/- 9.3e-02  |  4.16e+00 +/- 3.7e-02  |
 test_objective_compile_atf              |     -0.74 +/- 2.91     | -6.31e-02 +/- 2.48e-01 |  8.46e+00 +/- 8.4e-02  |  8.52e+00 +/- 2.3e-01  |
 test_objective_compute_atf              |     -0.93 +/- 2.81     | -1.54e-04 +/- 4.67e-04 |  1.64e-02 +/- 3.1e-04  |  1.66e-02 +/- 3.5e-04  |
 test_objective_jac_atf                  |     -1.89 +/- 1.78     | -3.93e-02 +/- 3.70e-02 |  2.04e+00 +/- 2.1e-02  |  2.08e+00 +/- 3.1e-02  |
 test_perturb_1                          |     +5.79 +/- 2.16     | +8.89e-01 +/- 3.31e-01 |  1.62e+01 +/- 3.1e-01  |  1.54e+01 +/- 1.0e-01  |
 test_proximal_jac_atf                   |     -1.20 +/- 1.23     | -9.91e-02 +/- 1.02e-01 |  8.15e+00 +/- 7.1e-02  |  8.25e+00 +/- 7.2e-02  |
 test_proximal_freeb_compute             |     -0.90 +/- 1.42     | -1.96e-03 +/- 3.09e-03 |  2.16e-01 +/- 2.4e-03  |  2.18e-01 +/- 1.9e-03  |
-test_solve_fixed_iter                   |    +18.15 +/- 1.76     | +6.08e+00 +/- 5.88e-01 |  3.96e+01 +/- 4.4e-01  |  3.35e+01 +/- 3.9e-01  |
 test_objective_compute_ripple           |     +0.20 +/- 2.21     | +1.41e-03 +/- 1.58e-02 |  7.16e-01 +/- 9.4e-03  |  7.14e-01 +/- 1.3e-02  |
 test_objective_grad_ripple              |     -0.33 +/- 1.25     | -9.22e-03 +/- 3.52e-02 |  2.80e+00 +/- 2.7e-02  |  2.81e+00 +/- 2.3e-02  |
 test_build_transform_fft_lowres         |     +0.94 +/- 7.25     | +5.84e-03 +/- 4.50e-02 |  6.26e-01 +/- 3.5e-02  |  6.21e-01 +/- 2.8e-02  |
 test_equilibrium_init_medres            |     +0.84 +/- 1.55     | +3.69e-02 +/- 6.80e-02 |  4.42e+00 +/- 5.2e-02  |  4.38e+00 +/- 4.4e-02  |
 test_equilibrium_init_highres           |     +0.84 +/- 1.48     | +4.50e-02 +/- 7.89e-02 |  5.39e+00 +/- 6.4e-02  |  5.35e+00 +/- 4.6e-02  |
 test_objective_compile_dshape_current   |     +0.71 +/- 4.95     | +3.00e-02 +/- 2.08e-01 |  4.23e+00 +/- 1.4e-01  |  4.20e+00 +/- 1.6e-01  |
 test_objective_compute_dshape_current   |     +0.80 +/- 2.36     | +4.33e-05 +/- 1.28e-04 |  5.46e-03 +/- 9.2e-05  |  5.42e-03 +/- 8.8e-05  |
 test_objective_jac_dshape_current       |     -0.51 +/- 7.26     | -2.19e-04 +/- 3.14e-03 |  4.30e-02 +/- 2.2e-03  |  4.32e-02 +/- 2.2e-03  |
-test_perturb_2                          |     +5.07 +/- 1.43     | +1.04e+00 +/- 2.95e-01 |  2.16e+01 +/- 2.0e-01  |  2.06e+01 +/- 2.2e-01  |
 test_proximal_jac_atf_with_eq_update    |     +1.77 +/- 0.93     | +3.02e-01 +/- 1.58e-01 |  1.73e+01 +/- 9.0e-02  |  1.70e+01 +/- 1.3e-01  |
 test_proximal_freeb_jac                 |     +0.31 +/- 1.08     | +2.18e-02 +/- 7.66e-02 |  7.13e+00 +/- 4.9e-02  |  7.11e+00 +/- 5.9e-02  |
-test_solve_fixed_iter_compiled          |    +11.42 +/- 1.71     | +2.40e+00 +/- 3.60e-01 |  2.34e+01 +/- 3.3e-01  |  2.10e+01 +/- 1.4e-01  |
 test_LinearConstraintProjection_build   |     +3.38 +/- 2.35     | +3.80e-01 +/- 2.65e-01 |  1.16e+01 +/- 1.1e-01  |  1.12e+01 +/- 2.4e-01  |
 test_objective_compute_ripple_spline    |     +0.03 +/- 1.19     | +1.09e-04 +/- 4.13e-03 |  3.47e-01 +/- 1.8e-03  |  3.47e-01 +/- 3.7e-03  |
 test_objective_grad_ripple_spline       |     +1.46 +/- 1.83     | +2.06e-02 +/- 2.58e-02 |  1.43e+00 +/- 2.1e-02  |  1.41e+00 +/- 1.5e-02  |

@dpanici (Collaborator) commented Jan 6, 2025

Check for overlap with #763.

@YigitElma YigitElma requested review from a team, rahulgaur104, f0uriest, dpanici, sinaatalay and unalmis and removed request for a team February 14, 2025 05:43
@YigitElma YigitElma self-assigned this Feb 14, 2025
@YigitElma YigitElma added the performance (New feature or request to make the code faster) and gpu (Issues related to the GPU backend) labels Feb 14, 2025
YigitElma and others added 2 commits February 20, 2025 19:55
…e so I removed the obj._device attribute, instead use config kind and device_id to have same function
@@ -4,7 +4,8 @@

 import numpy as np

-from desc.backend import jit, jnp, put
+from desc import config as desc_config
+from desc.backend import jax, jit, jnp, put, pconcat
📝 [flake8] <1> reported by reviewdog 🐶
isort found an import in the wrong position

Labels
gpu (Issues related to the GPU backend), performance (New feature or request to make the code faster)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Parallelize across multiple GPUs with MPI4Jax
3 participants