diff --git a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py b/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py index 6eccb056390a..1397a026a33c 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py +++ b/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py @@ -18,7 +18,7 @@ - https://github.com/grpc/grpc/blob/master/doc/statuscodes.md """ - +import uuid from typing import Callable, Iterator import grpc @@ -91,9 +91,12 @@ def Join( # pylint: disable=invalid-name wrapping the actual message - The `Join` method is (pretty much) unaware of the protocol """ - peer: str = context.peer() + # When running Flower behind a proxy, the peer can be the same for + # different clients, so instead of `cid: str = context.peer()` we + # use a `UUID4` that is unique. + cid: str = uuid.uuid4().hex bridge = self.grpc_bridge_factory() - client_proxy = self.client_proxy_factory(peer, bridge) + client_proxy = self.client_proxy_factory(cid, bridge) is_success = register_client_proxy(self.client_manager, client_proxy, context) if is_success: diff --git a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py b/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py index b5c3f504af03..37eef167879c 100644 --- a/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py +++ b/src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer_test.py @@ -16,6 +16,7 @@ import unittest +import uuid from unittest.mock import MagicMock, call from flwr.proto.transport_pb2 import ( # pylint: disable=E0611 @@ -30,7 +31,8 @@ CLIENT_MESSAGE = ClientMessage() SERVER_MESSAGE = ServerMessage() -CLIENT_CID = "some_client_cid" + +CID: str = uuid.uuid4().hex class FlowerServiceServicerTestCase(unittest.TestCase): @@ -42,7 +44,6 @@ def setUp(self) -> None: """Create mocks for tests.""" # Mock for the gRPC context argument self.context_mock = MagicMock() - self.context_mock.peer.return_value = CLIENT_CID # Define client_messages to be processed by FlowerServiceServicer instance self.client_messages = [CLIENT_MESSAGE for _ in range(5)] @@ -70,7 +71,7 @@ def setUp(self) -> None: # Create a GrpcClientProxy mock which we will use to test if correct # methods where called and client_messages are getting passed to it self.grpc_client_proxy_mock = MagicMock() - self.grpc_client_proxy_mock.cid = CLIENT_CID + self.grpc_client_proxy_mock.cid = CID self.client_proxy_factory_mock = MagicMock() self.client_proxy_factory_mock.return_value = self.grpc_client_proxy_mock @@ -127,11 +128,7 @@ def test_join(self) -> None: num_server_messages += 1 assert len(self.client_messages) == num_server_messages - assert self.grpc_client_proxy_mock.cid == CLIENT_CID - - self.client_proxy_factory_mock.assert_called_once_with( - CLIENT_CID, self.grpc_bridge_mock - ) + assert self.grpc_client_proxy_mock.cid == CID # Check if the client was registered with the client_manager self.client_manager_mock.register.assert_called_once_with(