From bf8ae08d56255ed63724721174176a28333c556b Mon Sep 17 00:00:00 2001 From: Lin Guo Date: Tue, 17 Dec 2024 07:01:50 +0000 Subject: [PATCH] Move conditional expand logic to base class Also add a gpu-per-node entry --- lib/ramble/ramble/workflow_manager.py | 22 ++++++++++++++ .../slurm/workflow_manager.py | 30 +++++-------------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/lib/ramble/ramble/workflow_manager.py b/lib/ramble/ramble/workflow_manager.py index ce0e3376f..567dde991 100644 --- a/lib/ramble/ramble/workflow_manager.py +++ b/lib/ramble/ramble/workflow_manager.py @@ -14,6 +14,7 @@ from ramble.util.naming import NS_SEPARATOR import ramble.util.class_attributes import ramble.util.directives +from ramble.expander import ExpanderError class WorkflowManagerBase(metaclass=WorkflowManagerMeta): @@ -48,6 +49,27 @@ def get_status(self, workspace): """Return status of a given job""" return None + def conditional_expand(self, templates): + """Return a (potentially empty) list of expanded strings + + Args: + templates: A list of templates to expand. + If the template cannot be fully expanded, it's skipped. + Returns: + A list of expanded strings + """ + expander = self.app_inst.expander + expanded = [] + for tpl in templates: + try: + rendered = expander.expand_var(tpl, allow_passthrough=False) + if rendered: + expanded.append(rendered) + except ExpanderError: + # Skip a particular entry if any of the vars are not defined + continue + return expanded + def copy(self): """Deep copy a workflow manager instance""" new_copy = type(self)(self._file_path) diff --git a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py index ff528a6fb..b043c90c7 100644 --- a/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py +++ b/var/ramble/repos/builtin/workflow_managers/slurm/workflow_manager.py @@ -37,12 +37,6 @@ def __init__(self, file_path): self.runner = SlurmRunner() - workflow_manager_variable( - name="partition", - default="", - description="Name of the slurm partition for job submission", - ) - workflow_manager_variable( name="job_name", default="{application_name}_{workload_name}_{experiment_name}", @@ -96,31 +90,21 @@ def _slurm_execute_script(self): expander = self.app_inst.expander # Adding pre-defined sets of headers pragmas = [ - ("#SBATCH -N {}", "n_nodes"), - ("#SBATCH -p {}", "partition"), - ("#SBATCH --ntasks-per-node {}", "processes_per_node"), - ("#SBATCH -J {}", "job_name"), + ("#SBATCH -N {n_nodes}"), + ("#SBATCH -p {partition}"), + ("#SBATCH --ntasks-per-node {processes_per_node}"), + ("#SBATCH -J {job_name}"), + ("#SBATCH --gpus-per-node {gpus_per_node}"), ] - for tpl, var in pragmas: - try: - val = expander.expand_var_name(var, allow_passthrough=False) - except ExpanderError: - # Skip a particular header if any of the vars are not defined - continue - if val: - headers.append(tpl.format(val)) - # Adding extra arbitrary headers try: extra_sbatch_headers_raw = expander.expand_var_name( "extra_sbatch_headers", allow_passthrough=False ) extra_sbatch_headers = extra_sbatch_headers_raw.strip().split("\n") - extra_headers = [ - expander.expand_var(h) for h in extra_sbatch_headers - ] - headers = headers + extra_headers + pragmas = pragmas + extra_sbatch_headers except ExpanderError: pass + headers = headers + self.conditional_expand(pragmas) header_str = "\n".join(headers) content = rf""" {header_str}