Skip to content

Commit

Permalink
More cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep committed Nov 21, 2024
1 parent c86675f commit 376c72a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 19 deletions.
8 changes: 1 addition & 7 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
[![GitHub license](https://img.shields.io/github/license/materialsvirtuallab/python_template)](https://github.com/materialsvirtuallab/python_template/blob/main/LICENSE)
[![Linting](https://github.com/materialsvirtuallab/python_template/workflows/Linting/badge.svg)](https://github.com/materialsvirtuallab/python_template/workflows/Linting/badge.svg)
[![Testing](https://github.com/materialsvirtuallab/python_template/workflows/Testing/badge.svg)](https://github.com/materialsvirtuallab/python_template/workflows/Testing/badge.svg)
<!--
[![Downloads](https://pepy.tech/badge/python_template)](https://pepy.tech/project/python_template)
[![codecov](https://codecov.io/gh/materialsvirtuallab/python_template/branch/main/graph/badge.svg?token=3V3O79GODQ)]
(https://codecov.io/gh/materialsvirtuallab/python_template)
-->

# Introduction

This is a template for setting up Python packages in the Materials Virtual Lab. It comes with the standard Github
workflows, pyproject and linting.
Tools for working with MatPES.
9 changes: 9 additions & 0 deletions src/matpes/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ def __init__(self, dbname="matpes"):
client = MongoClient()
self.db = client[dbname]

def get_json(self, functional: str, criteria: dict) -> list:
"""
Args:
functional (str): The name of the functional to query.
criteria (dict): The criteria to query.
"""
return list(self.db[functional].find(criteria))

def get_df(self, functional: str) -> pd.DataFrame:
"""
Retrieve data for the given functional from the MongoDB database.
Expand All @@ -41,6 +49,7 @@ def get_df(self, functional: str) -> pd.DataFrame:
"formation_energy_per_atom",
"natoms",
"nelements",
"bandgap",
],
)
)
51 changes: 39 additions & 12 deletions src/matpes/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import collections
import functools
import itertools
import json
from typing import TYPE_CHECKING
Expand All @@ -24,24 +25,34 @@
DB = MatPESDB()


def get_data(functional: str, element_filter: list, chemsys: str) -> pd.DataFrame:
@functools.lru_cache
def get_full_data(functional: str) -> pd.DataFrame:
"""Cache data for each functional for more responsive UI."""
return DB.get_df(functional)


def get_data(functional: str, element_filter: list, chemsys_filter: str, bandgap_filter) -> pd.DataFrame:
"""
Filter data based on the selected functional, element, and chemical system.
Args:
functional (str): Functional to filter data for.
element_filter (list | None): Elements to filter (if any).
chemsys (str | None): Chemical system to filter (if any).
chemsys_filter (str | None): Chemical system to filter (if any).
bandgap_filter (str | None): Bandgap to filter (not used at the moment).
Returns:
pd.DataFrame: Filtered data.
"""
df = DB.get_df(functional)
df = get_full_data(functional)
if element_filter:
df = df[df["elements"].apply(lambda x: set(x).issuperset(element_filter))]
if chemsys:
sorted_chemsys = "-".join(sorted(chemsys.split("-")))
if chemsys_filter:
sorted_chemsys = "-".join(sorted(chemsys_filter.split("-")))
df = df[df["chemsys"] == sorted_chemsys]

# df = df[bandgap_filter[0] <= df["bandgap"]]
# df = df[df["bandgap"] <= bandgap_filter[1]]
return df


Expand All @@ -53,11 +64,16 @@ def get_data(functional: str, element_filter: list, chemsys: str) -> pd.DataFram
Output("natoms_hist", "figure"),
Output("nelements_hist", "figure"),
],
[Input("functional", "value"), Input("el_filter", "value"), Input("chemsys_filter", "value")],
[
Input("functional", "value"),
Input("el_filter", "value"),
Input("chemsys_filter", "value"),
# Input("bandgap_filter", "value")
],
)
def update_graph(functional, el_filter, chemsys_filter):
def update_graph(functional, el_filter, chemsys_filter, bandgap_filter):
"""Update graphs based on user inputs."""
df = get_data(functional, el_filter, chemsys_filter)
df = get_data(functional, el_filter, chemsys_filter, bandgap_filter)
element_counts = collections.Counter(itertools.chain(*df["elements"]))
heatmap_figure = pt_heatmap(element_counts, label="Count", log=True)
return (
Expand Down Expand Up @@ -90,13 +106,12 @@ def update_graph(functional, el_filter, chemsys_filter):
)
def download_data(n_clicks, functional, el_filter, chemsys_filter):
"""Handle data download requests."""
collection = DB[functional]
criteria = {}
if el_filter:
criteria["elements"] = el_filter
if chemsys_filter:
criteria["chemsys"] = "-".join(sorted(chemsys_filter.split("-")))
data = list(collection.find(criteria))
data = DB.get_json(functional, criteria)
for entry in data:
entry.pop("_id", None) # Remove MongoDB's internal ID
return dict(
Expand All @@ -120,6 +135,9 @@ def main():
src="https://github.com/materialsvirtuallab/matpes/blob"
"/2b7f8de716289de8089504a63c6431c456268172/assets/logo.png?raw=true",
width="50%",
style={
"padding": "12px",
},
)
],
width={"size": 8, "order": 1, "offset": 4},
Expand Down Expand Up @@ -158,14 +176,23 @@ def main():
),
dbc.Col(
[
html.Label("Filter by Chemsys"),
html.Div("Filter by Chemsys"),
dcc.Input(
id="chemsys_filter",
placeholder="Li-Fe-O",
),
],
width=2,
),
# dbc.Col(
# [
# html.Div("Bandgap", className="text-center"),
# dcc.RangeSlider(0, 10, 0.1,
# marks={i: str(i) for i in range(0, 10)},
# value=[0, 10], id='bandgap_filter'),
# ],
# width=2,
# ),
dbc.Col(
[
html.Label("Data Tools"),
Expand All @@ -179,7 +206,7 @@ def main():
dbc.Row(
dbc.Col(
dcc.Graph(id="ptheatmap", style={"marginLeft": "auto", "marginRight": "auto"}),
width={"size": 8, "offset": 2},
width={"size": 16},
)
),
dbc.Row(
Expand Down

0 comments on commit 376c72a

Please sign in to comment.