diff --git a/src/py/flwr/client/message_handler/message_handler_test.py b/src/py/flwr/client/message_handler/message_handler_test.py index 2a510b291c49..5244951c8a48 100644 --- a/src/py/flwr/client/message_handler/message_handler_test.py +++ b/src/py/flwr/client/message_handler/message_handler_test.py @@ -163,7 +163,7 @@ def test_client_without_get_properties() -> None: src_node_id=1123, dst_node_id=0, reply_to_message=message.metadata.message_id, - ttl=DEFAULT_TTL, + ttl=actual_msg.metadata.ttl, # computed based on [message].create_reply() message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=expected_rs, @@ -227,7 +227,7 @@ def test_client_with_get_properties() -> None: src_node_id=1123, dst_node_id=0, reply_to_message=message.metadata.message_id, - ttl=DEFAULT_TTL, + ttl=actual_msg.metadata.ttl, # computed based on [message].create_reply() message_type=MessageTypeLegacy.GET_PROPERTIES, ), content=expected_rs, diff --git a/src/py/flwr/common/message.py b/src/py/flwr/common/message.py index 6e0ab9149828..7707f3c72de1 100644 --- a/src/py/flwr/common/message.py +++ b/src/py/flwr/common/message.py @@ -297,22 +297,33 @@ def _create_reply_metadata(self, ttl: float) -> Metadata: partition_id=self.metadata.partition_id, ) - def create_error_reply( - self, - error: Error, - ttl: float, - ) -> Message: + def create_error_reply(self, error: Error, ttl: float | None = None) -> Message: """Construct a reply message indicating an error happened. Parameters ---------- error : Error The error that was encountered. - ttl : float - Time-to-live for this message in seconds. + ttl : Optional[float] (default: None) + Time-to-live for this message in seconds. If unset, it will be set based + on the remaining time for the received message before it expires. This + follows the equation: + + ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at) """ + # If no TTL passed, use default for message creation (will update after + # message creation) + ttl_ = DEFAULT_TTL if ttl is None else ttl # Create reply with error - message = Message(metadata=self._create_reply_metadata(ttl), error=error) + message = Message(metadata=self._create_reply_metadata(ttl_), error=error) + + if ttl is None: + # Set TTL equal to the remaining time for the received message to expire + ttl = self.metadata.ttl - ( + message.metadata.created_at - self.metadata.created_at + ) + message.metadata.ttl = ttl + return message def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message: @@ -327,18 +338,31 @@ def create_reply(self, content: RecordSet, ttl: float | None = None) -> Message: content : RecordSet The content for the reply message. ttl : Optional[float] (default: None) - Time-to-live for this message in seconds. If unset, it will use - the `common.DEFAULT_TTL` value. + Time-to-live for this message in seconds. If unset, it will be set based + on the remaining time for the received message before it expires. This + follows the equation: + + ttl = msg.meta.ttl - (reply.meta.created_at - msg.meta.created_at) Returns ------- Message A new `Message` instance representing the reply. """ - if ttl is None: - ttl = DEFAULT_TTL + # If no TTL passed, use default for message creation (will update after + # message creation) + ttl_ = DEFAULT_TTL if ttl is None else ttl - return Message( - metadata=self._create_reply_metadata(ttl), + message = Message( + metadata=self._create_reply_metadata(ttl_), content=content, ) + + if ttl is None: + # Set TTL equal to the remaining time for the received message to expire + ttl = self.metadata.ttl - ( + message.metadata.created_at - self.metadata.created_at + ) + message.metadata.ttl = ttl + + return message diff --git a/src/py/flwr/common/message_test.py b/src/py/flwr/common/message_test.py index cd5a7d72272f..1a5da0517352 100644 --- a/src/py/flwr/common/message_test.py +++ b/src/py/flwr/common/message_test.py @@ -16,7 +16,7 @@ import time from contextlib import ExitStack -from typing import Any, Callable +from typing import Any, Callable, Optional import pytest @@ -73,17 +73,21 @@ def test_message_creation( assert message.metadata.created_at < time.time() -def create_message_with_content() -> Message: +def create_message_with_content(ttl: Optional[float] = None) -> Message: """Create a Message with content.""" maker = RecordMaker(state=2) metadata = maker.metadata() + if ttl: + metadata.ttl = ttl return Message(metadata=metadata, content=RecordSet()) -def create_message_with_error() -> Message: +def create_message_with_error(ttl: Optional[float] = None) -> Message: """Create a Message with error.""" maker = RecordMaker(state=2) metadata = maker.metadata() + if ttl: + metadata.ttl = ttl return Message(metadata=metadata, error=Error(code=1)) @@ -111,3 +115,45 @@ def test_altering_message( message.error = Error(code=123) if message.has_error(): message.content = RecordSet() + + +@pytest.mark.parametrize( + "message_creation_fn,ttl,reply_ttl", + [ + (create_message_with_content, 1e6, None), + (create_message_with_error, 1e6, None), + (create_message_with_content, 1e6, 3600), + (create_message_with_error, 1e6, 3600), + ], +) +def test_create_reply( + message_creation_fn: Callable[ + [float], + Message, + ], + ttl: float, + reply_ttl: Optional[float], +) -> None: + """Test reply creation from message.""" + message: Message = message_creation_fn(ttl) + + time.sleep(0.1) + + if message.has_error(): + dummy_error = Error(code=0, reason="it crashed") + reply_message = message.create_error_reply(dummy_error, ttl=reply_ttl) + else: + reply_message = message.create_reply(content=RecordSet(), ttl=reply_ttl) + + # Ensure reply has a higher timestamp + assert message.metadata.created_at < reply_message.metadata.created_at + if reply_ttl: + # Ensure the TTL is the one specify upon reply creation + assert reply_message.metadata.ttl == reply_ttl + else: + # Ensure reply ttl is lower (since it uses remaining time left) + assert message.metadata.ttl > reply_message.metadata.ttl + + assert message.metadata.src_node_id == reply_message.metadata.dst_node_id + assert message.metadata.dst_node_id == reply_message.metadata.src_node_id + assert reply_message.metadata.reply_to_message == message.metadata.message_id