Skip to content

Commit

Permalink
Restructure tests (#237)
Browse files Browse the repository at this point in the history
Closes #236.

This PR puts all the tests in the same directory, with a subdirectory
for the adjoint ones. Various parts of the conftest are no longer
needed.

I also plan to crank up the test coverage on a few modules in this PR.
  • Loading branch information
jwallwork23 authored Nov 24, 2024
1 parent 5303c0d commit 7f61988
Show file tree
Hide file tree
Showing 19 changed files with 139 additions and 217 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/test_suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ jobs:
python $(which firedrake-clean)
export GITHUB_ACTIONS_TEST_RUN=1
python -m coverage erase
python -m coverage run -a --source=goalie -m pytest -v --durations=20 test
python -m coverage run -a --source=goalie -m pytest -v --durations=10 test_adjoint
python -m coverage run --source=goalie -m pytest -v --durations=20 test
python -m coverage report
changed-files-patterns: |
**/*.py
**/*.msh
**/*.geo
**/*.geo
4 changes: 1 addition & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ lint:
test: lint
@echo "Running test suite..."
@cd test && make
@cd test_adjoint && make
@echo "PASS"

coverage:
@echo "Generating coverage report..."
@python3 -m coverage erase
@python3 -m coverage run -a --source=goalie -m pytest -v test
@python3 -m coverage run -a --source=goalie -m pytest -v test_adjoint
@python3 -m coverage run --source=goalie -m pytest -v test
@python3 -m coverage html

demo:
Expand Down
4 changes: 2 additions & 2 deletions goalie/mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def plot(self, fig=None, axes=None, **kwargs):
from matplotlib.pyplot import subplots

if self.dim != 2:
raise ValueError("MeshSeq plotting only supported in 2D")
raise ValueError("MeshSeq plotting only supported in 2D.")

# Process kwargs
interior_kw = {"edgecolor": "k"}
Expand Down Expand Up @@ -336,7 +336,7 @@ def _outputs_consistent(self):
self.debug(
"Current and lagged solutions are equal. Does the"
" solver yield before updating lagged solutions?"
)
) # noqa
break
assert isinstance(method_map, dict), f"get_{method} should return a dict"
mesh_seq_fields = set(self.fields)
Expand Down
5 changes: 3 additions & 2 deletions test/Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
all: run

run:
@echo "Running all tests..."
@python3 -m pytest -v -n auto --durations=20 .
@echo "Running all non-adjoint tests..."
@python3 -m pytest -v -n auto --durations=20 test_*.py
@echo "Done."

clean:
Expand All @@ -13,3 +13,4 @@ clean:
@rm -rf *.jpg *.png
@rm -rf outputs*
@echo "Done."
@cd adjoint && make clean
File renamed without changes.
51 changes: 51 additions & 0 deletions test/adjoint/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Global pytest configuration for adjoint tests.
**Disclaimer: some functions copied from firedrake/src/tests/conftest.py
"""

import pyadjoint
import pytest


@pytest.fixture(scope="module", autouse=True)
def check_empty_tape(request):
"""
Check that the tape is empty at the end of each module
**Disclaimer: copied from firedrake/src/tests/conftest.py
"""

def fin():
tape = pyadjoint.get_working_tape()
if tape is not None:
assert len(tape.get_blocks()) == 0

request.addfinalizer(fin)


@pytest.fixture(autouse=True)
def handle_taping():
"""
**Disclaimer: copied from
firedrake/tests/regression/test_adjoint_operators.py
"""
yield
tape = pyadjoint.get_working_tape()
tape.clear_tape()


@pytest.fixture(autouse=True, scope="module")
def handle_annotation():
"""
Since importing firedrake-adjoint modifies a global variable, we need to
pause annotations at the end of the module.
**Disclaimer: copied from
firedrake/tests/regression/test_adjoint_operators.py
"""
if not pyadjoint.annotate_tape():
pyadjoint.continue_annotation()
yield
if pyadjoint.annotate_tape():
pyadjoint.pause_annotation()
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,7 @@ def test_enrichment_error(self):
mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0))
with self.assertRaises(ValueError) as cm:
mesh_seq.get_enriched_mesh_seq(enrichment_method="q")
msg = "Enrichment method 'q' not supported."
self.assertEqual(str(cm.exception), msg)
self.assertEqual(str(cm.exception), "Enrichment method 'q' not supported.")

def test_num_enrichments_error(self):
mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0))
Expand All @@ -369,6 +368,16 @@ def test_num_enrichments_error(self):
msg = "A positive number of enrichments is required."
self.assertEqual(str(cm.exception), msg)

def test_form_error(self):
mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0))
with self.assertRaises(AttributeError) as cm:
mesh_seq.forms()
msg = (
"Forms have not been read in. Use read_forms({'field_name': F}) in"
" get_solver to read in the forms."
)
self.assertEqual(str(cm.exception), msg)

def test_h_enrichment_error(self):
end_time = 1.0
num_subintervals = 2
Expand Down
2 changes: 1 addition & 1 deletion test_adjoint/test_demos.py → test/adjoint/test_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from goalie.log import *

cwd = os.path.abspath(os.path.dirname(__file__))
demo_dir = os.path.abspath(os.path.join(cwd, "..", "demos"))
demo_dir = os.path.abspath(os.path.join(cwd, "..", "..", "demos"))
all_demos = glob.glob(os.path.join(demo_dir, "*.py"))

# Modifications dictionary to cut down run time of demos:
Expand Down
File renamed without changes.
File renamed without changes.
102 changes: 68 additions & 34 deletions test/test_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,77 +5,91 @@
import re
import unittest

from firedrake import Function, FunctionSpace, UnitCubeMesh, UnitSquareMesh
from firedrake import (
Function,
FunctionSpace,
UnitCubeMesh,
UnitIntervalMesh,
UnitSquareMesh,
)
from parameterized import parameterized

from goalie.mesh_seq import MeshSeq
from goalie.time_partition import TimeInterval, TimePartition


class TestGeneric(unittest.TestCase):
class BaseClasses:
"""
Generic unit tests for :class:`MeshSeq`.
Base classes for mesh sequence unit testing.
"""

def setUp(self):
self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"])
self.time_interval = TimeInterval(1.0, [0.5], ["field"])
class MeshSeqTestCase(unittest.TestCase):
"""
Test case with a simple setUp method and mesh constructor.
"""

def test_setitem(self):
mesh1 = UnitSquareMesh(1, 1, diagonal="left")
mesh2 = UnitSquareMesh(1, 1, diagonal="right")
mesh_seq = MeshSeq(self.time_interval, [mesh1])
self.assertEqual(mesh_seq[0], mesh1)
mesh_seq[0] = mesh2
self.assertEqual(mesh_seq[0], mesh2)
def setUp(self):
self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"])
self.time_interval = TimeInterval(1.0, [0.5], ["field"])

def trivial_mesh(self, dim):
try:
return {
1: UnitIntervalMesh(1),
2: UnitSquareMesh(1, 1),
3: UnitCubeMesh(1, 1, 1),
}[dim]
except KeyError:
raise ValueError(f"Dimension {dim} not supported.") from None

def test_inconsistent_dim(self):
meshes = [UnitSquareMesh(1, 1), UnitCubeMesh(1, 1, 1)]

class TestExceptions(BaseClasses.MeshSeqTestCase):
"""
Unit tests for exceptions raised by :class:`MeshSeq`.
"""

def test_inconsistent_dim_error(self):
meshes = [self.trivial_mesh(2), self.trivial_mesh(3)]
with self.assertRaises(ValueError) as cm:
MeshSeq(self.time_partition, meshes)
msg = "Meshes must all have the same topological dimension."
self.assertEqual(str(cm.exception), msg)

@parameterized.expand(["get_function_spaces", "get_solver"])
def test_notimplemented_error(self, function_name):
mesh_seq = MeshSeq(self.time_interval, UnitSquareMesh(1, 1))
mesh_seq = MeshSeq(self.time_interval, self.trivial_mesh(2))
with self.assertRaises(NotImplementedError) as cm:
if function_name == "get_function_spaces":
getattr(mesh_seq, function_name)(mesh_seq[0])
else:
getattr(mesh_seq, function_name)()
msg = f"'{function_name}' needs implementing."
self.assertEqual(str(cm.exception), msg)
self.assertEqual(str(cm.exception), f"'{function_name}' needs implementing.")

@parameterized.expand(["get_function_spaces", "get_initial_condition"])
def test_return_dict_error(self, method):
mesh = UnitSquareMesh(1, 1)
kwargs = {method: lambda _: 0}
with self.assertRaises(AssertionError) as cm:
MeshSeq(self.time_interval, mesh, **kwargs)
msg = f"{method} should return a dict"
self.assertEqual(str(cm.exception), msg)
MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs)
self.assertEqual(str(cm.exception), f"{method} should return a dict")

@parameterized.expand(["get_function_spaces", "get_initial_condition"])
def test_missing_field_error(self, method):
mesh = UnitSquareMesh(1, 1)
kwargs = {method: lambda _: {}}
with self.assertRaises(AssertionError) as cm:
MeshSeq(self.time_interval, mesh, **kwargs)
MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs)
msg = "missing fields {'field'} in " + f"{method}"
self.assertEqual(str(cm.exception), msg)

@parameterized.expand(["get_function_spaces", "get_initial_condition"])
def test_unexpected_field_error(self, method):
mesh = UnitSquareMesh(1, 1)
kwargs = {method: lambda _: {"field": None, "extra_field": None}}
with self.assertRaises(AssertionError) as cm:
MeshSeq(self.time_interval, mesh, **kwargs)
MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs)
msg = "unexpected fields {'extra_field'} in " + f"{method}"
self.assertEqual(str(cm.exception), msg)

def test_solver_generator_error(self):
mesh = UnitSquareMesh(1, 1)
mesh = self.trivial_mesh(2)
f_space = FunctionSpace(mesh, "CG", 1)
kwargs = {
"get_function_spaces": lambda _: {"field": f_space},
Expand All @@ -84,8 +98,32 @@ def test_solver_generator_error(self):
}
with self.assertRaises(AssertionError) as cm:
MeshSeq(self.time_interval, mesh, **kwargs)
msg = "solver should yield"
self.assertEqual(str(cm.exception), msg)
self.assertEqual(str(cm.exception), "solver should yield")

@parameterized.expand([1, 3])
def test_plot_dim_error(self, dim):
mesh_seq = MeshSeq(self.time_interval, self.trivial_mesh(dim))
with self.assertRaises(ValueError) as cm:
mesh_seq.plot()
self.assertEqual(str(cm.exception), "MeshSeq plotting only supported in 2D.")


class TestGeneric(BaseClasses.MeshSeqTestCase):
"""
Generic unit tests for :class:`MeshSeq`.
"""

def setUp(self):
self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"])
self.time_interval = TimeInterval(1.0, [0.5], ["field"])

def test_setitem(self):
mesh1 = UnitSquareMesh(1, 1, diagonal="left")
mesh2 = UnitSquareMesh(1, 1, diagonal="right")
mesh_seq = MeshSeq(self.time_interval, [mesh1])
self.assertEqual(mesh_seq[0], mesh1)
mesh_seq[0] = mesh2
self.assertEqual(mesh_seq[0], mesh2)

def test_counting_2d(self):
mesh_seq = MeshSeq(self.time_interval, [UnitSquareMesh(3, 3)])
Expand All @@ -98,16 +136,12 @@ def test_counting_3d(self):
self.assertEqual(mesh_seq.count_vertices(), [64])


class TestStringFormatting(unittest.TestCase):
class TestStringFormatting(BaseClasses.MeshSeqTestCase):
"""
Test that the :meth:`__str__` and :meth:`__repr__` methods work as intended for
Goalie's :class:`MeshSeq` object.
"""

def setUp(self):
self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"])
self.time_interval = TimeInterval(1.0, [0.5], ["field"])

def test_mesh_seq_time_interval_str(self):
mesh_seq = MeshSeq(self.time_interval, [UnitSquareMesh(1, 1)])
got = re.sub("#[0-9]*", "?", str(mesh_seq))
Expand Down
Loading

0 comments on commit 7f61988

Please sign in to comment.