Skip to content

Commit

Permalink
udpate test
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Nov 21, 2024
1 parent 3f935e1 commit cfe0ae7
Show file tree
Hide file tree
Showing 15 changed files with 258 additions and 149 deletions.
21 changes: 8 additions & 13 deletions src/aiidalab_qe/app/submission/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,12 @@

from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.app.utils import get_entry_items
from aiidalab_qe.common.code import CodeModel, PluginCodes, PwCodeModel
from aiidalab_qe.common.panel import SettingsModel, SettingsPanel
from aiidalab_qe.common.setup_codes import QESetupWidget
from aiidalab_qe.common.setup_pseudos import PseudosInstallWidget
from aiidalab_qe.common.widgets import (
PwCodeResourceSetupWidget,
QEAppComputationalResourcesWidget,
)
from aiidalab_widgets_base import WizardAppWidgetStep

from .basic import BasicCodeModel, BasicCodeSettings
from .global_settings import GlobalCodeModel, GlobalCodeSettings
from .model import SubmissionStepModel

DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore
Expand Down Expand Up @@ -82,20 +77,20 @@ def __init__(self, model: SubmissionStepModel, qe_auto_setup=True, **kwargs):

self.rendered = False

basic_code_model = BasicCodeModel()
self.basic_code_settings = BasicCodeSettings(model=basic_code_model)
self._model.add_model("basic", basic_code_model)
basic_code_model.observe(
global_code_model = GlobalCodeModel()
self.global_code_settings = GlobalCodeSettings(model=global_code_model)
self._model.add_model("global", global_code_model)
global_code_model.observe(
self._on_plugin_submission_blockers_change,
["submission_blockers"],
)
basic_code_model.observe(
global_code_model.observe(
self._on_plugin_submission_warning_messages_change,
["submission_warning_messages"],
)

self.settings = {
"basic": self.basic_code_settings,
"global": self.global_code_settings,
}
self._fetch_plugin_settings()

Expand Down Expand Up @@ -350,7 +345,7 @@ def _fetch_plugin_settings(self):
)
self._model.add_model(identifier, model)

def toggle_plugin(change, identifier=identifier, model=model):
def toggle_plugin(_, model=model):
model.update()
self._update_tabs()

Expand Down
7 changes: 0 additions & 7 deletions src/aiidalab_qe/app/submission/basic/__init__.py

This file was deleted.

7 changes: 7 additions & 0 deletions src/aiidalab_qe/app/submission/global_settings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .model import GlobalCodeModel
from .setting import GlobalCodeSettings

__all__ = [
"GlobalCodeModel",
"GlobalCodeSettings",
]
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore


class BasicCodeModel(
class GlobalCodeModel(
SettingsModel,
HasInputStructure,
):
"""Model for the basic code setting."""
"""Model for the global code setting."""

dependencies = [
"input_parameters",
Expand All @@ -34,9 +34,9 @@ class BasicCodeModel(
value_trait=tl.Instance(CodeModel), # code metadata
)
# this is a copy of the codes trait, which is used to trigger the update of the plugin
basic_codes = tl.Dict(
global_codes = tl.Dict(
key_trait=tl.Unicode(), # code name
value_trait=tl.Instance(CodeModel), # code metadata
value_trait=tl.Dict(), # code metadata
)

plugin_mapping = tl.Dict(
Expand Down Expand Up @@ -83,15 +83,18 @@ def update_active_codes(self):
self.codes[code_name].activate()

def get_model_state(self):
return {"codes": self.codes}
codes = {name: model.get_model_state() for name, model in self.codes.items()}

return {"codes": codes}

def set_model_state(self, code_data: dict):
for name, code_model in self.codes.items():
if name in code_data:
if name in code_data and code_model.is_active:
code_model.set_model_state(code_data[name])

def add_code(self, identifier: str, code: CodeModel) -> CodeModel | None:
"""Add a code to the codes trait."""
code_model = None
default_calc_job_plugin = code.default_calc_job_plugin
name = default_calc_job_plugin.split(".")[-1]
if default_calc_job_plugin not in self.codes:
Expand All @@ -108,13 +111,14 @@ def add_code(self, identifier: str, code: CodeModel) -> CodeModel | None:
default_calc_job_plugin=default_calc_job_plugin,
)
self.codes[default_calc_job_plugin] = code_model
return code_model
# update the plugin mapping to keep track of which codes are associated with which plugins
if identifier not in self.plugin_mapping:
self.plugin_mapping[identifier] = [default_calc_job_plugin]
else:
self.plugin_mapping[identifier].append(default_calc_job_plugin)

return code_model

def get_code(self, name) -> CodeModel | None:
if name in self.codes: # type: ignore
return self.codes[name] # type: ignore
Expand All @@ -127,10 +131,9 @@ def get_selected_codes(self) -> dict[str, dict]:
}

def set_selected_codes(self, code_data=DEFAULT["codes"]):
with self.hold_trait_notifications():
for name, code_model in self.codes.items():
if name in code_data:
code_model.set_model_state(code_data[name])
for name, code_model in self.codes.items():
if name in code_data and code_model.is_active:
code_model.set_model_state(code_data[name])

def reset(self):
"""Reset the model to its default state."""
Expand All @@ -145,21 +148,22 @@ def update_submission_blockers(self):

def _check_submission_blockers(self):
# No pw code selected (this is ignored while the setup process is running).
pw_code = self._model.get_code("quantumespresso.pw")
pw_code = self.get_code("quantumespresso.pw")
if pw_code and not pw_code.selected and not self.installing_qe:
yield ("No pw code selected")

# code related to the selected property is not installed
properties = self._get_properties()
message = "Calculating the {property} property requires code {code} to be set."
for identifier, codes in self.get_model("basic").codes.items():
for identifier, code_names in self.plugin_mapping.items():
if identifier in properties:
for code in codes.values():
for name in code_names:
code = self.get_code(name)
if not code.is_ready:
yield message.format(property=identifier, code=code.description)

# check if the QEAppComputationalResourcesWidget is used
for name, code in self.get_model("basic").codes.items():
for name, code in self.codes.items():
# skip if the code is not displayed, convenient for the plugin developer
if not code.is_ready:
continue
Expand Down Expand Up @@ -241,7 +245,6 @@ def check_resources(self):
+ suggestions["go_remote"]
+ "</ul>"
)
print("alert_message: ", alert_message)

self.submission_warning_messages = (
""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
QEAppComputationalResourcesWidget,
)

from .model import BasicCodeModel
from .model import GlobalCodeModel

DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore


class BasicCodeSettings(SettingsPanel[BasicCodeModel]):
title = "Basic"
identifier = "basic"
class GlobalCodeSettings(SettingsPanel[GlobalCodeModel]):
title = "Global settings"
identifier = "global"

def __init__(self, model: BasicCodeModel, **kwargs):
def __init__(self, model: GlobalCodeModel, **kwargs):
super().__init__(model, **kwargs)
self._set_up_codes()
self._model.observe(
Expand Down Expand Up @@ -77,17 +77,16 @@ def _on_code_activation_change(self, change):

def _on_code_selection_change(self, _):
""""""
# TODO: update the selected code in the input parameters
# self._model.update_submission_blockers()
self._model.update_submission_blockers()

def _on_pw_code_resource_change(self, _):
self._model.check_resources()

def _on_code_resource_change(self, _):
"""Update the plugin resources."""
# trigger the update of basic codes
self._model.basic_codes = {}
self._model.basic_codes = self._model.get_model_state()["codes"]
# trigger the update of global codes
self._model.global_codes = {}
self._model.global_codes = self._model.get_model_state()["codes"]

def _set_up_codes(self):
codes: PluginCodes = {
Expand All @@ -105,14 +104,14 @@ def _set_up_codes(self):
codes[identifier] = data["model"].codes
for identifier, code_models in codes.items():
for _, code_model in code_models.items():
# use the new code model created using the basic code model
code_model = self._model.add_code(identifier, code_model)
if code_model is not None:
code_model.observe(
# use the new code model created using the global code model
base_code_model = self._model.add_code(identifier, code_model)
if base_code_model is not None:
base_code_model.observe(
self._on_code_activation_change,
"is_active",
)
code_model.observe(
base_code_model.observe(
self._on_code_selection_change,
"selected",
)
Expand Down Expand Up @@ -140,7 +139,7 @@ def _render_code_widget(
code_model: CodeModel,
code_widget: QEAppComputationalResourcesWidget,
):
code_model.update()
code_model.update(None)
ipw.dlink(
(code_model, "options"),
(code_widget.code_selection.code_select_dropdown, "options"),
Expand Down
25 changes: 14 additions & 11 deletions src/aiidalab_qe/app/submission/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._default_models = {
"basic",
"global",
}

self._ALERT_MESSAGE = """
Expand Down Expand Up @@ -78,7 +78,7 @@ def refresh_codes(self):

def update_active_models(self):
for identifier, model in self.get_models():
if identifier in ["basic"]:
if identifier in ["global"]:
continue
if identifier not in self._get_properties():
model.include = False
Expand Down Expand Up @@ -164,11 +164,12 @@ def set_model_state(self, parameters):
}
workchain_parameters: dict = parameters.get("workchain", {})
properties = set(workchain_parameters.get("properties", []))
print("codes: ", parameters["codes"])
with self.hold_trait_notifications():
for identifier, model in self._models.items():
model.include = identifier in self._default_models | properties
if parameters.get(identifier):
model.set_model_state(parameters[identifier])
if parameters["codes"].get(identifier):
model.set_model_state(parameters["codes"][identifier]["codes"])
model.loaded_from_process = True

if self.process_node:
Expand All @@ -179,7 +180,7 @@ def set_model_state(self, parameters):
def get_selected_codes(self) -> dict[str, dict]:
return {
name: code_model.get_model_state()
for name, code_model in self.get_model("basic").codes.items()
for name, code_model in self.get_model("global").codes.items()
if code_model.is_ready
}

Expand Down Expand Up @@ -240,16 +241,18 @@ def _create_builder(self, parameters) -> ProcessBuilderNamespace:
parameters=deepcopy(parameters), # TODO why deepcopy again?
)

codes = parameters["codes"]
codes = parameters["codes"]["global"]["codes"]

builder.relax.base.pw.metadata.options.resources = {
"num_machines": codes.get("pw")["nodes"],
"num_mpiprocs_per_machine": codes.get("pw")["ntasks_per_node"],
"num_cores_per_mpiproc": codes.get("pw")["cpus_per_task"],
"num_machines": codes.get("quantumespresso.pw")["nodes"],
"num_mpiprocs_per_machine": codes.get("quantumespresso.pw")[
"ntasks_per_node"
],
"num_cores_per_mpiproc": codes.get("quantumespresso.pw")["cpus_per_task"],
}
mws = codes.get("pw")["max_wallclock_seconds"]
mws = codes.get("quantumespresso.pw")["max_wallclock_seconds"]
builder.relax.base.pw.metadata.options["max_wallclock_seconds"] = mws
parallelization = codes["pw"]["parallelization"]
parallelization = codes["quantumespresso.pw"]["parallelization"]
builder.relax.base.pw.parallelization = orm.Dict(dict=parallelization)

return builder
Expand Down
4 changes: 2 additions & 2 deletions src/aiidalab_qe/common/code/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def activate(self):
def deactivate(self):
self.is_active = False

def update(self, user_email: str = None):
def update(self, user_email: str):
if not self.options:
self.options = self._get_codes(user_email)
self.selected = self.options[0][1] if self.options else None
Expand Down Expand Up @@ -88,7 +88,7 @@ def _get_uuid(self, identifier):
self.selected = uuid if uuid in [opt[1] for opt in self.options] else None
return self.selected

def _get_codes(self, user_email: str = None):
def _get_codes(self, user_email: str):
# set default user_email if not provided
user_email = user_email or orm.User.collection.get_default().email
user = orm.User.collection.get(email=user_email)
Expand Down
Loading

0 comments on commit cfe0ae7

Please sign in to comment.