Skip to content

Commit

Permalink
More updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Nov 22, 2024
1 parent f688492 commit 2d19b18
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 43 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ lint.ignore = [
"PT013", # pytest-incorrect-pytest-import
"RUF012", # Disable checks for mutable class args. This is a non-problem.
"SIM105", # Use contextlib.suppress(OSError) instead of try-except-pass
"ISC001"
]
lint.pydocstyle.convention = "google"
lint.isort.required-imports = ["from __future__ import annotations"]
Expand Down
2 changes: 2 additions & 0 deletions src/matpes/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def get_df(self, functional: str) -> pd.DataFrame:
collection.find(
{},
projection=[
"matpesid",
"formula",
"elements",
"energy",
"chemsys",
Expand Down
138 changes: 99 additions & 39 deletions src/matpes/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,20 +135,35 @@ def update_graph(

# Callback to download data
@callback(
Output("download-data", "data"),
Input("btn-download", "n_clicks"),
Output("download-json", "data"),
Input("json-download", "n_clicks"),
State("functional", "value"),
State("el_filter", "value"),
State("chemsys_filter", "value"),
State("min_coh_e_filter", "value"),
State("max_coh_e_filter", "value"),
State("min_form_e_filter", "value"),
State("max_form_e_filter", "value"),
prevent_initial_call=True,
)
def download_data(n_clicks, functional, el_filter, chemsys_filter):
"""Handle data download requests."""
def download_json(
n_clicks,
functional,
el_filter,
chemsys_filter,
min_coh_e_filter,
max_coh_e_filter,
min_form_e_filter,
max_form_e_filter,
):
"""Handle json download requests."""
criteria = {}
if el_filter:
criteria["elements"] = el_filter
if chemsys_filter:
criteria["chemsys"] = "-".join(sorted(chemsys_filter.split("-")))
criteria["cohesive_energy_per_atom"] = {"$gte": min_coh_e_filter, "$lte": max_coh_e_filter}
criteria["formation_energy_per_atom"] = {"$gte": min_form_e_filter, "$lte": max_form_e_filter}
data = DB.get_json(functional, criteria)
for entry in data:
entry.pop("_id", None) # Remove MongoDB's internal ID
Expand All @@ -157,6 +172,35 @@ def download_data(n_clicks, functional, el_filter, chemsys_filter):
)


@callback(
Output("download-csv", "data"),
Input("csv-download", "n_clicks"),
State("functional", "value"),
State("el_filter", "value"),
State("chemsys_filter", "value"),
State("min_coh_e_filter", "value"),
State("max_coh_e_filter", "value"),
State("min_form_e_filter", "value"),
State("max_form_e_filter", "value"),
prevent_initial_call=True,
)
def download_csv(
n_clicks,
functional,
el_filter,
chemsys_filter,
min_coh_e_filter,
max_coh_e_filter,
min_form_e_filter,
max_form_e_filter,
):
"""Handle csv download requests."""
df = get_data(
functional, el_filter, chemsys_filter, min_coh_e_filter, max_coh_e_filter, min_form_e_filter, max_form_e_filter
)
return dict(content=df.to_csv(), filename=f"matpes_{functional}_{el_filter or 'all'}_{chemsys_filter or 'all'}.csv")


@callback(Output("el_filter", "value"), Input("ptheatmap", "clickData"), State("el_filter", "value"))
def display_click_data(clickdata, el_filter):
"""
Expand All @@ -166,8 +210,9 @@ def display_click_data(clickdata, el_filter):
clickdata (dict): Click data.
el_filter (dict): Element filter.
"""
el_filter = el_filter or []
new_el_filter = {*el_filter, Element.from_Z(clickdata["points"][0]["pointNumber"] + 1).symbol}
new_el_filter = el_filter or []
if clickdata:
new_el_filter = {*new_el_filter, Element.from_Z(clickdata["points"][0]["pointNumber"] + 1).symbol}
return list(new_el_filter)


Expand All @@ -180,25 +225,22 @@ def main():
[
dbc.Row(
[
# html.Div("MatPES Explorer", className="text-primary text-center fs-3"),
dbc.Col(
[
html.Img(
src="https://github.com/materialsvirtuallab/matpes/blob"
"/2b7f8de716289de8089504a63c6431c456268172/assets/logo.png?raw=true",
width="50%",
style={
"padding": "12px",
},
html.Div(
html.Img(
src="https://github.com/materialsvirtuallab/matpes/blob"
"/2b7f8de716289de8089504a63c6431c456268172/assets/logo.png?raw=true",
width="50%",
style={
"padding": "12px",
},
),
className="text-primary text-center",
)
],
width={"size": 8, "order": 1, "offset": 4},
)
]
),
dbc.Row(
[
html.H2("Explorer", className="text-primary text-center fs-3"),
width={"size": 6, "offset": 3},
),
]
),
dbc.Row(
Expand All @@ -213,18 +255,17 @@ def main():
clearable=False,
),
],
width=4,
width=2,
)
]
),
dbc.Row(
[
dbc.Col(
[
html.Div("Filters: "),
],
width=1,
),
html.Div("Filters: "),
],
),
dbc.Row(
[
dbc.Col(
[
html.Label("Element(s)"),
Expand All @@ -238,7 +279,7 @@ def main():
multi=True,
),
],
width=2,
width=1,
),
dbc.Col(
[
Expand All @@ -252,28 +293,44 @@ def main():
),
dbc.Col(
[
html.Div("Coh. Energy", className="text-center"),
html.Div("Coh. Energy (Min, Max)"),
dcc.Input(0, type="number", id="min_coh_e_filter"),
dcc.Input(10, type="number", id="max_coh_e_filter"),
],
width=2,
width=4,
),
dbc.Col(
[
html.Div("Form. Energy", className="text-center"),
html.Div("Form. Energy (Min, Max)"),
dcc.Input(0, type="number", id="min_form_e_filter"),
dcc.Input(10, type="number", id="max_form_e_filter"),
],
width=2,
width=4,
),
]
),
dbc.Col(
dbc.Row(
[
html.Button("Download", id="btn-download"),
dcc.Download(id="download-data"),
html.Div("Download"),
],
width=1,
),
dbc.Row(
[
dbc.Col(
[
html.Button("JSON", id="json-download"),
dcc.Download(id="download-json"),
],
width=2,
),
dbc.Col(
[
html.Button("CSV", id="csv-download"),
dcc.Download(id="download-csv"),
],
width=2,
),
]
),
html.Div(
[
Expand All @@ -296,8 +353,11 @@ def main():
),
dbc.Row(
dbc.Col(
dcc.Graph(id="ptheatmap", style={"marginLeft": "auto", "marginRight": "auto"}),
width={"size": 16},
html.Div(
[dcc.Graph(id="ptheatmap")],
style={"marginLeft": "auto", "marginRight": "auto", "text-align": "center"},
),
width=12,
)
),
dbc.Row(
Expand Down
14 changes: 10 additions & 4 deletions src/matpes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ def pt_heatmap(values: dict[str, float], label: str = "value", log: bool = False
)

fig.update_layout(
xaxis=dict(title="Group", range=[0.5, 18.5], dtick=1),
yaxis=dict(title="Period", range=[0.5, 10.5], dtick=1, autorange="reversed"),
xaxis=dict(title=None, range=[0.5, 18.5], dtick=1),
yaxis=dict(title=None, range=[0.5, 10.5], dtick=1, autorange="reversed"),
showlegend=False,
plot_bgcolor="white",
width=1100,
height=650,
width=1080,
height=640,
font=dict(
family="Arial",
size=14,
Expand All @@ -161,6 +161,12 @@ def pt_heatmap(values: dict[str, float], label: str = "value", log: bool = False
),
)

# Hide x-axis
fig.update_xaxes(showticklabels=False, showgrid=False)

# Hide y-axis
fig.update_yaxes(showticklabels=False, showgrid=False)

if log:
max_log = int(df[f"log10_{label}"].max())
fig.update_layout(
Expand Down

0 comments on commit 2d19b18

Please sign in to comment.