Skip to content

Commit

Permalink
use pytest-mpl to test test_energy_band_alignment_diagram
Browse files Browse the repository at this point in the history
  • Loading branch information
ireaml committed Sep 7, 2023
1 parent aa4de7e commit 9786159
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ jobs:
pip show -V ase
- name: Run tests
run: pytest tests/unit_tests.py
run: pytest --mpl tests/unit_tests.py

11 changes: 8 additions & 3 deletions macrodensity/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def energy_band_alignment_diagram(
outfile: str = "BandAlignment",
references: dict = {},
edge=None,
fig_format: str = "pdf",
) -> plt.figure:
"""Plot an energy band alignment diagram for a list of materials.
Expand Down Expand Up @@ -81,8 +82,12 @@ def energy_band_alignment_diagram(
edge (None or str, optional): The edge color for the bars.
If None, there will be no edge color. Default is None.
fig_format (str, optional): The format used to save the image.
Default is "pdf".
Returns:
Figure: A matplotlib figure object containing the energy band alignment diagram.
Figure: A matplotlib figure object containing the energy band alignment
diagram.
Example:
>>> energies = [(5.2, 2.8), (4.9, 3.1), (5.5, 2.6)]
Expand Down Expand Up @@ -196,8 +201,8 @@ def energy_band_alignment_diagram(
color="r",
)

fig.savefig(f"{outfile}.pdf", bbox_inches="tight")
print(f"Figure saved as {outfile}.pdf")
fig.savefig(f"{outfile}.{fig_format}", bbox_inches="tight")
print(f"Figure saved as {outfile}.{fig_format}")
plt.close(fig)
return fig

Expand Down
94 changes: 94 additions & 0 deletions tests/01_Generate_test_figs.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generate reference figures for testing plotting functions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"path_for_ref_images = \"testIm\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generate test fig for test_energy_band_alignment_diagram"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Figure saved as BandAlignment.pdf\n"
]
}
],
"source": [
"import macrodensity as md\n",
"\n",
"fig = md.energy_band_alignment_diagram(\n",
" {\n",
" \"ZnO\": [4.4, 7.7],\n",
" \"MOF-5\": [2.7, 7.3],\n",
" \"HKUST-1\": [5.1, 6.0],\n",
" \"ZIF-8\": [0.9, 6.4],\n",
" \"COF-1M\": [1.3, 4.7],\n",
" \"CPO-27-Mg\": [2.9, 5.9],\n",
" \"MIL-125\": [3.8, 7.6],\n",
" \"TiO2\": [4.8, 7.8],\n",
" },\n",
" ylims=(-10, 0.0),\n",
" arrowhead=0.15,\n",
")\n",
"fig.savefig(f\"{path_for_ref_images}/BandAlignment.png\", bbox_inches=\"tight\", transparent=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "macrodensity2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Binary file removed tests/testIm/BandAlignment.pdf
Binary file not shown.
Binary file added tests/testIm/BandAlignment.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
39 changes: 30 additions & 9 deletions tests/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from os.path import join as path_join
import matplotlib as mpl
import matplotlib.pyplot as plt
import pytest
import pytest


import numpy as np
Expand All @@ -23,7 +23,7 @@
has_pandas = False


test_dir = os.path.abspath(os.path.dirname(__file__))
_file_path = os.path.dirname(__file__)


class TestDensityReadingFunctions(unittest.TestCase):
Expand Down Expand Up @@ -134,7 +134,6 @@ def tearDown(self):


class TestOtherReadingFunctions(unittest.TestCase):

def test_read_vasp_classic(self):
"""Test the function for reading CHGCAR/LOCPOT"""
chgcar = pkg_resources.resource_filename(
Expand Down Expand Up @@ -238,12 +237,11 @@ def test_get_band_extrema(self):
Outcar = pkg_resources.resource_filename(
__name__, path_join("../tests", "OUTCAR.test")
)
out = md.get_band_extrema(
input_file=Outcar
)
out = md.get_band_extrema(input_file=Outcar)
self.assertEqual(out[0], 2.8952)
self.assertEqual(out[1], 4.411)


class TestGeometryFunctions(unittest.TestCase):
"""Test the functions that do geometry and trigonometry"""

Expand Down Expand Up @@ -303,7 +301,6 @@ def test_GCD_List(self):


class TestConvenienceFunctions(unittest.TestCase):

def test_bulk_interstitial_alignment(self):
"""Tests the bulk_interstitial_alignment function"""
Locpot = pkg_resources.resource_filename(
Expand Down Expand Up @@ -431,8 +428,32 @@ def test_plot_planar_cube(self):
self.assertEqual(dfpot["Planar"].tolist()[-1], -0.581089179258661)
self.addCleanup(os.remove, "planar_average.csv")
self.addCleanup(os.remove, "planar_average.png")



@pytest.mark.mpl_image_compare(
baseline_dir=f"{_file_path}/testIm",
filename="BandAlignment.png",
style=f"{_file_path}/../macroDensity/macrodensity.mplstyle",
savefig_kwargs={"transparent": True, "bbox_inches": "tight"},
)
def test_energy_band_alignment_diagram(self):
"""Tests the energy_band_alignment_diagram function"""

fig = md.energy_band_alignment_diagram(
{
"ZnO": [4.4, 7.7],
"MOF-5": [2.7, 7.3],
"HKUST-1": [5.1, 6.0],
"ZIF-8": [0.9, 6.4],
"COF-1M": [1.3, 4.7],
"CPO-27-Mg": [2.9, 5.9],
"MIL-125": [3.8, 7.6],
"TiO2": [4.8, 7.8],
},
ylims=(-10, 0.0),
arrowhead=0.15,
)
self.addCleanup(os.remove, "BandAlignment.pdf")
return fig

if __name__ == "__main__":
unittest.main()

0 comments on commit 9786159

Please sign in to comment.