diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index 9efb0748d9d5..3448e18e20c5 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -41,8 +41,8 @@ from .grpc_client.connection import grpc_connection from .grpc_rere_client.connection import grpc_request_response from .message_handler.message_handler import handle_control_message +from .node_state import NodeState from .numpy_client import NumPyClient -from .workload_state import WorkloadState def run_client() -> None: @@ -318,6 +318,8 @@ def _load_app() -> Flower: # Initialize connection context manager connection, address = _init_connection(transport, server_address) + node_state = NodeState() + while True: sleep_duration: int = 0 with connection( @@ -345,16 +347,27 @@ def _load_app() -> Flower: send(task_res) break + # Register state + node_state.register_workloadstate(workload_id=task_ins.workload_id) + # Load app app: Flower = load_flower_callable_fn() # Handle task message fwd_msg: Fwd = Fwd( task_ins=task_ins, - state=WorkloadState(state={}), + state=node_state.retrieve_workloadstate( + workload_id=task_ins.workload_id + ), ) bwd_msg: Bwd = app(fwd=fwd_msg) + # Update node state + node_state.update_workloadstate( + workload_id=bwd_msg.task_res.workload_id, + workload_state=bwd_msg.state, + ) + # Send send(bwd_msg.task_res) diff --git a/src/py/flwr/client/node_state.py b/src/py/flwr/client/node_state.py new file mode 100644 index 000000000000..ee4f70dc4dca --- /dev/null +++ b/src/py/flwr/client/node_state.py @@ -0,0 +1,50 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Node state.""" + + +from typing import Any, Dict + +from flwr.client.workload_state import WorkloadState + + +class NodeState: + """State of a node where client nodes execute workloads.""" + + def __init__(self) -> None: + self._meta: Dict[str, Any] = {} # holds metadata about the node + self.workload_states: Dict[int, WorkloadState] = {} + + def register_workloadstate(self, workload_id: int) -> None: + """Register new workload state for this node.""" + if workload_id not in self.workload_states: + self.workload_states[workload_id] = WorkloadState({}) + + def retrieve_workloadstate(self, workload_id: int) -> WorkloadState: + """Get workload state given a workload_id.""" + if workload_id in self.workload_states: + return self.workload_states[workload_id] + + raise RuntimeError( + f"WorkloadState for workload_id={workload_id} doesn't exist." + " A workload must be registered before it can be retrieved or updated " + " by a client." + ) + + def update_workloadstate( + self, workload_id: int, workload_state: WorkloadState + ) -> None: + """Update workload state.""" + self.workload_states[workload_id] = workload_state diff --git a/src/py/flwr/client/node_state_tests.py b/src/py/flwr/client/node_state_tests.py new file mode 100644 index 000000000000..d9f9ae7db3b0 --- /dev/null +++ b/src/py/flwr/client/node_state_tests.py @@ -0,0 +1,59 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Node state tests.""" + + +from flwr.client.node_state import NodeState +from flwr.client.workload_state import WorkloadState +from flwr.proto.task_pb2 import TaskIns + + +def _run_dummy_task(state: WorkloadState) -> WorkloadState: + if "counter" in state.state: + state.state["counter"] += "1" + else: + state.state["counter"] = "1" + + return state + + +def test_multiworkload_in_node_state() -> None: + """Test basic NodeState logic.""" + # Tasks to perform + tasks = [TaskIns(workload_id=w_id) for w_id in [0, 1, 1, 2, 3, 2, 1, 5]] + # the "tasks" is to count how many times each workload is executed + expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"} + + # NodeState + node_state = NodeState() + + for task in tasks: + w_id = task.workload_id + + # Register + node_state.register_workloadstate(workload_id=w_id) + + # Get workload state + state = node_state.retrieve_workloadstate(workload_id=w_id) + + # Run "task" + updated_state = _run_dummy_task(state) + + # Update workload state + node_state.update_workloadstate(workload_id=w_id, workload_state=updated_state) + + # Verify values + for w_id, state in node_state.workload_states.items(): + assert state.state["counter"] == expected_values[w_id]