Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace RunState with RecordSet #2855

Merged
merged 6 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import flwr as fl
import numpy as np

from flwr.common.configsrecord import ConfigsRecord

SUBSET_SIZE = 1000
STATE_VAR = 'timestamp'

Expand All @@ -18,13 +20,15 @@ def get_parameters(self, config):
def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)
value = str(t_stamp)
if STATE_VAR in self.state.configs.keys():
value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]
return self.state.get_configs(STATE_VAR)[STATE_VAR]

def fit(self, parameters, config):
model_params = parameters
Expand Down
15 changes: 9 additions & 6 deletions e2e/pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tqdm import tqdm

import flwr as fl
from flwr.common.configsrecord import ConfigsRecord

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
Expand Down Expand Up @@ -95,14 +96,16 @@ def get_parameters(self, config):
def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)
value = str(t_stamp)
if STATE_VAR in self.state.configs.keys():
value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]

return self.state.get_configs(STATE_VAR)[STATE_VAR]
danieljanes marked this conversation as resolved.
Show resolved Hide resolved
def fit(self, parameters, config):
set_parameters(net, parameters)
train(net, trainloader, epochs=1)
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from abc import ABC

from flwr.client.run_state import RunState
from flwr.common import (
Code,
EvaluateIns,
Expand All @@ -33,12 +32,13 @@
Parameters,
Status,
)
from flwr.common.recordset import RecordSet


class Client(ABC):
"""Abstract base class for Flower clients."""

state: RunState
state: RecordSet

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.
Expand Down Expand Up @@ -141,11 +141,11 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> RunState:
def get_state(self) -> RecordSet:
"""Get the run state from this client."""
return self.state

def set_state(self, state: RunState) -> None:
def set_state(self, state: RecordSet) -> None:
"""Apply a run state to this client."""
self.state = state

Expand Down
14 changes: 7 additions & 7 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
get_server_message_from_task_ins,
wrap_client_message_in_task_res,
)
from flwr.client.run_state import RunState
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.common import serde
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import ( # pylint: disable=E0611
SecureAggregation,
Task,
Expand Down Expand Up @@ -88,15 +88,15 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:


def handle(
client_fn: ClientFn, state: RunState, task_ins: TaskIns
) -> Tuple[TaskRes, RunState]:
client_fn: ClientFn, state: RecordSet, task_ins: TaskIns
) -> Tuple[TaskRes, RecordSet]:
"""Handle incoming TaskIns from the server.

Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : RunState
state : RecordSet
A dataclass storing the state for the run being executed by the client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.
Expand Down Expand Up @@ -135,15 +135,15 @@ def handle(


def handle_legacy_message(
client_fn: ClientFn, state: RunState, server_msg: ServerMessage
) -> Tuple[ClientMessage, RunState]:
client_fn: ClientFn, state: RecordSet, server_msg: ServerMessage
) -> Tuple[ClientMessage, RecordSet]:
"""Handle incoming messages from the server.

Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : RunState
state : RecordSet
A dataclass storing the state for the run being executed by the client.
server_msg: ServerMessage
The message coming from the server, to be processed by the client.
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import uuid

from flwr.client import Client
from flwr.client.run_state import RunState
from flwr.client.typing import ClientFn
from flwr.common import (
EvaluateIns,
Expand All @@ -33,6 +32,7 @@
serde,
typing,
)
from flwr.common.recordset import RecordSet
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_client_without_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=RunState(state={}),
state=RecordSet(),
task_ins=task_ins,
)

Expand Down Expand Up @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=RunState(state={}),
state=RecordSet(),
task_ins=task_ins,
)

Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/client/middleware/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import unittest
from typing import List

from flwr.client.run_state import RunState
from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

from .utils import make_ffn
Expand All @@ -45,7 +45,7 @@ def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable:
def app(fwd: Fwd) -> Bwd:
footprint.append(name)
fwd.task_ins.task_id += f"{name}"
return Bwd(task_res=TaskRes(task_id=name), state=RunState({}))
return Bwd(task_res=TaskRes(task_id=name), state=RecordSet())

return app

Expand All @@ -66,7 +66,7 @@ def test_multiple_middlewares(self) -> None:

# Execute
wrapped_app = make_ffn(mock_app, mock_middleware_layers)
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res

# Assert
trace = mock_middleware_names + ["app"]
Expand All @@ -86,11 +86,11 @@ def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd:
footprint.append("filter")
fwd.task_ins.task_id += "filter"
# Skip calling app
return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({}))
return Bwd(task_res=TaskRes(task_id="filter"), state=RecordSet())

# Execute
wrapped_app = make_ffn(mock_app, [filter_layer])
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res

# Assert
self.assertEqual(footprint, ["filter"])
Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@

from typing import Any, Dict

from flwr.client.run_state import RunState
from flwr.common.recordset import RecordSet


class NodeState:
"""State of a node where client nodes execute runs."""

def __init__(self) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
self.run_states: Dict[int, RunState] = {}
self.run_states: Dict[int, RecordSet] = {}

def register_runstate(self, run_id: int) -> None:
"""Register new run state for this node."""
if run_id not in self.run_states:
self.run_states[run_id] = RunState({})
self.run_states[run_id] = RecordSet()

def retrieve_runstate(self, run_id: int) -> RunState:
def retrieve_runstate(self, run_id: int) -> RecordSet:
"""Get run state given a run_id."""
if run_id in self.run_states:
return self.run_states[run_id]
Expand All @@ -43,6 +43,6 @@ def retrieve_runstate(self, run_id: int) -> RunState:
" by a client."
)

def update_runstate(self, run_id: int, run_state: RunState) -> None:
def update_runstate(self, run_id: int, run_state: RecordSet) -> None:
"""Update run state."""
self.run_states[run_id] = run_state
17 changes: 10 additions & 7 deletions src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@


from flwr.client.node_state import NodeState
from flwr.client.run_state import RunState
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611


def _run_dummy_task(state: RunState) -> RunState:
if "counter" in state.state:
state.state["counter"] += "1"
else:
state.state["counter"] = "1"
def _run_dummy_task(state: RecordSet) -> RecordSet:
counter_value: str = "1"
if "counter" in state.configs.keys():
counter_value = state.get_configs("counter")["count"] # type: ignore
counter_value += "1"

state.set_configs(name="counter", record=ConfigsRecord({"count": counter_value}))

return state

Expand Down Expand Up @@ -56,4 +59,4 @@ def test_multirun_in_node_state() -> None:

# Verify values
for run_id, state in node_state.run_states.items():
assert state.state["counter"] == expected_values[run_id]
assert state.get_configs("counter")["count"] == expected_values[run_id]
12 changes: 6 additions & 6 deletions src/py/flwr/client/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from typing import Callable, Dict, Tuple

from flwr.client.client import Client
from flwr.client.run_state import RunState
from flwr.common import (
Config,
NDArrays,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.recordset import RecordSet
from flwr.common.typing import (
Code,
EvaluateIns,
Expand Down Expand Up @@ -70,7 +70,7 @@
class NumPyClient(ABC):
"""Abstract base class for Flower clients using NumPy."""

state: RunState
state: RecordSet

def get_properties(self, config: Config) -> Dict[str, Scalar]:
"""Return a client's set of properties.
Expand Down Expand Up @@ -174,11 +174,11 @@ def evaluate(
_ = (self, parameters, config)
return 0.0, 0, {}

def get_state(self) -> RunState:
def get_state(self) -> RecordSet:
"""Get the run state from this client."""
return self.state

def set_state(self, state: RunState) -> None:
def set_state(self, state: RecordSet) -> None:
"""Apply a run state to this client."""
self.state = state

Expand Down Expand Up @@ -278,12 +278,12 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
)


def _get_state(self: Client) -> RunState:
def _get_state(self: Client) -> RecordSet:
"""Return state of underlying NumPyClient."""
return self.numpy_client.get_state() # type: ignore


def _set_state(self: Client, state: RunState) -> None:
def _set_state(self: Client, state: RecordSet) -> None:
"""Apply state to underlying NumPyClient."""
self.numpy_client.set_state(state) # type: ignore

Expand Down
25 changes: 0 additions & 25 deletions src/py/flwr/client/run_state.py

This file was deleted.

6 changes: 3 additions & 3 deletions src/py/flwr/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass
from typing import Callable

from flwr.client.run_state import RunState
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

from .client import Client as Client
Expand All @@ -28,15 +28,15 @@ class Fwd:
"""."""

task_ins: TaskIns
state: RunState
state: RecordSet


@dataclass
class Bwd:
"""."""

task_res: TaskRes
state: RunState
state: RecordSet


FlowerCallable = Callable[[Fwd], Bwd]
Expand Down
Loading
Loading