Skip to content

Commit

Permalink
tests; removed #TODO; v0 ready?
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Nov 24, 2023
1 parent a7babb7 commit bb5644f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def handle(
if server_msg is None:
# Instantiate the client
client = client_fn("-1")
client.state = state # TODO: inject state into client object
client.state = state
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand All @@ -118,7 +118,7 @@ def handle(
sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
),
)
return task_res, client.state # TODO: return updated state
return task_res, client.state
raise NotImplementedError()
client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg)
task_res = wrap_client_message_in_task_res(client_msg)
Expand Down Expand Up @@ -152,7 +152,7 @@ def handle_legacy_message(

# Instantiate the client
client = client_fn("-1")
client.state = state # TODO: inject state into client object
client.state = state
# Execute task
message = None
if field == "get_properties_ins":
Expand All @@ -164,7 +164,7 @@ def handle_legacy_message(
if field == "evaluate_ins":
message = _evaluate(client, server_msg.evaluate_ins)
if message:
return message, client.state # TODO: return updated state
return message, client.state
raise UnknownServerMessage()


Expand Down
13 changes: 11 additions & 2 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from flwr.client import Client
from flwr.client.typing import ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common import (
EvaluateIns,
EvaluateRes,
Expand Down Expand Up @@ -133,7 +134,11 @@ def test_client_without_get_properties() -> None:
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)
task_res, state_updated = handle( # pylint: disable=unused-variable
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
task_ins=task_ins,
)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down Expand Up @@ -197,7 +202,11 @@ def test_client_with_get_properties() -> None:
disconnect_task_res, actual_sleep_duration = handle_control_message(
task_ins=task_ins
)
task_res = handle(client_fn=_get_client_fn(client), task_ins=task_ins)
task_res, updated_state = handle( # pylint: disable=unused-variable
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
task_ins=task_ins,
)

if not task_res.HasField("task"):
raise ValueError("Task value not found")
Expand Down

0 comments on commit bb5644f

Please sign in to comment.