Enable Nvidia GPU Support #57

Merged
15 commits merged on Jan 13, 2025
Changes from 7 commits
1 change: 1 addition & 0 deletions charms/sackd/charmcraft.yaml
@@ -34,6 +34,7 @@ parts:
charm:
charm-binary-python-packages:
- cryptography ~= 44.0.0
- jsonschema ~= 4.23.0

provides:
slurmctld:
1 change: 1 addition & 0 deletions charms/slurmctld/charmcraft.yaml
@@ -54,6 +54,7 @@ parts:
charm:
charm-binary-python-packages:
- cryptography ~= 44.0.0
- jsonschema ~= 4.23.0
- pydantic

config:
2 changes: 1 addition & 1 deletion charms/slurmctld/requirements.txt
@@ -1,2 +1,2 @@
ops==2.17.1
slurmutils~=0.9.0
slurmutils<1.0.0,>=0.11.0
58 changes: 55 additions & 3 deletions charms/slurmctld/src/charm.py
@@ -37,7 +37,7 @@
WaitingStatus,
main,
)
from slurmutils.models import CgroupConfig, SlurmConfig
from slurmutils.models import CgroupConfig, GRESConfig, GRESNode, SlurmConfig

from charms.grafana_agent.v0.cos_agent import COSAgentProvider
from charms.hpc_libs.v0.is_container import is_container
@@ -87,8 +87,8 @@ def __init__(self, *args):
self._slurmdbd.on.slurmdbd_unavailable: self._on_slurmdbd_unavailable,
self._slurmd.on.partition_available: self._on_write_slurm_conf,
self._slurmd.on.partition_unavailable: self._on_write_slurm_conf,
self._slurmd.on.slurmd_available: self._on_write_slurm_conf,
self._slurmd.on.slurmd_departed: self._on_write_slurm_conf,
self._slurmd.on.slurmd_available: self._on_slurmd_available,
self._slurmd.on.slurmd_departed: self._on_slurmd_departed,
self._slurmrestd.on.slurmrestd_available: self._on_slurmrestd_available,
self.on.show_current_config_action: self._on_show_current_config_action,
self.on.drain_action: self._on_drain_nodes_action,
@@ -214,6 +214,58 @@ def _on_resume_nodes_action(self, event: ActionEvent) -> None:
except subprocess.CalledProcessError as e:
event.fail(message=f"Error resuming {nodes}: {e.output}")

def _on_slurmd_available(self, event: SlurmdAvailableEvent) -> None:
self._add_to_gres_conf(event)
self._on_write_slurm_conf(event)

def _on_slurmd_departed(self, event: SlurmdDepartedEvent) -> None:
# There is no mapping between the departing unit and its NodeName, which complicates
# removing a single node from gres.conf. Instead, rewrite the full gres.conf with
# data from the remaining units.
self._write_gres_conf(event)
self._on_write_slurm_conf(event)

def _add_to_gres_conf(self, event: SlurmdAvailableEvent) -> None:
"""Write new nodes to the gres.conf configuration file for Generic Resource scheduling."""
# This function does not perform an "scontrol reconfigure". It is expected that
# _on_write_slurm_conf() is called immediately afterwards to do this.

# Only the leader should write the config.
if not self.model.unit.is_leader():
return

if not self._check_status():
event.defer()
return

if gres_info := event.gres_info:
# Build list of GRESNodes expected by slurmutils
gres_nodes = []
for resource in gres_info:
node = GRESNode(NodeName=str(event.node_name), **resource)
gres_nodes.append(node)

# Update gres.conf
with self._slurmctld.gres.edit() as config:
config.nodes[event.node_name] = gres_nodes

def _write_gres_conf(self, event: SlurmdDepartedEvent) -> None:
"""Write out the current gres.conf configuration file for Generic Resource scheduling."""
# This function does not perform an "scontrol reconfigure". It is expected that
# _on_write_slurm_conf() is called immediately afterwards to do this.

# Only the leader should write the config.
if not self.model.unit.is_leader():
return

if not self._check_status():
event.defer()
return

# Get current GRES state for all available nodes and write to gres.conf.
gres_all_nodes = self._slurmd.get_gres()
gres_conf = GRESConfig(Nodes=gres_all_nodes)
self._slurmctld.gres.dump(gres_conf)

def _on_write_slurm_conf(
self,
event: Union[
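For context, a minimal sketch of the data these two handlers work with, assuming the payload shape visible elsewhere in this diff (the node name, GPU model, and device path below are illustrative, not taken from the PR):

from slurmutils.models import GRESConfig, GRESNode

# Assumed GRES payload published by one slurmd unit for a single GPU model.
gres_info = [{"Name": "gpu", "Type": "a100", "File": "/dev/nvidia[0-3]"}]
node_name = "compute-0"

# _add_to_gres_conf(): wrap each device entry in a GRESNode keyed by the node name.
gres_nodes = [GRESNode(NodeName=node_name, **resource) for resource in gres_info]

# _write_gres_conf(): rebuild gres.conf from every remaining unit at once. get_gres()
# (see interface_slurmd.py below) returns a mapping of node name to device entries,
# with "NodeName" added to each entry as slurmutils expects.
gres_all_nodes = {node_name: [dict(entry, NodeName=node_name) for entry in gres_info]}
gres_conf = GRESConfig(Nodes=gres_all_nodes)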
44 changes: 43 additions & 1 deletion charms/slurmctld/src/interface_slurmd.py
@@ -31,6 +31,23 @@ class PartitionUnavailableEvent(EventBase):
class SlurmdAvailableEvent(EventBase):
"""Emitted when the slurmd unit joins the relation."""

def __init__(self, handle, node_name, gres_info=None):
super().__init__(handle)
self.node_name = node_name
self.gres_info = gres_info

def snapshot(self):
"""Snapshot the event data."""
return {
"node_name": self.node_name,
"gres_info": self.gres_info,
}

def restore(self, snapshot):
"""Restore the snapshot of the event data."""
self.node_name = snapshot.get("node_name")
self.gres_info = snapshot.get("gres_info")


class SlurmdDepartedEvent(EventBase):
"""Emitted when one slurmd departs."""
@@ -124,7 +141,10 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None:
if node_config := node.get("node_parameters"):
if node_name := node_config.get("NodeName"):
self._charm.new_nodes = list(set(self._charm.new_nodes + [node_name]))
self.on.slurmd_available.emit()
self.on.slurmd_available.emit(
node_name=node_name, gres_info=node.get("gres")
)
logger.debug(f"_on_relation_changed node_config = {node_config}")
else:
logger.debug(f"`node` data does not exist for unit: {unit}.")
else:
@@ -245,3 +265,25 @@ def get_new_nodes_and_nodes_and_partitions(self) -> Dict[str, Any]:
else []
)
return {"DownNodes": new_node_down_nodes, "Nodes": nodes, "Partitions": partitions}

def get_gres(self) -> Dict[str, Any]:
"""Return GRES configuration for all currently related compute nodes."""
# Loop over all relation units, gathering GRES info.
gres_info = {}
if relations := self.framework.model.relations.get(self._relation_name):
for relation in relations:
for unit in relation.units:

if node := self._get_node_from_relation(relation, unit):
# Ignore nodes without GRES devices
if (gres := node.get("gres")) and (
node_config := node.get("node_parameters")
):

node_name = node_config["NodeName"]
# slurmutils expects NodeName in values.
for device in gres:
device["NodeName"] = node_name
gres_info[node_name] = gres

return gres_info
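As a rough illustration, the mapping returned by get_gres() for two related compute nodes would look something like the following (node names and GPU types are invented for the example):

# Hypothetical return value of get_gres().
{
    "compute-0": [
        {"Name": "gpu", "Type": "a100", "File": "/dev/nvidia[0-3]", "NodeName": "compute-0"},
    ],
    "compute-1": [
        {"Name": "gpu", "Type": "t4", "File": "/dev/nvidia0", "NodeName": "compute-1"},
    ],
}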
1 change: 1 addition & 0 deletions charms/slurmd/charmcraft.yaml
@@ -32,6 +32,7 @@ parts:
charm:
charm-binary-python-packages:
- cryptography ~= 44.0.0
- jsonschema ~= 4.23.0
nhc:
plugin: nil
build-packages:
87 changes: 84 additions & 3 deletions charms/slurmd/src/charm.py
@@ -4,7 +4,9 @@

"""Slurmd Operator Charm."""

import itertools
import logging
from pathlib import Path
from typing import Any, Dict, cast

from interface_slurmctld import Slurmctld, SlurmctldAvailableEvent
@@ -21,7 +23,7 @@
main,
)
from slurmutils.models.option import NodeOptionSet, PartitionOptionSet
from utils import machine, nhc, service
from utils import gpu, machine, nhc, service

from charms.hpc_libs.v0.slurm_ops import SlurmdManager, SlurmOpsError
from charms.operator_libs_linux.v0.juju_systemd_notices import ( # type: ignore[import-untyped]
@@ -74,11 +76,18 @@ def __init__(self, *args, **kwargs):

def _on_install(self, event: InstallEvent) -> None:
"""Perform installation operations for slurmd."""
# Account for the case where the base image has been auto-upgraded by Juju and a reboot
# is pending before the charm code runs. Reboot "now", before the current hook completes,
# so that the hook is restarted after the reboot. This prevents issues such as drivers or
# kernel modules being installed for a running kernel that is pending replacement by a
# newer version on reboot.
self._reboot_if_required(now=True)

self.unit.status = WaitingStatus("installing slurmd")

try:
self._slurmd.install()
nhc.install()
gpu.autoinstall()
self.unit.set_workload_version(self._slurmd.version())
# TODO: https://github.com/orgs/charmed-hpc/discussions/10 -
# Evaluate if we should continue doing the service override here
@@ -92,6 +101,7 @@ def _on_install(self, event: InstallEvent) -> None:
event.defer()

self._check_status()
self._reboot_if_required()

def _on_config_changed(self, _: ConfigChangedEvent) -> None:
"""Handle charm configuration changes."""
@@ -214,7 +224,7 @@ def _on_show_nhc_config(self, event: ActionEvent) -> None:
event.set_results({"nhc.conf": "/etc/nhc/nhc.conf not found."})

def _on_node_config_action_event(self, event: ActionEvent) -> None:
"""Get or set the user_supplied_node_conifg.
"""Get or set the user_supplied_node_config.

Return the node config if the `node-config` parameter is not specified, otherwise
parse, validate, and store the input of the `node-config` parameter in stored state.
@@ -321,15 +331,86 @@ def _check_status(self) -> bool:

return True

def _reboot_if_required(self, now: bool = False) -> None:
"""Perform a reboot of the unit if required, e.g. following a driver installation."""
if Path("/var/run/reboot-required").exists():
logger.info("unit rebooting")
self.unit.reboot(now)

@staticmethod
def _ranges_and_strides(nums) -> str:
"""Return ranges and strides for given iterable.

Requires input elements to be unique and sorted ascending.

Returns:
A square-bracketed string with comma-separated ranges of consecutive values.

example_input = [0,1,2,3,4,5,6,8,9,10,12,14,15,16,18]
example_output = '[0-6,8-10,12,14-16,18]'
"""
out = "["

# The input is enumerate()-ed to produce (index, element) tuples.
# groupby() uses the lambda key function to group these tuples by the difference between element and index.
# Consecutive values share the same difference between element and index, so they are grouped together.
# Hence, the elements of the first and last members of each group give the range of consecutive values.
# If a group has only a single member, there are no consecutive values on either side of it (a "stride").
for _, group in itertools.groupby(enumerate(nums), lambda elems: elems[1] - elems[0]):
group = list(group)

if len(group) == 1:
# Single member, this is a stride.
out += f"{group[0][1]},"
else:
# Range of consecutive values is first-last in group.
out += f"{group[0][1]}-{group[-1][1]},"

out = out.rstrip(",") + "]"
return out
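# For reference, a standalone, runnable sketch of the same enumerate()/groupby()
# technique used by _ranges_and_strides() above. The function name "compress" is
# illustrative and not part of the charm.
import itertools

def compress(nums):
    """Collapse sorted, unique integers into bracketed range notation."""
    parts = []
    # Consecutive values share the same (element - index) difference, so groupby()
    # collects each run of consecutive values into a single group.
    for _, group in itertools.groupby(enumerate(nums), lambda pair: pair[1] - pair[0]):
        group = list(group)
        first, last = group[0][1], group[-1][1]
        parts.append(str(first) if first == last else f"{first}-{last}")
    return "[" + ",".join(parts) + "]"

print(compress([0, 1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 14, 15, 16, 18]))
# -> [0-6,8-10,12,14-16,18]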

def get_node(self) -> Dict[Any, Any]:
"""Get the node from stored state."""
slurmd_info = machine.get_slurmd_info()

# Get GPU info and build GRES configuration.
gres_info = []
if gpus := gpu.get_gpus():
for model, devices in gpus.items():
# Build gres.conf line for this GPU model.
if len(devices) == 1:
device_suffix = next(iter(devices))
else:
# For multi-gpu setups, "File" uses ranges and strides syntax,
# e.g. File=/dev/nvidia[0-3], File=/dev/nvidia[0,2-3]
device_suffix = self._ranges_and_strides(sorted(devices))
gres_line = {
# NodeName included in node_parameters.
"Name": "gpu",
"Type": model,
"File": f"/dev/nvidia{device_suffix}",
}
# Append to list of GRES lines for all models
gres_info.append(gres_line)

# Add to the node parameters so the GRES devices are included in slurm.conf.
slurm_conf_gres = f"gpu:{model}:{len(devices)}"
try:
# Add to existing Gres line.
slurmd_info["Gres"].append(slurm_conf_gres)
except KeyError:
# Create a new Gres entry if none present
slurmd_info["Gres"] = [slurm_conf_gres]

node = {
"node_parameters": {
**machine.get_slurmd_info(),
**slurmd_info,
"MemSpecLimit": "1024",
**self._user_supplied_node_parameters,
},
"new_node": self._new_node,
# Do not include GRES configuration if no GPUs detected.
**({"gres": gres_info} if len(gres_info) > 0 else {}),
}
logger.debug(f"Node Configuration: {node}")
return node
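For a node exposing four GPUs of a single model, the resulting node dictionary would look roughly like this (node name, GPU model, and values are illustrative assumptions, not taken from the PR):

# Hypothetical get_node() result.
node = {
    "node_parameters": {
        "NodeName": "compute-0",
        "Gres": ["gpu:a100:4"],  # rendered into slurm.conf as Gres=gpu:a100:4
        "MemSpecLimit": "1024",
        # ...remaining values from machine.get_slurmd_info() and user-supplied config
    },
    "new_node": True,
    "gres": [
        # rendered into gres.conf as:
        #   NodeName=compute-0 Name=gpu Type=a100 File=/dev/nvidia[0-3]
        {"Name": "gpu", "Type": "a100", "File": "/dev/nvidia[0-3]"},
    ],
}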