Skip to content

Commit

Permalink
Reduce code duplication with mixins
Browse files Browse the repository at this point in the history
  • Loading branch information
edan-bainglass committed Nov 1, 2024
1 parent 4b32281 commit 76cc60e
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 97 deletions.
34 changes: 9 additions & 25 deletions src/aiidalab_qe/app/configuration/advanced/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,35 @@
from aiida_quantumespresso.calculations.functions.create_kpoints_from_distance import (
create_kpoints_from_distance,
)
from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain
from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.common.mixins import HasInputStructure, HasModels
from aiidalab_qe.common.panel import SettingsModel
from aiidalab_qe.setup.pseudos import PseudoFamily

from .subsettings import AdvancedSubModel

if t.TYPE_CHECKING:
from .hubbard.hubbard import HubbardModel
from .magnetization import MagnetizationModel
from .pseudos.pseudos import PseudosModel
from .smearing import SmearingModel
from .subsettings import AdvancedSubModel

DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore


class AdvancedModel(SettingsModel):
class AdvancedModel(
SettingsModel,
HasModels[AdvancedSubModel],
HasInputStructure,
):
dependencies = [
"input_structure",
"workchain.protocol",
"workchain.spin_type",
"workchain.electronic_type",
]

input_structure = tl.Union(
[
tl.Instance(orm.StructureData),
tl.Instance(HubbardStructureData),
],
allow_none=True,
)
protocol = tl.Unicode()
spin_type = tl.Unicode()
electronic_type = tl.Unicode()
Expand Down Expand Up @@ -72,8 +70,6 @@ def __init__(self, *args, **kwargs):
"dft-d3mbj": 6,
}

self._models: dict[str, AdvancedSubModel] = {}

self._defaults = {
"forc_conv_thr": self.traits()["forc_conv_thr"].default_value,
"forc_conv_thr_step": self.traits()["forc_conv_thr_step"].default_value,
Expand All @@ -92,18 +88,6 @@ def update(self, specific=""):
self._update_thresholds(parameters)
self._update_kpoints_mesh()

def add_model(self, identifier, model):
self._models[identifier] = model
self._link_model(model)

def get_model(self, identifier) -> AdvancedSubModel:
if identifier in self._models:
return self._models[identifier]
raise ValueError(f"Model with identifier '{identifier}' not found.")

def get_models(self):
return self._models.items()

def get_model_state(self):
parameters = {
"initial_magnetic_moments": None,
Expand Down Expand Up @@ -245,7 +229,7 @@ def reset(self):
self.kpoints_distance = self._defaults["kpoints_distance"]
self.override = self.traits()["override"].default_value

def _link_model(self, model):
def _link_model(self, model: AdvancedSubModel):
ipw.dlink(
(self, "override"),
(model, "override"),
Expand Down
42 changes: 12 additions & 30 deletions src/aiidalab_qe/app/configuration/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,32 @@
import ipywidgets as ipw
import traitlets as tl

from aiida import orm
from aiida_quantumespresso.common.types import RelaxType
from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.common.mixins import (
Confirmable,
HasInputStructure,
HasModels,
HasTraitsAndMixins,
)
from aiidalab_qe.common.panel import SettingsModel

DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore


class ConfigurationModel(SettingsModel):
input_structure = tl.Union(
[
tl.Instance(orm.StructureData),
tl.Instance(HubbardStructureData),
],
allow_none=True,
)

class ConfigurationModel(
HasTraitsAndMixins,
HasModels[SettingsModel],
Confirmable,
HasInputStructure,
):
relax_type_help = tl.Unicode()
relax_type_options = tl.List([DEFAULT["workchain"]["relax_type"]])
relax_type = tl.Unicode(DEFAULT["workchain"]["relax_type"])

confirmed = tl.Bool(False)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._models: dict[str, SettingsModel] = {}

self._default_models = {
"workchain",
"advanced",
Expand Down Expand Up @@ -86,18 +83,6 @@ def update(self):
self.relax_type_options = self._get_default_relax_type_options()
self.relax_type = self._get_default_relax_type()

def add_model(self, identifier, model):
self._models[identifier] = model
self._link_model(model)

def get_model(self, identifier) -> SettingsModel:
if identifier in self._models:
return self._models[identifier]
raise ValueError(f"Model with identifier '{identifier}' not found.")

def get_models(self):
return self._models.items()

def get_model_state(self):
parameters = {
identifier: model.get_model_state()
Expand All @@ -120,9 +105,6 @@ def set_model_state(self, parameters):
if parameters.get(identifier):
model.set_model_state(parameters[identifier])

def confirm(self):
self.confirmed = True

def reset(self):
self.confirmed = False
self.relax_type_help = self._get_default_relax_type_help()
Expand Down
17 changes: 2 additions & 15 deletions src/aiidalab_qe/app/structure/model.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,17 @@
import traitlets as tl

from aiida import orm
from aiidalab_qe.common.mixins import Confirmable, HasTraitsAndMixins


class StructureModel(tl.HasTraits):
class StructureModel(HasTraitsAndMixins, Confirmable):
structure = tl.Instance(
orm.StructureData,
allow_none=True,
)
structure_name = tl.Unicode("")
manager_output = tl.Unicode("")
message_area = tl.Unicode("")
confirmed = tl.Bool(False)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.observe(
self._unconfirm,
"structure",
)

def update_widget_text(self):
if self.structure is None:
Expand All @@ -28,14 +21,8 @@ def update_widget_text(self):
self.manager_output = ""
self.structure_name = str(self.structure.get_formula())

def confirm(self):
self.confirmed = True

def reset(self):
self.structure = None
self.structure_name = ""
self.manager_output = ""
self.message_area = ""

def _unconfirm(self, _):
self.confirmed = False
11 changes: 2 additions & 9 deletions src/aiidalab_qe/app/submission/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from aiida.engine import ProcessBuilderNamespace
from aiida.engine import submit as aiida_submit
from aiida.orm.utils.serialize import serialize
from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData
from aiidalab_qe.app.parameters import DEFAULT_PARAMETERS
from aiidalab_qe.common.mixins import HasInputStructure, HasTraitsAndMixins
from aiidalab_qe.common.widgets import QEAppComputationalResourcesWidget
from aiidalab_qe.workflows import QeAppWorkChain

Expand All @@ -20,14 +20,7 @@
DEFAULT: dict = DEFAULT_PARAMETERS # type: ignore


class SubmissionModel(tl.HasTraits):
input_structure = tl.Union(
[
tl.Instance(orm.StructureData),
tl.Instance(HubbardStructureData),
],
allow_none=True,
)
class SubmissionModel(HasTraitsAndMixins, HasInputStructure):
input_parameters = tl.Dict()

process = tl.Instance(orm.WorkChainNode, allow_none=True)
Expand Down
84 changes: 84 additions & 0 deletions src/aiidalab_qe/common/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import typing as t

import traitlets as tl

from aiida import orm
from aiida_quantumespresso.data.hubbard_structure import HubbardStructureData

T = t.TypeVar("T")


class MissingMixinError(Exception):
"""Raised when no mixin is found in a class definition."""

def __init__(self, *args: object) -> None:
super().__init__("expected at least one mixin in class definion", *args)


class MetaMixinTraitsRegister(tl.MetaHasTraits):
"""A metaclass to register traits from `HasTraits`-subclassed mixins.
The metaclass removes the `HasTraits` base class from the mixin to
avoid MRO conflicts.
"""

def __new__(cls, name, bases, classdict):
if len(bases) == 1 and not issubclass(bases[0], tl.HasTraits):
raise MissingMixinError()
for base in bases:
if issubclass(base, tl.HasTraits):
for name, trait in base.class_traits().items():
if name not in classdict:
classdict[name] = trait
bases = tuple(filter(lambda base: base is not tl.HasTraits, bases))
return super().__new__(cls, name, bases, classdict)


class HasTraitsAndMixins(tl.HasTraits, metaclass=MetaMixinTraitsRegister):
"""An extension of `traitlet`'s `HasTraits` to support trait-ful mixins."""


class Confirmable(tl.HasTraits):
confirmed = tl.Bool(False)

def confirm(self):
self.confirmed = True

@tl.observe(tl.All)
def unconfirm(self, change=None):
if change and change["name"] != "confirmed":
self.confirmed = False


class HasInputStructure(tl.HasTraits):
input_structure = tl.Union(
[
tl.Instance(orm.StructureData),
tl.Instance(HubbardStructureData),
],
allow_none=True,
)

@property
def has_pbc(self):
return self.input_structure is None or any(self.input_structure.pbc)


class HasModels(t.Generic[T]):
def __init__(self):
self._models: dict[str, T] = {}

def add_model(self, identifier, model):
self._models[identifier] = model
self._link_model(model)

def get_model(self, identifier) -> T:
if identifier in self._models:
return self._models[identifier]
raise ValueError(f"Model with identifier '{identifier}' not found.")

def get_models(self) -> t.Iterable[tuple[str, T]]:
return self._models.items()

def _link_model(self, model: T):
raise NotImplementedError()
19 changes: 3 additions & 16 deletions src/aiidalab_qe/common/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import ipywidgets as ipw
import traitlets as tl

from aiidalab_qe.common.mixins import Confirmable, HasTraitsAndMixins

DEFAULT_PARAMETERS = {}


Expand Down Expand Up @@ -66,25 +68,14 @@ def __init__(self, **kwargs):
)


class SettingsModel(tl.HasTraits):
class SettingsModel(HasTraitsAndMixins, Confirmable):
title = "Model"
dependencies: list[str] = []

include = tl.Bool()
confirmed = tl.Bool(False)

_defaults = {}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.observe(self.unconfirm, tl.All)

@property
def has_pbc(self):
if hasattr(self, "input_structure"):
return self.input_structure is None or any(self.input_structure.pbc)
return False

def update(self, specific=""):
"""Updates the model.
Expand All @@ -107,10 +98,6 @@ def reset(self):
"""Resets the model to present defaults."""
pass

def unconfirm(self, change):
if change["name"] != "confirmed":
self.confirmed = False


class SettingsPanel(Panel):
title = "Settings"
Expand Down
4 changes: 2 additions & 2 deletions src/aiidalab_qe/plugins/pdos/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
create_kpoints_from_distance,
)
from aiida_quantumespresso.workflows.pdos import PdosWorkChain
from aiidalab_qe.common.mixins import HasInputStructure
from aiidalab_qe.common.panel import SettingsModel


class PdosModel(SettingsModel):
class PdosModel(SettingsModel, HasInputStructure):
dependencies = [
"input_structure",
"workchain.protocol",
]

input_structure = tl.Instance(orm.StructureData, allow_none=True)
protocol = tl.Unicode(allow_none=True)

kpoints_distance = tl.Float(0.1)
Expand Down

0 comments on commit 76cc60e

Please sign in to comment.