From 76cc60e4d6dc81382b1af1de6fc156c51238fc3d Mon Sep 17 00:00:00 2001 From: Edan Bainglass Date: Fri, 1 Nov 2024 12:03:18 +0000 Subject: [PATCH] Reduce code duplication with mixins --- .../app/configuration/advanced/model.py | 34 ++------ src/aiidalab_qe/app/configuration/model.py | 42 +++------- src/aiidalab_qe/app/structure/model.py | 17 +--- src/aiidalab_qe/app/submission/model.py | 11 +-- src/aiidalab_qe/common/mixins.py | 84 +++++++++++++++++++ src/aiidalab_qe/common/panel.py | 19 +---- src/aiidalab_qe/plugins/pdos/model.py | 4 +- 7 files changed, 114 insertions(+), 97 deletions(-) create mode 100644 src/aiidalab_qe/common/mixins.py diff --git a/src/aiidalab_qe/app/configuration/advanced/model.py b/src/aiidalab_qe/app/configuration/advanced/model.py index 246a74f2e..88829ea75 100644 --- a/src/aiidalab_qe/app/configuration/advanced/model.py +++ b/src/aiidalab_qe/app/configuration/advanced/model.py @@ -10,23 +10,28 @@ from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import ( create_kpoints_from_distance, ) -from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS +from aiidalab_qe.common.mixins import HasInputStructure, HasModels from aiidalab_qe.common.panel import SettingsModel from aiidalab_qe.setup.pseudos import PseudoFamily +from .subsettings import AdvancedSubModel + if t.TYPE_CHECKING: from .hubbard.hubbard import HubbardModel from .magnetization import MagnetizationModel from .pseudos.pseudos import PseudosModel from .smearing import SmearingModel - from .subsettings import AdvancedSubModel DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore -class AdvancedModel(SettingsModel): +class AdvancedModel( + SettingsModel, + HasModels[AdvancedSubModel], + HasInputStructure, +): dependencies = [ "input_structure", "workchain.protocol", @@ -34,13 +39,6 @@ class AdvancedModel(SettingsModel): "workchain.electronic_type", ] - input_structure = tl.Union( - [ - tl.Instance(orm.StructureData), - tl.Instance(HubbardStructureData), - ], - allow_none=True, - ) protocol = tl.Unicode() spin_type = tl.Unicode() electronic_type = tl.Unicode() @@ -72,8 +70,6 @@ def __init__(self, *args, **kwargs): "dft-d3mbj": 6, } - self._models: dict[str, AdvancedSubModel] = {} - self._defaults = { "forc_conv_thr": self.traits()["forc_conv_thr"].default_value, "forc_conv_thr_step": self.traits()["forc_conv_thr_step"].default_value, @@ -92,18 +88,6 @@ def update(self, specific=""): self._update_thresholds(parameters) self._update_kpoints_mesh() - def add_model(self, identifier, model): - self._models[identifier] = model - self._link_model(model) - - def get_model(self, identifier) -> AdvancedSubModel: - if identifier in self._models: - return self._models[identifier] - raise ValueError(f"Model with identifier '{identifier}' not found.") - - def get_models(self): - return self._models.items() - def get_model_state(self): parameters = { "initial_magnetic_moments": None, @@ -245,7 +229,7 @@ def reset(self): self.kpoints_distance = self._defaults["kpoints_distance"] self.override = self.traits()["override"].default_value - def _link_model(self, model): + def _link_model(self, model: AdvancedSubModel): ipw.dlink( (self, "override"), (model, "override"), diff --git a/src/aiidalab_qe/app/configuration/model.py b/src/aiidalab_qe/app/configuration/model.py index ab5e15faa..2b8b0b272 100644 --- a/src/aiidalab_qe/app/configuration/model.py +++ b/src/aiidalab_qe/app/configuration/model.py @@ -3,35 +3,32 @@ import ipywidgets as ipw import traitlets as tl -from aiida import orm from aiida_quantumespresso.common.types import RelaxType -from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS +from aiidalab_qe.common.mixins import ( + Confirmable, + HasInputStructure, + HasModels, + HasTraitsAndMixins, +) from aiidalab_qe.common.panel import SettingsModel DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore -class ConfigurationModel(SettingsModel): - input_structure = tl.Union( - [ - tl.Instance(orm.StructureData), - tl.Instance(HubbardStructureData), - ], - allow_none=True, - ) - +class ConfigurationModel( + HasTraitsAndMixins, + HasModels[SettingsModel], + Confirmable, + HasInputStructure, +): relax_type_help = tl.Unicode() relax_type_options = tl.List([DEFAULT["workchain"]["relax_type"]]) relax_type = tl.Unicode(DEFAULT["workchain"]["relax_type"]) - confirmed = tl.Bool(False) - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._models: dict[str, SettingsModel] = {} - self._default_models = { "workchain", "advanced", @@ -86,18 +83,6 @@ def update(self): self.relax_type_options = self._get_default_relax_type_options() self.relax_type = self._get_default_relax_type() - def add_model(self, identifier, model): - self._models[identifier] = model - self._link_model(model) - - def get_model(self, identifier) -> SettingsModel: - if identifier in self._models: - return self._models[identifier] - raise ValueError(f"Model with identifier '{identifier}' not found.") - - def get_models(self): - return self._models.items() - def get_model_state(self): parameters = { identifier: model.get_model_state() @@ -120,9 +105,6 @@ def set_model_state(self, parameters): if parameters.get(identifier): model.set_model_state(parameters[identifier]) - def confirm(self): - self.confirmed = True - def reset(self): self.confirmed = False self.relax_type_help = self._get_default_relax_type_help() diff --git a/src/aiidalab_qe/app/structure/model.py b/src/aiidalab_qe/app/structure/model.py index f85ec507f..86b549bb1 100644 --- a/src/aiidalab_qe/app/structure/model.py +++ b/src/aiidalab_qe/app/structure/model.py @@ -1,9 +1,10 @@ import traitlets as tl from aiida import orm +from aiidalab_qe.common.mixins import Confirmable, HasTraitsAndMixins -class StructureModel(tl.HasTraits): +class StructureModel(HasTraitsAndMixins, Confirmable): structure = tl.Instance( orm.StructureData, allow_none=True, @@ -11,14 +12,6 @@ class StructureModel(tl.HasTraits): structure_name = tl.Unicode("") manager_output = tl.Unicode("") message_area = tl.Unicode("") - confirmed = tl.Bool(False) - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.observe( - self._unconfirm, - "structure", - ) def update_widget_text(self): if self.structure is None: @@ -28,14 +21,8 @@ def update_widget_text(self): self.manager_output = "" self.structure_name = str(self.structure.get_formula()) - def confirm(self): - self.confirmed = True - def reset(self): self.structure = None self.structure_name = "" self.manager_output = "" self.message_area = "" - - def _unconfirm(self, _): - self.confirmed = False diff --git a/src/aiidalab_qe/app/submission/model.py b/src/aiidalab_qe/app/submission/model.py index decb379d0..3c253c431 100644 --- a/src/aiidalab_qe/app/submission/model.py +++ b/src/aiidalab_qe/app/submission/model.py @@ -10,8 +10,8 @@ from aiida.engine import ProcessBuilderNamespace from aiida.engine import submit as aiida_submit from aiida.orm.utils.serialize import serialize -from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS +from aiidalab_qe.common.mixins import HasInputStructure, HasTraitsAndMixins from aiidalab_qe.common.widgets import QEAppComputationalResourcesWidget from aiidalab_qe.workflows import QeAppWorkChain @@ -20,14 +20,7 @@ DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore -class SubmissionModel(tl.HasTraits): - input_structure = tl.Union( - [ - tl.Instance(orm.StructureData), - tl.Instance(HubbardStructureData), - ], - allow_none=True, - ) +class SubmissionModel(HasTraitsAndMixins, HasInputStructure): input_parameters = tl.Dict() process = tl.Instance(orm.WorkChainNode, allow_none=True) diff --git a/src/aiidalab_qe/common/mixins.py b/src/aiidalab_qe/common/mixins.py new file mode 100644 index 000000000..9efcf297a --- /dev/null +++ b/src/aiidalab_qe/common/mixins.py @@ -0,0 +1,84 @@ +import typing as t + +import traitlets as tl + +from aiida import orm +from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData + +T = t.TypeVar("T") + + +class MissingMixinError(Exception): + """Raised when no mixin is found in a class definition.""" + + def __init__(self, *args: object) -> None: + super().__init__("expected at least one mixin in class definion", *args) + + +class MetaMixinTraitsRegister(tl.MetaHasTraits): + """A metaclass to register traits from `HasTraits`-subclassed mixins. + + The metaclass removes the `HasTraits` base class from the mixin to + avoid MRO conflicts. + """ + + def __new__(cls, name, bases, classdict): + if len(bases) == 1 and not issubclass(bases[0], tl.HasTraits): + raise MissingMixinError() + for base in bases: + if issubclass(base, tl.HasTraits): + for name, trait in base.class_traits().items(): + if name not in classdict: + classdict[name] = trait + bases = tuple(filter(lambda base: base is not tl.HasTraits, bases)) + return super().__new__(cls, name, bases, classdict) + + +class HasTraitsAndMixins(tl.HasTraits, metaclass=MetaMixinTraitsRegister): + """An extension of `traitlet`'s `HasTraits` to support trait-ful mixins.""" + + +class Confirmable(tl.HasTraits): + confirmed = tl.Bool(False) + + def confirm(self): + self.confirmed = True + + @tl.observe(tl.All) + def unconfirm(self, change=None): + if change and change["name"] != "confirmed": + self.confirmed = False + + +class HasInputStructure(tl.HasTraits): + input_structure = tl.Union( + [ + tl.Instance(orm.StructureData), + tl.Instance(HubbardStructureData), + ], + allow_none=True, + ) + + @property + def has_pbc(self): + return self.input_structure is None or any(self.input_structure.pbc) + + +class HasModels(t.Generic[T]): + def __init__(self): + self._models: dict[str, T] = {} + + def add_model(self, identifier, model): + self._models[identifier] = model + self._link_model(model) + + def get_model(self, identifier) -> T: + if identifier in self._models: + return self._models[identifier] + raise ValueError(f"Model with identifier '{identifier}' not found.") + + def get_models(self) -> t.Iterable[tuple[str, T]]: + return self._models.items() + + def _link_model(self, model: T): + raise NotImplementedError() diff --git a/src/aiidalab_qe/common/panel.py b/src/aiidalab_qe/common/panel.py index 878a51db6..ba011a148 100644 --- a/src/aiidalab_qe/common/panel.py +++ b/src/aiidalab_qe/common/panel.py @@ -12,6 +12,8 @@ import ipywidgets as ipw import traitlets as tl +from aiidalab_qe.common.mixins import Confirmable, HasTraitsAndMixins + DEFAULT_PARAMETERS = {} @@ -66,25 +68,14 @@ def __init__(self, **kwargs): ) -class SettingsModel(tl.HasTraits): +class SettingsModel(HasTraitsAndMixins, Confirmable): title = "Model" dependencies: list[str] = [] include = tl.Bool() - confirmed = tl.Bool(False) _defaults = {} - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.observe(self.unconfirm, tl.All) - - @property - def has_pbc(self): - if hasattr(self, "input_structure"): - return self.input_structure is None or any(self.input_structure.pbc) - return False - def update(self, specific=""): """Updates the model. @@ -107,10 +98,6 @@ def reset(self): """Resets the model to present defaults.""" pass - def unconfirm(self, change): - if change["name"] != "confirmed": - self.confirmed = False - class SettingsPanel(Panel): title = "Settings" diff --git a/src/aiidalab_qe/plugins/pdos/model.py b/src/aiidalab_qe/plugins/pdos/model.py index 8c664acae..b64dea435 100644 --- a/src/aiidalab_qe/plugins/pdos/model.py +++ b/src/aiidalab_qe/plugins/pdos/model.py @@ -5,16 +5,16 @@ create_kpoints_from_distance, ) from aiida_quantumespresso.workflows.pdos import PdosWorkChain +from aiidalab_qe.common.mixins import HasInputStructure from aiidalab_qe.common.panel import SettingsModel -class PdosModel(SettingsModel): +class PdosModel(SettingsModel, HasInputStructure): dependencies = [ "input_structure", "workchain.protocol", ] - input_structure = tl.Instance(orm.StructureData, allow_none=True) protocol = tl.Unicode(allow_none=True) kpoints_distance = tl.Float(0.1)