From ef0296c743467bb8d9e6e1c29bb6deca0057c0a8 Mon Sep 17 00:00:00 2001 From: Lin Guo Date: Wed, 8 Jan 2025 10:17:50 -0800 Subject: [PATCH] Add the ability to define template-specific extra variables The ability to generate dynamic variables is often needed when defining more complicated templates (for instance, programmatically generate slurm job headers.) Having this template-specific option to generate extra rendering vars reduces the need to ingest (more global-scoped) variables via `define_variables` via phase hooks. --- lib/ramble/ramble/application.py | 10 +++++++--- lib/ramble/ramble/language/shared_language.py | 13 ++++++++++++- .../applications/template/application.py | 15 +++++++-------- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/lib/ramble/ramble/application.py b/lib/ramble/ramble/application.py index fbe9f2e33..e6d14a754 100644 --- a/lib/ramble/ramble/application.py +++ b/lib/ramble/ramble/application.py @@ -2270,7 +2270,7 @@ def _get_template_config( break if not found: raise ApplicationError(f"Object {obj.name} is missing template file at {src_path}") - return {**tpl_config, "src_path": src_path} + return (obj, {**tpl_config, "src_path": src_path}) for tpl_config in self.templates.values(): yield _get_template_config(self, tpl_config) @@ -2289,10 +2289,14 @@ def _get_template_config( def _render_object_templates(self, extra_vars): run_dir = self.expander.experiment_run_dir - for tpl_config in self._object_templates(): + for obj, tpl_config in self._object_templates(): src_path = tpl_config["src_path"] with open(src_path) as f_in: content = f_in.read() + extra_vars_func_name = tpl_config.get("extra_vars_func_name") + if extra_vars_func_name is not None: + extra_vars_func = getattr(obj, extra_vars_func_name) + extra_vars.update(extra_vars_func()) rendered = self.expander.expand_var(content, extra_vars=extra_vars) out_path = os.path.join(run_dir, tpl_config["dest_name"]) perm = tpl_config.get("content_perm", _DEFAULT_CONTENT_PERM) @@ -2303,7 +2307,7 @@ def _render_object_templates(self, extra_vars): def _define_object_template_vars(self): run_dir = self.expander.experiment_run_dir - for tpl_config in self._object_templates(): + for _, tpl_config in self._object_templates(): var_name = tpl_config["var_name"] if var_name is not None: path = os.path.join(run_dir, tpl_config["dest_name"]) diff --git a/lib/ramble/ramble/language/shared_language.py b/lib/ramble/ramble/language/shared_language.py index 10da695f7..9f386ddae 100644 --- a/lib/ramble/ramble/language/shared_language.py +++ b/lib/ramble/ramble/language/shared_language.py @@ -6,6 +6,8 @@ # option. This file may not be copied, modified, or distributed # except according to those terms. +from typing import Optional + import ramble.language.language_base import ramble.language.language_helpers import ramble.success_criteria @@ -480,7 +482,12 @@ def _execute_target_shells(obj): @shared_directive("templates") def register_template( - name: str, src_name: str, dest_name: str, define_var: bool = True, output_perm=None + name: str, + src_name: str, + dest_name: str, + define_var: bool = True, + extra_vars_func: Optional[str] = None, + output_perm=None, ): """Directive to define an object-specific template to be rendered into experiment run_dir. @@ -498,15 +505,19 @@ def register_template( dest_name: The leaf name of the rendered output under the experiment run directory. define_var: Controls if a variable named `name` should be defined. + extra_vars_func: If present, the name of the function to call to return + a dict of extra variables used to render the template. output_perm: The chmod mask for the rendered output file. """ def _define_template(obj): var_name = name if define_var else None + extra_vars_func_name = f"_{extra_vars_func}" if extra_vars_func is not None else None obj.templates[name] = { "src_name": src_name, "dest_name": dest_name, "var_name": var_name, + "extra_vars_func_name": extra_vars_func_name, "output_perm": output_perm, } diff --git a/var/ramble/repos/builtin.mock/applications/template/application.py b/var/ramble/repos/builtin.mock/applications/template/application.py index cf231ad55..6493f8183 100644 --- a/var/ramble/repos/builtin.mock/applications/template/application.py +++ b/var/ramble/repos/builtin.mock/applications/template/application.py @@ -25,15 +25,14 @@ class Template(ExecutableApplication): workload="test_template", ) - register_phase( - "ingest_dynamic_variables", - pipeline="setup", - run_before=["make_experiments"], + register_template( + name="bar", + src_name="bar.tpl", + dest_name="bar.sh", + extra_vars_func="bar_vars", ) - def _ingest_dynamic_variables(self, workspace, app_inst): + def _bar_vars(self): expander = self.expander val = expander.expand_var('"hello {hello_name}"') - self.define_variable("dynamic_hello_world", val) - - register_template("bar", src_name="bar.tpl", dest_name="bar.sh") + return {"dynamic_hello_world": val}