Skip to content

Commit

Permalink
Add NodeState to start_client. (#2645)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Dec 7, 2023
1 parent 306e9f6 commit 229341b
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
from .grpc_client.connection import grpc_connection
from .grpc_rere_client.connection import grpc_request_response
from .message_handler.message_handler import handle_control_message
from .node_state import NodeState
from .numpy_client import NumPyClient
from .workload_state import WorkloadState


def run_client() -> None:
Expand Down Expand Up @@ -318,6 +318,8 @@ def _load_app() -> Flower:
# Initialize connection context manager
connection, address = _init_connection(transport, server_address)

node_state = NodeState()

while True:
sleep_duration: int = 0
with connection(
Expand Down Expand Up @@ -345,16 +347,27 @@ def _load_app() -> Flower:
send(task_res)
break

# Register state
node_state.register_workloadstate(workload_id=task_ins.workload_id)

# Load app
app: Flower = load_flower_callable_fn()

# Handle task message
fwd_msg: Fwd = Fwd(
task_ins=task_ins,
state=WorkloadState(state={}),
state=node_state.retrieve_workloadstate(
workload_id=task_ins.workload_id
),
)
bwd_msg: Bwd = app(fwd=fwd_msg)

# Update node state
node_state.update_workloadstate(
workload_id=bwd_msg.task_res.workload_id,
workload_state=bwd_msg.state,
)

# Send
send(bwd_msg.task_res)

Expand Down
50 changes: 50 additions & 0 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Node state."""


from typing import Any, Dict

from flwr.client.workload_state import WorkloadState


class NodeState:
"""State of a node where client nodes execute workloads."""

def __init__(self) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
self.workload_states: Dict[int, WorkloadState] = {}

def register_workloadstate(self, workload_id: int) -> None:
"""Register new workload state for this node."""
if workload_id not in self.workload_states:
self.workload_states[workload_id] = WorkloadState({})

def retrieve_workloadstate(self, workload_id: int) -> WorkloadState:
"""Get workload state given a workload_id."""
if workload_id in self.workload_states:
return self.workload_states[workload_id]

raise RuntimeError(
f"WorkloadState for workload_id={workload_id} doesn't exist."
" A workload must be registered before it can be retrieved or updated "
" by a client."
)

def update_workloadstate(
self, workload_id: int, workload_state: WorkloadState
) -> None:
"""Update workload state."""
self.workload_states[workload_id] = workload_state
59 changes: 59 additions & 0 deletions src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Node state tests."""


from flwr.client.node_state import NodeState
from flwr.client.workload_state import WorkloadState
from flwr.proto.task_pb2 import TaskIns


def _run_dummy_task(state: WorkloadState) -> WorkloadState:
if "counter" in state.state:
state.state["counter"] += "1"
else:
state.state["counter"] = "1"

return state


def test_multiworkload_in_node_state() -> None:
"""Test basic NodeState logic."""
# Tasks to perform
tasks = [TaskIns(workload_id=w_id) for w_id in [0, 1, 1, 2, 3, 2, 1, 5]]
# the "tasks" is to count how many times each workload is executed
expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}

# NodeState
node_state = NodeState()

for task in tasks:
w_id = task.workload_id

# Register
node_state.register_workloadstate(workload_id=w_id)

# Get workload state
state = node_state.retrieve_workloadstate(workload_id=w_id)

# Run "task"
updated_state = _run_dummy_task(state)

# Update workload state
node_state.update_workloadstate(workload_id=w_id, workload_state=updated_state)

# Verify values
for w_id, state in node_state.workload_states.items():
assert state.state["counter"] == expected_values[w_id]

0 comments on commit 229341b

Please sign in to comment.