From 0b6964a62ddee462bd6fcd621ee1d04e1decf031 Mon Sep 17 00:00:00 2001 From: ddundo Date: Thu, 31 Oct 2024 20:38:40 +0000 Subject: [PATCH] #117: Add unit tests --- test/test_function_data.py | 76 +++++++++++++++++++++++++++++--------- 1 file changed, 58 insertions(+), 18 deletions(-) diff --git a/test/test_function_data.py b/test/test_function_data.py index 2f31058d..122cc0a2 100644 --- a/test/test_function_data.py +++ b/test/test_function_data.py @@ -4,6 +4,7 @@ import abc import unittest +from tempfile import TemporaryDirectory from firedrake import * @@ -108,24 +109,6 @@ def test_extract_by_subinterval(self): for f in sub_data[self.field][label]: self.assertTrue(isinstance(f, Function)) - def test_export_extension_error(self): - with self.assertRaises(ValueError) as cm: - self.solution_data.export("test.ext") - msg = ( - "Output file format not recognised: 'test.ext'." - + " Supported formats are '.pvd' and '.h5'." - ) - self.assertEqual(str(cm.exception), msg) - - def test_export_field_error(self): - with self.assertRaises(ValueError) as cm: - self.solution_data.export("test.pvd", export_field_types="test") - msg = ( - "Field types ['test'] not recognised." - + f" Available types are {self.solution_data.labels}." - ) - self.assertEqual(str(cm.exception), msg) - class TestSteadyForwardSolutionData(BaseTestCases.TestFunctionData): """ @@ -231,5 +214,62 @@ def test_extract_by_subinterval(self): self.assertTrue(isinstance(f, Function)) +class TestExportFunctionData(BaseTestCases.TestFunctionData): + """ + Unit tests for exporting and checkpointing :class:`~.FunctionData`. + """ + + @classmethod + def setUpClass(cls): + cls.tmpdir = TemporaryDirectory() + + def setUp(self): + super().setUpUnsteady() + self.labels = ("forward", "forward_old") + + def _create_function_data(self): + self.solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + self.solution_data._create_data() + + def test_export_extension_error(self): + with self.assertRaises(ValueError) as cm: + self.solution_data.export("test.ext") + msg = ( + "Output file format not recognised: 'test.ext'." + + " Supported formats are '.pvd' and '.h5'." + ) + self.assertEqual(str(cm.exception), msg) + + def test_export_field_error(self): + with self.assertRaises(ValueError) as cm: + self.solution_data.export("test.pvd", export_field_types="test") + msg = ( + "Field types ['test'] not recognised." + + f" Available types are {self.solution_data.labels}." + ) + self.assertEqual(str(cm.exception), msg) + + def test_export_pvd(self): + with TemporaryDirectory() as tmpdir: + export_filepath = os.path.join(tmpdir, "test.pvd") + self.solution_data.export(export_filepath) + self.assertTrue(os.path.exists(export_filepath)) + + def test_export_pvd_ic(self): + ic = {field: Function(fs[0]) for field, fs in self.function_spaces.items()} + with TemporaryDirectory() as tmpdir: + export_filepath = os.path.join(tmpdir, "test.pvd") + self.solution_data.export(export_filepath, initial_condition=ic) + self.assertTrue(os.path.exists(export_filepath)) + + def test_export_h5(self): + with TemporaryDirectory() as tmpdir: + export_filepath = os.path.join(tmpdir, "test.h5") + self.solution_data.export(export_filepath) + self.assertTrue(os.path.exists(export_filepath)) + + if __name__ == "__main__": unittest.main()