Skip to content

Commit

Permalink
#153: Extract lagged fields in mixed cases
Browse files Browse the repository at this point in the history
  • Loading branch information
ddundo committed Apr 15, 2024
1 parent fc2cf23 commit 8e45ed4
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 12 deletions.
49 changes: 41 additions & 8 deletions goalie/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,14 +393,47 @@ def wrapped_solver(subinterval, initial_condition_map, **kwargs):
if block.adj_sol is not None:
solutions.adjoint[i][j].assign(block.adj_sol)

# Lagged forward solution comes from dependencies
dep = self._dependency(field, i, block)
if not self.steady and dep is not None:
solutions.forward_old[i][j].assign(dep.saved_output)

# Adjoint action also comes from dependencies
if get_adj_values and dep is not None:
solutions.adj_value[i][j].assign(dep.adj_value)
if not self.steady:
# Lagged solution comes from dependencies for unsteady fields
if self.field_types[field] == "unsteady":
dep = self._dependency(field, i, block)
solutions.forward_old[i][j].assign(dep.saved_output)
# Adjoint action also comes from dependencies
if get_adj_values:
solutions.adj_value[i][j].assign(dep.adj_value)
# Lagged solution comes from previous block for steady fields
if self.field_types[field] == "steady":
if stride == 1:
if j == 0:
if i == 0:
forward_old = self.initial_condition[field]
else:
forward_old = self._transfer(
solutions.forward[i - 1][-1], fs[i]
)
else:
forward_old = solutions.forward[i][j - 1]
if get_adj_values:
# TODO this does not consider the j==0 case
if j == num_solve_blocks - 1:
if i != num_subintervals - 1:
adj_value = self._transfer(
out.adj_value, fs[i + 1]
)
solutions.adj_value[i + 1][0].assign(
adj_value
)
else:
solutions.adj_value[i][j + 1].assign(
out.adj_value
)
else:
old_block = solve_blocks[solve_blocks.index(block) - 1]
old_out = self._output(field, i, old_block)
forward_old = old_out.saved_output
if get_adj_values:
solutions.adj_value[i][j].assign(old_out.adj_value)
solutions.forward_old[i][j].assign(forward_old)

# The adjoint solution at the 'next' timestep is determined from the
# adj_sol attribute of the next solve block
Expand Down
27 changes: 23 additions & 4 deletions goalie/mesh_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,10 +760,29 @@ def solve_forward(self, solver_kwargs={}):
if out is not None:
solutions.forward[i][j].assign(out.saved_output)

# Lagged solution comes from dependencies
dep = self._dependency(field, i, block)
if not self.steady and dep is not None:
solutions.forward_old[i][j].assign(dep.saved_output)
if not self.steady:
# Lagged solution comes from dependencies for unsteady fields
if self.field_types[field] == "unsteady":
dep = self._dependency(field, i, block)
solutions.forward_old[i][j].assign(dep.saved_output)
# Lagged solution comes from previous block for steady fields
elif self.field_types[field] == "steady":
if stride == 1:
if j == 0:
if i == 0:
forward_old = self.initial_condition[field]
else:
forward_old = self._transfer(
solutions.forward[i - 1][-1], fs[i]
)
else:
forward_old = solutions.forward[i][j - 1]
else:
old_block = solve_blocks[solve_blocks.index(block) - 1]
old_out = self._output(field, i, old_block)
if out is not None:
forward_old = old_out.saved_output
solutions.forward_old[i][j].assign(forward_old)

# Transfer the checkpoint between subintervals
if i < num_subintervals - 1:
Expand Down

0 comments on commit 8e45ed4

Please sign in to comment.