Skip to content

Commit

Permalink
Rename WorkloadState to RunState (#2770)
Browse files (browse the repository at this point in the history)
  • Loading branch information
panh99 authored Jan 5, 2024
1 parent 0f5ce99 commit 2b4297d
Show file tree
Hide file tree
Showing 34 changed files with 228 additions and 231 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def main(cfg: DictConfig):
save_path = HydraConfig.get().runtime.output_dir

## 2. Prepare your dataset
# When simulating FL workloads we have a lot of freedom on how the FL clients behave,
# When simulating FL runs we have a lot of freedom on how the FL clients behave,
# what data they have, how much data, etc. This is not possible in real FL settings.
# In simulation you'd often encounter two types of dataset:
# * naturally partitioned, that come pre-partitioned by user id (e.g. FEMNIST,
Expand Down Expand Up @@ -91,7 +91,7 @@ def main(cfg: DictConfig):
"num_gpus": 0.0,
}, # (optional) controls the degree of parallelism of your simulation.
# Lower resources per client allow for more clients to run concurrently
# (but need to be set taking into account the compute/memory footprint of your workload)
# (but need to be set taking into account the compute/memory footprint of your run)
# `num_cpus` is an absolute number (integer) indicating the number of threads a client should be allocated
# `num_gpus` is a ratio indicating the portion of gpu memory that a client needs.
)
Expand Down
6 changes: 3 additions & 3 deletions examples/mt-pytorch/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload(
req=driver_pb2.CreateWorkloadRequest()
create_run_res: driver_pb2.CreateRunResponse = driver.create_run(
req=driver_pb2.CreateRunRequest()
)
# -------------------------------------------------------------------------- Driver SDK

run_id = create_workload_res.run_id
run_id = create_run_res.run_id
print(f"Created run id {run_id}")

history = History()
Expand Down
7 changes: 4 additions & 3 deletions examples/secaggplus-mt/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task:
task_id="", # Do not set, will be created and set by the DriverAPI
group_id="",
run_id=run_id,
run_id=run_id,
task=merge(
task,
task_pb2.Task(
Expand Down Expand Up @@ -84,12 +85,12 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:

# -------------------------------------------------------------------------- Driver SDK
driver.connect()
create_workload_res: driver_pb2.CreateWorkloadResponse = driver.create_workload(
req=driver_pb2.CreateWorkloadRequest()
create_run_res: driver_pb2.CreateRunResponse = driver.create_run(
req=driver_pb2.CreateRunRequest()
)
# -------------------------------------------------------------------------- Driver SDK

run_id = create_workload_res.run_id
run_id = create_run_res.run_id
print(f"Created run id {run_id}")

history = History()
Expand Down
10 changes: 5 additions & 5 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ import "flwr/proto/node.proto";
import "flwr/proto/task.proto";

service Driver {
// Request workload_id
rpc CreateWorkload(CreateWorkloadRequest) returns (CreateWorkloadResponse) {}
// Request run_id
rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {}

// Return a set of nodes
rpc GetNodes(GetNodesRequest) returns (GetNodesResponse) {}
Expand All @@ -34,9 +34,9 @@ service Driver {
rpc PullTaskRes(PullTaskResRequest) returns (PullTaskResResponse) {}
}

// CreateWorkload
message CreateWorkloadRequest {}
message CreateWorkloadResponse { sint64 run_id = 1; }
// CreateRun
message CreateRunRequest {}
message CreateRunResponse { sint64 run_id = 1; }

// GetNodes messages
message GetNodesRequest { sint64 run_id = 1; }
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,22 +349,22 @@ def _load_app() -> Flower:
break

# Register state
node_state.register_workloadstate(run_id=task_ins.run_id)
node_state.register_runstate(run_id=task_ins.run_id)

# Load app
app: Flower = load_flower_callable_fn()

# Handle task message
fwd_msg: Fwd = Fwd(
task_ins=task_ins,
state=node_state.retrieve_workloadstate(run_id=task_ins.run_id),
state=node_state.retrieve_runstate(run_id=task_ins.run_id),
)
bwd_msg: Bwd = app(fwd=fwd_msg)

# Update node state
node_state.update_workloadstate(
node_state.update_runstate(
run_id=bwd_msg.task_res.run_id,
workload_state=bwd_msg.state,
run_state=bwd_msg.state,
)

# Send
Expand Down
12 changes: 6 additions & 6 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from abc import ABC

from flwr.client.workload_state import WorkloadState
from flwr.client.run_state import RunState
from flwr.common import (
Code,
EvaluateIns,
Expand All @@ -38,7 +38,7 @@
class Client(ABC):
"""Abstract base class for Flower clients."""

state: WorkloadState
state: RunState

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.
Expand Down Expand Up @@ -141,12 +141,12 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> WorkloadState:
"""Get the workload state from this client."""
def get_state(self) -> RunState:
"""Get the run state from this client."""
return self.state

def set_state(self, state: WorkloadState) -> None:
"""Apply a workload state to this client."""
def set_state(self, state: RunState) -> None:
"""Apply a run state to this client."""
self.state = state

def to_client(self) -> Client:
Expand Down
18 changes: 9 additions & 9 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
get_server_message_from_task_ins,
wrap_client_message_in_task_res,
)
from flwr.client.run_state import RunState
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common import serde
from flwr.proto.task_pb2 import SecureAggregation, Task, TaskIns, TaskRes
from flwr.proto.transport_pb2 import ClientMessage, Reason, ServerMessage
Expand Down Expand Up @@ -79,16 +79,16 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:


def handle(
client_fn: ClientFn, state: WorkloadState, task_ins: TaskIns
) -> Tuple[TaskRes, WorkloadState]:
client_fn: ClientFn, state: RunState, task_ins: TaskIns
) -> Tuple[TaskRes, RunState]:
"""Handle incoming TaskIns from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : WorkloadState
A dataclass storing the state for the workload being executed by the client.
state : RunState
A dataclass storing the state for the run being executed by the client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.
Expand Down Expand Up @@ -126,16 +126,16 @@ def handle(


def handle_legacy_message(
client_fn: ClientFn, state: WorkloadState, server_msg: ServerMessage
) -> Tuple[ClientMessage, WorkloadState]:
client_fn: ClientFn, state: RunState, server_msg: ServerMessage
) -> Tuple[ClientMessage, RunState]:
"""Handle incoming messages from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : WorkloadState
A dataclass storing the state for the workload being executed by the client.
state : RunState
A dataclass storing the state for the run being executed by the client.
server_msg: ServerMessage
The message coming from the server, to be processed by the client.
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import uuid

from flwr.client import Client
from flwr.client.run_state import RunState
from flwr.client.typing import ClientFn
from flwr.client.workload_state import WorkloadState
from flwr.common import (
EvaluateIns,
EvaluateRes,
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_client_without_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
state=RunState(state={}),
task_ins=task_ins,
)

Expand Down Expand Up @@ -204,7 +204,7 @@ def test_client_with_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=WorkloadState(state={}),
state=RunState(state={}),
task_ins=task_ins,
)

Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/client/middleware/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import unittest
from typing import List

from flwr.client.run_state import RunState
from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer
from flwr.client.workload_state import WorkloadState
from flwr.proto.task_pb2 import TaskIns, TaskRes

from .utils import make_ffn
Expand All @@ -45,7 +45,7 @@ def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable:
def app(fwd: Fwd) -> Bwd:
footprint.append(name)
fwd.task_ins.task_id += f"{name}"
return Bwd(task_res=TaskRes(task_id=name), state=WorkloadState({}))
return Bwd(task_res=TaskRes(task_id=name), state=RunState({}))

return app

Expand All @@ -66,7 +66,7 @@ def test_multiple_middlewares(self) -> None:

# Execute
wrapped_app = make_ffn(mock_app, mock_middleware_layers)
task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res

# Assert
trace = mock_middleware_names + ["app"]
Expand All @@ -86,11 +86,11 @@ def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd:
footprint.append("filter")
fwd.task_ins.task_id += "filter"
# Skip calling app
return Bwd(task_res=TaskRes(task_id="filter"), state=WorkloadState({}))
return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({}))

# Execute
wrapped_app = make_ffn(mock_app, [filter_layer])
task_res = wrapped_app(Fwd(task_ins=task_ins, state=WorkloadState({}))).task_res
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res

# Assert
self.assertEqual(footprint, ["filter"])
Expand Down
32 changes: 16 additions & 16 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,32 @@

from typing import Any, Dict

from flwr.client.workload_state import WorkloadState
from flwr.client.run_state import RunState


class NodeState:
"""State of a node where client nodes execute workloads."""
"""State of a node where client nodes execute runs."""

def __init__(self) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
self.workload_states: Dict[int, WorkloadState] = {}
self.run_states: Dict[int, RunState] = {}

def register_workloadstate(self, run_id: int) -> None:
"""Register new workload state for this node."""
if run_id not in self.workload_states:
self.workload_states[run_id] = WorkloadState({})
def register_runstate(self, run_id: int) -> None:
"""Register new run state for this node."""
if run_id not in self.run_states:
self.run_states[run_id] = RunState({})

def retrieve_workloadstate(self, run_id: int) -> WorkloadState:
"""Get workload state given a run_id."""
if run_id in self.workload_states:
return self.workload_states[run_id]
def retrieve_runstate(self, run_id: int) -> RunState:
"""Get run state given a run_id."""
if run_id in self.run_states:
return self.run_states[run_id]

raise RuntimeError(
f"WorkloadState for run_id={run_id} doesn't exist."
" A workload must be registered before it can be retrieved or updated "
f"RunState for run_id={run_id} doesn't exist."
" A run must be registered before it can be retrieved or updated "
" by a client."
)

def update_workloadstate(self, run_id: int, workload_state: WorkloadState) -> None:
"""Update workload state."""
self.workload_states[run_id] = workload_state
def update_runstate(self, run_id: int, run_state: RunState) -> None:
"""Update run state."""
self.run_states[run_id] = run_state
26 changes: 13 additions & 13 deletions src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@


from flwr.client.node_state import NodeState
from flwr.client.workload_state import WorkloadState
from flwr.client.run_state import RunState
from flwr.proto.task_pb2 import TaskIns


def _run_dummy_task(state: WorkloadState) -> WorkloadState:
def _run_dummy_task(state: RunState) -> RunState:
if "counter" in state.state:
state.state["counter"] += "1"
else:
Expand All @@ -29,31 +29,31 @@ def _run_dummy_task(state: WorkloadState) -> WorkloadState:
return state


def test_multiworkload_in_node_state() -> None:
def test_multirun_in_node_state() -> None:
"""Test basic NodeState logic."""
# Tasks to perform
tasks = [TaskIns(run_id=r_id) for r_id in [0, 1, 1, 2, 3, 2, 1, 5]]
# the "tasks" is to count how many times each workload is executed
tasks = [TaskIns(run_id=run_id) for run_id in [0, 1, 1, 2, 3, 2, 1, 5]]
# the "tasks" is to count how many times each run is executed
expected_values = {0: "1", 1: "1" * 3, 2: "1" * 2, 3: "1", 5: "1"}

# NodeState
node_state = NodeState()

for task in tasks:
r_id = task.run_id
run_id = task.run_id

# Register
node_state.register_workloadstate(run_id=r_id)
node_state.register_runstate(run_id=run_id)

# Get workload state
state = node_state.retrieve_workloadstate(run_id=r_id)
# Get run state
state = node_state.retrieve_runstate(run_id=run_id)

# Run "task"
updated_state = _run_dummy_task(state)

# Update workload state
node_state.update_workloadstate(run_id=r_id, workload_state=updated_state)
# Update run state
node_state.update_runstate(run_id=run_id, run_state=updated_state)

# Verify values
for r_id, state in node_state.workload_states.items():
assert state.state["counter"] == expected_values[r_id]
for run_id, state in node_state.run_states.items():
assert state.state["counter"] == expected_values[run_id]
Loading

0 comments on commit 2b4297d

Please sign in to comment.