Skip to content

Commit

Permalink
Add the ability to define template-specific extra variables
Browse files Browse the repository at this point in the history
The ability to generate dynamic variables is often needed when defining
more complicated templates (for instance, programmatically generate
slurm job headers.) Having this template-specific option to generate
extra rendering vars reduces the need to ingest (more global-scoped)
variables via `define_variables` via phase hooks.
  • Loading branch information
linsword13 committed Jan 8, 2025
1 parent 418a9a6 commit ef0296c
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
10 changes: 7 additions & 3 deletions lib/ramble/ramble/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -2270,7 +2270,7 @@ def _get_template_config(
break
if not found:
raise ApplicationError(f"Object {obj.name} is missing template file at {src_path}")
return {**tpl_config, "src_path": src_path}
return (obj, {**tpl_config, "src_path": src_path})

for tpl_config in self.templates.values():
yield _get_template_config(self, tpl_config)
Expand All @@ -2289,10 +2289,14 @@ def _get_template_config(

def _render_object_templates(self, extra_vars):
run_dir = self.expander.experiment_run_dir
for tpl_config in self._object_templates():
for obj, tpl_config in self._object_templates():
src_path = tpl_config["src_path"]
with open(src_path) as f_in:
content = f_in.read()
extra_vars_func_name = tpl_config.get("extra_vars_func_name")
if extra_vars_func_name is not None:
extra_vars_func = getattr(obj, extra_vars_func_name)
extra_vars.update(extra_vars_func())
rendered = self.expander.expand_var(content, extra_vars=extra_vars)
out_path = os.path.join(run_dir, tpl_config["dest_name"])
perm = tpl_config.get("content_perm", _DEFAULT_CONTENT_PERM)
Expand All @@ -2303,7 +2307,7 @@ def _render_object_templates(self, extra_vars):

def _define_object_template_vars(self):
run_dir = self.expander.experiment_run_dir
for tpl_config in self._object_templates():
for _, tpl_config in self._object_templates():
var_name = tpl_config["var_name"]
if var_name is not None:
path = os.path.join(run_dir, tpl_config["dest_name"])
Expand Down
13 changes: 12 additions & 1 deletion lib/ramble/ramble/language/shared_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# option. This file may not be copied, modified, or distributed
# except according to those terms.

from typing import Optional

import ramble.language.language_base
import ramble.language.language_helpers
import ramble.success_criteria
Expand Down Expand Up @@ -480,7 +482,12 @@ def _execute_target_shells(obj):

@shared_directive("templates")
def register_template(
name: str, src_name: str, dest_name: str, define_var: bool = True, output_perm=None
name: str,
src_name: str,
dest_name: str,
define_var: bool = True,
extra_vars_func: Optional[str] = None,
output_perm=None,
):
"""Directive to define an object-specific template to be rendered into experiment run_dir.
Expand All @@ -498,15 +505,19 @@ def register_template(
dest_name: The leaf name of the rendered output under the experiment
run directory.
define_var: Controls if a variable named `name` should be defined.
extra_vars_func: If present, the name of the function to call to return
a dict of extra variables used to render the template.
output_perm: The chmod mask for the rendered output file.
"""

def _define_template(obj):
var_name = name if define_var else None
extra_vars_func_name = f"_{extra_vars_func}" if extra_vars_func is not None else None
obj.templates[name] = {
"src_name": src_name,
"dest_name": dest_name,
"var_name": var_name,
"extra_vars_func_name": extra_vars_func_name,
"output_perm": output_perm,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,14 @@ class Template(ExecutableApplication):
workload="test_template",
)

register_phase(
"ingest_dynamic_variables",
pipeline="setup",
run_before=["make_experiments"],
register_template(
name="bar",
src_name="bar.tpl",
dest_name="bar.sh",
extra_vars_func="bar_vars",
)

def _ingest_dynamic_variables(self, workspace, app_inst):
def _bar_vars(self):
expander = self.expander
val = expander.expand_var('"hello {hello_name}"')
self.define_variable("dynamic_hello_world", val)

register_template("bar", src_name="bar.tpl", dest_name="bar.sh")
return {"dynamic_hello_world": val}

0 comments on commit ef0296c

Please sign in to comment.