Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to interpolate data from one FunctionData object to another #243

Merged
merged 8 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions goalie/function_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import firedrake.function as ffunc
import firedrake.functionspace as ffs
from firedrake import TransferManager
from firedrake.checkpointing import CheckpointFile
from firedrake.output.vtk_output import VTKFile

Expand Down Expand Up @@ -280,6 +281,56 @@ def _export_h5(self, output_fpath, export_field_types, initial_condition=None):
f = self._data[field][field_type][i][j]
outfile.save_function(f, name=name, idx=j)

def transfer(self, target, method="interpolate"):
"""
Transfer all functions from this :class:`~.FunctionData` object to the target
:class:`~.FunctionData` object by interpolation, projection or prolongation.

:arg target: the target :class:`~.FunctionData` object to which to transfer the
data
:type target: :class:`~.FunctionData`
:arg method: the transfer method to use. Either 'interpolate', 'project' or
'prolong'
:type method: :class:`str`
"""
stp = self.time_partition
ttp = target.time_partition

if method not in ["interpolate", "project", "prolong"]:
raise ValueError(
f"Transfer method '{method}' not supported."
" Supported methods are 'interpolate', 'project', and 'prolong'."
)
if stp.num_subintervals != ttp.num_subintervals:
raise ValueError(
"Source and target have different numbers of subintervals."
)
if stp.num_exports_per_subinterval != ttp.num_exports_per_subinterval:
raise ValueError(
"Source and target have different numbers of exports per subinterval."
)

common_fields = set(stp.field_names) & set(ttp.field_names)
if not common_fields:
raise ValueError("No common fields between source and target.")

common_labels = set(self.labels) & set(target.labels)
if not common_labels:
raise ValueError("No common labels between source and target.")

for field in common_fields:
for label in common_labels:
for i in range(stp.num_subintervals):
for j in range(stp.num_exports_per_subinterval[i] - 1):
source_function = self._data[field][label][i][j]
target_function = target._data[field][label][i][j]
if method == "interpolate":
target_function.interpolate(source_function)
elif method == "project":
target_function.project(source_function)
elif method == "prolong":
TransferManager().prolong(source_function, target_function)


class ForwardSolutionData(FunctionData):
"""
Expand Down
188 changes: 188 additions & 0 deletions test/test_function_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,5 +267,193 @@ def test_export_h5(self):
self.assertTrue(os.path.exists(export_filepath))


class TestTransferFunctionData(BaseTestCases.TestFunctionData):
"""
Unit tests for transferring data from one :class:`~.FunctionData` to another.
"""

def setUp(self):
super().setUpUnsteady()
self.labels = ("forward", "forward_old")

def _create_function_data(self):
self.solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
self.solution_data._create_data()

# Assign 1 to all functions
tp = self.solution_data.time_partition
for field in tp.field_names:
for label in self.solution_data.labels:
for i in range(tp.num_subintervals):
for j in range(tp.num_exports_per_subinterval[i] - 1):
self.solution_data._data[field][label][i][j].assign(1)

def test_transfer_method_error(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="invalid_method")
self.assertEqual(
str(cm.exception),
"Transfer method 'invalid_method' not supported."
" Supported methods are 'interpolate', 'project', and 'prolong'.",
)

def test_transfer_subintervals_error(self):
target_time_partition = TimePartition(
1.5 * self.time_partition.end_time,
self.time_partition.num_subintervals + 1,
self.time_partition.timesteps + [0.25],
self.time_partition.field_names,
)
target_function_spaces = {
self.field: [
FunctionSpace(self.mesh, "DG", 0)
for _ in range(target_time_partition.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
target_time_partition, target_function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception),
"Source and target have different numbers of subintervals.",
)

def test_transfer_exports_error(self):
target_time_partition = TimePartition(
self.time_partition.end_time,
self.time_partition.num_subintervals,
self.time_partition.timesteps,
self.time_partition.field_names,
num_timesteps_per_export=[1, 2],
)
target_function_spaces = {
self.field: [
FunctionSpace(self.mesh, "DG", 0)
for _ in range(target_time_partition.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
target_time_partition, target_function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception),
"Source and target have different numbers of exports per subinterval.",
)

def test_transfer_common_fields_error(self):
target_time_partition = TimePartition(
self.time_partition.end_time,
self.time_partition.num_subintervals,
self.time_partition.timesteps,
["different_field"],
)
target_function_spaces = {
"different_field": [
FunctionSpace(self.mesh, "DG", 0)
for _ in range(target_time_partition.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
target_time_partition, target_function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception), "No common fields between source and target."
)

def test_transfer_common_labels_error(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
target_solution_data.labels = ("different_label",)
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception), "No common labels between source and target."
)

def test_transfer_interpolate(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
self.solution_data.transfer(target_solution_data, method="interpolate")
for field in self.solution_data.time_partition.field_names:
for label in self.solution_data.labels:
for i in range(self.solution_data.time_partition.num_subintervals):
for j in range(
self.solution_data.time_partition.num_exports_per_subinterval[i]
- 1
):
source_function = self.solution_data._data[field][label][i][j]
target_function = target_solution_data._data[field][label][i][j]
self.assertTrue(
source_function.dat.data.all()
== target_function.dat.data.all()
)
jwallwork23 marked this conversation as resolved.
Show resolved Hide resolved

def test_transfer_project(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
self.solution_data.transfer(target_solution_data, method="project")
for field in self.solution_data.time_partition.field_names:
for label in self.solution_data.labels:
for i in range(self.solution_data.time_partition.num_subintervals):
for j in range(
self.solution_data.time_partition.num_exports_per_subinterval[i]
- 1
):
source_function = self.solution_data._data[field][label][i][j]
target_function = target_solution_data._data[field][label][i][j]
self.assertTrue(
source_function.dat.data.all()
== target_function.dat.data.all()
)

def test_transfer_prolong(self):
enriched_mesh = MeshHierarchy(self.mesh, 1)[-1]
target_function_spaces = {
self.field: [
FunctionSpace(enriched_mesh, "DG", 0)
for _ in range(self.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
self.time_partition, target_function_spaces
)
target_solution_data._create_data()
self.solution_data.transfer(target_solution_data, method="prolong")
for field in self.solution_data.time_partition.field_names:
for label in self.solution_data.labels:
for i in range(self.solution_data.time_partition.num_subintervals):
for j in range(
self.solution_data.time_partition.num_exports_per_subinterval[i]
- 1
):
source_function = self.solution_data._data[field][label][i][j]
target_function = target_solution_data._data[field][label][i][j]
self.assertTrue(
source_function.dat.data.all()
== target_function.dat.data.all()
)


if __name__ == "__main__":
unittest.main()
Loading