diff --git a/goalie/adjoint.py b/goalie/adjoint.py index b3c5a756..b94fee85 100644 --- a/goalie/adjoint.py +++ b/goalie/adjoint.py @@ -393,14 +393,47 @@ def wrapped_solver(subinterval, initial_condition_map, **kwargs): if block.adj_sol is not None: solutions.adjoint[i][j].assign(block.adj_sol) - # Lagged forward solution comes from dependencies - dep = self._dependency(field, i, block) - if not self.steady and dep is not None: - solutions.forward_old[i][j].assign(dep.saved_output) - - # Adjoint action also comes from dependencies - if get_adj_values and dep is not None: - solutions.adj_value[i][j].assign(dep.adj_value) + if not self.steady: + # Lagged solution comes from dependencies for unsteady fields + if self.field_types[field] == "unsteady": + dep = self._dependency(field, i, block) + solutions.forward_old[i][j].assign(dep.saved_output) + # Adjoint action also comes from dependencies + if get_adj_values: + solutions.adj_value[i][j].assign(dep.adj_value) + # Lagged solution comes from previous block for steady fields + if self.field_types[field] == "steady": + if stride == 1: + if j == 0: + if i == 0: + forward_old = self.initial_condition[field] + else: + forward_old = self._transfer( + solutions.forward[i - 1][-1], fs[i] + ) + else: + forward_old = solutions.forward[i][j - 1] + if get_adj_values: + # TODO this does not consider the j==0 case + if j == num_solve_blocks - 1: + if i != num_subintervals - 1: + adj_value = self._transfer( + out.adj_value, fs[i + 1] + ) + solutions.adj_value[i + 1][0].assign( + adj_value + ) + else: + solutions.adj_value[i][j + 1].assign( + out.adj_value + ) + else: + old_block = solve_blocks[solve_blocks.index(block) - 1] + old_out = self._output(field, i, old_block) + forward_old = old_out.saved_output + if get_adj_values: + solutions.adj_value[i][j].assign(old_out.adj_value) + solutions.forward_old[i][j].assign(forward_old) # The adjoint solution at the 'next' timestep is determined from the # adj_sol attribute of the next solve block diff --git a/goalie/mesh_seq.py b/goalie/mesh_seq.py index 8ad5087d..4703d8a5 100644 --- a/goalie/mesh_seq.py +++ b/goalie/mesh_seq.py @@ -760,10 +760,29 @@ def solve_forward(self, solver_kwargs={}): if out is not None: solutions.forward[i][j].assign(out.saved_output) - # Lagged solution comes from dependencies - dep = self._dependency(field, i, block) - if not self.steady and dep is not None: - solutions.forward_old[i][j].assign(dep.saved_output) + if not self.steady: + # Lagged solution comes from dependencies for unsteady fields + if self.field_types[field] == "unsteady": + dep = self._dependency(field, i, block) + solutions.forward_old[i][j].assign(dep.saved_output) + # Lagged solution comes from previous block for steady fields + elif self.field_types[field] == "steady": + if stride == 1: + if j == 0: + if i == 0: + forward_old = self.initial_condition[field] + else: + forward_old = self._transfer( + solutions.forward[i - 1][-1], fs[i] + ) + else: + forward_old = solutions.forward[i][j - 1] + else: + old_block = solve_blocks[solve_blocks.index(block) - 1] + old_out = self._output(field, i, old_block) + if out is not None: + forward_old = old_out.saved_output + solutions.forward_old[i][j].assign(forward_old) # Transfer the checkpoint between subintervals if i < num_subintervals - 1: