Skip to content

Commit

Permalink
Fix issues relating to kinds
Browse files Browse the repository at this point in the history
  • Loading branch information
edan-bainglass committed Oct 28, 2024
1 parent e7a0911 commit 1dc9406
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 66 deletions.
54 changes: 39 additions & 15 deletions src/aiidalab_qe/app/configuration/advanced/hubbard/hubbard.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ def _build_hubbard_widget(self):
(self._model, "parameters"),
(float_widget, "value"),
[
lambda p, label=label: p.get(label, 0.0),
lambda v, label=label: {
lambda parameters, label=label: parameters.get(label, 0.0),
lambda value, label=label: {
**self._model.parameters,
label: v,
label: value,
},
],
)
Expand All @@ -149,24 +149,22 @@ def _build_hubbard_widget(self):

def _build_eigenvalues_widget(self):
def update(index, spin, state, symbol, value):
"""Update the eigenvalues list."""
eigenvalues = [*self._model.eigenvalues]
eigenvalues[index][spin][state] = [state + 1, spin, symbol, value]
return eigenvalues

children = []

for ei, element in enumerate(self._model.applicable_elements):
es = element.symbol
num_states = 5 if element.is_transition_metal else 7 # d or f states
for kind_index, (kind, num_states) in enumerate(self._model.applicable_kinds):
symbol = kind.symbol

label_layout = ipw.Layout(justify_content="flex-start", width="50px")
spin_up_row = ipw.HBox([ipw.Label("Up:", layout=label_layout)])
spin_down_row = ipw.HBox([ipw.Label("Down:", layout=label_layout)])

for si in range(num_states):
for state_index in range(num_states):
eigenvalues_up = ipw.Dropdown(
description=f"{si+1}",
description=f"{state_index+1}",
options=["-1", "0", "1"],
layout=ipw.Layout(width="65px"),
style={"description_width": "initial"},
Expand All @@ -175,15 +173,28 @@ def update(index, spin, state, symbol, value):
(self._model, "eigenvalues"),
(eigenvalues_up, "value"),
[
lambda evs, ei=ei, si=si: str(evs[ei][0][si][-1]),
lambda v, ei=ei, si=si, es=es: update(ei, 0, si, es, float(v)),
lambda eigenvalues,
kind_index=kind_index,
state_index=state_index: str(
eigenvalues[kind_index][0][state_index][-1]
),
lambda value,
kind_index=kind_index,
state_index=state_index,
symbol=symbol: update(
kind_index,
0,
state_index,
symbol,
float(value),
),
],
)
self.links.append(link)
spin_up_row.children += (eigenvalues_up,)

eigenvalues_down = ipw.Dropdown(
description=f"{si+1}",
description=f"{state_index+1}",
options=["-1", "0", "1"],
layout=ipw.Layout(width="65px"),
style={"description_width": "initial"},
Expand All @@ -192,8 +203,21 @@ def update(index, spin, state, symbol, value):
(self._model, "eigenvalues"),
(eigenvalues_down, "value"),
[
lambda evs, ei=ei, si=si: str(evs[ei][1][si][-1]),
lambda v, ei=ei, si=si, es=es: update(ei, 1, si, es, float(v)),
lambda eigenvalues,
kind_index=kind_index,
state_index=state_index: str(
eigenvalues[kind_index][1][state_index][-1]
),
lambda value,
kind_index=kind_index,
state_index=state_index,
symbol=symbol: update(
kind_index,
1,
state_index,
symbol,
float(value),
),
],
)
self.links.append(link)
Expand All @@ -202,7 +226,7 @@ def update(index, spin, state, symbol, value):
children.append(
ipw.HBox(
[
ipw.Label(element.symbol, layout=label_layout),
ipw.Label(kind.name, layout=label_layout),
ipw.VBox(
children=[
spin_up_row,
Expand Down
77 changes: 42 additions & 35 deletions src/aiidalab_qe/app/configuration/advanced/hubbard/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class HubbardModel(AdvancedSubModel):
default_value=[],
)

applicable_elements = []
applicable_kinds = []
orbital_labels = []

def __init__(self, *args, **kwargs):
Expand All @@ -51,7 +51,7 @@ def update(self, which):
self._update_defaults(which)
self.parameters = self._get_default_parameters()
self.eigenvalues = self._get_default_eigenvalues()
self.needs_eigenvalues_widget = len(self.applicable_elements) > 0
self.needs_eigenvalues_widget = len(self.applicable_kinds) > 0

def get_active_eigenvalues(self):
return [
Expand Down Expand Up @@ -82,7 +82,7 @@ def reset(self):

def _update_defaults(self, which):
if self.input_structure is None:
self.applicable_elements = []
self.applicable_kinds = []
self.orbital_labels = []
self._defaults.update(
{
Expand All @@ -91,39 +91,12 @@ def _update_defaults(self, which):
}
)
else:
self.orbital_labels = self._get_labels()
self._defaults["parameters"] = {label: 0.0 for label in self.orbital_labels}
self.applicable_elements = [
*filter(
lambda element: (
element.is_transition_metal
or element.is_lanthanoid
or element.is_actinoid
),
[
Element(symbol)
for symbol in self.input_structure.get_symbols_set()
],
)
]
self._defaults["eigenvalues"] = [
[
[
[state + 1, spin, element.symbol, -1] # default eigenvalue
for state in range(5 if element.is_transition_metal else 7)
]
for spin in range(2) # spin up and down
]
for element in self.applicable_elements # transition metals and lanthanoids
]
self.orbital_labels = self._define_orbital_labels()
self._defaults["parameters"] = self._define_default_parameters()
self.applicable_kinds = self._define_applicable_kinds()
self._defaults["eigenvalues"] = self._define_default_eigenvalues()

def _get_default_parameters(self):
return deepcopy(self._defaults["parameters"])

def _get_default_eigenvalues(self):
return deepcopy(self._defaults["eigenvalues"])

def _get_labels(self):
def _define_orbital_labels(self):
hubbard_manifold_list = [
self._get_manifold(Element(kind.symbol))
for kind in self.input_structure.kinds
Expand All @@ -136,6 +109,40 @@ def _get_labels(self):
)
]

def _define_default_parameters(self):
return {label: 0.0 for label in self.orbital_labels}

def _define_applicable_kinds(self):
applicable_kinds = []
for kind in self.input_structure.kinds:
element = Element(kind.symbol)
if (
element.is_transition_metal
or element.is_lanthanoid
or element.is_actinoid
):
num_states = 5 if element.is_transition_metal else 7
applicable_kinds.append((kind, num_states))
return applicable_kinds

def _define_default_eigenvalues(self):
return [
[
[
[state + 1, spin, kind.symbol, -1] # default eigenvalue
for state in range(num_states)
]
for spin in range(2) # spin up and down
]
for kind, num_states in self.applicable_kinds # transition metals and lanthanoids
]

def _get_default_parameters(self):
return deepcopy(self._defaults["parameters"])

def _get_default_eigenvalues(self):
return deepcopy(self._defaults["eigenvalues"])

def _get_manifold(self, element):
valence = [
orbital
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def _build_kinds_widget(self):
(self._model, "moments"),
(element_widget, "value"),
[
lambda d, symbol=symbol: d.get(symbol, 0.0),
lambda v, symbol=symbol: {
lambda moments, symbol=symbol: moments.get(symbol, 0.0),
lambda value, symbol=symbol: {
**self._model.moments,
symbol: v,
symbol: value,
},
],
)
Expand Down
20 changes: 12 additions & 8 deletions src/aiidalab_qe/app/configuration/advanced/pseudos/pseudos.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,17 +295,21 @@ def _build_setter_widgets(self):
(self._model, "dictionary"),
(upload_widget, "pseudo"),
[
lambda d, symbol=kind.symbol: orm.load_node(d.get(symbol)),
lambda v, symbol=kind.symbol: {
lambda dictionary, kind_name=kind.name: orm.load_node(
dictionary.get(kind_name)
),
lambda pseudo, kind_name=kind.name: {
**self._model.dictionary,
symbol: v.uuid,
kind_name: pseudo.uuid,
},
],
)
cutoffs_link = ipw.dlink(
(self._model, "cutoffs"),
(upload_widget, "cutoffs"),
lambda c, i=index: [c[0][i], c[1][i]] if len(c[0]) > i else [0.0, 0.0],
lambda cutoffs, index=index: [cutoffs[0][index], cutoffs[1][index]]
if len(cutoffs[0]) > index
else [0.0, 0.0],
)
upload_widget.render()

Expand Down Expand Up @@ -373,15 +377,15 @@ def render(self):
pseudo_link = ipw.dlink(
(self, "pseudo"),
(self.pseudo_text, "value"),
lambda p: p.filename if p else "",
lambda pseudo: pseudo.filename if pseudo else "",
)

cutoff_link = ipw.dlink(
(self, "cutoffs"),
(self.cutoff_message, "value"),
lambda c: cutoffs_message_template.format(
ecutwfc=c[0] if len(c) else "not set",
ecutrho=c[1] if len(c) else "not set",
lambda cutoffs: cutoffs_message_template.format(
ecutwfc=cutoffs[0] if len(cutoffs) else "not set",
ecutrho=cutoffs[1] if len(cutoffs) else "not set",
),
)

Expand Down
2 changes: 1 addition & 1 deletion src/aiidalab_qe/plugins/xas/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _build_core_hole_treatments_widget(self):
(self._model, "core_hole_treatments"),
(treatment_selector, "value"),
[
lambda cht, element=element: cht.get(element, "full"),
lambda treatments, element=element: treatments.get(element, "full"),
lambda value, element=element: {
**self._model.core_hole_treatments,
element: value,
Expand Down
6 changes: 3 additions & 3 deletions src/aiidalab_qe/plugins/xps/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,10 @@ def _build_core_levels_widget(self):
(self._model, "core_levels"),
(checkbox, "value"),
[
lambda cl, orbital=orbital: cl.get(orbital, False),
lambda v, orbital=orbital: {
lambda levels, orbital=orbital: levels.get(orbital, False),
lambda value, orbital=orbital: {
**self._model.core_levels,
orbital: v,
orbital: value,
},
],
)
Expand Down
2 changes: 1 addition & 1 deletion tests/configuration/test_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def test_advanced_hubbard_settings(generate_structure_data):

# Check there is only eigenvalues for Co (Transition metal)
model.hubbard.has_eigenvalues = True
assert len(model.hubbard.applicable_elements) == 1
assert len(model.hubbard.applicable_kinds) == 1
assert len(model.hubbard.eigenvalues) == 1

Co_eigenvalues = hubbard.eigenvalues_widget.children[0].children[1] # type: ignore
Expand Down

0 comments on commit 1dc9406

Please sign in to comment.