diff --git a/src/aiidalab_qe/app/submission/__init__.py b/src/aiidalab_qe/app/submission/__init__.py index 61d089c2f..773d7e8a4 100644 --- a/src/aiidalab_qe/app/submission/__init__.py +++ b/src/aiidalab_qe/app/submission/__init__.py @@ -10,17 +10,12 @@ from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS from aiidalab_qe.app.utils import get_entry_items -from aiidalab_qe.common.code import CodeModel, PluginCodes, PwCodeModel from aiidalab_qe.common.panel import SettingsModel, SettingsPanel from aiidalab_qe.common.setup_codes import QESetupWidget from aiidalab_qe.common.setup_pseudos import PseudosInstallWidget -from aiidalab_qe.common.widgets import ( - PwCodeResourceSetupWidget, - QEAppComputationalResourcesWidget, -) from aiidalab_widgets_base import WizardAppWidgetStep -from .basic import BasicCodeModel, BasicCodeSettings +from .global_settings import GlobalCodeModel, GlobalCodeSettings from .model import SubmissionStepModel DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore @@ -82,20 +77,20 @@ def __init__(self, model: SubmissionStepModel, qe_auto_setup=True, **kwargs): self.rendered = False - basic_code_model = BasicCodeModel() - self.basic_code_settings = BasicCodeSettings(model=basic_code_model) - self._model.add_model("basic", basic_code_model) - basic_code_model.observe( + global_code_model = GlobalCodeModel() + self.global_code_settings = GlobalCodeSettings(model=global_code_model) + self._model.add_model("global", global_code_model) + global_code_model.observe( self._on_plugin_submission_blockers_change, ["submission_blockers"], ) - basic_code_model.observe( + global_code_model.observe( self._on_plugin_submission_warning_messages_change, ["submission_warning_messages"], ) self.settings = { - "basic": self.basic_code_settings, + "global": self.global_code_settings, } self._fetch_plugin_settings() @@ -350,7 +345,7 @@ def _fetch_plugin_settings(self): ) self._model.add_model(identifier, model) - def toggle_plugin(change, identifier=identifier, model=model): + def toggle_plugin(_, model=model): model.update() self._update_tabs() diff --git a/src/aiidalab_qe/app/submission/basic/__init__.py b/src/aiidalab_qe/app/submission/basic/__init__.py deleted file mode 100644 index 29dea71a1..000000000 --- a/src/aiidalab_qe/app/submission/basic/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .model import BasicCodeModel -from .setting import BasicCodeSettings - -__all__ = [ - "BasicCodeModel", - "BasicCodeSettings", -] diff --git a/src/aiidalab_qe/app/submission/global_settings/__init__.py b/src/aiidalab_qe/app/submission/global_settings/__init__.py new file mode 100644 index 000000000..9a44d03de --- /dev/null +++ b/src/aiidalab_qe/app/submission/global_settings/__init__.py @@ -0,0 +1,7 @@ +from .model import GlobalCodeModel +from .setting import GlobalCodeSettings + +__all__ = [ + "GlobalCodeModel", + "GlobalCodeSettings", +] diff --git a/src/aiidalab_qe/app/submission/basic/model.py b/src/aiidalab_qe/app/submission/global_settings/model.py similarity index 93% rename from src/aiidalab_qe/app/submission/basic/model.py rename to src/aiidalab_qe/app/submission/global_settings/model.py index 243a84f41..bf7838337 100644 --- a/src/aiidalab_qe/app/submission/basic/model.py +++ b/src/aiidalab_qe/app/submission/global_settings/model.py @@ -16,11 +16,11 @@ DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore -class BasicCodeModel( +class GlobalCodeModel( SettingsModel, HasInputStructure, ): - """Model for the basic code setting.""" + """Model for the global code setting.""" dependencies = [ "input_parameters", @@ -34,9 +34,9 @@ class BasicCodeModel( value_trait=tl.Instance(CodeModel), # code metadata ) # this is a copy of the codes trait, which is used to trigger the update of the plugin - basic_codes = tl.Dict( + global_codes = tl.Dict( key_trait=tl.Unicode(), # code name - value_trait=tl.Instance(CodeModel), # code metadata + value_trait=tl.Dict(), # code metadata ) plugin_mapping = tl.Dict( @@ -83,15 +83,18 @@ def update_active_codes(self): self.codes[code_name].activate() def get_model_state(self): - return {"codes": self.codes} + codes = {name: model.get_model_state() for name, model in self.codes.items()} + + return {"codes": codes} def set_model_state(self, code_data: dict): for name, code_model in self.codes.items(): - if name in code_data: + if name in code_data and code_model.is_active: code_model.set_model_state(code_data[name]) def add_code(self, identifier: str, code: CodeModel) -> CodeModel | None: """Add a code to the codes trait.""" + code_model = None default_calc_job_plugin = code.default_calc_job_plugin name = default_calc_job_plugin.split(".")[-1] if default_calc_job_plugin not in self.codes: @@ -108,13 +111,14 @@ def add_code(self, identifier: str, code: CodeModel) -> CodeModel | None: default_calc_job_plugin=default_calc_job_plugin, ) self.codes[default_calc_job_plugin] = code_model - return code_model # update the plugin mapping to keep track of which codes are associated with which plugins if identifier not in self.plugin_mapping: self.plugin_mapping[identifier] = [default_calc_job_plugin] else: self.plugin_mapping[identifier].append(default_calc_job_plugin) + return code_model + def get_code(self, name) -> CodeModel | None: if name in self.codes: # type: ignore return self.codes[name] # type: ignore @@ -127,10 +131,9 @@ def get_selected_codes(self) -> dict[str, dict]: } def set_selected_codes(self, code_data=DEFAULT["codes"]): - with self.hold_trait_notifications(): - for name, code_model in self.codes.items(): - if name in code_data: - code_model.set_model_state(code_data[name]) + for name, code_model in self.codes.items(): + if name in code_data and code_model.is_active: + code_model.set_model_state(code_data[name]) def reset(self): """Reset the model to its default state.""" @@ -145,21 +148,22 @@ def update_submission_blockers(self): def _check_submission_blockers(self): # No pw code selected (this is ignored while the setup process is running). - pw_code = self._model.get_code("quantumespresso.pw") + pw_code = self.get_code("quantumespresso.pw") if pw_code and not pw_code.selected and not self.installing_qe: yield ("No pw code selected") # code related to the selected property is not installed properties = self._get_properties() message = "Calculating the {property} property requires code {code} to be set." - for identifier, codes in self.get_model("basic").codes.items(): + for identifier, code_names in self.plugin_mapping.items(): if identifier in properties: - for code in codes.values(): + for name in code_names: + code = self.get_code(name) if not code.is_ready: yield message.format(property=identifier, code=code.description) # check if the QEAppComputationalResourcesWidget is used - for name, code in self.get_model("basic").codes.items(): + for name, code in self.codes.items(): # skip if the code is not displayed, convenient for the plugin developer if not code.is_ready: continue @@ -241,7 +245,6 @@ def check_resources(self): + suggestions["go_remote"] + "" ) - print("alert_message: ", alert_message) self.submission_warning_messages = ( "" diff --git a/src/aiidalab_qe/app/submission/basic/setting.py b/src/aiidalab_qe/app/submission/global_settings/setting.py similarity index 89% rename from src/aiidalab_qe/app/submission/basic/setting.py rename to src/aiidalab_qe/app/submission/global_settings/setting.py index 564eb4621..27d7b475f 100644 --- a/src/aiidalab_qe/app/submission/basic/setting.py +++ b/src/aiidalab_qe/app/submission/global_settings/setting.py @@ -17,16 +17,16 @@ QEAppComputationalResourcesWidget, ) -from .model import BasicCodeModel +from .model import GlobalCodeModel DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore -class BasicCodeSettings(SettingsPanel[BasicCodeModel]): - title = "Basic" - identifier = "basic" +class GlobalCodeSettings(SettingsPanel[GlobalCodeModel]): + title = "Global settings" + identifier = "global" - def __init__(self, model: BasicCodeModel, **kwargs): + def __init__(self, model: GlobalCodeModel, **kwargs): super().__init__(model, **kwargs) self._set_up_codes() self._model.observe( @@ -77,17 +77,16 @@ def _on_code_activation_change(self, change): def _on_code_selection_change(self, _): """""" - # TODO: update the selected code in the input parameters - # self._model.update_submission_blockers() + self._model.update_submission_blockers() def _on_pw_code_resource_change(self, _): self._model.check_resources() def _on_code_resource_change(self, _): """Update the plugin resources.""" - # trigger the update of basic codes - self._model.basic_codes = {} - self._model.basic_codes = self._model.get_model_state()["codes"] + # trigger the update of global codes + self._model.global_codes = {} + self._model.global_codes = self._model.get_model_state()["codes"] def _set_up_codes(self): codes: PluginCodes = { @@ -105,14 +104,14 @@ def _set_up_codes(self): codes[identifier] = data["model"].codes for identifier, code_models in codes.items(): for _, code_model in code_models.items(): - # use the new code model created using the basic code model - code_model = self._model.add_code(identifier, code_model) - if code_model is not None: - code_model.observe( + # use the new code model created using the global code model + base_code_model = self._model.add_code(identifier, code_model) + if base_code_model is not None: + base_code_model.observe( self._on_code_activation_change, "is_active", ) - code_model.observe( + base_code_model.observe( self._on_code_selection_change, "selected", ) @@ -140,7 +139,7 @@ def _render_code_widget( code_model: CodeModel, code_widget: QEAppComputationalResourcesWidget, ): - code_model.update() + code_model.update(None) ipw.dlink( (code_model, "options"), (code_widget.code_selection.code_select_dropdown, "options"), diff --git a/src/aiidalab_qe/app/submission/model.py b/src/aiidalab_qe/app/submission/model.py index 3970d95b9..10e2a0298 100644 --- a/src/aiidalab_qe/app/submission/model.py +++ b/src/aiidalab_qe/app/submission/model.py @@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._default_models = { - "basic", + "global", } self._ALERT_MESSAGE = """ @@ -78,7 +78,7 @@ def refresh_codes(self): def update_active_models(self): for identifier, model in self.get_models(): - if identifier in ["basic"]: + if identifier in ["global"]: continue if identifier not in self._get_properties(): model.include = False @@ -164,11 +164,12 @@ def set_model_state(self, parameters): } workchain_parameters: dict = parameters.get("workchain", {}) properties = set(workchain_parameters.get("properties", [])) + print("codes: ", parameters["codes"]) with self.hold_trait_notifications(): for identifier, model in self._models.items(): model.include = identifier in self._default_models | properties - if parameters.get(identifier): - model.set_model_state(parameters[identifier]) + if parameters["codes"].get(identifier): + model.set_model_state(parameters["codes"][identifier]["codes"]) model.loaded_from_process = True if self.process_node: @@ -179,7 +180,7 @@ def set_model_state(self, parameters): def get_selected_codes(self) -> dict[str, dict]: return { name: code_model.get_model_state() - for name, code_model in self.get_model("basic").codes.items() + for name, code_model in self.get_model("global").codes.items() if code_model.is_ready } @@ -240,16 +241,18 @@ def _create_builder(self, parameters) -> ProcessBuilderNamespace: parameters=deepcopy(parameters), # TODO why deepcopy again? ) - codes = parameters["codes"] + codes = parameters["codes"]["global"]["codes"] builder.relax.base.pw.metadata.options.resources = { - "num_machines": codes.get("pw")["nodes"], - "num_mpiprocs_per_machine": codes.get("pw")["ntasks_per_node"], - "num_cores_per_mpiproc": codes.get("pw")["cpus_per_task"], + "num_machines": codes.get("quantumespresso.pw")["nodes"], + "num_mpiprocs_per_machine": codes.get("quantumespresso.pw")[ + "ntasks_per_node" + ], + "num_cores_per_mpiproc": codes.get("quantumespresso.pw")["cpus_per_task"], } - mws = codes.get("pw")["max_wallclock_seconds"] + mws = codes.get("quantumespresso.pw")["max_wallclock_seconds"] builder.relax.base.pw.metadata.options["max_wallclock_seconds"] = mws - parallelization = codes["pw"]["parallelization"] + parallelization = codes["quantumespresso.pw"]["parallelization"] builder.relax.base.pw.parallelization = orm.Dict(dict=parallelization) return builder diff --git a/src/aiidalab_qe/common/code/model.py b/src/aiidalab_qe/common/code/model.py index ff026e9a9..6bf684eb7 100644 --- a/src/aiidalab_qe/common/code/model.py +++ b/src/aiidalab_qe/common/code/model.py @@ -54,7 +54,7 @@ def activate(self): def deactivate(self): self.is_active = False - def update(self, user_email: str = None): + def update(self, user_email: str): if not self.options: self.options = self._get_codes(user_email) self.selected = self.options[0][1] if self.options else None @@ -88,7 +88,7 @@ def _get_uuid(self, identifier): self.selected = uuid if uuid in [opt[1] for opt in self.options] else None return self.selected - def _get_codes(self, user_email: str = None): + def _get_codes(self, user_email: str): # set default user_email if not provided user_email = user_email or orm.User.collection.get_default().email user = orm.User.collection.get(email=user_email) diff --git a/src/aiidalab_qe/plugins/bands/code.py b/src/aiidalab_qe/plugins/bands/code.py index a7735fe7a..e191bf7fe 100644 --- a/src/aiidalab_qe/plugins/bands/code.py +++ b/src/aiidalab_qe/plugins/bands/code.py @@ -4,7 +4,7 @@ import traitlets as tl from aiida import orm -from aiidalab_qe.common.code.model import CodeModel +from aiidalab_qe.common.code.model import CodeModel, PwCodeModel from aiidalab_qe.common.panel import SettingsModel, SettingsPanel from aiidalab_qe.common.widgets import ( LoadingWidget, @@ -17,11 +17,11 @@ class BandsCodeModel(SettingsModel): """Model for the band structure plugin.""" dependencies = [ - "basic.basic_codes", + "global.global_codes", ] codes = { - "pw": CodeModel( + "pw": PwCodeModel( name="pw.x", description="pw.x", default_calc_job_plugin="quantumespresso.pw", @@ -33,9 +33,9 @@ class BandsCodeModel(SettingsModel): ), } - basic_codes = tl.Dict( + global_codes = tl.Dict( key_trait=tl.Unicode(), # code name - value_trait=tl.Instance(CodeModel), # code metadata + value_trait=tl.Dict(), # code metadata ) submission_blockers = tl.List(tl.Unicode()) submission_warning_messages = tl.Unicode("") @@ -54,20 +54,20 @@ def refresh_codes(self): for _, code_model in self.codes.items(): code_model.update(self._default_user_email) # type: ignore - def update_code_from_basic(self): + def update_code_from_global(self): # skip the sync if the user has overridden the settings - print("override: ", self.override) if self.override: return - for name, code_model in self.codes.items(): + for _, code_model in self.codes.items(): default_calc_job_plugin = code_model.default_calc_job_plugin - if default_calc_job_plugin in self.basic_codes: - code_data = self.basic_codes[default_calc_job_plugin].get_model_state() + if default_calc_job_plugin in self.global_codes: + code_data = self.global_codes[default_calc_job_plugin] code_model.set_model_state(code_data) def get_model_state(self): + codes = {name: model.get_model_state() for name, model in self.codes.items()} return { - "codes": self.codes, + "codes": codes, "override": self.override, } @@ -90,8 +90,8 @@ def __init__(self, model, **kwargs): super().__init__(model, **kwargs) self._model.observe( - self._on_basic_codes_change, - "basic_codes", + self._on_global_codes_change, + "global_codes", ) self._model.observe( self._on_override_change, @@ -127,8 +127,8 @@ def render(self): self._toggle_code(code_model) return self.code_widgets_container - def _on_basic_codes_change(self, _): - self._model.update_code_from_basic() + def _on_global_codes_change(self, _): + self._model.update_code_from_global() def _on_override_change(self, change): if change["new"]: diff --git a/src/aiidalab_qe/plugins/pdos/code.py b/src/aiidalab_qe/plugins/pdos/code.py index 3dc010a6c..fc6af0d99 100644 --- a/src/aiidalab_qe/plugins/pdos/code.py +++ b/src/aiidalab_qe/plugins/pdos/code.py @@ -4,7 +4,7 @@ import traitlets as tl from aiida import orm -from aiidalab_qe.common.code.model import CodeModel +from aiidalab_qe.common.code.model import CodeModel, PwCodeModel from aiidalab_qe.common.panel import SettingsModel, SettingsPanel from aiidalab_qe.common.widgets import ( LoadingWidget, @@ -19,6 +19,11 @@ class PdosCodeModel(SettingsModel): dependencies = [] codes = { + "pw": PwCodeModel( + name="pw.x", + description="pw.x", + default_calc_job_plugin="quantumespresso.pw", + ), "dos": CodeModel( name="dos.x", description="dos.x", @@ -30,6 +35,11 @@ class PdosCodeModel(SettingsModel): default_calc_job_plugin="quantumespresso.projwfc", ), } + global_codes = tl.Dict( + key_trait=tl.Unicode(), # code name + value_trait=tl.Dict(), # code metadata + ) + submission_blockers = tl.List(tl.Unicode()) submission_warning_messages = tl.Unicode("") override = tl.Bool(False) @@ -46,8 +56,22 @@ def refresh_codes(self): for _, code_model in self.codes.items(): code_model.update(self._default_user_email) # type: ignore + def update_code_from_global(self): + # skip the sync if the user has overridden the settings + if self.override: + return + for _, code_model in self.codes.items(): + default_calc_job_plugin = code_model.default_calc_job_plugin + if default_calc_job_plugin in self.global_codes: + code_data = self.global_codes[default_calc_job_plugin] + code_model.set_model_state(code_data) + def get_model_state(self): - return {"codes": self.codes} + codes = {name: model.get_model_state() for name, model in self.codes.items()} + return { + "codes": codes, + "override": self.override, + } def set_model_state(self, code_data: dict): for name, code_model in self.codes.items(): @@ -61,7 +85,7 @@ def reset(self): class PdosCodeSettings(SettingsPanel[PdosCodeModel]): - title = "PDOS Structure" + title = "PDOS" identifier = "pdos" def render(self): diff --git a/src/aiidalab_qe/plugins/xas/code.py b/src/aiidalab_qe/plugins/xas/code.py index f4a130da0..d889fea78 100644 --- a/src/aiidalab_qe/plugins/xas/code.py +++ b/src/aiidalab_qe/plugins/xas/code.py @@ -4,7 +4,7 @@ import traitlets as tl from aiida import orm -from aiidalab_qe.common.code.model import CodeModel +from aiidalab_qe.common.code.model import CodeModel, PwCodeModel from aiidalab_qe.common.panel import SettingsModel, SettingsPanel from aiidalab_qe.common.widgets import ( LoadingWidget, @@ -19,17 +19,27 @@ class XasCodeModel(SettingsModel): dependencies = [] codes = { + "pw": PwCodeModel( + name="pw.x", + description="pw.x", + default_calc_job_plugin="quantumespresso.pw", + ), "xspectra": CodeModel( name="xspectra.x", description="xspectra.x", default_calc_job_plugin="quantumespresso.xspectra", - ) + ), } - basic_codes = tl.Dict( + global_codes = tl.Dict( key_trait=tl.Unicode(), # code name value_trait=tl.Instance(CodeModel), # code metadata ) + global_codes = tl.Dict( + key_trait=tl.Unicode(), # code name + value_trait=tl.Dict(), # code metadata + ) + submission_blockers = tl.List(tl.Unicode()) submission_warning_messages = tl.Unicode("") override = tl.Bool(False) @@ -46,8 +56,22 @@ def refresh_codes(self): for _, code_model in self.codes.items(): code_model.update(self._default_user_email) # type: ignore + def update_code_from_global(self): + # skip the sync if the user has overridden the settings + if self.override: + return + for _, code_model in self.codes.items(): + default_calc_job_plugin = code_model.default_calc_job_plugin + if default_calc_job_plugin in self.global_codes: + code_data = self.global_codes[default_calc_job_plugin] + code_model.set_model_state(code_data) + def get_model_state(self): - return {"codes": self.codes} + codes = {name: model.get_model_state() for name, model in self.codes.items()} + return { + "codes": codes, + "override": self.override, + } def set_model_state(self, code_data: dict): for name, code_model in self.codes.items(): diff --git a/src/aiidalab_qe/workflows/__init__.py b/src/aiidalab_qe/workflows/__init__.py index 0b7187f6b..dd2836913 100644 --- a/src/aiidalab_qe/workflows/__init__.py +++ b/src/aiidalab_qe/workflows/__init__.py @@ -175,8 +175,9 @@ def get_builder_from_protocol( "base_final_scf": parameters["advanced"], } protocol = parameters["workchain"]["protocol"] + print("codes: ", codes["global"]["codes"].get("quantumespresso.pw")) relax_builder = PwRelaxWorkChain.get_builder_from_protocol( - code=codes["basic"]["codes"].get("quantumespresso.pw")["code"], + code=codes["global"]["codes"].get("quantumespresso.pw")["code"], structure=structure, protocol=protocol, relax_type=RelaxType(parameters["workchain"]["relax_type"]), @@ -202,7 +203,10 @@ def get_builder_from_protocol( for name, entry_point in plugin_entries.items(): if name in properties: plugin_builder = entry_point["get_builder"]( - codes, builder.structure, copy.deepcopy(parameters), **kwargs + codes[name]["codes"], + builder.structure, + copy.deepcopy(parameters), + **kwargs, ) plugin_workchain = entry_point["workchain"] if plugin_workchain.spec().has_input("clean_workdir"): diff --git a/tests/conftest.py b/tests/conftest.py index 4d1ecdb0a..094a6ccb4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -424,22 +424,23 @@ def app(pw_code, dos_code, projwfc_code, projwfc_bands_code): app.submit_model.qe_installed = True # set up codes - pw_code_model = app.submit_model.get_code("dft", "pw") - dos_code_model = app.submit_model.get_code("pdos", "dos") - projwfc_code_model = app.submit_model.get_code("pdos", "projwfc") - projwfc_bands_code_model = app.submit_model.get_code("bands", "projwfc_bands") + pw_code_model = app.submit_model.get_model("global").get_code("quantumespresso.pw") + dos_code_model = app.submit_model.get_model("global").get_code( + "quantumespresso.dos" + ) + projwfc_code_model = app.submit_model.get_model("global").get_code( + "quantumespresso.projwfc" + ) pw_code_model.activate() dos_code_model.activate() projwfc_code_model.activate() - projwfc_bands_code_model.activate() - app.submit_model.set_selected_codes( + app.submit_model.get_model("global").set_selected_codes( { "pw": {"code": pw_code.label}, "dos": {"code": dos_code.label}, "projwfc": {"code": projwfc_code.label}, - "projwfc_bands": {"code": projwfc_bands_code.label}, } ) @@ -503,7 +504,7 @@ def _submit_app_generator( app.configure_model.confirm() app.submit_model.input_structure = generate_structure_data() - app.submit_model.get_code("dft", "pw").num_cpus = 2 + app.submit_model.get_model("global").get_code("quantumespresso.pw").num_cpus = 2 return app @@ -812,7 +813,7 @@ def _generate_qeapp_workchain( app.configure_model.confirm() # step 3 setup code and resources - app.submit_model.get_code("dft", "pw").num_cpus = 4 + app.submit_model.get_model("global").get_code("quantumespresso.pw").num_cpus = 4 parameters = app.submit_model.get_model_state() builder = app.submit_model._create_builder(parameters) diff --git a/tests/test_codes.py b/tests/test_codes.py index be3e480fb..984215606 100644 --- a/tests/test_codes.py +++ b/tests/test_codes.py @@ -7,7 +7,7 @@ def test_code_not_selected(submit_app_generator): """Test if there is an error when the code is not selected.""" app: App = submit_app_generator(properties=["dos"]) model = app.submit_model - model.get_code("pdos", "dos").selected = None + model.get_model("global").get_code("quantumespresso.dos").selected = None # Check builder construction passes without an error parameters = model.get_model_state() model._create_builder(parameters) @@ -19,11 +19,10 @@ def test_set_selected_codes(submit_app_generator): parameters = app.submit_model.get_model_state() model = SubmissionStepModel() _ = SubmitQeAppWorkChainStep(model=model, qe_auto_setup=False) - for identifier, code_models in app.submit_model.get_code_models(): - for name, code_model in code_models.items(): - model.get_code(identifier, name).is_active = code_model.is_active + for name, code_model in app.submit_model.get_model("global").codes.items(): + model.get_model("global").get_code(name).is_active = code_model.is_active model.qe_installed = True - model.set_selected_codes(parameters["codes"]) + model.get_model("global").set_selected_codes(parameters["codes"]["global"]["codes"]) assert model.get_selected_codes() == app.submit_model.get_selected_codes() @@ -33,48 +32,62 @@ def test_update_codes_display(app: App): """ app.submit_step.render() model = app.submit_model - model.update_active_codes() - assert app.submit_step.code_widgets["dos"].layout.display == "none" + model.get_model("global").update_active_codes() + assert ( + app.submit_step.global_code_settings.code_widgets["dos"].layout.display + == "none" + ) model.input_parameters = {"workchain": {"properties": ["pdos"]}} - model.update_active_codes() - assert app.submit_step.code_widgets["dos"].layout.display == "block" + model.get_model("global").update_active_codes() + assert ( + app.submit_step._model.get_model("global") + .codes["quantumespresso.dos"] + .is_active + is True + ) + assert ( + app.submit_step.global_code_settings.code_widgets["dos"].layout.display + == "block" + ) def test_check_submission_blockers(app: App): """Test check_submission_blockers method.""" model = app.submit_model - blockers = list(model._check_submission_blockers()) - assert len(blockers) == 0 + model.update_submission_blockers() + assert len(model.internal_submission_blockers) == 0 model.input_parameters = {"workchain": {"properties": ["pdos"]}} - blockers = list(model._check_submission_blockers()) - assert len(blockers) == 0 + model.update_submission_blockers() + assert len(model.internal_submission_blockers) == 0 # set dos code to None, will introduce another blocker - dos_code = model.get_code("pdos", "dos") + dos_code = model.get_model("global").get_code("quantumespresso.dos") dos_value = dos_code.selected dos_code.selected = None - blockers = list(model._check_submission_blockers()) - assert len(blockers) == 1 + model.update_submission_blockers() + assert len(model.internal_submission_blockers) == 1 # set dos code back will remove the blocker dos_code.selected = dos_value - blockers = list(model._check_submission_blockers()) - assert len(blockers) == 0 + model.update_submission_blockers() + assert len(model.internal_submission_blockers) == 0 def test_qeapp_computational_resources_widget(app: App): """Test QEAppComputationalResourcesWidget.""" app.submit_step.render() - pw_code_model = app.submit_model.get_code("dft", "pw") - pw_code_widget = app.submit_step.code_widgets["pw"] + pw_code_model = app.submit_model.get_model("global").get_code("quantumespresso.pw") + pw_code_widget = app.submit_step.global_code_settings.code_widgets["pw"] assert pw_code_widget.parallelization.npool.layout.display == "none" pw_code_model.override = True pw_code_model.npool = 2 assert pw_code_widget.parallelization.npool.layout.display == "block" assert pw_code_widget.parameters == { - "code": app.submit_step.code_widgets["pw"].value, # TODO why None? + "code": app.submit_step.global_code_settings.code_widgets[ + "pw" + ].value, # TODO why None? "cpus": 1, "cpus_per_task": 1, "max_wallclock_seconds": 43200, diff --git a/tests/test_submit_qe_workchain.py b/tests/test_submit_qe_workchain.py index 7d183e41b..e0c367a69 100644 --- a/tests/test_submit_qe_workchain.py +++ b/tests/test_submit_qe_workchain.py @@ -145,15 +145,15 @@ def test_warning_messages( app: App = submit_app_generator(properties=["bands", "pdos"]) submit_model = app.submit_model - pw_code = submit_model.get_code("dft", "pw") + pw_code = submit_model.get_model("global").get_code("quantumespresso.pw") pw_code.num_cpus = 1 - submit_model.check_resources() + submit_model.get_model("global").check_resources() # no warning: assert submit_model.submission_warning_messages == "" # now we increase the resources, so we should have the Warning-3 pw_code.num_cpus = len(os.sched_getaffinity(0)) - submit_model.check_resources() + submit_model.get_model("global").check_resources() for suggestion in ["avoid_overloading", "go_remote"]: assert suggestions[suggestion] in submit_model.submission_warning_messages @@ -161,10 +161,12 @@ def test_warning_messages( structure = generate_structure_data("H2O-larger") submit_model.input_structure = structure pw_code.num_cpus = 1 - submit_model.check_resources() + submit_model.get_model("global").check_resources() num_sites = len(structure.sites) volume = structure.get_cell_volume() - estimated_CPUs = submit_model._estimate_min_cpus(num_sites, volume) + estimated_CPUs = submit_model.get_model("global")._estimate_min_cpus( + num_sites, volume + ) assert estimated_CPUs == 2 for suggestion in ["more_resources", "change_configuration"]: assert suggestions[suggestion] in submit_model.submission_warning_messages diff --git a/tests/test_submit_qe_workchain/test_create_builder_default.yml b/tests/test_submit_qe_workchain/test_create_builder_default.yml index 3ad9d35b7..2d6f756e3 100644 --- a/tests/test_submit_qe_workchain/test_create_builder_default.yml +++ b/tests/test_submit_qe_workchain/test_create_builder_default.yml @@ -22,31 +22,72 @@ advanced: bands: projwfc_bands: false codes: - dos: - cpus: 1 - cpus_per_task: 1 - max_wallclock_seconds: 43200 - nodes: 1 - ntasks_per_node: 1 - projwfc: - cpus: 1 - cpus_per_task: 1 - max_wallclock_seconds: 43200 - nodes: 1 - ntasks_per_node: 1 - projwfc_bands: - cpus: 1 - cpus_per_task: 1 - max_wallclock_seconds: 43200 - nodes: 1 - ntasks_per_node: 1 - pw: - cpus: 2 - cpus_per_task: 1 - max_wallclock_seconds: 43200 - nodes: 1 - ntasks_per_node: 2 - parallelization: {} + bands: + codes: + projwfc_bands: + cpus: 1 + cpus_per_task: 1 + max_wallclock_seconds: 43200 + nodes: 1 + ntasks_per_node: 1 + pw: + cpus: 1 + cpus_per_task: 1 + max_wallclock_seconds: 43200 + nodes: 1 + ntasks_per_node: 1 + parallelization: {} + override: false + global: + codes: + quantumespresso.dos: + cpus: 1 + cpus_per_task: 1 + max_wallclock_seconds: 43200 + nodes: 1 + ntasks_per_node: 1 + quantumespresso.projwfc: + cpus: 1 + cpus_per_task: 1 + max_wallclock_seconds: 43200 + nodes: 1 + ntasks_per_node: 1 + quantumespresso.pw: + cpus: 2 + cpus_per_task: 1 + max_wallclock_seconds: 43200 + nodes: 1 + ntasks_per_node: 2 + parallelization: {} + quantumespresso.xspectra: + code: null + cpus: 1 + cpus_per_task: 1 + max_wallclock_seconds: 43200 + nodes: 1 + ntasks_per_node: 1 + pdos: + codes: + dos: + cpus: 1 + cpus_per_task: 1 + max_wallclock_seconds: 43200 + nodes: 1 + ntasks_per_node: 1 + projwfc: + cpus: 1 + cpus_per_task: 1 + max_wallclock_seconds: 43200 + nodes: 1 + ntasks_per_node: 1 + pw: + cpus: 1 + cpus_per_task: 1 + max_wallclock_seconds: 43200 + nodes: 1 + ntasks_per_node: 1 + parallelization: {} + override: false pdos: nscf_kpoints_distance: 0.1 pdos_degauss: 0.005