Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to interpolate data from one FunctionData object to another #243

Merged
merged 8 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions goalie/function_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,53 @@ def _export_h5(self, output_fpath, export_field_types, initial_condition=None):
f = self._data[field][field_type][i][j]
outfile.save_function(f, name=name, idx=j)

def transfer(self, target, method="interpolate"):
"""
Interpolate or project all functions from this :class:`~FunctionData` object to
the target :class:`~FunctionData` object.

:arg target: the target :class:`~FunctionData` object to which to transfer the
data
:type target: :class:`.FunctionData`
jwallwork23 marked this conversation as resolved.
Show resolved Hide resolved
:arg method: the transfer method to use, either 'interpolate' or 'project'
:type method: :class:`str`
"""
stp = self.time_partition
ttp = target.time_partition

if method not in ["interpolate", "project"]:
raise ValueError(
f"Transfer method '{method}' not supported."
" Supported methods are 'interpolate' or 'project'."
)
jwallwork23 marked this conversation as resolved.
Show resolved Hide resolved
if stp.num_subintervals != ttp.num_subintervals:
raise ValueError(
"Source and target have different numbers of subintervals."
)
if stp.num_exports_per_subinterval != ttp.num_exports_per_subinterval:
raise ValueError(
"Source and target have different numbers of exports per subinterval."
)

common_fields = set(stp.field_names) & set(ttp.field_names)
if not common_fields:
raise ValueError("No common fields between source and target.")

common_labels = set(self.labels) & set(target.labels)
if not common_labels:
raise ValueError("No common labels between source and target.")

for field in common_fields:
for label in common_labels:
for i in range(stp.num_subintervals):
for j in range(stp.num_exports_per_subinterval[i] - 1):
source_function = self._data[field][label][i][j]
target_function = target._data[field][label][i][j]
if method == "interpolate":
target_function.interpolate(source_function)
elif method == "project":
target_function.project(source_function)


class ForwardSolutionData(FunctionData):
"""
Expand Down
153 changes: 153 additions & 0 deletions test/test_function_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,5 +267,158 @@ def test_export_h5(self):
self.assertTrue(os.path.exists(export_filepath))


class TestTransferFunctionData(BaseTestCases.TestFunctionData):
"""
Unit tests for transferring data from one :class:`~.FunctionData` to another.
"""

def setUp(self):
super().setUpUnsteady()
self.labels = ("forward", "forward_old")

def _create_function_data(self):
self.solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
self.solution_data._create_data()

def test_transfer_method_error(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="invalid_method")
self.assertEqual(
str(cm.exception),
"Transfer method 'invalid_method' not supported."
" Supported methods are 'interpolate' or 'project'.",
)

def test_transfer_subintervals_error(self):
target_time_partition = TimePartition(
1.5 * self.time_partition.end_time,
self.time_partition.num_subintervals + 1,
self.time_partition.timesteps + [0.25],
self.time_partition.field_names,
)
target_function_spaces = {
self.field: [
FunctionSpace(self.mesh, "DG", 0)
for _ in range(target_time_partition.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
target_time_partition, target_function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception),
"Source and target have different numbers of subintervals.",
)

def test_transfer_exports_error(self):
target_time_partition = TimePartition(
self.time_partition.end_time,
self.time_partition.num_subintervals,
self.time_partition.timesteps,
self.time_partition.field_names,
num_timesteps_per_export=[1, 2],
)
target_function_spaces = {
self.field: [
FunctionSpace(self.mesh, "DG", 0)
for _ in range(target_time_partition.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
target_time_partition, target_function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception),
"Source and target have different numbers of exports per subinterval.",
)

def test_transfer_common_fields_error(self):
target_time_partition = TimePartition(
self.time_partition.end_time,
self.time_partition.num_subintervals,
self.time_partition.timesteps,
["different_field"],
)
target_function_spaces = {
"different_field": [
FunctionSpace(self.mesh, "DG", 0)
for _ in range(target_time_partition.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
target_time_partition, target_function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception), "No common fields between source and target."
)

def test_transfer_common_labels_error(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
target_solution_data.labels = ("different_label",)
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception), "No common labels between source and target."
)

def test_transfer_interpolate(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
self.solution_data.transfer(target_solution_data, method="interpolate")
for field in self.solution_data.time_partition.field_names:
for label in self.solution_data.labels:
for i in range(self.solution_data.time_partition.num_subintervals):
for j in range(
self.solution_data.time_partition.num_exports_per_subinterval[i]
- 1
):
source_function = self.solution_data._data[field][label][i][j]
target_function = target_solution_data._data[field][label][i][j]
self.assertTrue(
source_function.dat.data.all()
== target_function.dat.data.all()
)
jwallwork23 marked this conversation as resolved.
Show resolved Hide resolved

def test_transfer_project(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
self.solution_data.transfer(target_solution_data, method="project")
for field in self.solution_data.time_partition.field_names:
for label in self.solution_data.labels:
for i in range(self.solution_data.time_partition.num_subintervals):
for j in range(
self.solution_data.time_partition.num_exports_per_subinterval[i]
- 1
):
source_function = self.solution_data._data[field][label][i][j]
target_function = target_solution_data._data[field][label][i][j]
self.assertTrue(
source_function.dat.data.all()
== target_function.dat.data.all()
)


if __name__ == "__main__":
unittest.main()