Support for Multiple GPUs #1495
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #1495 +/- ##
==========================================
- Coverage 95.69% 95.63% -0.07%
==========================================
Files 101 100 -1
Lines 26348 26372 +24
==========================================
+ Hits 25215 25220 +5
- Misses 1133 1152 +19
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_midres | +1.34 +/- 3.85 | +8.72e-03 +/- 2.52e-02 | 6.62e-01 +/- 2.1e-02 | 6.53e-01 +/- 1.4e-02 |
test_build_transform_fft_highres | +0.83 +/- 6.54 | +7.87e-03 +/- 6.22e-02 | 9.59e-01 +/- 5.9e-02 | 9.51e-01 +/- 2.0e-02 |
test_equilibrium_init_lowres | +1.21 +/- 2.42 | +5.04e-02 +/- 1.01e-01 | 4.21e+00 +/- 9.3e-02 | 4.16e+00 +/- 3.7e-02 |
test_objective_compile_atf | -0.74 +/- 2.91 | -6.31e-02 +/- 2.48e-01 | 8.46e+00 +/- 8.4e-02 | 8.52e+00 +/- 2.3e-01 |
test_objective_compute_atf | -0.93 +/- 2.81 | -1.54e-04 +/- 4.67e-04 | 1.64e-02 +/- 3.1e-04 | 1.66e-02 +/- 3.5e-04 |
test_objective_jac_atf | -1.89 +/- 1.78 | -3.93e-02 +/- 3.70e-02 | 2.04e+00 +/- 2.1e-02 | 2.08e+00 +/- 3.1e-02 |
test_perturb_1 | +5.79 +/- 2.16 | +8.89e-01 +/- 3.31e-01 | 1.62e+01 +/- 3.1e-01 | 1.54e+01 +/- 1.0e-01 |
test_proximal_jac_atf | -1.20 +/- 1.23 | -9.91e-02 +/- 1.02e-01 | 8.15e+00 +/- 7.1e-02 | 8.25e+00 +/- 7.2e-02 |
test_proximal_freeb_compute | -0.90 +/- 1.42 | -1.96e-03 +/- 3.09e-03 | 2.16e-01 +/- 2.4e-03 | 2.18e-01 +/- 1.9e-03 |
-test_solve_fixed_iter | +18.15 +/- 1.76 | +6.08e+00 +/- 5.88e-01 | 3.96e+01 +/- 4.4e-01 | 3.35e+01 +/- 3.9e-01 |
test_objective_compute_ripple | +0.20 +/- 2.21 | +1.41e-03 +/- 1.58e-02 | 7.16e-01 +/- 9.4e-03 | 7.14e-01 +/- 1.3e-02 |
test_objective_grad_ripple | -0.33 +/- 1.25 | -9.22e-03 +/- 3.52e-02 | 2.80e+00 +/- 2.7e-02 | 2.81e+00 +/- 2.3e-02 |
test_build_transform_fft_lowres | +0.94 +/- 7.25 | +5.84e-03 +/- 4.50e-02 | 6.26e-01 +/- 3.5e-02 | 6.21e-01 +/- 2.8e-02 |
test_equilibrium_init_medres | +0.84 +/- 1.55 | +3.69e-02 +/- 6.80e-02 | 4.42e+00 +/- 5.2e-02 | 4.38e+00 +/- 4.4e-02 |
test_equilibrium_init_highres | +0.84 +/- 1.48 | +4.50e-02 +/- 7.89e-02 | 5.39e+00 +/- 6.4e-02 | 5.35e+00 +/- 4.6e-02 |
test_objective_compile_dshape_current | +0.71 +/- 4.95 | +3.00e-02 +/- 2.08e-01 | 4.23e+00 +/- 1.4e-01 | 4.20e+00 +/- 1.6e-01 |
test_objective_compute_dshape_current | +0.80 +/- 2.36 | +4.33e-05 +/- 1.28e-04 | 5.46e-03 +/- 9.2e-05 | 5.42e-03 +/- 8.8e-05 |
test_objective_jac_dshape_current | -0.51 +/- 7.26 | -2.19e-04 +/- 3.14e-03 | 4.30e-02 +/- 2.2e-03 | 4.32e-02 +/- 2.2e-03 |
-test_perturb_2 | +5.07 +/- 1.43 | +1.04e+00 +/- 2.95e-01 | 2.16e+01 +/- 2.0e-01 | 2.06e+01 +/- 2.2e-01 |
test_proximal_jac_atf_with_eq_update | +1.77 +/- 0.93 | +3.02e-01 +/- 1.58e-01 | 1.73e+01 +/- 9.0e-02 | 1.70e+01 +/- 1.3e-01 |
test_proximal_freeb_jac | +0.31 +/- 1.08 | +2.18e-02 +/- 7.66e-02 | 7.13e+00 +/- 4.9e-02 | 7.11e+00 +/- 5.9e-02 |
-test_solve_fixed_iter_compiled | +11.42 +/- 1.71 | +2.40e+00 +/- 3.60e-01 | 2.34e+01 +/- 3.3e-01 | 2.10e+01 +/- 1.4e-01 |
test_LinearConstraintProjection_build | +3.38 +/- 2.35 | +3.80e-01 +/- 2.65e-01 | 1.16e+01 +/- 1.1e-01 | 1.12e+01 +/- 2.4e-01 |
test_objective_compute_ripple_spline | +0.03 +/- 1.19 | +1.09e-04 +/- 4.13e-03 | 3.47e-01 +/- 1.8e-03 | 3.47e-01 +/- 3.7e-03 |
 test_objective_grad_ripple_spline | +1.46 +/- 1.83 | +2.06e-02 +/- 2.58e-02 | 1.43e+00 +/- 2.1e-02 | 1.41e+00 +/- 1.5e-02 |
#763
…e so I removed the obj._device attribute, instead use config kind and device_id to have same function
@@ -4,7 +4,8 @@

 import numpy as np

-from desc.backend import jit, jnp, put
+from desc import config as desc_config
+from desc.backend import jax, jit, jnp, put, pconcat
📝 [flake8] <1> reported by reviewdog 🐶
isort found an import in the wrong position
Initial support for multi-GPU optimization.

- `ObjectiveFunction` with multiple `ForceBalance` objectives that are distributed to multiple GPUs
- `jit_with_device` decorator to jit a function on a specific device (without this, everything runs on GPU id:0)
- `pconcat` to concatenate, hstack or vstack a list of arrays that live on different devices onto a single device. If the resulting array fits on GPU id:0, it is put on that GPU; otherwise it is put on the CPU. This function is used for the compute and jac methods. If the CPU is used, QR and other types of linear algebra on the Jacobian get slower, but there is no memory restriction.
- `_device_id` to `_Objective` class (defaults to 0) for making parallelization work with other objectives (currently not automated, maybe add `jax.device_put` stuff to build method)

We won't see any speed improvement for the trust_region_subproblem solvers, because JAX doesn't support distributed linear algebra yet.

TODO:
- `ProximalProjection` (a new `jit_if_not_parallel` decorator)

Things to consider:
- `_jvp_blocked` for the derivatives. Ideally, we should run each Jacobian calculation in parallel using MPI (or similar)
- `get_forcabalance_parallel` smaller.

Resolves #1071 (currently without MPI4JAX, but we should use it to run operations on GPUs in parallel)
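For readers unfamiliar with per-device jitting, here is a minimal sketch of the two ideas described above: compiling a function so it runs on a chosen device, and gathering results from several devices onto one. It is not the code in this PR. The names `jit_on_device`, `concat_on_device` and the `max_bytes` threshold are hypothetical stand-ins for `jit_with_device` and `pconcat`, and the memory check is reduced to a simple byte count supplied by the caller. It only assumes array-valued arguments and falls back to a single device when no extra GPUs are present.

```python
import functools

import jax
import jax.numpy as jnp


def jit_on_device(device):
    """Hypothetical decorator: jit `fun` and run it on `device`."""

    def decorator(fun):
        jitted = jax.jit(fun)

        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            # Commit the inputs to the target device; a jitted function
            # runs where its committed inputs live.
            args, kwargs = jax.device_put((args, kwargs), device)
            return jitted(*args, **kwargs)

        return wrapper

    return decorator


def concat_on_device(arrays, preferred, axis=0, max_bytes=None):
    """Hypothetical helper: move arrays to one device, then concatenate.

    If `max_bytes` is given and the combined size exceeds it, the result
    is assembled on the CPU instead of on `preferred`.
    """
    total = sum(a.nbytes for a in arrays)
    target = preferred
    if max_bytes is not None and total > max_bytes:
        target = jax.devices("cpu")[0]
    gathered = [jax.device_put(a, target) for a in arrays]
    return jnp.concatenate(gathered, axis=axis)


def residual(x):
    # Toy stand-in for an objective's compute method.
    return jnp.sin(x) ** 2


devices = jax.devices()
x = jnp.linspace(0.0, 1.0, 8)

pieces = []
for i in range(2):
    # Cycle through the available devices (a single device also works).
    dev = devices[i % len(devices)]
    pieces.append(jit_on_device(dev)(residual)(x * (i + 1)))

out = concat_on_device(pieces, devices[0])
out_cpu = concat_on_device(pieces, devices[0], max_bytes=0)
```

The final concatenation still happens on a single device, which is why the linear algebra that follows does not speed up with more GPUs.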