Skip to content

Commit

Permalink
Merge pull request #828 from linsword13/check-partition
Browse files Browse the repository at this point in the history
Add checks around partition config
  • Loading branch information
douglasjacobsen authored Jan 16, 2025
2 parents f74a69c + 14429e8 commit e313fff
Showing 1 changed file with 51 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def _execute_vars(self):
("#SBATCH -e {experiment_run_dir}/slurm-%j.err"),
("#SBATCH --gpus-per-node {gpus_per_node}"),
]
if expander.expand_var_name("slurm_partition"):
partition = expander.expand_var_name("slurm_partition")
self._check_partition(partition)
if partition:
pragmas.append("#SBATCH -p {slurm_partition}")
extra_headers = (
self.app_inst.variables["extra_sbatch_headers"].strip().split("\n")
Expand All @@ -129,6 +131,33 @@ def _execute_vars(self):
header_str = "\n".join(self.conditional_expand(pragmas))
return {"sbatch_headers_str": header_str}

def _check_partition(self, partition):
    """Sanity-check the expanded `slurm_partition` value against the runner.

    This never raises or aborts: it only logs, since a custom execute
    template may already carry the partition information.

    Args:
        partition: the expanded `slurm_partition` value (may be empty).

    The check is silently skipped when the runner raises `RunnerError`
    while listing partitions, or when it reports no partition info (None).
    """
    try:
        props = self.runner.get_partitions()
    except RunnerError:
        # Could not query the scheduler; nothing to validate against.
        return
    if props is None:
        return

    known = props["partitions"]
    if partition in known:
        # A recognized partition was requested; nothing to report.
        return

    fallback = props["default_partition"]
    if fallback is None or partition:
        # Either an unknown partition was given, or none was given and
        # the scheduler advertises no default to fall back on.
        logger.warn(
            "Missing valid `slurm_partition` setting. "
            f"It should be one of {known}"
        )
        return

    # No partition requested, but the scheduler has a default one.
    logger.info(
        "`slurm_partition` is not given, "
        f"using default partition {fallback}"
    )

def get_status(self, workspace):
expander = self.app_inst.expander
run_dir = expander.expand_var_name("experiment_run_dir")
Expand All @@ -154,6 +183,7 @@ def __init__(self, dry_run=False):
self.dry_run = dry_run
self.squeue_runner = None
self.sacct_runner = None
self.sinfo_runner = None
self.run_dir = None

def _ensure_runner(self, runner_name: str):
Expand Down Expand Up @@ -190,3 +220,23 @@ def get_status(self, job_id):
sacct_args = ["-o", "state", "-X", "-n", "-j", job_id]
status_out = self.sacct_runner.command(*sacct_args, output=str)
return status_out.strip()

def get_partitions(self):
    """List the partitions reported by `sinfo`.

    Returns:
        None when `self.dry_run` is set (in which case `sinfo` is not
        invoked). Otherwise a dict with two keys:

        * "partitions": set of partition names, taken from the first
          whitespace-separated field of each output line, with any
          trailing `*` removed.
        * "default_partition": the name whose field carried the trailing
          `*` marker, or None if no line carried one.

        Empty `sinfo` output yields an empty set and a None default.
    """
    if self.dry_run:
        return None
    self._ensure_runner("sinfo")
    # NOTE(review): `-h` is expected to suppress the header row so every
    # line is a partition record -- confirm against the sinfo man page.
    sinfo_args = ["-h"]
    out = self.sinfo_runner.command(*sinfo_args, output=str).strip()
    partitions = set()
    default_partition = None
    for line in out.splitlines():
        info = line.split()
        if not info:
            # Blank line (or entirely empty output): there is no name
            # field, so skip it instead of raising IndexError on info[0].
            continue
        name = info[0]
        if name.endswith("*"):
            name = name[:-1]
            default_partition = name
        # A set de-duplicates names that appear on several lines.
        partitions.add(name)
    return {
        "default_partition": default_partition,
        "partitions": partitions,
    }

0 comments on commit e313fff

Please sign in to comment.