Skip to content

Commit

Permalink
Replace RunState with RecordSet (#2855)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Jan 26, 2024
1 parent 52c3952 commit 9bfd38e
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 88 deletions.
14 changes: 9 additions & 5 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import flwr as fl
import numpy as np

from flwr.common.configsrecord import ConfigsRecord

SUBSET_SIZE = 1000
STATE_VAR = 'timestamp'

Expand All @@ -18,13 +20,15 @@ def get_parameters(self, config):
def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)
value = str(t_stamp)
if STATE_VAR in self.state.configs.keys():
value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]
return self.state.get_configs(STATE_VAR)[STATE_VAR]

def fit(self, parameters, config):
model_params = parameters
Expand Down
14 changes: 8 additions & 6 deletions e2e/pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tqdm import tqdm

import flwr as fl
from flwr.common.configsrecord import ConfigsRecord

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
Expand Down Expand Up @@ -95,14 +96,15 @@ def get_parameters(self, config):
def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)
value = str(t_stamp)
if STATE_VAR in self.state.configs.keys():
value = self.state.get_configs(STATE_VAR)[STATE_VAR] # type: ignore
value += f",{t_stamp}"

self.state.set_configs(name=STATE_VAR, record=ConfigsRecord({STATE_VAR: value}))

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]

return self.state.get_configs(STATE_VAR)[STATE_VAR]
def fit(self, parameters, config):
set_parameters(net, parameters)
train(net, trainloader, epochs=1)
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from abc import ABC

from flwr.client.run_state import RunState
from flwr.common import (
Code,
EvaluateIns,
Expand All @@ -33,12 +32,13 @@
Parameters,
Status,
)
from flwr.common.recordset import RecordSet


class Client(ABC):
"""Abstract base class for Flower clients."""

state: RunState
state: RecordSet

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Return set of client's properties.
Expand Down Expand Up @@ -141,11 +141,11 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
metrics={},
)

def get_state(self) -> RunState:
def get_state(self) -> RecordSet:
"""Get the run state from this client."""
return self.state

def set_state(self, state: RunState) -> None:
def set_state(self, state: RecordSet) -> None:
"""Apply a run state to this client."""
self.state = state

Expand Down
14 changes: 7 additions & 7 deletions src/py/flwr/client/message_handler/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
get_server_message_from_task_ins,
wrap_client_message_in_task_res,
)
from flwr.client.run_state import RunState
from flwr.client.secure_aggregation import SecureAggregationHandler
from flwr.client.typing import ClientFn
from flwr.common import serde
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import ( # pylint: disable=E0611
SecureAggregation,
Task,
Expand Down Expand Up @@ -88,15 +88,15 @@ def handle_control_message(task_ins: TaskIns) -> Tuple[Optional[TaskRes], int]:


def handle(
client_fn: ClientFn, state: RunState, task_ins: TaskIns
) -> Tuple[TaskRes, RunState]:
client_fn: ClientFn, state: RecordSet, task_ins: TaskIns
) -> Tuple[TaskRes, RecordSet]:
"""Handle incoming TaskIns from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : RunState
state : RecordSet
A dataclass storing the state for the run being executed by the client.
task_ins: TaskIns
The task instruction coming from the server, to be processed by the client.
Expand Down Expand Up @@ -135,15 +135,15 @@ def handle(


def handle_legacy_message(
client_fn: ClientFn, state: RunState, server_msg: ServerMessage
) -> Tuple[ClientMessage, RunState]:
client_fn: ClientFn, state: RecordSet, server_msg: ServerMessage
) -> Tuple[ClientMessage, RecordSet]:
"""Handle incoming messages from the server.
Parameters
----------
client_fn : ClientFn
A callable that instantiates a Client.
state : RunState
state : RecordSet
A dataclass storing the state for the run being executed by the client.
server_msg: ServerMessage
The message coming from the server, to be processed by the client.
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/message_handler/message_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import uuid

from flwr.client import Client
from flwr.client.run_state import RunState
from flwr.client.typing import ClientFn
from flwr.common import (
EvaluateIns,
Expand All @@ -33,6 +32,7 @@
serde,
typing,
)
from flwr.common.recordset import RecordSet
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.proto.transport_pb2 import ( # pylint: disable=E0611
Expand Down Expand Up @@ -141,7 +141,7 @@ def test_client_without_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=RunState(state={}),
state=RecordSet(),
task_ins=task_ins,
)

Expand Down Expand Up @@ -209,7 +209,7 @@ def test_client_with_get_properties() -> None:
)
task_res, _ = handle(
client_fn=_get_client_fn(client),
state=RunState(state={}),
state=RecordSet(),
task_ins=task_ins,
)

Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/client/middleware/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import unittest
from typing import List

from flwr.client.run_state import RunState
from flwr.client.typing import Bwd, FlowerCallable, Fwd, Layer
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

from .utils import make_ffn
Expand All @@ -45,7 +45,7 @@ def make_mock_app(name: str, footprint: List[str]) -> FlowerCallable:
def app(fwd: Fwd) -> Bwd:
footprint.append(name)
fwd.task_ins.task_id += f"{name}"
return Bwd(task_res=TaskRes(task_id=name), state=RunState({}))
return Bwd(task_res=TaskRes(task_id=name), state=RecordSet())

return app

Expand All @@ -66,7 +66,7 @@ def test_multiple_middlewares(self) -> None:

# Execute
wrapped_app = make_ffn(mock_app, mock_middleware_layers)
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res

# Assert
trace = mock_middleware_names + ["app"]
Expand All @@ -86,11 +86,11 @@ def filter_layer(fwd: Fwd, _: FlowerCallable) -> Bwd:
footprint.append("filter")
fwd.task_ins.task_id += "filter"
# Skip calling app
return Bwd(task_res=TaskRes(task_id="filter"), state=RunState({}))
return Bwd(task_res=TaskRes(task_id="filter"), state=RecordSet())

# Execute
wrapped_app = make_ffn(mock_app, [filter_layer])
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RunState({}))).task_res
task_res = wrapped_app(Fwd(task_ins=task_ins, state=RecordSet())).task_res

# Assert
self.assertEqual(footprint, ["filter"])
Expand Down
10 changes: 5 additions & 5 deletions src/py/flwr/client/node_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@

from typing import Any, Dict

from flwr.client.run_state import RunState
from flwr.common.recordset import RecordSet


class NodeState:
"""State of a node where client nodes execute runs."""

def __init__(self) -> None:
self._meta: Dict[str, Any] = {} # holds metadata about the node
self.run_states: Dict[int, RunState] = {}
self.run_states: Dict[int, RecordSet] = {}

def register_runstate(self, run_id: int) -> None:
"""Register new run state for this node."""
if run_id not in self.run_states:
self.run_states[run_id] = RunState({})
self.run_states[run_id] = RecordSet()

def retrieve_runstate(self, run_id: int) -> RunState:
def retrieve_runstate(self, run_id: int) -> RecordSet:
"""Get run state given a run_id."""
if run_id in self.run_states:
return self.run_states[run_id]
Expand All @@ -43,6 +43,6 @@ def retrieve_runstate(self, run_id: int) -> RunState:
" by a client."
)

def update_runstate(self, run_id: int, run_state: RunState) -> None:
def update_runstate(self, run_id: int, run_state: RecordSet) -> None:
"""Update run state."""
self.run_states[run_id] = run_state
17 changes: 10 additions & 7 deletions src/py/flwr/client/node_state_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@


from flwr.client.node_state import NodeState
from flwr.client.run_state import RunState
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611


def _run_dummy_task(state: RunState) -> RunState:
if "counter" in state.state:
state.state["counter"] += "1"
else:
state.state["counter"] = "1"
def _run_dummy_task(state: RecordSet) -> RecordSet:
counter_value: str = "1"
if "counter" in state.configs.keys():
counter_value = state.get_configs("counter")["count"] # type: ignore
counter_value += "1"

state.set_configs(name="counter", record=ConfigsRecord({"count": counter_value}))

return state

Expand Down Expand Up @@ -56,4 +59,4 @@ def test_multirun_in_node_state() -> None:

# Verify values
for run_id, state in node_state.run_states.items():
assert state.state["counter"] == expected_values[run_id]
assert state.get_configs("counter")["count"] == expected_values[run_id]
12 changes: 6 additions & 6 deletions src/py/flwr/client/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from typing import Callable, Dict, Tuple

from flwr.client.client import Client
from flwr.client.run_state import RunState
from flwr.common import (
Config,
NDArrays,
Scalar,
ndarrays_to_parameters,
parameters_to_ndarrays,
)
from flwr.common.recordset import RecordSet
from flwr.common.typing import (
Code,
EvaluateIns,
Expand Down Expand Up @@ -70,7 +70,7 @@
class NumPyClient(ABC):
"""Abstract base class for Flower clients using NumPy."""

state: RunState
state: RecordSet

def get_properties(self, config: Config) -> Dict[str, Scalar]:
"""Return a client's set of properties.
Expand Down Expand Up @@ -174,11 +174,11 @@ def evaluate(
_ = (self, parameters, config)
return 0.0, 0, {}

def get_state(self) -> RunState:
def get_state(self) -> RecordSet:
"""Get the run state from this client."""
return self.state

def set_state(self, state: RunState) -> None:
def set_state(self, state: RecordSet) -> None:
"""Apply a run state to this client."""
self.state = state

Expand Down Expand Up @@ -278,12 +278,12 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
)


def _get_state(self: Client) -> RunState:
def _get_state(self: Client) -> RecordSet:
"""Return state of underlying NumPyClient."""
return self.numpy_client.get_state() # type: ignore


def _set_state(self: Client, state: RunState) -> None:
def _set_state(self: Client, state: RecordSet) -> None:
"""Apply state to underlying NumPyClient."""
self.numpy_client.set_state(state) # type: ignore

Expand Down
25 changes: 0 additions & 25 deletions src/py/flwr/client/run_state.py

This file was deleted.

6 changes: 3 additions & 3 deletions src/py/flwr/client/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dataclasses import dataclass
from typing import Callable

from flwr.client.run_state import RunState
from flwr.common.recordset import RecordSet
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

from .client import Client as Client
Expand All @@ -28,15 +28,15 @@ class Fwd:
"""."""

task_ins: TaskIns
state: RunState
state: RecordSet


@dataclass
class Bwd:
"""."""

task_res: TaskRes
state: RunState
state: RecordSet


FlowerCallable = Callable[[Fwd], Bwd]
Expand Down
Loading

0 comments on commit 9bfd38e

Please sign in to comment.