Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle client legacy message from task_type #2839

Closed
wants to merge 15 commits into from
23 changes: 13 additions & 10 deletions src/py/flwr/client/flower.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
import importlib
from typing import List, Optional, cast

from flwr.client.message_handler.message_handler import handle
from flwr.client.message_handler.message_handler import (
handle_legacy_message_from_tasktype,
)
from flwr.client.middleware.utils import make_ffn
from flwr.client.typing import Bwd, ClientFn, Fwd, Layer
from flwr.client.typing import ClientFn, Layer
from flwr.common.flowercontext import FlowerContext


class Flower:
Expand Down Expand Up @@ -55,20 +58,20 @@ def __init__(
layers: Optional[List[Layer]] = None,
) -> None:
# Create wrapper function for `handle`
def ffn(fwd: Fwd) -> Bwd: # pylint: disable=invalid-name
task_res, state_updated = handle(
client_fn=client_fn,
state=fwd.state,
task_ins=fwd.task_ins,
def ffn(
context: FlowerContext,
) -> FlowerContext: # pylint: disable=invalid-name
context = handle_legacy_message_from_tasktype(
client_fn=client_fn, context=context
)
return Bwd(task_res=task_res, state=state_updated)
return context

# Wrap middleware layers around the wrapped handle function
self._call = make_ffn(ffn, layers if layers is not None else [])

def __call__(self, fwd: Fwd) -> Bwd:
def __call__(self, context: FlowerContext) -> FlowerContext:
"""."""
return self._call(fwd)
return self._call(context)


class LoadCallableError(Exception):
Expand Down
49 changes: 49 additions & 0 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.common import serde
from flwr.common.flowercontext import FlowerContext
from flwr.common.recordset_utils import (
evaluate_res_to_recordset,
fit_res_to_recordset,
getparameters_res_to_recordset,
getproperties_res_to_recordset,
recordset_to_evaluate_ins,
recordset_to_fit_ins,
recordset_to_getparameters_ins,
recordset_to_getproperties_ins,
)
from flwr.proto.task_pb2 import ( # pylint: disable=E0611
SecureAggregation,
Task,
Expand Down Expand Up @@ -177,6 +188,44 @@ def handle_legacy_message(
raise UnknownServerMessage()


def handle_legacy_message_from_tasktype(
client_fn: ClientFn, context: FlowerContext
) -> FlowerContext:
"""Handle legacy message in the inner most middleware layer."""
client = client_fn("-1")
task_type = context.metadata.task_type

if task_type == "get_properties_ins":
get_properties_res = maybe_call_get_properties(
client=client,
get_properties_ins=recordset_to_getproperties_ins(context.in_message),
)
context.out_message = getproperties_res_to_recordset(get_properties_res)
elif task_type == "get_parameteres_ins":
get_parameters_res = maybe_call_get_parameters(
client=client,
get_parameters_ins=recordset_to_getparameters_ins(context.in_message),
)
context.out_message = getparameters_res_to_recordset(get_parameters_res)
elif task_type == "fit_ins":
fit_res = maybe_call_fit(
client=client,
fit_ins=recordset_to_fit_ins(context.in_message),
)
context.out_message = fit_res_to_recordset(fit_res, keep_input=False)
elif task_type == "evaluate_ins":
evaluate_res = maybe_call_evaluate(
client=client,
evaluate_ins=recordset_to_evaluate_ins(context.in_message),
)
context.out_message = evaluate_res_to_recordset(evaluate_res)
else:
# TODO: what to do with reconnect?
print("do something")

return context


def _reconnect(
reconnect_msg: ServerMessage.ReconnectIns,
) -> Tuple[ClientMessage, int]:
Expand Down
Loading