Skip to content

Commit

Permalink
Implement non-breaking solution and update tests accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasVin committed Jan 11, 2025
1 parent debb274 commit a6edac4
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
21 changes: 6 additions & 15 deletions fenicsprecice/fenicsprecice.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,26 +459,17 @@ def retrieve_checkpoint(self):
-------
u : FEniCS Function
Current state of the physical variable of interest for this participant.
t : double (optional)
Current simulation time.
n : int (optional)
Current time window (iteration) number.
t : double
Current simulation time or None if not specified in store_checkpoint
n : int
Current time window (iteration) number or None if not specified in store_checkpoint
"""
assert (not self.is_time_window_complete())
logger.debug("Restore solver state")

# since t and n are optional, they should not be returned, if not specified
payload, t, n = self._checkpoint.get_state()
match (t, n):
case (None, None):
return payload
case (_, None):
return payload, t
case (None, _):
return payload, n
case _:
return payload, t, n

return self._checkpoint.get_state()

def advance(self, dt):
"""
Advances coupling in preCICE.
Expand Down
12 changes: 8 additions & 4 deletions tests/integration/test_write_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,17 @@ def test_opt_parameters(adapter):
adapter.store_checkpoint(u, t) # without n
res = adapter.retrieve_checkpoint()
self.assertEqual(len(res), 2) # correct number of return values
res_u, res_t = res
res_u, res_t, res_n = res
self.assertEqual(res_t, t)
self.assertEqual(res_n, None)
np.testing.assert_array_equal(res_u.vector(), u.vector())

adapter.store_checkpoint(u, n) # without t
res = adapter.retrieve_checkpoint()
self.assertEqual(len(res), 2) # correct number of return values
res_u, res_n = res
res_u, res_t, res_n = res
self.assertEqual(res_n, n)
self.assertEqual(res_t, None)
np.testing.assert_array_equal(res_u.vector(), u.vector())

def test_payload_only(adapter):
Expand All @@ -241,8 +243,10 @@ def test_payload_only(adapter):
u = interpolate(E, V)
# test adapter
adapter.store_checkpoint(u) # no optional parameters
res = adapter.retrieve_checkpoint()
np.testing.assert_array_equal(res.vector(), u.vector())
res_u, res_t, res_n = adapter.retrieve_checkpoint()
self.assertEqual(res_t, None)
self.assertEqual(res_n, None)
np.testing.assert_array_equal(res_u.vector(), u.vector())



Expand Down

0 comments on commit a6edac4

Please sign in to comment.