From 888d042bfe3cb2862944e9ecd0da11aa03d56a9e Mon Sep 17 00:00:00 2001 From: Edan Bainglass Date: Thu, 28 Nov 2024 07:02:00 +0000 Subject: [PATCH] Add missing features in resource settings tabs --- src/aiidalab_qe/app/configuration/__init__.py | 2 +- src/aiidalab_qe/app/submission/__init__.py | 66 +++--- .../app/submission/global_settings/model.py | 195 ++++++++---------- .../app/submission/global_settings/setting.py | 190 ++++++----------- src/aiidalab_qe/app/submission/model.py | 47 +++-- src/aiidalab_qe/common/code/model.py | 56 +++-- src/aiidalab_qe/common/mixins.py | 9 +- src/aiidalab_qe/common/panel.py | 184 +++++++++++------ src/aiidalab_qe/plugins/bands/__init__.py | 2 +- src/aiidalab_qe/plugins/bands/code.py | 28 +-- src/aiidalab_qe/plugins/pdos/__init__.py | 2 +- src/aiidalab_qe/plugins/pdos/code.py | 38 ++-- src/aiidalab_qe/plugins/xas/__init__.py | 2 +- src/aiidalab_qe/plugins/xas/code.py | 28 +-- tests/conftest.py | 25 +-- tests/test_codes.py | 43 ++-- tests/test_submit_qe_workchain.py | 21 +- .../test_create_builder_default.yml | 2 - 18 files changed, 471 insertions(+), 469 deletions(-) diff --git a/src/aiidalab_qe/app/configuration/__init__.py b/src/aiidalab_qe/app/configuration/__init__.py index 5551da9b1..b9775ea0b 100644 --- a/src/aiidalab_qe/app/configuration/__init__.py +++ b/src/aiidalab_qe/app/configuration/__init__.py @@ -56,7 +56,7 @@ def __init__(self, model: ConfigurationStepModel, **kwargs): lambda structure: "" if structure else """ -
+
Please set the input structure first.
""", diff --git a/src/aiidalab_qe/app/submission/__init__.py b/src/aiidalab_qe/app/submission/__init__.py index d9176b9ab..ba457fe72 100644 --- a/src/aiidalab_qe/app/submission/__init__.py +++ b/src/aiidalab_qe/app/submission/__init__.py @@ -39,10 +39,6 @@ def __init__(self, model: SubmissionStepModel, qe_auto_setup=True, **kwargs): self._on_submission, "confirmed", ) - self._model.observe( - self._on_input_structure_change, - "input_structure", - ) self._model.observe( self._on_input_parameters_change, "input_parameters", @@ -77,22 +73,28 @@ def __init__(self, model: SubmissionStepModel, qe_auto_setup=True, **kwargs): self.rendered = False - global_code_model = GlobalResourceSettingsModel() - self.global_code_settings = GlobalResourceSettingsPanel(model=global_code_model) - self._model.add_model("global", global_code_model) - global_code_model.observe( + global_resources_model = GlobalResourceSettingsModel() + self.global_resources = GlobalResourceSettingsPanel( + model=global_resources_model + ) + self._model.add_model("global", global_resources_model) + ipw.dlink( + (self._model, "plugin_overrides"), + (global_resources_model, "plugin_overrides"), + ) + global_resources_model.observe( self._on_plugin_submission_blockers_change, ["submission_blockers"], ) - global_code_model.observe( + global_resources_model.observe( self._on_plugin_submission_warning_messages_change, ["submission_warning_messages"], ) self.settings = { - "global": self.global_code_settings, + "global": self.global_resources, } - self._fetch_plugin_settings() + self._fetch_plugin_resource_settings() self._install_sssp(qe_auto_setup) self._set_up_qe(qe_auto_setup) @@ -211,14 +213,15 @@ def _on_tab_change(self, change): tab: ResourceSettingsPanel = self.tabs.children[tab_index] # type: ignore tab.render() - def _on_input_structure_change(self, _): - """""" - def _on_input_parameters_change(self, _): - self._model.update_active_models() - self._update_tabs() self._model.update_process_label() + self._model.update_plugin_inclusion() + self._model.update_plugin_overrides() self._model.update_submission_blockers() + self._update_tabs() + + def _on_plugin_overrides_change(self, _): + self._model.update_plugin_overrides() def _on_plugin_submission_blockers_change(self, _): self._model.update_submission_blockers() @@ -237,16 +240,13 @@ def _on_submission_blockers_change(self, _): self._model.update_submission_blocker_message() self._update_state() - def _on_submission_warning_change(self, _): - self._model.update_submission_warning_message() - def _on_installation_change(self, _): self._model.update_submission_blockers() def _on_qe_installed(self, _): self._toggle_qe_installation_widget() if self._model.qe_installed: - self._model.refresh_codes() + self._model.update() def _on_sssp_installed(self, _): self._toggle_sssp_installation_widget() @@ -325,14 +325,19 @@ def _update_state(self, _=None): else: self.state = self.state.CONFIGURED - def _fetch_plugin_settings(self): - eps = get_entry_items("aiidalab_qe.properties", "code") - for identifier, data in eps.items(): + def _fetch_plugin_resource_settings(self): + entries = get_entry_items("aiidalab_qe.properties", "resources") + for identifier, resources in entries.items(): for key in ("panel", "model"): - if key not in data: + if key not in resources: raise ValueError(f"Entry {identifier} is missing the '{key}' key") - panel = data["panel"] - model: ResourceSettingsModel = data["model"]() + + panel = resources["panel"] + model: ResourceSettingsModel = resources["model"]() + model.observe( + self._on_plugin_overrides_change, + "override", + ) model.observe( self._on_plugin_submission_blockers_change, ["submission_blockers"], @@ -343,15 +348,6 @@ def _fetch_plugin_settings(self): ) self._model.add_model(identifier, model) - def toggle_plugin(_, model=model): - model.update() - self._update_tabs() - - model.observe( - toggle_plugin, - "include", - ) - self.settings[identifier] = panel( identifier=identifier, model=model, diff --git a/src/aiidalab_qe/app/submission/global_settings/model.py b/src/aiidalab_qe/app/submission/global_settings/model.py index 0345cc6ab..88c072181 100644 --- a/src/aiidalab_qe/app/submission/global_settings/model.py +++ b/src/aiidalab_qe/app/submission/global_settings/model.py @@ -5,14 +5,11 @@ import traitlets as tl from aiida import orm -from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS from aiidalab_qe.common.code import CodeModel, PwCodeModel from aiidalab_qe.common.mixins import HasInputStructure from aiidalab_qe.common.panel import ResourceSettingsModel from aiidalab_qe.common.widgets import QEAppComputationalResourcesWidget -DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore - class GlobalResourceSettingsModel( ResourceSettingsModel, @@ -20,6 +17,8 @@ class GlobalResourceSettingsModel( ): """Model for the global code setting.""" + identifier = "global" + dependencies = [ "input_parameters", "input_structure", @@ -27,33 +26,14 @@ class GlobalResourceSettingsModel( input_parameters = tl.Dict() - codes = tl.Dict( - key_trait=tl.Unicode(), # code name - value_trait=tl.Instance(CodeModel), # code metadata - ) - # this is a copy of the codes trait, which is used to trigger the update of the plugin - global_codes = tl.Dict( - key_trait=tl.Unicode(), # code name - value_trait=tl.Dict(), # code metadata - ) - - plugin_mapping = tl.Dict( - key_trait=tl.Unicode(), # plugin identifier - value_trait=tl.List(tl.Unicode()), # list of code names - ) - - submission_blockers = tl.List(tl.Unicode()) - submission_warning_messages = tl.Unicode("") + plugin_overrides = tl.List(tl.Unicode()) + plugin_overrides_notification = tl.Unicode("") include = True def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Used by the code-setup thread to fetch code options - # This is necessary to avoid passing the User object - # between session in separate threads. - self._default_user_email = orm.User.collection.get_default().email self._RUN_ON_LOCALHOST_NUM_SITES_WARN_THRESHOLD = 10 self._RUN_ON_LOCALHOST_VOLUME_WARN_THRESHOLD = 1000 # \AA^3 @@ -66,120 +46,85 @@ def __init__(self, *args, **kwargs):
""" - def refresh_codes(self): - for _, code_model in self.codes.items(): - code_model.update(self._default_user_email) # type: ignore + self.plugin_mapping: dict[str, list[str]] = {} + + self.override = True + + def update(self): + for _, code_model in self.get_models(): + code_model.update(self.DEFAULT_USER_EMAIL) + + def update_global_codes(self): + self.global_codes = self.get_model_state()["codes"] def update_active_codes(self): - for name, code_model in self.codes.items(): - if name != "quantumespresso.pw": + for identifier, code_model in self.get_models(): + if identifier != "quantumespresso.pw": code_model.deactivate() properties = self._get_properties() for identifier, code_names in self.plugin_mapping.items(): if identifier in properties: for code_name in code_names: - self.codes[code_name].activate() + self.get_model(code_name).activate() - def get_model_state(self): - codes = {name: model.get_model_state() for name, model in self.codes.items()} - - return {"codes": codes} - - def set_model_state(self, code_data: dict): - for name, code_model in self.codes.items(): - if name in code_data and code_model.is_active: - code_model.set_model_state(code_data[name]) + def update_plugin_overrides_notification(self): + if self.plugin_overrides: + formatted = "\n".join( + f"
  • {plugin}
  • " for plugin in self.plugin_overrides + ) + self.plugin_overrides_notification = f""" +
    + Currently overriding computational resources for: + +
    + """ + else: + self.plugin_overrides_notification = "" - def add_code(self, identifier: str, code: CodeModel) -> CodeModel | None: - """Add a code to the codes trait.""" - code_model = None - default_calc_job_plugin = code.default_calc_job_plugin + def add_global_model( + self, + identifier: str, + code_model: CodeModel, + ) -> CodeModel | None: + """Registers a code with this model.""" + base_code_model = None + default_calc_job_plugin = code_model.default_calc_job_plugin name = default_calc_job_plugin.split(".")[-1] - if default_calc_job_plugin not in self.codes: + + if not self.has_model(default_calc_job_plugin): if default_calc_job_plugin == "quantumespresso.pw": - code_model = PwCodeModel( + base_code_model = PwCodeModel( name=name, description=name, default_calc_job_plugin=default_calc_job_plugin, ) else: - code_model = CodeModel( + base_code_model = CodeModel( name=name, description=name, default_calc_job_plugin=default_calc_job_plugin, ) - self.codes[default_calc_job_plugin] = code_model - # update the plugin mapping to keep track of which codes are associated with which plugins + self.add_model(default_calc_job_plugin, base_code_model) + if identifier not in self.plugin_mapping: self.plugin_mapping[identifier] = [default_calc_job_plugin] else: self.plugin_mapping[identifier].append(default_calc_job_plugin) - return code_model - - def get_code(self, name) -> CodeModel | None: - if name in self.codes: # type: ignore - return self.codes[name] # type: ignore - - def get_selected_codes(self) -> dict[str, dict]: - return { - name: code_model.get_model_state() - for name, code_model in self.codes.items() - if code_model.is_ready - } - - def set_selected_codes(self, code_data=DEFAULT["codes"]): - for name, code_model in self.codes.items(): - if name in code_data and code_model.is_active: - code_model.set_model_state(code_data[name]) - - def reset(self): - """Reset the model to its default state.""" - for code_model in self.codes.values(): - code_model.reset() - - def _get_properties(self) -> list[str]: - return self.input_parameters.get("workchain", {}).get("properties", []) - - def update_submission_blockers(self): - self.submission_blockers = list(self._check_submission_blockers()) - - def _check_submission_blockers(self): - # No pw code selected (this is ignored while the setup process is running). - pw_code = self.get_code("quantumespresso.pw") - if pw_code and not pw_code.selected and not self.installing_qe: - yield ("No pw code selected") - - # code related to the selected property is not installed - properties = self._get_properties() - message = "Calculating the {property} property requires code {code} to be set." - for identifier, code_names in self.plugin_mapping.items(): - if identifier in properties: - for name in code_names: - code = self.get_code(name) - if not code.is_ready: - yield message.format(property=identifier, code=code.description) - - # check if the QEAppComputationalResourcesWidget is used - for name, code in self.codes.items(): - # skip if the code is not displayed, convenient for the plugin developer - if not code.is_ready: - continue - if not issubclass( - code.code_widget_class, QEAppComputationalResourcesWidget - ): - yield ( - f"Error: hi, plugin developer, please use the QEAppComputationalResourcesWidget from aiidalab_qe.common.widgets for code {name}." - ) + return base_code_model def check_resources(self): - pw_code = self.get_code("quantumespresso.pw") + pw_code_model = self.get_model("quantumespresso.pw") - if not self.input_structure or not pw_code.selected: + if not self.input_structure or not pw_code_model.selected: return # No code selected or no structure, so nothing to do - num_cpus = pw_code.num_cpus * pw_code.num_nodes - on_localhost = orm.load_node(pw_code.selected).computer.hostname == "localhost" + num_cpus = pw_code_model.num_cpus * pw_code_model.num_nodes + on_localhost = ( + orm.load_node(pw_code_model.selected).computer.hostname == "localhost" + ) num_sites = len(self.input_structure.sites) volume = self.input_structure.get_cell_volume() @@ -254,6 +199,40 @@ def check_resources(self): ) ) + def _get_properties(self) -> list[str]: + return self.input_parameters.get("workchain", {}).get("properties", []) + + def _check_submission_blockers(self): + # No pw code selected + pw_code_model = self.get_model("quantumespresso.pw") + if pw_code_model and not pw_code_model.selected: + yield ("No pw code selected") + + # Code related to the selected property is not installed + properties = self._get_properties() + message = "Calculating the {property} property requires code {code} to be set." + for identifier, code_names in self.plugin_mapping.items(): + if identifier in properties: + for name in code_names: + code_model = self.get_model(name) + if not code_model.is_ready: + yield message.format( + property=name, + code=code_model.description, + ) + + # Check if the QEAppComputationalResourcesWidget is used + for identifier, code_model in self.get_models(): + # Skip if the code is not displayed, convenient for the plugin developer + if not code_model.is_ready: + continue + if not issubclass( + code_model.code_widget_class, QEAppComputationalResourcesWidget + ): + yield ( + f"Error: hi, plugin developer, please use the QEAppComputationalResourcesWidget from aiidalab_qe.common.widgets for code {identifier}." + ) + def _estimate_min_cpus( self, n, diff --git a/src/aiidalab_qe/app/submission/global_settings/setting.py b/src/aiidalab_qe/app/submission/global_settings/setting.py index a2df3bbc2..6007226c9 100644 --- a/src/aiidalab_qe/app/submission/global_settings/setting.py +++ b/src/aiidalab_qe/app/submission/global_settings/setting.py @@ -7,42 +7,47 @@ import ipywidgets as ipw -from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS from aiidalab_qe.app.utils import get_entry_items from aiidalab_qe.common.code import CodeModel, PluginCodes, PwCodeModel -from aiidalab_qe.common.panel import ResourceSettingsPanel -from aiidalab_qe.common.widgets import ( - LoadingWidget, - PwCodeResourceSetupWidget, - QEAppComputationalResourcesWidget, -) +from aiidalab_qe.common.panel import ResourceSettingsModel, ResourceSettingsPanel +from aiidalab_qe.common.widgets import QEAppComputationalResourcesWidget from .model import GlobalResourceSettingsModel -DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore - class GlobalResourceSettingsPanel(ResourceSettingsPanel[GlobalResourceSettingsModel]): + title = "Global resources" identifier = "global" def __init__(self, model: GlobalResourceSettingsModel, **kwargs): super().__init__(model, **kwargs) - self._set_up_codes() + + self._model.observe( + self._on_input_structure_change, + "input_structure", + ) self._model.observe( self._on_input_parameters_change, "input_parameters", ) self._model.observe( - self._on_input_structure_change, - "input_structure", + self._on_plugin_overrides_change, + "plugin_overrides", ) + self._fetch_plugin_codes() + def render(self): if self.rendered: return self.code_widgets_container = ipw.VBox() - self.code_widgets = {} + + self.plugin_overrides_notification = ipw.HTML() + ipw.dlink( + (self._model, "plugin_overrides_notification"), + (self.plugin_overrides_notification, "value"), + ) self.children = [ ipw.HTML(""" @@ -51,12 +56,14 @@ def render(self): """), self.code_widgets_container, + self.plugin_overrides_notification, ] self.rendered = True + # Render any active codes - self._model.get_code("quantumespresso.pw").activate() - for code_model in self._model.codes.values(): + self._model.get_model("quantumespresso.pw").activate() + for _, code_model in self._model.get_models(): if code_model.is_active: self._toggle_code(code_model) @@ -71,120 +78,28 @@ def _on_input_parameters_change(self, _): def _on_input_structure_change(self, _): self._model.check_resources() + def _on_plugin_overrides_change(self, _): + self._model.update_plugin_overrides_notification() + def _on_code_activation_change(self, change): self._toggle_code(change["owner"]) def _on_code_selection_change(self, _): - """""" self._model.update_submission_blockers() def _on_pw_code_resource_change(self, _): self._model.check_resources() def _on_code_resource_change(self, _): - """Update the plugin resources.""" - # trigger the update of global codes - self._model.global_codes = {} - self._model.global_codes = self._model.get_model_state()["codes"] - - def _set_up_codes(self): - codes: PluginCodes = { - "dft": { - "pw": PwCodeModel( - description="pw.x", - default_calc_job_plugin="quantumespresso.pw", - code_widget_class=PwCodeResourceSetupWidget, - ), - }, - } - # Load codes from plugins - eps = get_entry_items("aiidalab_qe.properties", "code") - for identifier, data in eps.items(): - codes[identifier] = data["model"].codes - for identifier, code_models in codes.items(): - for _, code_model in code_models.items(): - # use the new code model created using the global code model - base_code_model = self._model.add_code(identifier, code_model) - if base_code_model is not None: - base_code_model.observe( - self._on_code_activation_change, - "is_active", - ) - base_code_model.observe( - self._on_code_selection_change, - "selected", - ) - - def _toggle_code(self, code_model: CodeModel): - if not self.rendered: - return - if not code_model.is_rendered: - loading_message = LoadingWidget(f"Loading {code_model.name} code") - self.code_widgets_container.children += (loading_message,) - if code_model.name not in self.code_widgets: - code_widget = code_model.code_widget_class( - description=code_model.description, - default_calc_job_plugin=code_model.default_calc_job_plugin, - ) - self.code_widgets[code_model.name] = code_widget - else: - code_widget = self.code_widgets[code_model.name] - code_widget.layout.display = "block" if code_model.is_active else "none" - if not code_model.is_rendered: - self._render_code_widget(code_model, code_widget) + self._model.update_global_codes() def _render_code_widget( self, code_model: CodeModel, code_widget: QEAppComputationalResourcesWidget, ): - code_model.update(None) - ipw.dlink( - (code_model, "options"), - (code_widget.code_selection.code_select_dropdown, "options"), - ) - ipw.link( - (code_model, "selected"), - (code_widget.code_selection.code_select_dropdown, "value"), - ) - code_widget.code_selection.code_select_dropdown.observe( - self._on_code_selection_change, - "value", - ) - ipw.dlink( - (code_model, "selected"), - (code_widget.code_selection.code_select_dropdown, "disabled"), - lambda selected: not selected, - ) - ipw.link( - (code_model, "num_cpus"), - (code_widget.num_cpus, "value"), - ) - ipw.link( - (code_model, "num_nodes"), - (code_widget.num_nodes, "value"), - ) - ipw.link( - (code_model, "ntasks_per_node"), - (code_widget.resource_detail.ntasks_per_node, "value"), - ) - ipw.link( - (code_model, "cpus_per_task"), - (code_widget.resource_detail.cpus_per_task, "value"), - ) - ipw.link( - (code_model, "max_wallclock_seconds"), - (code_widget.resource_detail.max_wallclock_seconds, "value"), - ) + super()._render_code_widget(code_model, code_widget) if code_model.default_calc_job_plugin == "quantumespresso.pw": - ipw.link( - (code_model, "override"), - (code_widget.parallelization.override, "value"), - ) - ipw.link( - (code_model, "npool"), - (code_widget.parallelization.npool, "value"), - ) code_model.observe( self._on_pw_code_resource_change, [ @@ -195,16 +110,45 @@ def _render_code_widget( "max_wallclock_seconds", ], ) + + def update_options(_, model=code_model): + model.update(self._model.DEFAULT_USER_EMAIL, refresh=True) + + code_widget.code_selection.code_select_dropdown.observe( + update_options, + "options", + ) + + def toggle_widget(_=None, model=code_model, widget=code_widget): + widget = self.code_widgets[model.name] + widget.layout.display = "block" if model.is_active else "none" + code_model.observe( - self._on_code_resource_change, - [ - "num_cpus", - "num_nodes", - "ntasks_per_node", - "cpus_per_task", - "max_wallclock_seconds", - ], + toggle_widget, + "is_active", ) - code_widgets = self.code_widgets_container.children[:-1] # type: ignore - self.code_widgets_container.children = [*code_widgets, code_widget] - code_model.is_rendered = True + + toggle_widget() + + def _fetch_plugin_codes(self): + codes: PluginCodes = { + "dft": { + "pw": PwCodeModel(), + }, + } + entries = get_entry_items("aiidalab_qe.properties", "resources") + for identifier, resources in entries.items(): + resource_model: ResourceSettingsModel = resources["model"]() + codes[identifier] = dict(resource_model.get_models()) + for identifier, code_models in codes.items(): + for _, code_model in code_models.items(): + base_code_model = self._model.add_global_model(identifier, code_model) + if base_code_model is not None: + base_code_model.observe( + self._on_code_activation_change, + "is_active", + ) + base_code_model.observe( + self._on_code_selection_change, + "selected", + ) diff --git a/src/aiidalab_qe/app/submission/model.py b/src/aiidalab_qe/app/submission/model.py index c47c93ed6..d424c51d5 100644 --- a/src/aiidalab_qe/app/submission/model.py +++ b/src/aiidalab_qe/app/submission/model.py @@ -40,6 +40,8 @@ class SubmissionStepModel( internal_submission_blockers = tl.List(tl.Unicode()) external_submission_blockers = tl.List(tl.Unicode()) + plugin_overrides = tl.List(tl.Unicode()) + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -72,18 +74,9 @@ def confirm(self): # Once submitted, nothing should unconfirm the model! self.unobserve_all("confirmed") - def refresh_codes(self): + def update(self): for _, model in self.get_models(): - model.refresh_codes() - - def update_active_models(self): - for identifier, model in self.get_models(): - if identifier in ["global"]: - continue - if identifier not in self._get_properties(): - model.include = False - else: - model.include = True + model.update() def update_process_label(self): if not self.input_structure: @@ -112,29 +105,41 @@ def update_process_label(self): label = f"{structure_label} [{relax_info}, {protocol_and_magnetic_info}] {properties_info}".strip() self.process_label = label + def update_plugin_inclusion(self): + properties = self._get_properties() + for identifier, model in self.get_models(): + if identifier in self._default_models: + continue + model.include = identifier in properties + + def update_plugin_overrides(self): + self.plugin_overrides = [ + identifier + for identifier, model in self.get_models() + if identifier != "global" and model.include and model.override + ] + def update_submission_blockers(self): submission_blockers = list(self._check_submission_blockers()) for _, model in self.get_models(): - if hasattr(model, "submission_blockers"): - submission_blockers += model.submission_blockers + submission_blockers += model.submission_blockers self.internal_submission_blockers = submission_blockers def update_submission_warnings(self): submission_warning_messages = self._check_submission_warnings() for _, model in self.get_models(): - if hasattr(model, "submission_warning_messages"): - submission_warning_messages += model.submission_warning_messages + submission_warning_messages += model.submission_warning_messages self.submission_warning_messages = submission_warning_messages def update_submission_blocker_message(self): blockers = self.internal_submission_blockers + self.external_submission_blockers if any(blockers): - fmt_list = "\n".join(f"
  • {item}
  • " for item in sorted(blockers)) + formatted = "\n".join(f"
  • {item}
  • " for item in blockers) self.submission_blocker_messages = f""" -
    +
    The submission is blocked due to the following reason(s):
      - {fmt_list} + {formatted}
    """ @@ -178,8 +183,8 @@ def set_model_state(self, parameters): def get_selected_codes(self) -> dict[str, dict]: return { - name: code_model.get_model_state() - for name, code_model in self.get_model("global").codes.items() + identifier: code_model.get_model_state() + for identifier, code_model in self.get_model("global").get_models() if code_model.is_ready } @@ -252,11 +257,9 @@ def _create_builder(self, parameters) -> ProcessBuilderNamespace: return builder def _check_submission_blockers(self): - # Do not submit while any of the background setup processes are running. if self.installing_qe or self.installing_sssp: yield "Background setup processes must finish." - # SSSP library not installed if not self.sssp_installed: yield "The SSSP library is not installed." diff --git a/src/aiidalab_qe/common/code/model.py b/src/aiidalab_qe/common/code/model.py index 4e40cc871..90b761575 100644 --- a/src/aiidalab_qe/common/code/model.py +++ b/src/aiidalab_qe/common/code/model.py @@ -24,6 +24,7 @@ class CodeModel(Model): max_wallclock_seconds = tl.Int(3600 * 12) allow_hidden_codes = tl.Bool(False) allow_disabled_computers = tl.Bool(False) + override = tl.Bool(False) def __init__( self, @@ -48,19 +49,24 @@ def __init__( def is_ready(self): return self.is_active and bool(self.selected) + @property + def first_option(self): + return self.options[0][1] if self.options else None # type: ignore + def activate(self): self.is_active = True def deactivate(self): self.is_active = False - def update(self, user_email: str): - if not self.options: + def update(self, user_email="", refresh=False): + if not self.options or refresh: self.options = self._get_codes(user_email) - self.selected = self.options[0][1] if self.options else None + self.selected = self.first_option def get_model_state(self) -> dict: return { + "options": self.options, "code": self.selected, "nodes": self.num_nodes, "cpus": self.num_cpus, @@ -69,8 +75,12 @@ def get_model_state(self) -> dict: "max_wallclock_seconds": self.max_wallclock_seconds, } - def set_model_state(self, parameters): - self.selected = self._get_uuid(parameters["code"]) + def set_model_state(self, parameters: dict): + self.selected = ( + self._get_uuid(identifier) + if (identifier := parameters.get("code")) + else self.first_option + ) self.num_nodes = parameters.get("nodes", 1) self.num_cpus = parameters.get("cpus", 1) self.ntasks_per_node = parameters.get("ntasks_per_node", 1) @@ -78,19 +88,15 @@ def set_model_state(self, parameters): self.max_wallclock_seconds = parameters.get("max_wallclock_seconds", 3600 * 12) def _get_uuid(self, identifier): - if not self.selected: - try: - uuid = orm.load_code(identifier).uuid - except NotExistent: - uuid = None - # If the code was imported from another user, it is not usable - # in the app and thus will not be considered as an option! - self.selected = uuid if uuid in [opt[1] for opt in self.options] else None - return self.selected - - def _get_codes(self, user_email: str): - # set default user_email if not provided - user_email = user_email or orm.User.collection.get_default().email + try: + uuid = orm.load_code(identifier).uuid + except NotExistent: + uuid = None + # If the code was imported from another user, it is not usable + # in the app and thus will not be considered as an option! + return uuid if uuid in [opt[1] for opt in self.options] else None + + def _get_codes(self, user_email: str = ""): user = orm.User.collection.get(email=user_email) filters = ( @@ -122,7 +128,7 @@ def _full_code_label(code): class PwCodeModel(CodeModel): - override = tl.Bool(False) + parallelization_override = tl.Bool(False) npool = tl.Int(1) def __init__( @@ -142,14 +148,22 @@ def __init__( def get_model_state(self) -> dict: parameters = super().get_model_state() - parameters["parallelization"] = {"npool": self.npool} if self.override else {} + parameters["parallelization"] = ( + { + "npool": self.npool, + } + if self.parallelization_override + else {} + ) return parameters def set_model_state(self, parameters): super().set_model_state(parameters) if "parallelization" in parameters and "npool" in parameters["parallelization"]: - self.override = True + self.parallelization_override = True self.npool = parameters["parallelization"].get("npool", 1) + else: + self.parallelization_override = False CodesDict = dict[str, CodeModel] diff --git a/src/aiidalab_qe/common/mixins.py b/src/aiidalab_qe/common/mixins.py index 21421c233..bd710a8df 100644 --- a/src/aiidalab_qe/common/mixins.py +++ b/src/aiidalab_qe/common/mixins.py @@ -31,12 +31,19 @@ class HasModels(t.Generic[T]): def __init__(self): self._models: dict[str, T] = {} + def has_model(self, identifier): + return identifier in self._models + def add_model(self, identifier, model): self._models[identifier] = model self._link_model(model) + def add_models(self, models: dict[str, T]): + for identifier, model in models.items(): + self.add_model(identifier, model) + def get_model(self, identifier) -> T: - if identifier in self._models: + if self.has_model(identifier): return self._models[identifier] raise ValueError(f"Model with identifier '{identifier}' not found.") diff --git a/src/aiidalab_qe/common/panel.py b/src/aiidalab_qe/common/panel.py index 3cca22ab0..0a5a913a4 100644 --- a/src/aiidalab_qe/common/panel.py +++ b/src/aiidalab_qe/common/panel.py @@ -15,8 +15,9 @@ from aiida import orm from aiida.common.extendeddicts import AttributeDict +from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS from aiidalab_qe.common.code.model import CodeModel -from aiidalab_qe.common.mixins import Confirmable, HasProcess +from aiidalab_qe.common.mixins import Confirmable, HasModels, HasProcess from aiidalab_qe.common.mvc import Model from aiidalab_qe.common.widgets import ( LoadingWidget, @@ -24,7 +25,7 @@ QEAppComputationalResourcesWidget, ) -DEFAULT_PARAMETERS = {} +DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore class Panel(ipw.VBox): @@ -90,18 +91,12 @@ class SettingsModel(Model): _defaults = {} - def update(self, specific=""): - """Updates the model. - - Parameters - ---------- - `specific` : `str`, optional - If provided, specifies the level of update. - """ + def update(self): + """Updates the model.""" pass def get_model_state(self) -> dict: - """Retrieves the model current state as a dictionary.""" + """Retrieves the current state of the model as a dictionary.""" raise NotImplementedError() def set_model_state(self, parameters: dict): @@ -118,7 +113,7 @@ def reset(self): class SettingsPanel(Panel, t.Generic[SM]): title = "Settings" - description = "" + identifier = "" def __init__(self, model: SM, **kwargs): from aiidalab_qe.common.widgets import LoadingWidget @@ -209,11 +204,12 @@ def _reset(self): self._model.reset() -class ResourceSettingsModel(SettingsModel): +class ResourceSettingsModel(SettingsModel, HasModels[CodeModel]): """Base model for plugin code setting models.""" - dependencies = ["global.global_codes"] - codes = {} # To be defined by subclasses + dependencies = [ + "global.global_codes", + ] global_codes = tl.Dict( key_trait=tl.Unicode(), @@ -228,38 +224,66 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Used by the code-setup thread to fetch code options - self._default_user_email = orm.User.collection.get_default().email + self.DEFAULT_USER_EMAIL = orm.User.collection.get_default().email - def refresh_codes(self): - for _, code_model in self.codes.items(): - code_model.update(self._default_user_email) + def update(self): + """Updates the code models from the global resources. - def update_code_from_global(self): - # Skip the sync if the user has overridden the settings + Skips synchronization with global resources if the user has chosen to override + the resources for the plugin codes. + """ if self.override: return - for _, code_model in self.codes.items(): + for _, code_model in self.get_models(): + code_model.update(self.DEFAULT_USER_EMAIL) default_calc_job_plugin = code_model.default_calc_job_plugin if default_calc_job_plugin in self.global_codes: - code_data = self.global_codes[default_calc_job_plugin] - code_model.set_model_state(code_data) + code_resources: dict = self.global_codes[default_calc_job_plugin] # type: ignore + options = code_resources.get("options", []) + if options != code_model.options: + code_model.update(self.DEFAULT_USER_EMAIL, refresh=True) + code_model.set_model_state(code_resources) + + def update_submission_blockers(self): + self.submission_blockers = list(self._check_submission_blockers()) def get_model_state(self): - codes = {name: model.get_model_state() for name, model in self.codes.items()} return { - "codes": codes, - "override": self.override, + "codes": { + identifier: code_model.get_model_state() + for identifier, code_model in self.get_models() + }, } - def set_model_state(self, code_data: dict): - for name, code_model in self.codes.items(): - if name in code_data: - code_model.set_model_state(code_data[name]) + def set_model_state(self, parameters: dict): + for name, code_model in self.get_models(): + if name in parameters and code_model.is_active: + code_model.set_model_state(parameters[name]) + + def get_selected_codes(self) -> dict[str, dict]: + return { + identifier: code_model.get_model_state() + for identifier, code_model in self.get_models() + if code_model.is_ready + } + + def set_selected_codes(self, code_data=DEFAULT["codes"]): + for identifier, code_model in self.get_models(): + if identifier in code_data and code_model.is_active: + code_model.set_model_state(code_data[identifier]) def reset(self): - """Reset the model to its default state.""" - for code_model in self.codes.values(): - code_model.reset() + """If not overridden, updates the model w.r.t the global resources.""" + self.update() + + def _check_submission_blockers(self): + return [] + + def _link_model(self, model: CodeModel): + tl.link( + (self, "override"), + (model, "override"), + ) RSM = t.TypeVar("RSM", bound=ResourceSettingsModel) @@ -270,8 +294,7 @@ class ResourceSettingsPanel(SettingsPanel[RSM], t.Generic[RSM]): def __init__(self, model, **kwargs): super().__init__(model, **kwargs) - self.code_widgets = {} - self.rendered = False + self._model.observe( self._on_global_codes_change, "global_codes", @@ -281,9 +304,12 @@ def __init__(self, model, **kwargs): "override", ) + self.code_widgets = {} + def render(self): if self.rendered: return + self.override_help = ipw.HTML( "Click to override the resource settings for this plugin." ) @@ -297,35 +323,33 @@ def render(self): (self.override, "value"), ) self.code_widgets_container = ipw.VBox() - self.code_widgets = {} + self.children = [ - ipw.HBox([self.override, self.override_help]), + ipw.HBox( + children=[ + self.override, + self.override_help, + ] + ), self.code_widgets_container, ] self.rendered = True - for code_model in self._model.codes.values(): + # Render any active codes + for _, code_model in self._model.get_models(): self._toggle_code(code_model) + return self.code_widgets_container def _on_global_codes_change(self, _): - self._model.update_code_from_global() + self._model.update() def _on_code_resource_change(self, _): - """Update the submission blockers and warning messages.""" - - def _on_override_change(self, change): - if change["new"]: - for code_widget in self.code_widgets.values(): - code_widget.num_nodes.disabled = False - code_widget.num_cpus.disabled = False - code_widget.code_selection.code_select_dropdown.disabled = False - else: - for code_widget in self.code_widgets.values(): - code_widget.num_nodes.disabled = True - code_widget.num_cpus.disabled = True - code_widget.code_selection.code_select_dropdown.disabled = True + pass + + def _on_override_change(self, _): + self._model.reset() def _toggle_code(self, code_model: CodeModel): if not self.rendered: @@ -349,7 +373,6 @@ def _render_code_widget( code_model: CodeModel, code_widget: QEAppComputationalResourcesWidget, ): - code_model.update(None) ipw.dlink( (code_model, "options"), (code_widget.code_selection.code_select_dropdown, "options"), @@ -359,18 +382,28 @@ def _render_code_widget( (code_widget.code_selection.code_select_dropdown, "value"), ) ipw.dlink( - (code_model, "selected"), + (code_model, "override"), (code_widget.code_selection.code_select_dropdown, "disabled"), - lambda selected: not selected, + lambda override: not override, ) ipw.link( (code_model, "num_cpus"), (code_widget.num_cpus, "value"), ) + ipw.dlink( + (code_model, "override"), + (code_widget.num_cpus, "disabled"), + lambda override: not override, + ) ipw.link( (code_model, "num_nodes"), (code_widget.num_nodes, "value"), ) + ipw.dlink( + (code_model, "override"), + (code_widget.num_nodes, "disabled"), + lambda override: not override, + ) ipw.link( (code_model, "ntasks_per_node"), (code_widget.resource_detail.ntasks_per_node, "value"), @@ -383,18 +416,47 @@ def _render_code_widget( (code_model, "max_wallclock_seconds"), (code_widget.resource_detail.max_wallclock_seconds, "value"), ) + ipw.dlink( + (code_model, "override"), + (code_widget.code_selection.btn_setup_new_code, "disabled"), + lambda override: not override, + ) + ipw.dlink( + (code_model, "override"), + (code_widget.btn_setup_resource_detail, "disabled"), + lambda override: not override, + ) if isinstance(code_widget, PwCodeResourceSetupWidget): ipw.link( - (code_model, "override"), + (code_model, "parallelization_override"), (code_widget.parallelization.override, "value"), ) + ipw.dlink( + (code_model, "override"), + (code_widget.parallelization.override, "disabled"), + lambda override: not override, + ) ipw.link( (code_model, "npool"), (code_widget.parallelization.npool, "value"), ) + ipw.dlink( + (code_model, "override"), + (code_widget.parallelization.npool, "disabled"), + lambda override: not override, + ) + code_model.observe( + self._on_code_resource_change, + [ + "parallelization_override", + "npool", + ], + ) code_model.observe( self._on_code_resource_change, [ + "options", + "selected", "num_cpus", "num_nodes", "ntasks_per_node", @@ -402,15 +464,7 @@ def _render_code_widget( "max_wallclock_seconds", ], ) - # disable the code widget if the override is not set - code_widget.num_nodes.disabled = not self.override.value - code_widget.num_cpus.disabled = not self.override.value - code_widget.code_selection.code_select_dropdown.disabled = ( - not self.override.value - ) - code_widgets = self.code_widgets_container.children[:-1] # type: ignore - self.code_widgets_container.children = [*code_widgets, code_widget] code_model.is_rendered = True diff --git a/src/aiidalab_qe/plugins/bands/__init__.py b/src/aiidalab_qe/plugins/bands/__init__.py index 0c8b86ad6..7b607d1aa 100644 --- a/src/aiidalab_qe/plugins/bands/__init__.py +++ b/src/aiidalab_qe/plugins/bands/__init__.py @@ -18,7 +18,7 @@ class BandsPluginOutline(PluginOutline): "panel": BandsConfigurationSettingsPanel, "model": BandsConfigurationSettingsModel, }, - "code": { + "resources": { "panel": BandsResourceSettingsPanel, "model": BandsResourceSettingsModel, }, diff --git a/src/aiidalab_qe/plugins/bands/code.py b/src/aiidalab_qe/plugins/bands/code.py index 79571a00f..fa11f4318 100644 --- a/src/aiidalab_qe/plugins/bands/code.py +++ b/src/aiidalab_qe/plugins/bands/code.py @@ -7,18 +7,22 @@ class BandsResourceSettingsModel(ResourceSettingsModel): """Model for the band structure plugin.""" - codes = { - "pw": PwCodeModel( - name="pw.x", - description="pw.x", - default_calc_job_plugin="quantumespresso.pw", - ), - "projwfc_bands": CodeModel( - name="projwfc.x", - description="projwfc.x", - default_calc_job_plugin="quantumespresso.projwfc", - ), - } + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.add_models( + { + "pw": PwCodeModel( + name="pw.x", + description="pw.x", + default_calc_job_plugin="quantumespresso.pw", + ), + "projwfc_bands": CodeModel( + name="projwfc.x", + description="projwfc.x", + default_calc_job_plugin="quantumespresso.projwfc", + ), + } + ) class BandsResourceSettingsPanel(ResourceSettingsPanel[BandsResourceSettingsModel]): diff --git a/src/aiidalab_qe/plugins/pdos/__init__.py b/src/aiidalab_qe/plugins/pdos/__init__.py index 280a90c39..9d9d461fb 100644 --- a/src/aiidalab_qe/plugins/pdos/__init__.py +++ b/src/aiidalab_qe/plugins/pdos/__init__.py @@ -17,7 +17,7 @@ class PdosPluginOutline(PluginOutline): "panel": PdosConfigurationSettingPanel, "model": PdosConfigurationSettingsModel, }, - "code": { + "resources": { "panel": PdosResourceSettingsPanel, "model": PdosResourceSettingsModel, }, diff --git a/src/aiidalab_qe/plugins/pdos/code.py b/src/aiidalab_qe/plugins/pdos/code.py index 1e2095a25..dbe4a68b7 100644 --- a/src/aiidalab_qe/plugins/pdos/code.py +++ b/src/aiidalab_qe/plugins/pdos/code.py @@ -7,23 +7,27 @@ class PdosResourceSettingsModel(ResourceSettingsModel): """Model for the pdos code setting plugin.""" - codes = { - "pw": PwCodeModel( - name="pw.x", - description="pw.x", - default_calc_job_plugin="quantumespresso.pw", - ), - "dos": CodeModel( - name="dos.x", - description="dos.x", - default_calc_job_plugin="quantumespresso.dos", - ), - "projwfc": CodeModel( - name="projwfc.x", - description="projwfc.x", - default_calc_job_plugin="quantumespresso.projwfc", - ), - } + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.add_models( + { + "pw": PwCodeModel( + name="pw.x", + description="pw.x", + default_calc_job_plugin="quantumespresso.pw", + ), + "dos": CodeModel( + name="dos.x", + description="dos.x", + default_calc_job_plugin="quantumespresso.dos", + ), + "projwfc": CodeModel( + name="projwfc.x", + description="projwfc.x", + default_calc_job_plugin="quantumespresso.projwfc", + ), + } + ) class PdosResourceSettingsPanel(ResourceSettingsPanel[PdosResourceSettingsModel]): diff --git a/src/aiidalab_qe/plugins/xas/__init__.py b/src/aiidalab_qe/plugins/xas/__init__.py index 76a4af001..0237636a7 100644 --- a/src/aiidalab_qe/plugins/xas/__init__.py +++ b/src/aiidalab_qe/plugins/xas/__init__.py @@ -24,7 +24,7 @@ class XasPluginOutline(PluginOutline): "panel": XasConfigurationSettingsPanel, "model": XasConfigurationSettingsModel, }, - "code": { + "resources": { "panel": XasResourceSettingsPanel, "model": XasResourceSettingsModel, }, diff --git a/src/aiidalab_qe/plugins/xas/code.py b/src/aiidalab_qe/plugins/xas/code.py index ff07fb8f9..e980b9152 100644 --- a/src/aiidalab_qe/plugins/xas/code.py +++ b/src/aiidalab_qe/plugins/xas/code.py @@ -7,18 +7,22 @@ class XasResourceSettingsModel(ResourceSettingsModel): """Model for the XAS plugin.""" - codes = { - "pw": PwCodeModel( - name="pw.x", - description="pw.x", - default_calc_job_plugin="quantumespresso.pw", - ), - "xspectra": CodeModel( - name="xspectra.x", - description="xspectra.x", - default_calc_job_plugin="quantumespresso.xspectra", - ), - } + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.add_models( + { + "pw": PwCodeModel( + name="pw.x", + description="pw.x", + default_calc_job_plugin="quantumespresso.pw", + ), + "xspectra": CodeModel( + name="xspectra.x", + description="xspectra.x", + default_calc_job_plugin="quantumespresso.xspectra", + ), + } + ) class XasResourceSettingsPanel(ResourceSettingsPanel[XasResourceSettingsModel]): diff --git a/tests/conftest.py b/tests/conftest.py index bacfeb26c..0494ff941 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -429,19 +429,12 @@ def app(pw_code, dos_code, projwfc_code, projwfc_bands_code): app.submit_model.qe_installed = True # set up codes - pw_code_model = app.submit_model.get_model("global").get_code("quantumespresso.pw") - dos_code_model = app.submit_model.get_model("global").get_code( - "quantumespresso.dos" - ) - projwfc_code_model = app.submit_model.get_model("global").get_code( - "quantumespresso.projwfc" - ) - - pw_code_model.activate() - dos_code_model.activate() - projwfc_code_model.activate() + global_model = app.submit_model.get_model("global") + global_model.get_model("quantumespresso.pw").activate() + global_model.get_model("quantumespresso.dos").activate() + global_model.get_model("quantumespresso.projwfc").activate() - app.submit_model.get_model("global").set_selected_codes( + global_model.set_selected_codes( { "pw": {"code": pw_code.label}, "dos": {"code": dos_code.label}, @@ -509,7 +502,9 @@ def _submit_app_generator( app.configure_model.confirm() app.submit_model.input_structure = generate_structure_data() - app.submit_model.get_model("global").get_code("quantumespresso.pw").num_cpus = 2 + app.submit_model.get_model("global").get_model( + "quantumespresso.pw" + ).num_cpus = 2 return app @@ -818,7 +813,9 @@ def _generate_qeapp_workchain( app.configure_model.confirm() # step 3 setup code and resources - app.submit_model.get_model("global").get_code("quantumespresso.pw").num_cpus = 4 + app.submit_model.get_model("global").get_model( + "quantumespresso.pw" + ).num_cpus = 4 parameters = app.submit_model.get_model_state() builder = app.submit_model._create_builder(parameters) diff --git a/tests/test_codes.py b/tests/test_codes.py index 984215606..8df05c6ec 100644 --- a/tests/test_codes.py +++ b/tests/test_codes.py @@ -7,7 +7,7 @@ def test_code_not_selected(submit_app_generator): """Test if there is an error when the code is not selected.""" app: App = submit_app_generator(properties=["dos"]) model = app.submit_model - model.get_model("global").get_code("quantumespresso.dos").selected = None + model.get_model("global").get_model("quantumespresso.dos").selected = None # Check builder construction passes without an error parameters = model.get_model_state() model._create_builder(parameters) @@ -19,8 +19,8 @@ def test_set_selected_codes(submit_app_generator): parameters = app.submit_model.get_model_state() model = SubmissionStepModel() _ = SubmitQeAppWorkChainStep(model=model, qe_auto_setup=False) - for name, code_model in app.submit_model.get_model("global").codes.items(): - model.get_model("global").get_code(name).is_active = code_model.is_active + for identifier, code_model in app.submit_model.get_model("global").get_models(): + model.get_model("global").get_model(identifier).is_active = code_model.is_active model.qe_installed = True model.get_model("global").set_selected_codes(parameters["codes"]["global"]["codes"]) assert model.get_selected_codes() == app.submit_model.get_selected_codes() @@ -32,23 +32,14 @@ def test_update_codes_display(app: App): """ app.submit_step.render() model = app.submit_model - model.get_model("global").update_active_codes() - assert ( - app.submit_step.global_code_settings.code_widgets["dos"].layout.display - == "none" - ) + global_model = model.get_model("global") + global_model.update_active_codes() + global_resources = app.submit_step.global_resources + assert global_resources.code_widgets["dos"].layout.display == "none" model.input_parameters = {"workchain": {"properties": ["pdos"]}} - model.get_model("global").update_active_codes() - assert ( - app.submit_step._model.get_model("global") - .codes["quantumespresso.dos"] - .is_active - is True - ) - assert ( - app.submit_step.global_code_settings.code_widgets["dos"].layout.display - == "block" - ) + global_model.update_active_codes() + assert global_model.get_model("quantumespresso.dos").is_active is True + assert global_resources.code_widgets["dos"].layout.display == "block" def test_check_submission_blockers(app: App): @@ -63,7 +54,7 @@ def test_check_submission_blockers(app: App): assert len(model.internal_submission_blockers) == 0 # set dos code to None, will introduce another blocker - dos_code = model.get_model("global").get_code("quantumespresso.dos") + dos_code = model.get_model("global").get_model("quantumespresso.dos") dos_value = dos_code.selected dos_code.selected = None model.update_submission_blockers() @@ -78,16 +69,16 @@ def test_check_submission_blockers(app: App): def test_qeapp_computational_resources_widget(app: App): """Test QEAppComputationalResourcesWidget.""" app.submit_step.render() - pw_code_model = app.submit_model.get_model("global").get_code("quantumespresso.pw") - pw_code_widget = app.submit_step.global_code_settings.code_widgets["pw"] + global_model = app.submit_model.get_model("global") + global_resources = app.submit_step.global_resources + pw_code_model = global_model.get_model("quantumespresso.pw") + pw_code_widget = global_resources.code_widgets["pw"] assert pw_code_widget.parallelization.npool.layout.display == "none" - pw_code_model.override = True + pw_code_model.parallelization_override = True pw_code_model.npool = 2 assert pw_code_widget.parallelization.npool.layout.display == "block" assert pw_code_widget.parameters == { - "code": app.submit_step.global_code_settings.code_widgets[ - "pw" - ].value, # TODO why None? + "code": global_resources.code_widgets["pw"].value, "cpus": 1, "cpus_per_task": 1, "max_wallclock_seconds": 43200, diff --git a/tests/test_submit_qe_workchain.py b/tests/test_submit_qe_workchain.py index e0c367a69..90d06d58f 100644 --- a/tests/test_submit_qe_workchain.py +++ b/tests/test_submit_qe_workchain.py @@ -16,6 +16,7 @@ def test_create_builder_default( app.submit_model._create_builder(parameters) # since uuid is specific to each run, we remove it from the output ui_parameters = remove_uuid_fields(parameters) + remove_code_options(ui_parameters) # regression test for the parameters generated by the app # this parameters are passed to the workchain data_regression.check(ui_parameters) @@ -144,16 +145,17 @@ def test_warning_messages( app: App = submit_app_generator(properties=["bands", "pdos"]) submit_model = app.submit_model + global_model = submit_model.get_model("global") - pw_code = submit_model.get_model("global").get_code("quantumespresso.pw") + pw_code = global_model.get_model("quantumespresso.pw") pw_code.num_cpus = 1 - submit_model.get_model("global").check_resources() + global_model.check_resources() # no warning: assert submit_model.submission_warning_messages == "" # now we increase the resources, so we should have the Warning-3 pw_code.num_cpus = len(os.sched_getaffinity(0)) - submit_model.get_model("global").check_resources() + global_model.check_resources() for suggestion in ["avoid_overloading", "go_remote"]: assert suggestions[suggestion] in submit_model.submission_warning_messages @@ -161,12 +163,10 @@ def test_warning_messages( structure = generate_structure_data("H2O-larger") submit_model.input_structure = structure pw_code.num_cpus = 1 - submit_model.get_model("global").check_resources() + global_model.check_resources() num_sites = len(structure.sites) volume = structure.get_cell_volume() - estimated_CPUs = submit_model.get_model("global")._estimate_min_cpus( - num_sites, volume - ) + estimated_CPUs = global_model._estimate_min_cpus(num_sites, volume) assert estimated_CPUs == 2 for suggestion in ["more_resources", "change_configuration"]: assert suggestions[suggestion] in submit_model.submission_warning_messages @@ -232,3 +232,10 @@ def remove_uuid_fields(data): else: # Return the value unchanged if it's not a dictionary or list return data + + +def remove_code_options(parameters): + """Remove the code options from the parameters.""" + for panel in parameters["codes"].values(): # type: ignore + for code in panel["codes"].values(): + del code["options"] diff --git a/tests/test_submit_qe_workchain/test_create_builder_default.yml b/tests/test_submit_qe_workchain/test_create_builder_default.yml index 2d6f756e3..dd3f9dac3 100644 --- a/tests/test_submit_qe_workchain/test_create_builder_default.yml +++ b/tests/test_submit_qe_workchain/test_create_builder_default.yml @@ -37,7 +37,6 @@ codes: nodes: 1 ntasks_per_node: 1 parallelization: {} - override: false global: codes: quantumespresso.dos: @@ -87,7 +86,6 @@ codes: nodes: 1 ntasks_per_node: 1 parallelization: {} - override: false pdos: nscf_kpoints_distance: 0.1 pdos_degauss: 0.005