Support for Multiple GPUs #1495

Open

wants to merge 84 commits into master

Changes from all commits (84 commits)
41d3595
add multiple gpu support to backend
YigitElma Dec 25, 2024
860b1ff
fix if statement
YigitElma Dec 25, 2024
003eca9
fix stuff
YigitElma Dec 25, 2024
74f5f3d
try
YigitElma Dec 25, 2024
eed3f6d
fix issue
YigitElma Dec 25, 2024
b7e9435
update jac_chunk_size assignment
YigitElma Dec 25, 2024
ab7402e
try putting the grid across devices
YigitElma Dec 25, 2024
2a7ab0d
fix issue with none constants
YigitElma Dec 25, 2024
afa349c
revert jnp.asarrays in grid
YigitElma Dec 25, 2024
04af924
replicate state vector on all devices
YigitElma Dec 25, 2024
aa4f9aa
allow variable number of gpus, copy some data to every device
YigitElma Dec 25, 2024
d676888
not put back to one device for testing
YigitElma Dec 25, 2024
d3d2663
handle num_device=1 case
YigitElma Dec 25, 2024
ec05139
update
YigitElma Dec 25, 2024
ea0b584
fix typo
YigitElma Dec 25, 2024
f353649
fix issue
YigitElma Dec 25, 2024
a7847df
it was a stupid mistake
YigitElma Dec 25, 2024
5c0f811
I don't know why this was changed
YigitElma Dec 25, 2024
c976088
put the copying inside the jitted part
YigitElma Dec 26, 2024
e15f7b2
shard A, Z and D too
YigitElma Dec 26, 2024
36cd4e1
fix
YigitElma Dec 26, 2024
7e82f6d
fix
YigitElma Dec 26, 2024
c963c1a
don't shard A
YigitElma Dec 26, 2024
ebd8dd1
clean up
YigitElma Dec 26, 2024
172d211
shard tangents too
YigitElma Dec 26, 2024
bd986be
shard v in different way
YigitElma Dec 26, 2024
163801e
don't cover set_device for coverage
YigitElma Dec 26, 2024
528c17c
Merge branch 'master' into yge/multigpu
YigitElma Jan 24, 2025
33b7c0b
add getter for parallel force objective
YigitElma Jan 29, 2025
35dd7b0
add notebook for testing
YigitElma Jan 29, 2025
e9c6e63
build and distribute objectives in getter
YigitElma Jan 29, 2025
9f19885
maybe use same grid res
YigitElma Jan 30, 2025
57ab00c
add build flag to getter
YigitElma Jan 30, 2025
b28bc4e
do not jit the ObjectiveFunction because jax doesn't allow it
YigitElma Jan 30, 2025
c8f4826
move extra stuff
YigitElma Feb 5, 2025
c3a4803
move whole objective on gpu
YigitElma Feb 5, 2025
b599b91
add pconcat function, normal concatenate doesn't accept arrays from d…
YigitElma Feb 6, 2025
05f705a
use more pconcat
YigitElma Feb 6, 2025
7c36f3a
test not passing constants
YigitElma Feb 6, 2025
66a4f95
try something
YigitElma Feb 6, 2025
293b6f0
try something
YigitElma Feb 6, 2025
5088395
instead replicate eq every device
YigitElma Feb 6, 2025
2c93a6a
try something
YigitElma Feb 6, 2025
84179d1
return replicated eq and use that otherwise outer eq and obj eq are n…
YigitElma Feb 6, 2025
1ee3452
reorder steps
YigitElma Feb 6, 2025
088324f
copy params to device before passing to function
YigitElma Feb 10, 2025
97c3dec
add device_id for forcebalance
YigitElma Feb 11, 2025
2b7e007
update notebook
YigitElma Feb 11, 2025
856a115
delete old line
YigitElma Feb 11, 2025
27d0c73
add testing cell
YigitElma Feb 11, 2025
3015545
clean up
YigitElma Feb 11, 2025
c8481e1
move params to device for printing too
YigitElma Feb 11, 2025
e9ae2da
Merge branch 'master' into yge/multigpu
YigitElma Feb 11, 2025
a800fd4
update notebook to plot grid
YigitElma Feb 11, 2025
69161c2
made it WORK! pass all params on given device, merge arrays on cpu or…
Feb 11, 2025
8b044eb
Merge remote-tracking branch 'refs/remotes/origin/yge/multigpu' into …
YigitElma Feb 12, 2025
23f6612
fix formatting after cluster
YigitElma Feb 12, 2025
1e1dfeb
fix some problems for testing and docs
YigitElma Feb 12, 2025
dc34d8e
Merge branch 'master' into yge/multigpu
YigitElma Feb 12, 2025
fd7638b
ignore multidevice for notebook tests, add additional warnings for gp…
YigitElma Feb 12, 2025
0a77b1a
add changelog, fix notebook tests
YigitElma Feb 12, 2025
637c1c2
Merge branch 'master' into yge/multigpu
dpanici Feb 12, 2025
306ae44
add warning for deriv_mode blocked and moving array to CPU
YigitElma Feb 12, 2025
f5dd1fa
add option to suppress cpu warning
YigitElma Feb 12, 2025
3326426
make upper case
YigitElma Feb 12, 2025
fef9a90
clean up set_device
YigitElma Feb 12, 2025
7e09142
bunch of clean up
YigitElma Feb 12, 2025
927e8aa
Merge branch 'master' into yge/multigpu
YigitElma Feb 12, 2025
c32d7b4
clean up, fix issues
YigitElma Feb 13, 2025
e2c0f77
fix set_device config['device'] problem
YigitElma Feb 13, 2025
2052633
update notebook and add device_id to all objectives
YigitElma Feb 13, 2025
0d4cc47
fix missing docs
YigitElma Feb 13, 2025
12ba4db
initial test for proximal
YigitElma Feb 13, 2025
f446804
add obj._device attr for cleaner device_put
YigitElma Feb 13, 2025
e5ed5cb
jit what you can, use pconcat
YigitElma Feb 13, 2025
6dd7611
fix device jit issue
YigitElma Feb 13, 2025
b3e961f
make _device None for single device cases
YigitElma Feb 13, 2025
bfb371c
ok now it is fixed
YigitElma Feb 13, 2025
b6b4337
Merge branch 'master' into yge/multigpu
YigitElma Feb 13, 2025
62e827e
implement multicpu, add a test, need to make it work tho
YigitElma Feb 14, 2025
315b4ed
improve test
YigitElma Feb 14, 2025
3edf125
Merge branch 'master' into yge/multigpu
YigitElma Feb 14, 2025
b22fb50
add multiprocessing, for some reason jax.Device object is not picklabl…
YigitElma Feb 21, 2025
46ed909
Merge branch 'master' into yge/multigpu
YigitElma Feb 21, 2025
11 changes: 6 additions & 5 deletions .github/workflows/notebook_tests.yml
@@ -92,8 +92,9 @@ jobs:
source .venv-${{ env.version }}/bin/activate
export PYTHONPATH=$(pwd)
pytest -v --nbmake "./docs/notebooks" \
--nbmake-timeout=2000 \
--ignore=./docs/notebooks/zernike_eval.ipynb \
--splits 3 \
--group ${{ matrix.group }} \
--splitting-algorithm least_duration
--nbmake-timeout=2000 \
--ignore=./docs/notebooks/zernike_eval.ipynb \
--ignore=./docs/notebooks/tutorials/multi_device.ipynb \
--splits 3 \
--group ${{ matrix.group }} \
--splitting-algorithm least_duration
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ New Features
- Adds a new function ``desc.coils.initialize_helical_coils`` for creating an initial guess for stage 2 helical coil optimization.
- Adds ``desc.vmec_utils.make_boozmn_output`` for writing boozmn.nc style output files
for compatibility with other codes which expect such files from the Booz_Xform code.
- Adds initial support for multiple GPU optimization. This allows derivatives to be computed across multiple GPUs and enables more memory-intensive objectives. Note that at this stage, multi-device support is aimed at memory, not speed.
- Renames compute quantity ``sqrt(g)_B`` to ``sqrt(g)_Boozer_DESC`` to more accurately reflect what the quantity is (the jacobian from (rho,theta_B,zeta_B) to (rho,theta,zeta)), and adds a new function to compute ``sqrt(g)_Boozer`` which is the jacobian from (rho,theta_B,zeta_B) to (R,phi,Z).
- Allows specification of Nyquist spectrum maximum mode numbers when using ``VMECIO.save`` to save a DESC .h5 file as a VMEC-format wout file
- Adds a new objective ``desc.objectives.ExternalObjective`` for wrapping external codes with finite differences.
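For reference, the user-facing entry point for this feature is the ``num_device`` argument added to ``desc.set_device`` (see the desc/__init__.py diff below). A minimal usage sketch, assuming a machine with at least two visible CUDA GPUs:

from desc import set_device

# set_device must run before importing anything else from DESC or JAX,
# otherwise it has no effect (see the Notes in its docstring below).
set_device(kind="gpu", num_device=2)  # takes the first two available GPUs

from desc.equilibrium import Equilibrium  # noqa: E402  (imported only after set_device)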
173 changes: 129 additions & 44 deletions desc/__init__.py
@@ -2,10 +2,13 @@

import importlib
import os
import platform
import re
import subprocess
import warnings

import colorama
import psutil
from termcolor import colored

from ._version import get_versions
@@ -58,36 +61,111 @@ def __getattr__(name):
BANNER = colored(_BANNER, "magenta")


config = {"device": None, "avail_mem": None, "kind": None}
config = {"devices": None, "avail_mem": None, "kind": None, "num_device": None}


def set_device(kind="cpu", gpuid=None):
def _get_processor_name():
"""Get the processor name of the current system."""
if platform.system() == "Windows":
return platform.processor()
elif platform.system() == "Darwin":
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
command = "sysctl -n machdep.cpu.brand_string"
return subprocess.check_output(command).strip()
elif platform.system() == "Linux":
command = "cat /proc/cpuinfo"
all_info = subprocess.check_output(command, shell=True).decode().strip()
for line in all_info.split("\n"):
if "model name" in line:
return re.sub(".*model name.*:", "", line, 1)
return ""


def _set_cpu_count(n):
"""Set the number of CPUs visible to JAX.

By default, JAX sees the whole CPU as a single device, regardless of the number of
cores or threads. It then uses multiple cores and threads for lower level
parallelism within individual operations.

Alternatively, you can force JAX to expose a given number of "virtual" CPUs that
can then be used manually for higher level parallelism (e.g. at the level of
multiple objective functions).

This function is mainly intended for testing the parallelism in DESC on CI.

Parameters
----------
n : int
Number of virtual CPUs for high level parallelism.

Notes
-----
This function must be called before importing anything else from DESC or JAX,
and before calling ``desc.set_device``, otherwise it will have no effect.
"""
xla_flags = os.getenv("XLA_FLAGS", "")
xla_flags = re.sub(
r"--xla_force_host_platform_device_count=\S+", "", xla_flags
).split()
os.environ["XLA_FLAGS"] = " ".join(
[f"--xla_force_host_platform_device_count={n}"] + xla_flags
)


def set_device(kind="cpu", gpuid=None, num_device=1): # noqa: C901
"""Sets the device to use for computation.

If kind==``'gpu'`` and a gpuid is specified, uses the specified GPU. If
gpuid==``None`` or a wrong GPU id is given, checks available GPUs and selects the
one with the most available memory.
Respects environment variable CUDA_VISIBLE_DEVICES for selecting from multiple
available GPUs
available GPUs.

Notes
-----
This function must be called before importing anything else from DESC or JAX,
otherwise it will have no effect.

Parameters
----------
kind : {``'cpu'``, ``'gpu'``}
whether to use CPU or GPU.
gpuid : int, optional
GPU id to use. Default is None. Supported only when num_device is 1.
num_device : int
number of devices to use. Default is 1.

"""
config["kind"] = kind
config["num_device"] = num_device

cpu_mem = psutil.virtual_memory().available / 1024**3 # RAM in GB
cpu_info = _get_processor_name()
config["cpu_info"] = f"{cpu_info} CPU"
config["cpu_mem"] = cpu_mem

if kind == "cpu":
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import psutil

cpu_mem = psutil.virtual_memory().available / 1024**3 # RAM in GB
config["device"] = "CPU"
config["avail_mem"] = cpu_mem
if num_device == 1:
config["devices"] = [f"{cpu_info} CPU"]
config["avail_mems"] = [cpu_mem]
else:
try:
import jax

jax_cpu = jax.devices("cpu")
assert len(jax_cpu) == num_device
config["devices"] = [f"{dev}" for dev in jax_cpu]
config["avail_mems"] = [cpu_mem for _ in range(num_device)]
except ModuleNotFoundError:
raise ValueError(
"JAX not installed. Please install JAX to use multiple CPUs. "
"Alternatively, set num_device=1 to use a single CPU."
)

if kind == "gpu":
# Set CUDA_DEVICE_ORDER so the IDs assigned by CUDA match those from nvidia-smi
elif kind == "gpu":
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import nvgpu

@@ -100,55 +178,62 @@ def set_device(kind="cpu", gpuid=None):
set_device(kind="cpu")
return

maxmem = 0
selected_gpu = None
gpu_ids = [dev["index"] for dev in devices]
if "CUDA_VISIBLE_DEVICES" in os.environ:
cuda_ids = [
s for s in re.findall(r"\b\d+\b", os.environ["CUDA_VISIBLE_DEVICES"])
]
# check that the visible devices actually exist and are gpus
gpu_ids = [i for i in cuda_ids if i in gpu_ids]
if len(gpu_ids) == 0:
# cuda visible devices = '' -> don't use any gpu
warnings.warn(
colored(
(
"CUDA_VISIBLE_DEVICES={} ".format(
os.environ["CUDA_VISIBLE_DEVICES"]
)
+ "did not match any physical GPU "
+ "(id={}), falling back to CPU".format(
[dev["index"] for dev in devices]
)
),
f"CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']} did "
"not match any physical GPU "
f"(id={[dev['index'] for dev in devices]}), falling back to CPU",
"yellow",
)
)
set_device(kind="cpu")
return

devices = [dev for dev in devices if dev["index"] in gpu_ids]
memories = {dev["index"]: dev["mem_total"] - dev["mem_used"] for dev in devices}

if num_device == 1:
if gpuid is not None:
if str(gpuid) in gpu_ids:
selected_gpu = next(
dev for dev in devices if dev["index"] == str(gpuid)
)
else:
warnings.warn(
colored(
f"Specified gpuid {gpuid} not found, selecting GPU with "
"most memory",
"yellow",
)
)
else:
selected_gpu = max(
devices, key=lambda dev: dev["mem_total"] - dev["mem_used"]
)
devices = [selected_gpu]

if gpuid is not None and (str(gpuid) in gpu_ids):
selected_gpu = [dev for dev in devices if dev["index"] == str(gpuid)][0]
else:
for dev in devices:
mem = dev["mem_total"] - dev["mem_used"]
if mem > maxmem:
maxmem = mem
selected_gpu = dev
config["device"] = selected_gpu["type"] + " (id={})".format(
selected_gpu["index"]
)
if gpuid is not None and not (str(gpuid) in gpu_ids):
warnings.warn(
colored(
"Specified gpuid {} not found, falling back to ".format(str(gpuid))
+ config["device"],
"yellow",
if num_device > len(devices):
raise ValueError(
f"Requested {num_device} GPUs, but only {len(devices)} available"
)
)
config["avail_mem"] = (
selected_gpu["mem_total"] - selected_gpu["mem_used"]
) / 1024 # in GB
os.environ["CUDA_VISIBLE_DEVICES"] = str(selected_gpu["index"])
if gpuid is not None:
# TODO: implement multiple GPU selection
raise ValueError("Cannot specify `gpuid` when requesting multiple GPUs")

config["avail_mems"] = [
memories[dev["index"]] / 1024 for dev in devices[:num_device]
] # in GB
config["devices"] = [
f"{dev['type']} (id={dev['index']})" for dev in devices[:num_device]
]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
str(dev["index"]) for dev in devices[:num_device]
)
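To exercise the new multi-CPU path (intended mainly for CI testing of the parallelism, per the ``_set_cpu_count`` docstring above), a rough sketch based on the functions in this diff:

from desc import _set_cpu_count, set_device

# Must run before JAX is imported and before set_device: exposes 4 "virtual"
# CPU devices via XLA_FLAGS=--xla_force_host_platform_device_count=4.
_set_cpu_count(4)
set_device(kind="cpu", num_device=4)

import jax  # noqa: E402

print(jax.devices("cpu"))  # expected to list 4 CPU devices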