Skip to content

Commit

Permalink
Add method to interpolate data from one FunctionData object to anot…
Browse files Browse the repository at this point in the history
…her (#243)
  • Loading branch information
ddundo authored Dec 8, 2024
1 parent 736b2f5 commit 9f4f9bb
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 0 deletions.
51 changes: 51 additions & 0 deletions goalie/function_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import firedrake.function as ffunc
import firedrake.functionspace as ffs
from firedrake import TransferManager
from firedrake.checkpointing import CheckpointFile
from firedrake.output.vtk_output import VTKFile

Expand Down Expand Up @@ -280,6 +281,56 @@ def _export_h5(self, output_fpath, export_field_types, initial_condition=None):
f = self._data[field][field_type][i][j]
outfile.save_function(f, name=name, idx=j)

def transfer(self, target, method="interpolate"):
"""
Transfer all functions from this :class:`~.FunctionData` object to the target
:class:`~.FunctionData` object by interpolation, projection or prolongation.
:arg target: the target :class:`~.FunctionData` object to which to transfer the
data
:type target: :class:`~.FunctionData`
:arg method: the transfer method to use. Either 'interpolate', 'project' or
'prolong'
:type method: :class:`str`
"""
stp = self.time_partition
ttp = target.time_partition

if method not in ["interpolate", "project", "prolong"]:
raise ValueError(
f"Transfer method '{method}' not supported."
" Supported methods are 'interpolate', 'project', and 'prolong'."
)
if stp.num_subintervals != ttp.num_subintervals:
raise ValueError(
"Source and target have different numbers of subintervals."
)
if stp.num_exports_per_subinterval != ttp.num_exports_per_subinterval:
raise ValueError(
"Source and target have different numbers of exports per subinterval."
)

common_fields = set(stp.field_names) & set(ttp.field_names)
if not common_fields:
raise ValueError("No common fields between source and target.")

common_labels = set(self.labels) & set(target.labels)
if not common_labels:
raise ValueError("No common labels between source and target.")

for field in common_fields:
for label in common_labels:
for i in range(stp.num_subintervals):
for j in range(stp.num_exports_per_subinterval[i] - 1):
source_function = self._data[field][label][i][j]
target_function = target._data[field][label][i][j]
if method == "interpolate":
target_function.interpolate(source_function)
elif method == "project":
target_function.project(source_function)
elif method == "prolong":
TransferManager().prolong(source_function, target_function)


class ForwardSolutionData(FunctionData):
"""
Expand Down
188 changes: 188 additions & 0 deletions test/test_function_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,5 +267,193 @@ def test_export_h5(self):
self.assertTrue(os.path.exists(export_filepath))


class TestTransferFunctionData(BaseTestCases.TestFunctionData):
"""
Unit tests for transferring data from one :class:`~.FunctionData` to another.
"""

def setUp(self):
super().setUpUnsteady()
self.labels = ("forward", "forward_old")

def _create_function_data(self):
self.solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
self.solution_data._create_data()

# Assign 1 to all functions
tp = self.solution_data.time_partition
for field in tp.field_names:
for label in self.solution_data.labels:
for i in range(tp.num_subintervals):
for j in range(tp.num_exports_per_subinterval[i] - 1):
self.solution_data._data[field][label][i][j].assign(1)

def test_transfer_method_error(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="invalid_method")
self.assertEqual(
str(cm.exception),
"Transfer method 'invalid_method' not supported."
" Supported methods are 'interpolate', 'project', and 'prolong'.",
)

def test_transfer_subintervals_error(self):
target_time_partition = TimePartition(
1.5 * self.time_partition.end_time,
self.time_partition.num_subintervals + 1,
self.time_partition.timesteps + [0.25],
self.time_partition.field_names,
)
target_function_spaces = {
self.field: [
FunctionSpace(self.mesh, "DG", 0)
for _ in range(target_time_partition.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
target_time_partition, target_function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception),
"Source and target have different numbers of subintervals.",
)

def test_transfer_exports_error(self):
target_time_partition = TimePartition(
self.time_partition.end_time,
self.time_partition.num_subintervals,
self.time_partition.timesteps,
self.time_partition.field_names,
num_timesteps_per_export=[1, 2],
)
target_function_spaces = {
self.field: [
FunctionSpace(self.mesh, "DG", 0)
for _ in range(target_time_partition.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
target_time_partition, target_function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception),
"Source and target have different numbers of exports per subinterval.",
)

def test_transfer_common_fields_error(self):
target_time_partition = TimePartition(
self.time_partition.end_time,
self.time_partition.num_subintervals,
self.time_partition.timesteps,
["different_field"],
)
target_function_spaces = {
"different_field": [
FunctionSpace(self.mesh, "DG", 0)
for _ in range(target_time_partition.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
target_time_partition, target_function_spaces
)
target_solution_data._create_data()
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception), "No common fields between source and target."
)

def test_transfer_common_labels_error(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
target_solution_data.labels = ("different_label",)
with self.assertRaises(ValueError) as cm:
self.solution_data.transfer(target_solution_data, method="interpolate")
self.assertEqual(
str(cm.exception), "No common labels between source and target."
)

def test_transfer_interpolate(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
self.solution_data.transfer(target_solution_data, method="interpolate")
for field in self.solution_data.time_partition.field_names:
for label in self.solution_data.labels:
for i in range(self.solution_data.time_partition.num_subintervals):
for j in range(
self.solution_data.time_partition.num_exports_per_subinterval[i]
- 1
):
source_function = self.solution_data._data[field][label][i][j]
target_function = target_solution_data._data[field][label][i][j]
self.assertTrue(
source_function.dat.data.all()
== target_function.dat.data.all()
)

def test_transfer_project(self):
target_solution_data = ForwardSolutionData(
self.time_partition, self.function_spaces
)
target_solution_data._create_data()
self.solution_data.transfer(target_solution_data, method="project")
for field in self.solution_data.time_partition.field_names:
for label in self.solution_data.labels:
for i in range(self.solution_data.time_partition.num_subintervals):
for j in range(
self.solution_data.time_partition.num_exports_per_subinterval[i]
- 1
):
source_function = self.solution_data._data[field][label][i][j]
target_function = target_solution_data._data[field][label][i][j]
self.assertTrue(
source_function.dat.data.all()
== target_function.dat.data.all()
)

def test_transfer_prolong(self):
enriched_mesh = MeshHierarchy(self.mesh, 1)[-1]
target_function_spaces = {
self.field: [
FunctionSpace(enriched_mesh, "DG", 0)
for _ in range(self.num_subintervals)
]
}
target_solution_data = ForwardSolutionData(
self.time_partition, target_function_spaces
)
target_solution_data._create_data()
self.solution_data.transfer(target_solution_data, method="prolong")
for field in self.solution_data.time_partition.field_names:
for label in self.solution_data.labels:
for i in range(self.solution_data.time_partition.num_subintervals):
for j in range(
self.solution_data.time_partition.num_exports_per_subinterval[i]
- 1
):
source_function = self.solution_data._data[field][label][i][j]
target_function = target_solution_data._data[field][label][i][j]
self.assertTrue(
source_function.dat.data.all()
== target_function.dat.data.all()
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 9f4f9bb

Please sign in to comment.