Skip to content

Commit

Permalink
w/ set_stat()/get_state()
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq committed Nov 27, 2023
1 parent 655cdca commit 62b071b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 8 additions & 0 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> WorkloadState:
"""Get the workload state from this client."""
return self.state

def set_state(self, state: WorkloadState) -> None:
"""Apply a workload state to this client."""
self.state = state

def to_client(self) -> Client:
"""Return client (itself)."""
return self
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def handle(
if server_msg is None:
# Instantiate the client
client = client_fn("-1")
client.state = state
client.set_state(state)
# Secure Aggregation
if task_ins.task.HasField("sa") and isinstance(
client, SecureAggregationHandler
Expand All @@ -118,7 +118,7 @@ def handle(
sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
),
)
return task_res, client.state
return task_res, client.get_state()
raise NotImplementedError()
client_msg, updated_state = handle_legacy_message(client_fn, state, server_msg)
task_res = wrap_client_message_in_task_res(client_msg)
Expand Down Expand Up @@ -152,7 +152,7 @@ def handle_legacy_message(

# Instantiate the client
client = client_fn("-1")
client.state = state
client.set_state(state)
# Execute task
message = None
if field == "get_properties_ins":
Expand All @@ -164,7 +164,7 @@ def handle_legacy_message(
if field == "evaluate_ins":
message = _evaluate(client, server_msg.evaluate_ins)
if message:
return message, client.state
return message, client.get_state()
raise UnknownServerMessage()


Expand Down

0 comments on commit 62b071b

Please sign in to comment.