Skip to content

Commit

Permalink
Fix models
Browse files Browse the repository at this point in the history
  • Loading branch information
edan-bainglass committed Oct 29, 2024
1 parent dd00dd5 commit 608467b
Show file tree
Hide file tree
Showing 20 changed files with 205 additions and 232 deletions.
4 changes: 2 additions & 2 deletions src/aiidalab_qe/app/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def __init__(self, model: ConfigurationModel, **kwargs):

self.settings: dict[str, SettingsPanel] = {}

workchain_model = WorkChainModel(include=True)
workchain_model = WorkChainModel()
self._model.add_model("workchain", workchain_model)

advanced_model = AdvancedModel(include=True)
advanced_model = AdvancedModel()
self._model.add_model("advanced", advanced_model)

self._fetch_plugin_settings()
Expand Down
2 changes: 1 addition & 1 deletion src/aiidalab_qe/app/configuration/advanced/advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ def _on_input_structure_change(self, _):
def _on_protocol_change(self, _):
self.refresh(specific="protocol")

def _on_kpoints_distance_change(self, _=None):
def _on_kpoints_distance_change(self, _):
self.refresh(specific="mesh")

def _on_override_change(self, change):
Expand Down
27 changes: 12 additions & 15 deletions src/aiidalab_qe/app/configuration/advanced/hubbard/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,19 @@ def __init__(self, *args, **kwargs):
}

def update(self, specific=""):
if self.input_structure is None:
self.applicable_kinds = []
self.orbital_labels = []
self._defaults |= {
"parameters": {},
"eigenvalues": [],
}
else:
self.orbital_labels = self._define_orbital_labels()
self._defaults["parameters"] = self._define_default_parameters()
self.applicable_kinds = self._define_applicable_kinds()
self._defaults["eigenvalues"] = self._define_default_eigenvalues()
with self.hold_trait_notifications():
self._update_defaults(specific)
self.parameters = self._get_default_parameters()
self.eigenvalues = self._get_default_eigenvalues()
self.needs_eigenvalues_widget = len(self.applicable_kinds) > 0
Expand Down Expand Up @@ -82,20 +93,6 @@ def reset(self):
self.parameters = self._get_default_parameters()
self.eigenvalues = self._get_default_eigenvalues()

def _update_defaults(self, specific=""):
if self.input_structure is None:
self.applicable_kinds = []
self.orbital_labels = []
self._defaults |= {
"parameters": {},
"eigenvalues": [],
}
else:
self.orbital_labels = self._define_orbital_labels()
self._defaults["parameters"] = self._define_default_parameters()
self.applicable_kinds = self._define_applicable_kinds()
self._defaults["eigenvalues"] = self._define_default_eigenvalues()

def _define_orbital_labels(self):
hubbard_manifold_list = [
self._get_manifold(Element(kind.symbol))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ def __init__(self, *args, **kwargs):
}

def update(self, specific=""):
if self.spin_type == "none" or self.input_structure is None:
self._defaults["moments"] = {}
else:
self._defaults["moments"] = {
symbol: 0.0 for symbol in self.input_structure.get_kind_names()
}
with self.hold_trait_notifications():
self._update_defaults(specific)
self.moments = self._get_default_moments()

def reset(self):
Expand All @@ -45,13 +50,5 @@ def reset(self):
self.total = self.traits()["total"].default_value
self.moments = self._get_default_moments()

def _update_defaults(self, specific=""):
if self.spin_type == "none" or self.input_structure is None:
self._defaults["moments"] = {}
else:
self._defaults["moments"] = {
symbol: 0.0 for symbol in self.input_structure.get_kind_names()
}

def _get_default_moments(self):
return deepcopy(self._defaults["moments"])
64 changes: 30 additions & 34 deletions src/aiidalab_qe/app/configuration/advanced/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ class AdvancedModel(SettingsModel):
kpoints_distance = tl.Float(0.0)
mesh_grid = tl.Unicode("")

def __init__(self, include=False, *args, **kwargs):
super().__init__(include, *args, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.include = True

self.dftd3_version = {
"dft-d3": 3,
Expand All @@ -84,19 +86,21 @@ def __init__(self, include=False, *args, **kwargs):

def update(self, specific=""):
with self.hold_trait_notifications():
self._update_defaults(specific)
self.forc_conv_thr = self._defaults["forc_conv_thr"]
self.forc_conv_thr_step = self._defaults["forc_conv_thr_step"]
self.etot_conv_thr = self._defaults["etot_conv_thr"]
self.etot_conv_thr_step = self._defaults["etot_conv_thr_step"]
self.scf_conv_thr = self._defaults["scf_conv_thr"]
self.scf_conv_thr_step = self._defaults["scf_conv_thr_step"]
self.kpoints_distance = self._defaults["kpoints_distance"]
if not specific or specific != "mesh":
parameters = PwBaseWorkChain.get_protocol_inputs(self.protocol)
self._update_kpoints_distance(parameters)
self._update_thresholds(parameters)
self._update_kpoints_mesh()

def add_model(self, identifier, model):
self._models[identifier] = model
self._link_model(model)

def get_model(self, identifier) -> AdvancedSubModel:
if identifier in self._models:
return self._models[identifier]
raise ValueError(f"Model with identifier '{identifier}' not found.")

def get_models(self):
return self._models.items()

Expand All @@ -122,15 +126,15 @@ def get_model_state(self):
"kpoints_distance": self.kpoints_distance,
}

hubbard: HubbardModel = self._get_model("hubbard") # type: ignore
hubbard: HubbardModel = self.get_model("hubbard") # type: ignore
if hubbard.is_active:
parameters["hubbard_parameters"] = {"hubbard_u": hubbard.parameters}
if hubbard.has_eigenvalues:
parameters["pw"]["parameters"]["SYSTEM"] |= {
"starting_ns_eigenvalue": hubbard.get_active_eigenvalues()
}

pseudos: PseudosModel = self._get_model("pseudos") # type: ignore
pseudos: PseudosModel = self.get_model("pseudos") # type: ignore
parameters["pseudo_family"] = pseudos.family
if pseudos.dictionary:
parameters["pw"]["pseudos"] = pseudos.dictionary
Expand All @@ -145,8 +149,8 @@ def get_model_state(self):
self.dftd3_version[self.van_der_waals]
)

smearing: SmearingModel = self._get_model("smearing") # type: ignore
magnetization: MagnetizationModel = self._get_model("magnetization") # type: ignore
smearing: SmearingModel = self.get_model("smearing") # type: ignore
magnetization: MagnetizationModel = self.get_model("magnetization") # type: ignore
if self.spin_type == "collinear":
parameters["initial_magnetic_moments"] = magnetization.moments
if self.electronic_type == "metal":
Expand Down Expand Up @@ -181,7 +185,7 @@ def get_model_state(self):
return parameters

def set_model_state(self, parameters):
pseudos: PseudosModel = self._get_model("pseudos") # type: ignore
pseudos: PseudosModel = self.get_model("pseudos") # type: ignore
if "pseudo_family" in parameters:
pseudo_family = PseudoFamily.from_string(parameters["pseudo_family"])
library = pseudo_family.library
Expand All @@ -199,7 +203,7 @@ def set_model_state(self, parameters):
if (pw_parameters := parameters.get("pw", {}).get("parameters")) is not None:
self._set_pw_parameters(pw_parameters)

magnetization: MagnetizationModel = self._get_model("magnetization") # type: ignore
magnetization: MagnetizationModel = self.get_model("magnetization") # type: ignore
if magnetic_moments := parameters.get("initial_magnetic_moments"):
if isinstance(magnetic_moments, (int, float)):
magnetic_moments = [magnetic_moments]
Expand All @@ -212,7 +216,7 @@ def set_model_state(self, parameters):
)
magnetization.moments = magnetic_moments

hubbard: HubbardModel = self._get_model("hubbard") # type: ignore
hubbard: HubbardModel = self.get_model("hubbard") # type: ignore
if parameters.get("hubbard_parameters"):
hubbard.is_active = True
hubbard.parameters = parameters["hubbard_parameters"]["hubbard_u"]
Expand Down Expand Up @@ -256,16 +260,6 @@ def _link_model(self, model):
(model, trait),
)

def _update_defaults(self, specific=""):
if not specific or specific != "mesh":
parameters = PwBaseWorkChain.get_protocol_inputs(self.protocol)
self._update_kpoints_distance(parameters)

self._update_kpoints_mesh()

if not specific or specific == "protocol":
self._update_thresholds(parameters)

def _update_kpoints_mesh(self, _=None):
if self.input_structure is None:
mesh_grid = ""
Expand All @@ -284,18 +278,25 @@ def _update_kpoints_mesh(self, _=None):
def _update_kpoints_distance(self, parameters):
kpoints_distance = parameters["kpoints_distance"] if self.has_pbc else 100.0
self._defaults["kpoints_distance"] = kpoints_distance
self.kpoints_distance = self._defaults["kpoints_distance"]

def _update_thresholds(self, parameters):
num_atoms = len(self.input_structure.sites) if self.input_structure else 1

etot_value = num_atoms * parameters["meta_parameters"]["etot_conv_thr_per_atom"]
self._set_value_and_step("etot_conv_thr", etot_value)
self.etot_conv_thr = self._defaults["etot_conv_thr"]
self.etot_conv_thr_step = self._defaults["etot_conv_thr_step"]

scf_value = num_atoms * parameters["meta_parameters"]["conv_thr_per_atom"]
self._set_value_and_step("scf_conv_thr", scf_value)
self.scf_conv_thr = self._defaults["scf_conv_thr"]
self.scf_conv_thr_step = self._defaults["scf_conv_thr_step"]

forc_value = parameters["pw"]["parameters"]["CONTROL"]["forc_conv_thr"]
self._set_value_and_step("forc_conv_thr", forc_value)
self.forc_conv_thr = self._defaults["forc_conv_thr"]
self.forc_conv_thr_step = self._defaults["forc_conv_thr_step"]

def _set_value_and_step(self, attribute, value):
self._defaults[attribute] = value
Expand Down Expand Up @@ -324,18 +325,13 @@ def _set_pw_parameters(self, pw_parameters):
system_params.get("vdw_corr", "none"),
)

smearing: SmearingModel = self._get_model("smearing") # type: ignore
smearing: SmearingModel = self.get_model("smearing") # type: ignore
if "degauss" in system_params:
smearing.degauss = system_params["degauss"]

if "smearing" in system_params:
smearing.type = system_params["smearing"]

magnetization: MagnetizationModel = self._get_model("magnetization") # type: ignore
magnetization: MagnetizationModel = self.get_model("magnetization") # type: ignore
if "tot_magnetization" in system_params:
magnetization.type = "tot_magnetization"

def _get_model(self, identifier) -> AdvancedSubModel:
if identifier in self._models:
return self._models[identifier]
raise ValueError(f"Model with identifier '{identifier}' not found.")
23 changes: 10 additions & 13 deletions src/aiidalab_qe/app/configuration/advanced/pseudos/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,16 @@ def __init__(self, *args, **kwargs):

def update(self, specific=""):
with self.hold_trait_notifications():
self._update_defaults(specific)
if self.input_structure is None:
self._defaults |= {
"dictionary": {},
"cutoffs": [[0.0], [0.0]],
}
else:
self.update_default_pseudos()
self.update_default_cutoffs()
self.update_family_parameters()
self.update_family()

def update_default_pseudos(self):
try:
Expand Down Expand Up @@ -262,18 +271,6 @@ def reset(self):
self.family_help_message = self.PSEUDO_HELP_WO_SOC
self.status_message = ""

def _update_defaults(self, specific=""):
if self.input_structure is None:
self._defaults |= {
"dictionary": {},
"cutoffs": [[0.0], [0.0]],
}
else:
self.update_default_pseudos()
self.update_default_cutoffs()
self.update_family_parameters()
self.update_family()

def _get_pseudo_family_from_database(self):
"""Get the pseudo family from the database."""
return (
Expand Down
19 changes: 8 additions & 11 deletions src/aiidalab_qe/app/configuration/advanced/smearing/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,6 @@ def __init__(self, *args, **kwargs):
}

def update(self, specific=""):
with self.hold_trait_notifications():
self._update_defaults(specific)
self.type = self._defaults["type"]
self.degauss = self._defaults["degauss"]

def reset(self):
with self.hold_trait_notifications():
self.type = self._defaults["type"]
self.degauss = self._defaults["degauss"]

def _update_defaults(self, specific=""):
parameters = (
PwBaseWorkChain.get_protocol_inputs(self.protocol)
.get("pw", {})
Expand All @@ -48,3 +37,11 @@ def _update_defaults(self, specific=""):
"type": parameters["smearing"],
"degauss": parameters["degauss"],
}
with self.hold_trait_notifications():
self.type = self._defaults["type"]
self.degauss = self._defaults["degauss"]

def reset(self):
with self.hold_trait_notifications():
self.type = self._defaults["type"]
self.degauss = self._defaults["degauss"]
15 changes: 5 additions & 10 deletions src/aiidalab_qe/app/configuration/advanced/subsettings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import ipywidgets as ipw
import traitlets as tl

Expand All @@ -21,16 +23,6 @@ def reset(self):
"""Resets the model to present defaults."""
raise NotImplementedError

def _update_defaults(self, specific=""):
"""Updates the model's default values.
Parameters
----------
`specific` : `str`, optional
If provided, specifies the level of update.
"""
raise NotImplementedError


class AdvancedSubSettings(ipw.VBox):
identifier = "sub"
Expand Down Expand Up @@ -74,6 +66,9 @@ def refresh(self, specific=""):
self.updated = False
self._unsubscribe()
self._update(specific)
if "PYTEST_CURRENT_TEST" in os.environ:
# Skip resetting to avoid having to inject a structure when testing
return
if hasattr(self._model, "input_structure") and not self._model.input_structure:
self._reset()

Expand Down
6 changes: 4 additions & 2 deletions src/aiidalab_qe/app/configuration/basic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ class WorkChainModel(SettingsModel):
spin_type = tl.Unicode(DEFAULT["workchain"]["spin_type"])
electronic_type = tl.Unicode(DEFAULT["workchain"]["electronic_type"])

def __init__(self, include=False, *args, **kwargs):
super().__init__(include, *args, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self.include = True

self._defaults = {
"protocol": self.traits()["protocol"].default_value,
Expand Down
Loading

0 comments on commit 608467b

Please sign in to comment.