Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure tests #237

Merged
merged 9 commits into from
Nov 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/test_suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ jobs:
python $(which firedrake-clean)
export GITHUB_ACTIONS_TEST_RUN=1
python -m coverage erase
python -m coverage run -a --source=goalie -m pytest -v --durations=20 test
python -m coverage run -a --source=goalie -m pytest -v --durations=10 test_adjoint
python -m coverage run --source=goalie -m pytest -v --durations=20 test
python -m coverage report
changed-files-patterns: |
**/*.py
**/*.msh
**/*.geo
**/*.geo
4 changes: 1 addition & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ lint:
test: lint
@echo "Running test suite..."
@cd test && make
@cd test_adjoint && make
@echo "PASS"

coverage:
@echo "Generating coverage report..."
@python3 -m coverage erase
@python3 -m coverage run -a --source=goalie -m pytest -v test
@python3 -m coverage run -a --source=goalie -m pytest -v test_adjoint
@python3 -m coverage run --source=goalie -m pytest -v test
@python3 -m coverage html

demo:
Expand Down
4 changes: 2 additions & 2 deletions goalie/mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def plot(self, fig=None, axes=None, **kwargs):
from matplotlib.pyplot import subplots

if self.dim != 2:
raise ValueError("MeshSeq plotting only supported in 2D")
raise ValueError("MeshSeq plotting only supported in 2D.")

# Process kwargs
interior_kw = {"edgecolor": "k"}
Expand Down Expand Up @@ -336,7 +336,7 @@ def _outputs_consistent(self):
self.debug(
"Current and lagged solutions are equal. Does the"
" solver yield before updating lagged solutions?"
)
) # noqa
break
assert isinstance(method_map, dict), f"get_{method} should return a dict"
mesh_seq_fields = set(self.fields)
Expand Down
5 changes: 3 additions & 2 deletions test/Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
all: run

run:
@echo "Running all tests..."
@python3 -m pytest -v -n auto --durations=20 .
@echo "Running all non-adjoint tests..."
@python3 -m pytest -v -n auto --durations=20 test_*.py
@echo "Done."

clean:
Expand All @@ -13,3 +13,4 @@ clean:
@rm -rf *.jpg *.png
@rm -rf outputs*
@echo "Done."
@cd adjoint && make clean
File renamed without changes.
51 changes: 51 additions & 0 deletions test/adjoint/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
Global pytest configuration for adjoint tests.

**Disclaimer: some functions copied from firedrake/src/tests/conftest.py
"""

import pyadjoint
import pytest


@pytest.fixture(scope="module", autouse=True)
def check_empty_tape(request):
"""
Check that the tape is empty at the end of each module

**Disclaimer: copied from firedrake/src/tests/conftest.py
"""

def fin():
tape = pyadjoint.get_working_tape()
if tape is not None:
assert len(tape.get_blocks()) == 0

request.addfinalizer(fin)


@pytest.fixture(autouse=True)
def handle_taping():
"""
**Disclaimer: copied from
firedrake/tests/regression/test_adjoint_operators.py
"""
yield
tape = pyadjoint.get_working_tape()
tape.clear_tape()


@pytest.fixture(autouse=True, scope="module")
def handle_annotation():
"""
Since importing firedrake-adjoint modifies a global variable, we need to
pause annotations at the end of the module.

**Disclaimer: copied from
firedrake/tests/regression/test_adjoint_operators.py
"""
if not pyadjoint.annotate_tape():
pyadjoint.continue_annotation()
yield
if pyadjoint.annotate_tape():
pyadjoint.pause_annotation()
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,7 @@ def test_enrichment_error(self):
mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0))
with self.assertRaises(ValueError) as cm:
mesh_seq.get_enriched_mesh_seq(enrichment_method="q")
msg = "Enrichment method 'q' not supported."
self.assertEqual(str(cm.exception), msg)
self.assertEqual(str(cm.exception), "Enrichment method 'q' not supported.")

def test_num_enrichments_error(self):
mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0))
Expand All @@ -369,6 +368,16 @@ def test_num_enrichments_error(self):
msg = "A positive number of enrichments is required."
self.assertEqual(str(cm.exception), msg)

def test_form_error(self):
mesh_seq = self.go_mesh_seq(self.get_function_spaces_decorator("R", 0, 0))
with self.assertRaises(AttributeError) as cm:
mesh_seq.forms()
msg = (
"Forms have not been read in. Use read_forms({'field_name': F}) in"
" get_solver to read in the forms."
)
self.assertEqual(str(cm.exception), msg)

def test_h_enrichment_error(self):
end_time = 1.0
num_subintervals = 2
Expand Down
2 changes: 1 addition & 1 deletion test_adjoint/test_demos.py → test/adjoint/test_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from goalie.log import *

cwd = os.path.abspath(os.path.dirname(__file__))
demo_dir = os.path.abspath(os.path.join(cwd, "..", "demos"))
demo_dir = os.path.abspath(os.path.join(cwd, "..", "..", "demos"))
all_demos = glob.glob(os.path.join(demo_dir, "*.py"))

# Modifications dictionary to cut down run time of demos:
Expand Down
File renamed without changes.
File renamed without changes.
102 changes: 68 additions & 34 deletions test/test_mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,77 +5,91 @@
import re
import unittest

from firedrake import Function, FunctionSpace, UnitCubeMesh, UnitSquareMesh
from firedrake import (
Function,
FunctionSpace,
UnitCubeMesh,
UnitIntervalMesh,
UnitSquareMesh,
)
from parameterized import parameterized

from goalie.mesh_seq import MeshSeq
from goalie.time_partition import TimeInterval, TimePartition


class TestGeneric(unittest.TestCase):
class BaseClasses:
"""
Generic unit tests for :class:`MeshSeq`.
Base classes for mesh sequence unit testing.
"""

def setUp(self):
self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"])
self.time_interval = TimeInterval(1.0, [0.5], ["field"])
class MeshSeqTestCase(unittest.TestCase):
"""
Test case with a simple setUp method and mesh constructor.
"""

def test_setitem(self):
mesh1 = UnitSquareMesh(1, 1, diagonal="left")
mesh2 = UnitSquareMesh(1, 1, diagonal="right")
mesh_seq = MeshSeq(self.time_interval, [mesh1])
self.assertEqual(mesh_seq[0], mesh1)
mesh_seq[0] = mesh2
self.assertEqual(mesh_seq[0], mesh2)
def setUp(self):
self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"])
self.time_interval = TimeInterval(1.0, [0.5], ["field"])

def trivial_mesh(self, dim):
try:
return {
1: UnitIntervalMesh(1),
2: UnitSquareMesh(1, 1),
3: UnitCubeMesh(1, 1, 1),
}[dim]
except KeyError:
raise ValueError(f"Dimension {dim} not supported.") from None

def test_inconsistent_dim(self):
meshes = [UnitSquareMesh(1, 1), UnitCubeMesh(1, 1, 1)]

class TestExceptions(BaseClasses.MeshSeqTestCase):
"""
Unit tests for exceptions raised by :class:`MeshSeq`.
"""

def test_inconsistent_dim_error(self):
meshes = [self.trivial_mesh(2), self.trivial_mesh(3)]
with self.assertRaises(ValueError) as cm:
MeshSeq(self.time_partition, meshes)
msg = "Meshes must all have the same topological dimension."
self.assertEqual(str(cm.exception), msg)

@parameterized.expand(["get_function_spaces", "get_solver"])
def test_notimplemented_error(self, function_name):
mesh_seq = MeshSeq(self.time_interval, UnitSquareMesh(1, 1))
mesh_seq = MeshSeq(self.time_interval, self.trivial_mesh(2))
with self.assertRaises(NotImplementedError) as cm:
if function_name == "get_function_spaces":
getattr(mesh_seq, function_name)(mesh_seq[0])
else:
getattr(mesh_seq, function_name)()
msg = f"'{function_name}' needs implementing."
self.assertEqual(str(cm.exception), msg)
self.assertEqual(str(cm.exception), f"'{function_name}' needs implementing.")

@parameterized.expand(["get_function_spaces", "get_initial_condition"])
def test_return_dict_error(self, method):
mesh = UnitSquareMesh(1, 1)
kwargs = {method: lambda _: 0}
with self.assertRaises(AssertionError) as cm:
MeshSeq(self.time_interval, mesh, **kwargs)
msg = f"{method} should return a dict"
self.assertEqual(str(cm.exception), msg)
MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs)
self.assertEqual(str(cm.exception), f"{method} should return a dict")

@parameterized.expand(["get_function_spaces", "get_initial_condition"])
def test_missing_field_error(self, method):
mesh = UnitSquareMesh(1, 1)
kwargs = {method: lambda _: {}}
with self.assertRaises(AssertionError) as cm:
MeshSeq(self.time_interval, mesh, **kwargs)
MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs)
msg = "missing fields {'field'} in " + f"{method}"
self.assertEqual(str(cm.exception), msg)

@parameterized.expand(["get_function_spaces", "get_initial_condition"])
def test_unexpected_field_error(self, method):
mesh = UnitSquareMesh(1, 1)
kwargs = {method: lambda _: {"field": None, "extra_field": None}}
with self.assertRaises(AssertionError) as cm:
MeshSeq(self.time_interval, mesh, **kwargs)
MeshSeq(self.time_interval, self.trivial_mesh(2), **kwargs)
msg = "unexpected fields {'extra_field'} in " + f"{method}"
self.assertEqual(str(cm.exception), msg)

def test_solver_generator_error(self):
mesh = UnitSquareMesh(1, 1)
mesh = self.trivial_mesh(2)
f_space = FunctionSpace(mesh, "CG", 1)
kwargs = {
"get_function_spaces": lambda _: {"field": f_space},
Expand All @@ -84,8 +98,32 @@ def test_solver_generator_error(self):
}
with self.assertRaises(AssertionError) as cm:
MeshSeq(self.time_interval, mesh, **kwargs)
msg = "solver should yield"
self.assertEqual(str(cm.exception), msg)
self.assertEqual(str(cm.exception), "solver should yield")

@parameterized.expand([1, 3])
def test_plot_dim_error(self, dim):
mesh_seq = MeshSeq(self.time_interval, self.trivial_mesh(dim))
with self.assertRaises(ValueError) as cm:
mesh_seq.plot()
self.assertEqual(str(cm.exception), "MeshSeq plotting only supported in 2D.")


class TestGeneric(BaseClasses.MeshSeqTestCase):
"""
Generic unit tests for :class:`MeshSeq`.
"""

def setUp(self):
self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"])
self.time_interval = TimeInterval(1.0, [0.5], ["field"])

def test_setitem(self):
mesh1 = UnitSquareMesh(1, 1, diagonal="left")
mesh2 = UnitSquareMesh(1, 1, diagonal="right")
mesh_seq = MeshSeq(self.time_interval, [mesh1])
self.assertEqual(mesh_seq[0], mesh1)
mesh_seq[0] = mesh2
self.assertEqual(mesh_seq[0], mesh2)

def test_counting_2d(self):
mesh_seq = MeshSeq(self.time_interval, [UnitSquareMesh(3, 3)])
Expand All @@ -98,16 +136,12 @@ def test_counting_3d(self):
self.assertEqual(mesh_seq.count_vertices(), [64])


class TestStringFormatting(unittest.TestCase):
class TestStringFormatting(BaseClasses.MeshSeqTestCase):
"""
Test that the :meth:`__str__` and :meth:`__repr__` methods work as intended for
Goalie's :class:`MeshSeq` object.
"""

def setUp(self):
self.time_partition = TimePartition(1.0, 2, [0.5, 0.5], ["field"])
self.time_interval = TimeInterval(1.0, [0.5], ["field"])

def test_mesh_seq_time_interval_str(self):
mesh_seq = MeshSeq(self.time_interval, [UnitSquareMesh(1, 1)])
got = re.sub("#[0-9]*", "?", str(mesh_seq))
Expand Down
Loading
Loading