diff --git a/goalie/function_data.py b/goalie/function_data.py index 7a564bd..6d38207 100644 --- a/goalie/function_data.py +++ b/goalie/function_data.py @@ -6,6 +6,7 @@ import firedrake.function as ffunc import firedrake.functionspace as ffs +from firedrake import TransferManager from firedrake.checkpointing import CheckpointFile from firedrake.output.vtk_output import VTKFile @@ -280,6 +281,56 @@ def _export_h5(self, output_fpath, export_field_types, initial_condition=None): f = self._data[field][field_type][i][j] outfile.save_function(f, name=name, idx=j) + def transfer(self, target, method="interpolate"): + """ + Transfer all functions from this :class:`~.FunctionData` object to the target + :class:`~.FunctionData` object by interpolation, projection or prolongation. + + :arg target: the target :class:`~.FunctionData` object to which to transfer the + data + :type target: :class:`~.FunctionData` + :arg method: the transfer method to use. Either 'interpolate', 'project' or + 'prolong' + :type method: :class:`str` + """ + stp = self.time_partition + ttp = target.time_partition + + if method not in ["interpolate", "project", "prolong"]: + raise ValueError( + f"Transfer method '{method}' not supported." + " Supported methods are 'interpolate', 'project', and 'prolong'." + ) + if stp.num_subintervals != ttp.num_subintervals: + raise ValueError( + "Source and target have different numbers of subintervals." + ) + if stp.num_exports_per_subinterval != ttp.num_exports_per_subinterval: + raise ValueError( + "Source and target have different numbers of exports per subinterval." + ) + + common_fields = set(stp.field_names) & set(ttp.field_names) + if not common_fields: + raise ValueError("No common fields between source and target.") + + common_labels = set(self.labels) & set(target.labels) + if not common_labels: + raise ValueError("No common labels between source and target.") + + for field in common_fields: + for label in common_labels: + for i in range(stp.num_subintervals): + for j in range(stp.num_exports_per_subinterval[i] - 1): + source_function = self._data[field][label][i][j] + target_function = target._data[field][label][i][j] + if method == "interpolate": + target_function.interpolate(source_function) + elif method == "project": + target_function.project(source_function) + elif method == "prolong": + TransferManager().prolong(source_function, target_function) + class ForwardSolutionData(FunctionData): """ diff --git a/test/test_function_data.py b/test/test_function_data.py index 36a2789..b88375f 100644 --- a/test/test_function_data.py +++ b/test/test_function_data.py @@ -267,5 +267,193 @@ def test_export_h5(self): self.assertTrue(os.path.exists(export_filepath)) +class TestTransferFunctionData(BaseTestCases.TestFunctionData): + """ + Unit tests for transferring data from one :class:`~.FunctionData` to another. + """ + + def setUp(self): + super().setUpUnsteady() + self.labels = ("forward", "forward_old") + + def _create_function_data(self): + self.solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + self.solution_data._create_data() + + # Assign 1 to all functions + tp = self.solution_data.time_partition + for field in tp.field_names: + for label in self.solution_data.labels: + for i in range(tp.num_subintervals): + for j in range(tp.num_exports_per_subinterval[i] - 1): + self.solution_data._data[field][label][i][j].assign(1) + + def test_transfer_method_error(self): + target_solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + target_solution_data._create_data() + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="invalid_method") + self.assertEqual( + str(cm.exception), + "Transfer method 'invalid_method' not supported." + " Supported methods are 'interpolate', 'project', and 'prolong'.", + ) + + def test_transfer_subintervals_error(self): + target_time_partition = TimePartition( + 1.5 * self.time_partition.end_time, + self.time_partition.num_subintervals + 1, + self.time_partition.timesteps + [0.25], + self.time_partition.field_names, + ) + target_function_spaces = { + self.field: [ + FunctionSpace(self.mesh, "DG", 0) + for _ in range(target_time_partition.num_subintervals) + ] + } + target_solution_data = ForwardSolutionData( + target_time_partition, target_function_spaces + ) + target_solution_data._create_data() + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="interpolate") + self.assertEqual( + str(cm.exception), + "Source and target have different numbers of subintervals.", + ) + + def test_transfer_exports_error(self): + target_time_partition = TimePartition( + self.time_partition.end_time, + self.time_partition.num_subintervals, + self.time_partition.timesteps, + self.time_partition.field_names, + num_timesteps_per_export=[1, 2], + ) + target_function_spaces = { + self.field: [ + FunctionSpace(self.mesh, "DG", 0) + for _ in range(target_time_partition.num_subintervals) + ] + } + target_solution_data = ForwardSolutionData( + target_time_partition, target_function_spaces + ) + target_solution_data._create_data() + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="interpolate") + self.assertEqual( + str(cm.exception), + "Source and target have different numbers of exports per subinterval.", + ) + + def test_transfer_common_fields_error(self): + target_time_partition = TimePartition( + self.time_partition.end_time, + self.time_partition.num_subintervals, + self.time_partition.timesteps, + ["different_field"], + ) + target_function_spaces = { + "different_field": [ + FunctionSpace(self.mesh, "DG", 0) + for _ in range(target_time_partition.num_subintervals) + ] + } + target_solution_data = ForwardSolutionData( + target_time_partition, target_function_spaces + ) + target_solution_data._create_data() + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="interpolate") + self.assertEqual( + str(cm.exception), "No common fields between source and target." + ) + + def test_transfer_common_labels_error(self): + target_solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + target_solution_data._create_data() + target_solution_data.labels = ("different_label",) + with self.assertRaises(ValueError) as cm: + self.solution_data.transfer(target_solution_data, method="interpolate") + self.assertEqual( + str(cm.exception), "No common labels between source and target." + ) + + def test_transfer_interpolate(self): + target_solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + target_solution_data._create_data() + self.solution_data.transfer(target_solution_data, method="interpolate") + for field in self.solution_data.time_partition.field_names: + for label in self.solution_data.labels: + for i in range(self.solution_data.time_partition.num_subintervals): + for j in range( + self.solution_data.time_partition.num_exports_per_subinterval[i] + - 1 + ): + source_function = self.solution_data._data[field][label][i][j] + target_function = target_solution_data._data[field][label][i][j] + self.assertTrue( + source_function.dat.data.all() + == target_function.dat.data.all() + ) + + def test_transfer_project(self): + target_solution_data = ForwardSolutionData( + self.time_partition, self.function_spaces + ) + target_solution_data._create_data() + self.solution_data.transfer(target_solution_data, method="project") + for field in self.solution_data.time_partition.field_names: + for label in self.solution_data.labels: + for i in range(self.solution_data.time_partition.num_subintervals): + for j in range( + self.solution_data.time_partition.num_exports_per_subinterval[i] + - 1 + ): + source_function = self.solution_data._data[field][label][i][j] + target_function = target_solution_data._data[field][label][i][j] + self.assertTrue( + source_function.dat.data.all() + == target_function.dat.data.all() + ) + + def test_transfer_prolong(self): + enriched_mesh = MeshHierarchy(self.mesh, 1)[-1] + target_function_spaces = { + self.field: [ + FunctionSpace(enriched_mesh, "DG", 0) + for _ in range(self.num_subintervals) + ] + } + target_solution_data = ForwardSolutionData( + self.time_partition, target_function_spaces + ) + target_solution_data._create_data() + self.solution_data.transfer(target_solution_data, method="prolong") + for field in self.solution_data.time_partition.field_names: + for label in self.solution_data.labels: + for i in range(self.solution_data.time_partition.num_subintervals): + for j in range( + self.solution_data.time_partition.num_exports_per_subinterval[i] + - 1 + ): + source_function = self.solution_data._data[field][label][i][j] + target_function = target_solution_data._data[field][label][i][j] + self.assertTrue( + source_function.dat.data.all() + == target_function.dat.data.all() + ) + + if __name__ == "__main__": unittest.main()