Skip to content

Commit

Permalink
#117: Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ddundo committed Oct 31, 2024
1 parent 85ef5fd commit 0b6964a
Showing 1 changed file with 58 additions and 18 deletions.
76 changes: 58 additions & 18 deletions test/test_function_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import abc
import unittest
from tempfile import TemporaryDirectory

This comment has been minimized.

Copy link
@jwallwork23

jwallwork23 Nov 1, 2024

Member

Nice! I hadn't heard of this before but it sounds like a really useful utility. We should use this more often.

This comment has been minimized.

Copy link
@ddundo

ddundo Nov 1, 2024

Author Member

I hadn't either until yesterday! I was initially going to copy the animate checkpointing temporary directory workflow, but this seemed much neater. Might want to neaten that up at some point :)


from firedrake import *

Expand Down Expand Up @@ -108,24 +109,6 @@ def test_extract_by_subinterval(self):
for f in sub_data[self.field][label]:
self.assertTrue(isinstance(f, Function))

def test_export_extension_error(self):
with self.assertRaises(ValueError) as cm:
self.solution_data.export("test.ext")
msg = (
"Output file format not recognised: 'test.ext'."
+ " Supported formats are '.pvd' and '.h5'."
)
self.assertEqual(str(cm.exception), msg)

def test_export_field_error(self):
with self.assertRaises(ValueError) as cm:
self.solution_data.export("test.pvd", export_field_types="test")
msg = (
"Field types ['test'] not recognised."
+ f" Available types are {self.solution_data.labels}."
)
self.assertEqual(str(cm.exception), msg)


class TestSteadyForwardSolutionData(BaseTestCases.TestFunctionData):
"""
Expand Down Expand Up @@ -231,5 +214,62 @@ def test_extract_by_subinterval(self):
self.assertTrue(isinstance(f, Function))


class TestExportFunctionData(BaseTestCases.TestFunctionData):
"""
Unit tests for exporting and checkpointing :class:`~.FunctionData`.
"""

@classmethod
def setUpClass(cls):
cls.tmpdir = TemporaryDirectory()

def setUp(self):
super().setUpUnsteady()
self.labels = ("forward", "forward_old")

def _create_function_data(self):
self.solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
self.solution_data._create_data()

def test_export_extension_error(self):
with self.assertRaises(ValueError) as cm:
self.solution_data.export("test.ext")
msg = (
"Output file format not recognised: 'test.ext'."
+ " Supported formats are '.pvd' and '.h5'."
)
self.assertEqual(str(cm.exception), msg)

def test_export_field_error(self):
with self.assertRaises(ValueError) as cm:
self.solution_data.export("test.pvd", export_field_types="test")
msg = (
"Field types ['test'] not recognised."
+ f" Available types are {self.solution_data.labels}."
)
self.assertEqual(str(cm.exception), msg)

def test_export_pvd(self):
with TemporaryDirectory() as tmpdir:
export_filepath = os.path.join(tmpdir, "test.pvd")
self.solution_data.export(export_filepath)
self.assertTrue(os.path.exists(export_filepath))

def test_export_pvd_ic(self):
ic = {field: Function(fs[0]) for field, fs in self.function_spaces.items()}
with TemporaryDirectory() as tmpdir:
export_filepath = os.path.join(tmpdir, "test.pvd")
self.solution_data.export(export_filepath, initial_condition=ic)
self.assertTrue(os.path.exists(export_filepath))

def test_export_h5(self):
with TemporaryDirectory() as tmpdir:
export_filepath = os.path.join(tmpdir, "test.h5")
self.solution_data.export(export_filepath)
self.assertTrue(os.path.exists(export_filepath))


if __name__ == "__main__":
unittest.main()

0 comments on commit 0b6964a

Please sign in to comment.