diff --git a/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py b/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py index a2aae7605..2f907ab55 100644 --- a/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py +++ b/lib/ramble/ramble/test/workflow_manager_functionality/slurm_workflow_manager.py @@ -33,6 +33,7 @@ def test_slurm_workflow(): batch_submit: echo {wm_name} mpi_command: mpirun -n {n_ranks} -hostfile hostfile processes_per_node: 1 + n_nodes: 1 wm_name: ['None', 'slurm'] applications: hostname: @@ -41,8 +42,12 @@ def test_slurm_workflow(): experiments: test_{wm_name}: variables: - n_nodes: 1 - extra_sbatch_headers: "#SBATCH --gpus-per-task={n_threads}" + extra_sbatch_headers: | + #SBATCH --gpus-per-task={n_threads} + #SBATCH --time={time_limit_not_exist} + test_{wm_name}_2: + variables: + slurm_partition: h3 """ with ramble.workspace.create(workspace_name) as ws: ws.write() @@ -83,9 +88,17 @@ def test_slurm_workflow(): content = f.read() assert "scontrol show hostnames" in content assert "#SBATCH --gpus-per-task=1" in content + assert "#SBATCH -p" not in content + assert "#SBATCH --time" not in content with open(os.path.join(path, "batch_query")) as f: content = f.read() assert "squeue" in content with open(os.path.join(path, "batch_cancel")) as f: content = f.read() assert "scancel" in content + + # Assert on the experiment with non-empty partition variable given + path = os.path.join(ws.experiment_dir, "hostname", "local", "test_slurm_2") + with open(os.path.join(path, "slurm_execute_experiment")) as f: + content = f.read() + assert "#SBATCH -p h3" in content diff --git a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py index 105a1c855..9477befc0 100644 --- a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py +++ b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py @@ -9,7 +9,6 @@ import os from ramble.wmkit import * -from ramble.expander import ExpanderError from ramble.application import experiment_status from spack.util.executable import ProcessError @@ -78,6 +77,12 @@ def __init__(self, file_path): description="mpirun prefix, mostly served as an overridable default", ) + workflow_manager_variable( + name="slurm_partition", + default="", + description="partition to submit job to, if unspecified, it uses the default partition", + ) + register_template( name="batch_submit", src_name="batch_submit.tpl", @@ -109,21 +114,18 @@ def _execute_vars(self): # Adding pre-defined and custom headers pragmas = [ ("#SBATCH -N {n_nodes}"), - ("#SBATCH -p {partition}"), ("#SBATCH --ntasks-per-node {processes_per_node}"), ("#SBATCH -J {job_name}"), ("#SBATCH -o {experiment_run_dir}/slurm-%j.out"), ("#SBATCH -e {experiment_run_dir}/slurm-%j.err"), ("#SBATCH --gpus-per-node {gpus_per_node}"), ] - try: - extra_sbatch_headers_raw = expander.expand_var_name( - "extra_sbatch_headers", allow_passthrough=False - ) - extra_sbatch_headers = extra_sbatch_headers_raw.strip().split("\n") - pragmas = pragmas + extra_sbatch_headers - except ExpanderError: - pass + if expander.expand_var_name("slurm_partition"): + pragmas.append("#SBATCH -p {slurm_partition}") + extra_headers = ( + self.app_inst.variables["extra_sbatch_headers"].strip().split("\n") + ) + pragmas = pragmas + extra_headers header_str = "\n".join(self.conditional_expand(pragmas)) return {"sbatch_headers_str": header_str}