Skip to content

Commit

Permalink
refactor(slurmd): improve comments and function names. Simplify
Browse files Browse the repository at this point in the history
try/except blocks.

Add better support for comma-separated values in `get_slurmd_info`
  • Loading branch information
dsloanm committed Jan 8, 2025
1 parent 5575440 commit 5ae35e1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 48 deletions.
23 changes: 7 additions & 16 deletions charms/slurmd/src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _check_status(self) -> bool:
def _reboot_if_required(self, now: bool = False) -> None:
"""Perform a reboot of the unit if required, e.g. following a driver installation."""
if Path("/var/run/reboot-required").exists():
logger.info("unit rebooting")
logger.info("rebooting unit %s", self.unit.name)
self.unit.reboot(now)

@staticmethod
Expand Down Expand Up @@ -373,34 +373,25 @@ def get_node(self) -> Dict[Any, Any]:
"""Get the node from stored state."""
slurmd_info = machine.get_slurmd_info()

# Get GPU info and build GRES configuration.
gres_info = []
if gpus := gpu.get_gpus():
if gpus := gpu.get_all_gpu():
for model, devices in gpus.items():
# Build gres.conf line for this GPU model.
if len(devices) == 1:
device_suffix = next(iter(devices))
else:
# For multi-gpu setups, "File" uses ranges and strides syntax,
# e.g. File=/dev/nvidia[0-3], File=/dev/nvidia[0,2-3]
# Get numeric range of devices associated with this GRES resource. See:
# https://slurm.schedmd.com/gres.conf.html#OPT_File
device_suffix = self._ranges_and_strides(sorted(devices))
gres_line = {
# NodeName included in node_parameters.
"Name": "gpu",
"Type": model,
"File": f"/dev/nvidia{device_suffix}",
}
# Append to list of GRES lines for all models
gres_info.append(gres_line)

# Add to node parameters to ensure included in slurm.conf.
slurm_conf_gres = f"gpu:{model}:{len(devices)}"
try:
# Add to existing Gres line.
slurmd_info["Gres"].append(slurm_conf_gres)
except KeyError:
# Create a new Gres entry if none present
slurmd_info["Gres"] = [slurm_conf_gres]
slurmd_info["Gres"] = cast(list[str], slurmd_info.get("Gres", [])) + [
f"gpu:{model}:{len(devices)}"
]

node = {
"node_parameters": {
Expand Down
30 changes: 7 additions & 23 deletions charms/slurmd/src/utils/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
_logger = logging.getLogger(__name__)


class GPUInstallError(Exception):
class GPUOpsError(Exception):
    """Exception raised when a GPU driver or device operation fails."""


Expand All @@ -36,7 +36,7 @@ def __init__(self):
try:
apt.add_package(pkgs, update_cache=True)
except (apt.PackageNotFoundError, apt.PackageError) as e:
raise GPUInstallError(f"failed to install {pkgs} reason: {e}")
raise GPUOpsError(f"failed to install {pkgs} reason: {e}")

# ubuntu-drivers requires apt_pkg for package operations
self._detect = _import("UbuntuDrivers.detect")
Expand All @@ -54,9 +54,8 @@ def _get_linux_modules_metapackage(self, driver) -> str:
"""
return self._detect.get_linux_modules_metapackage(self._apt_pkg.Cache(None), driver)

def system_packages(self) -> list:
def system_packages(self) -> list[str]:
"""Return a list of GPU drivers and kernel module packages for this node."""
# Detect only GPGPU drivers. Not general purpose graphics drivers.
packages = self._system_gpgpu_driver_packages()

# Gather list of driver and kernel modules to install.
Expand All @@ -76,19 +75,14 @@ def system_packages(self) -> list:
# Add to list of packages to install
install_packages += [driver_metapackage, modules_metapackage]

# TODO: do we want to check for nvidia here and add nvidia-fabricmanager-535 libnvidia-nscq-535 in case of nvlink? This is suggested as a manual step at https://documentation.ubuntu.com/server/how-to/graphics/install-nvidia-drivers/#optional-step. If so, how do we get the version number "-535" robustly?

# TODO: what if drivers install but do not require a reboot? Should we "modprobe nvidia" manually? Just always reboot regardless?

# Filter out any empty results as returning
return [p for p in install_packages if p]


def autoinstall() -> None:
"""Autodetect available GPUs and install drivers.
Raises:
GPUInstallError: Raised if error is encountered during package install.
GPUOpsError: Raised if error is encountered during package install.
"""
_logger.info("detecting GPUs and installing drivers")
detector = GPUDriverDetector()
Expand All @@ -102,10 +96,10 @@ def autoinstall() -> None:
try:
apt.add_package(install_packages)
except (apt.PackageNotFoundError, apt.PackageError) as e:
raise GPUInstallError(f"failed to install packages {install_packages}. reason: {e}")
raise GPUOpsError(f"failed to install packages {install_packages}. reason: {e}")


def get_gpus() -> dict:
def get_all_gpu() -> dict[str, set[int]]:
"""Get the GPU devices on this node.
Returns:
Expand All @@ -119,14 +113,12 @@ def get_gpus() -> dict:
"""
gpu_info = {}

# Return immediately if pynvml not installed...
try:
pynvml = _import("pynvml")
except ModuleNotFoundError:
_logger.info("cannot gather GPU info: pynvml module not installed")
return gpu_info

# ...or Nvidia drivers not loaded.
try:
pynvml.nvmlInit()
except pynvml.NVMLError as e:
Expand All @@ -135,7 +127,6 @@ def get_gpus() -> dict:
return gpu_info

gpu_count = pynvml.nvmlDeviceGetCount()
# Loop over all detected GPUs, gathering info by model.
for i in range(gpu_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)

Expand All @@ -146,15 +137,8 @@ def get_gpus() -> dict:
model = pynvml.nvmlDeviceGetName(handle)
model = "_".join(model.split()).lower()

# Number for device path, e.g. if device is /dev/nvidia0, returns 0
minor_number = pynvml.nvmlDeviceGetMinorNumber(handle)

try:
# Add minor number to set of existing numbers for this model.
gpu_info[model].add(minor_number)
except KeyError:
# This is the first time we've seen this model. Create a new entry.
gpu_info[model] = {minor_number}
gpu_info[model] = gpu_info.get(model, set()) | {minor_number}

pynvml.nvmlShutdown()
return gpu_info
Expand Down
21 changes: 12 additions & 9 deletions charms/slurmd/src/utils/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@

import logging
import subprocess
from typing import Any, Dict

_logger = logging.getLogger(__name__)


def get_slurmd_info() -> Dict[str, Any]:
"""Get machine info as reported by `slurmd -C`."""
def get_slurmd_info() -> dict[str, str | list[str]]:
"""Get machine info as reported by `slurmd -C`.
For details see:
https://slurm.schedmd.com/slurmd.html
"""
try:
r = subprocess.check_output(["slurmd", "-C"], text=True).strip()
except subprocess.CalledProcessError as e:
Expand All @@ -31,10 +34,10 @@ def get_slurmd_info() -> Dict[str, Any]:

info = {}
for opt in r.split()[:-1]:
key, value = opt.split("=")
# Split comma-separated lists, e.g. Gres=gpu:model_a:1,gpu:model_b:1
if "," in value:
info[key] = value.split(",")
else:
info[key] = value
k, v = opt.split("=")
if k == "Gres":
info[k] = v.split(",")
continue

info[k] = v
return info

0 comments on commit 5ae35e1

Please sign in to comment.