From e09bc23523d64582ad9b96eb265eb391e1682289 Mon Sep 17 00:00:00 2001 From: Lin Guo Date: Mon, 6 Jan 2025 16:01:39 -0800 Subject: [PATCH] Search templates up the object inheritance chain --- lib/ramble/ramble/application.py | 26 +++++++++++--- .../ramble/test/end_to_end/test_template.py | 36 +++++++++++++++++++ .../template-inherited/application.py | 24 +++++++++++++ 3 files changed, 81 insertions(+), 5 deletions(-) create mode 100644 var/ramble/repos/builtin.mock/applications/template-inherited/application.py diff --git a/lib/ramble/ramble/application.py b/lib/ramble/ramble/application.py index 450843212..a98d66f6b 100644 --- a/lib/ramble/ramble/application.py +++ b/lib/ramble/ramble/application.py @@ -2275,9 +2275,19 @@ def evaluate_success(self): def _object_templates(self): """Return templates defined from different objects associated with the app_inst""" - def _get_template_config(obj, tpl_config): - src_path = os.path.join(os.path.dirname(obj._file_path), tpl_config["src_name"]) - if not os.path.isfile(src_path): + def _get_template_config( + obj, tpl_config, obj_type=ramble.repository.ObjectTypes.applications + ): + found = False + # Search up the object chain + object_paths = [e[1] for e in ramble.repository.list_object_files(obj, obj_type)] + src_name = tpl_config["src_name"] + for obj_path in object_paths: + src_path = os.path.join(os.path.dirname(obj_path), src_name) + if os.path.isfile(src_path): + found = True + break + if not found: raise ApplicationError(f"Object {obj.name} is missing template file at {src_path}") return {**tpl_config, "src_path": src_path} @@ -2285,10 +2295,16 @@ def _get_template_config(obj, tpl_config): yield _get_template_config(self, tpl_config) for mod in self._modifier_instances: for tpl_config in mod.templates.values(): - yield _get_template_config(mod, tpl_config) + yield _get_template_config( + mod, tpl_config, obj_type=ramble.repository.ObjectTypes.modifiers + ) if self.package_manager is not None: for tpl_config in self.package_manager.templates.values(): - yield _get_template_config(self.package_manager, tpl_config) + yield _get_template_config( + self.package_manager, + tpl_config, + obj_type=ramble.repository.ObjectTypes.package_managers, + ) def _render_object_templates(self, extra_vars): run_dir = self.expander.experiment_run_dir diff --git a/lib/ramble/ramble/test/end_to_end/test_template.py b/lib/ramble/ramble/test/end_to_end/test_template.py index ac73b3193..a94b25a08 100644 --- a/lib/ramble/ramble/test/end_to_end/test_template.py +++ b/lib/ramble/ramble/test/end_to_end/test_template.py @@ -55,3 +55,39 @@ def test_template(): with open(execute_path) as f: content = f.read() assert script_path in content + + +def test_template_inherited(): + test_config = """ +ramble: + variables: + mpi_command: mpirun -n {n_ranks} + batch_submit: 'batch_submit {execute_experiment}' + processes_per_node: 1 + n_nodes: 1 + applications: + template-inherited: + workloads: + test_template: + experiments: + test: {} +""" + workspace_name = "test_template_inherited" + ws = ramble.workspace.create(workspace_name) + ws.write() + config_path = os.path.join(ws.config_dir, ramble.workspace.config_file_name) + with open(config_path, "w+") as f: + f.write(test_config) + ws._re_read() + + workspace("setup", "--dry-run", global_args=["-w", workspace_name]) + run_dir = os.path.join(ws.experiment_dir, "template-inherited/test_template/test/") + script_path = os.path.join(run_dir, "bar.sh") + assert os.path.isfile(script_path) + with open(script_path) as f: + content = f.read() + assert "echo hello world-inherited" in content + execute_path = os.path.join(run_dir, "execute_experiment") + with open(execute_path) as f: + content = f.read() + assert script_path in content diff --git a/var/ramble/repos/builtin.mock/applications/template-inherited/application.py b/var/ramble/repos/builtin.mock/applications/template-inherited/application.py new file mode 100644 index 000000000..4882d9df7 --- /dev/null +++ b/var/ramble/repos/builtin.mock/applications/template-inherited/application.py @@ -0,0 +1,24 @@ +# Copyright 2022-2025 The Ramble Authors +# +# Licensed under the Apache License, Version 2.0 or the MIT license +# , at your +# option. This file may not be copied, modified, or distributed +# except according to those terms. + +from ramble.appkit import * + +from ramble.app.builtin.mock.template import Template as TemplateBase + + +class TemplateInherited(TemplateBase): + """An app for testing object templates inheritance.""" + + name = "template-inherited" + + workload_variable( + "hello_name", + default="world-inherited", + description="hello name", + workload="test_template", + )