Skip to content

Commit

Permalink
Identify gRPC clients using UUID instead of gRPC peer (#2889)
Browse files Browse the repository at this point in the history
Co-authored-by: Charles Beauville <[email protected]>
Co-authored-by: Alvaro Lopez Garcia <[email protected]>
  • Loading branch information
charlesbvll and alvarolopez authored Feb 1, 2024
1 parent d371edc commit 700445c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
9 changes: 6 additions & 3 deletions src/py/flwr/server/fleet/grpc_bidi/flower_service_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
- https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
"""


import uuid
from typing import Callable, Iterator

import grpc
Expand Down Expand Up @@ -91,9 +91,12 @@ def Join( # pylint: disable=invalid-name
wrapping the actual message
- The `Join` method is (pretty much) unaware of the protocol
"""
peer: str = context.peer()
# When running Flower behind a proxy, the peer can be the same for
# different clients, so instead of `cid: str = context.peer()` we
# use a `UUID4` that is unique.
cid: str = uuid.uuid4().hex
bridge = self.grpc_bridge_factory()
client_proxy = self.client_proxy_factory(peer, bridge)
client_proxy = self.client_proxy_factory(cid, bridge)
is_success = register_client_proxy(self.client_manager, client_proxy, context)

if is_success:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


import unittest
import uuid
from unittest.mock import MagicMock, call

from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
Expand All @@ -30,7 +31,8 @@

CLIENT_MESSAGE = ClientMessage()
SERVER_MESSAGE = ServerMessage()
CLIENT_CID = "some_client_cid"

CID: str = uuid.uuid4().hex


class FlowerServiceServicerTestCase(unittest.TestCase):
Expand All @@ -42,7 +44,6 @@ def setUp(self) -> None:
"""Create mocks for tests."""
# Mock for the gRPC context argument
self.context_mock = MagicMock()
self.context_mock.peer.return_value = CLIENT_CID

# Define client_messages to be processed by FlowerServiceServicer instance
self.client_messages = [CLIENT_MESSAGE for _ in range(5)]
Expand Down Expand Up @@ -70,7 +71,7 @@ def setUp(self) -> None:
# Create a GrpcClientProxy mock which we will use to test if correct
# methods where called and client_messages are getting passed to it
self.grpc_client_proxy_mock = MagicMock()
self.grpc_client_proxy_mock.cid = CLIENT_CID
self.grpc_client_proxy_mock.cid = CID

self.client_proxy_factory_mock = MagicMock()
self.client_proxy_factory_mock.return_value = self.grpc_client_proxy_mock
Expand Down Expand Up @@ -127,11 +128,7 @@ def test_join(self) -> None:
num_server_messages += 1

assert len(self.client_messages) == num_server_messages
assert self.grpc_client_proxy_mock.cid == CLIENT_CID

self.client_proxy_factory_mock.assert_called_once_with(
CLIENT_CID, self.grpc_bridge_mock
)
assert self.grpc_client_proxy_mock.cid == CID

# Check if the client was registered with the client_manager
self.client_manager_mock.register.assert_called_once_with(
Expand Down

0 comments on commit 700445c

Please sign in to comment.