Skip to content

Commit

Permalink
Refactor start client control message (#2559)
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Nov 6, 2023
1 parent c82c9e2 commit e70661a
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 38 deletions.
12 changes: 8 additions & 4 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from .grpc_client.connection import grpc_connection
from .grpc_rere_client.connection import grpc_request_response
from .message_handler.message_handler import handle
from .message_handler.message_handler import handle, handle_control_message
from .numpy_client import NumPyClient


Expand Down Expand Up @@ -159,13 +159,17 @@ def single_client_factory(
time.sleep(3) # Wait for 3s before asking again
continue

# Handle control message
task_res, sleep_duration = handle_control_message(task_ins=task_ins)
if task_res:
send(task_res)
break

# Handle task message
task_res, sleep_duration, keep_going = handle(client_fn, task_ins)
task_res = handle(client_fn, task_ins)

# Send
send(task_res)
if not keep_going:
break

# Unregister node
if delete_node is not None:
Expand Down
79 changes: 52 additions & 27 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Client-side message handler."""


from typing import Tuple
from typing import Optional, Tuple

from flwr.client.client import (
Client,
Expand All @@ -35,24 +35,24 @@
from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage


class UnexpectedServerMessage(Exception):
"""Exception indicating that the received message is unexpected."""


class UnknownServerMessage(Exception):
"""Exception indicating that the received message is unknown."""


def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
"""Handle incoming TaskIns from the server.
def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:
"""Handle control part of the incoming message.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
task_ins: TaskIns
task_ins : TaskIns
The task instruction coming from the server, to be processed by the client.
Returns
-------
task_res: TaskRes
The task response that should be returned to the server.
sleep_duration : int
Number of seconds that the client should disconnect from the server.
keep_going : bool
Expand All @@ -61,6 +61,38 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
reconnect later (False).
"""
server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)

# SecAgg message
if server_msg is None:
return None, 0

# ReconnectIns message
field = server_msg.WhichOneof("msg")
if field == "reconnect_ins":
disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect_ins)
task_res = wrap_client_message_in_task_res(disconnect_msg)
return task_res, sleep_duration

# Any other message
return None, 0


def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
"""Handle incoming TaskIns from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.
Returns
-------
task_res : TaskRes
The task response that should be returned to the server.
"""
server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)
if server_msg is None:
# Instantiate the client
client = client_fn("-1")
Expand All @@ -80,18 +112,16 @@ def handle(client_fn: ClientFn, task_ins: TaskIns) -> Tuple[TaskRes, int, bool]:
sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
),
)
return task_res, 0, True
return task_res
raise NotImplementedError()
client_msg, sleep_duration, keep_going = handle_legacy_message(
client_fn, server_msg
)
client_msg = handle_legacy_message(client_fn, server_msg)
task_res = wrap_client_message_in_task_res(client_msg)
return task_res, sleep_duration, keep_going
return task_res


def handle_legacy_message(
client_fn: ClientFn, server_msg: ServerMessage
) -> Tuple[ClientMessage, int, bool]:
) -> ClientMessage:
"""Handle incoming messages from the server.
Parameters
Expand All @@ -103,31 +133,26 @@ def handle_legacy_message(
Returns
-------
client_msg: ClientMessage
client_msg : ClientMessage
The result message that should be returned to the server.
sleep_duration : int
Number of seconds that the client should disconnect from the server.
keep_going : bool
Flag that indicates whether the client should continue to process the
next message from the server (True) or disconnect and optionally
reconnect later (False).
"""
field = server_msg.WhichOneof("msg")

# Must be handled elsewhere
if field == "reconnect_ins":
disconnect_msg, sleep_duration = _reconnect(server_msg.reconnect_ins)
return disconnect_msg, sleep_duration, False
raise UnexpectedServerMessage()

# Instantiate the client
client = client_fn("-1")
# Execute task
if field == "get_properties_ins":
return _get_properties(client, server_msg.get_properties_ins), 0, True
return _get_properties(client, server_msg.get_properties_ins)
if field == "get_parameters_ins":
return _get_parameters(client, server_msg.get_parameters_ins), 0, True
return _get_parameters(client, server_msg.get_parameters_ins)
if field == "fit_ins":
return _fit(client, server_msg.fit_ins), 0, True
return _fit(client, server_msg.fit_ins)
if field == "evaluate_ins":
return _evaluate(client, server_msg.evaluate_ins), 0, True
return _evaluate(client, server_msg.evaluate_ins)
raise UnknownServerMessage()


Expand Down
16 changes: 9 additions & 7 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import ClientMessage, Code, ServerMessage, Status

from .message_handler import handle
from .message_handler import handle, handle_control_message


class ClientWithoutProps(Client):
Expand Down Expand Up @@ -130,9 +130,10 @@ def test_client_without_get_properties() -> None:
)

# Execute
task_res, actual_sleep_duration, actual_keep_going = handle(
client_fn=_get_client_fn(client), task_ins=task_ins
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down Expand Up @@ -171,8 +172,8 @@ def test_client_without_get_properties() -> None:
expected_msg = ClientMessage(get_properties_res=expected_get_properties_res)

assert actual_msg == expected_msg
assert not disconnect_task_res
assert actual_sleep_duration == 0
assert actual_keep_going is True


def test_client_with_get_properties() -> None:
Expand All @@ -193,9 +194,10 @@ def test_client_with_get_properties() -> None:
)

# Execute
task_res, actual_sleep_duration, actual_keep_going = handle(
client_fn=_get_client_fn(client), task_ins=task_ins
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down Expand Up @@ -237,5 +239,5 @@ def test_client_with_get_properties() -> None:
expected_msg = ClientMessage(get_properties_res=expected_get_properties_res)

assert actual_msg == expected_msg
assert not disconnect_task_res
assert actual_sleep_duration == 0
assert actual_keep_going is True

0 comments on commit e70661a

Please sign in to comment.