# NOTE(review): in the patched file these two methods appear to live on
# different classes -- `_check_partition` on the object that owns
# `self.runner`, `get_partitions` on the runner whose `__init__` sets
# `self.sinfo_runner` -- confirm placement when applying.


def _check_partition(self, partition):
    """Warn about potential problems with the `slurm_partition` setting.

    This only logs and never raises, because the user may be relying on a
    custom execute template that already contains the partition info.

    Args:
        partition: Expanded value of `slurm_partition`; may be empty.
    """
    try:
        partition_prop = self.runner.get_partitions()
    except RunnerError:
        # The partitions could not be queried, so there is nothing to
        # validate against.
        return
    if partition_prop is None:
        # `get_partitions` returns None in dry-run mode.
        return

    partitions = partition_prop["partitions"]
    if partition in partitions:
        return

    default_partition = partition_prop["default_partition"]
    if default_partition is not None and not partition:
        # No partition requested, but a default partition is available.
        logger.info(
            "`slurm_partition` is not given, "
            f"using default partition {default_partition}"
        )
    else:
        logger.warn(
            "Missing valid `slurm_partition` setting. "
            f"It should be one of {partitions}"
        )


def get_partitions(self):
    """Collect the partition names reported by `sinfo`.

    Returns:
        None when `self.dry_run` is set. Otherwise a dict with:
          - "default_partition": the name that `sinfo` marks with a
            trailing `*`, or None if no partition is marked.
          - "partitions": the set of all partition names, with the `*`
            marker removed.
    """
    if self.dry_run:
        return None
    self._ensure_runner("sinfo")
    # `-h` asks sinfo to omit the header row.
    sinfo_args = ["-h"]
    out = self.sinfo_runner.command(*sinfo_args, output=str)
    partitions = set()
    default_partition = None
    for line in out.splitlines():
        fields = line.split()
        if not fields:
            # Empty output or a blank line: there is no partition column
            # to read, and indexing it would raise IndexError.
            continue
        name = fields[0]
        if name.endswith("*"):
            name = name[:-1]
            default_partition = name
        partitions.add(name)
    return {
        "default_partition": default_partition,
        "partitions": partitions,
    }