Skip to content

Commit

Permalink
Add E2E test for WorkloadState and NodeState (#2696)
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Dec 8, 2023
1 parent e584c6d commit dbf56d9
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 4 deletions.
20 changes: 18 additions & 2 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from datetime import datetime

import flwr as fl
import numpy as np

SUBSET_SIZE = 1000
STATE_VAR = 'timestamp'


model_params = np.array([1])
Expand All @@ -12,16 +15,29 @@ class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return model_params

def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]

def fit(self, parameters, config):
model_params = parameters
model_params = [param * (objective/np.mean(param)) for param in model_params]
return model_params, 1, {}
self._record_timestamp_to_state()
return model_params, 1, {STATE_VAR: self._retrieve_timestamp_from_state()}

def evaluate(self, parameters, config):
model_params = parameters
loss = min(np.abs(1 - np.mean(model_params)/objective), 1)
accuracy = 1 - loss
return loss, 1, {"accuracy": accuracy}
self._record_timestamp_to_state()
return loss, 1, {"accuracy": accuracy, STATE_VAR: self._retrieve_timestamp_from_state()}

def client_fn(cid):
return FlowerClient().to_client()
Expand Down
30 changes: 30 additions & 0 deletions e2e/bare/simulation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,41 @@
from typing import List, Tuple
import numpy as np

import flwr as fl
from flwr.common import Metrics

from client import client_fn
STATE_VAR = 'timestamp'


# Define metric aggregation function
def record_state_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Ensure that timestamps are monotonically increasing."""
states = []
for _, m in metrics:
# split string and covert timestamps to float
states.append([float(tt) for tt in m[STATE_VAR].split(',')])

for client_state in states:
if len(client_state) == 1:
continue
deltas = np.diff(client_state)
assert np.all(deltas > 0), f"Timestamps are not monotonically increasing: {client_state}"

return {STATE_VAR: states}


strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=record_state_metrics)

hist = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)

assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98

# The checks in record_state_metrics don't do anythinng if client's state has a single entry
state_metrics_last_round = hist.metrics_distributed[STATE_VAR][-1]
assert len(state_metrics_last_round[1][0]) == 2*state_metrics_last_round[0], f"There should be twice as many entries in the client state as rounds"
19 changes: 17 additions & 2 deletions e2e/pytorch/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from collections import OrderedDict
from datetime import datetime

import torch
import torch.nn as nn
Expand All @@ -18,6 +19,7 @@
warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
SUBSET_SIZE = 1000
STATE_VAR = 'timestamp'


class Net(nn.Module):
Expand Down Expand Up @@ -89,16 +91,29 @@ def load_data():
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]

def fit(self, parameters, config):
set_parameters(net, parameters)
train(net, trainloader, epochs=1)
return self.get_parameters(config={}), len(trainloader.dataset), {}
self._record_timestamp_to_state()
return self.get_parameters(config={}), len(trainloader.dataset), {STATE_VAR: self._retrieve_timestamp_from_state()}

def evaluate(self, parameters, config):
set_parameters(net, parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}
self._record_timestamp_to_state()
return loss, len(testloader.dataset), {"accuracy": accuracy, STATE_VAR: self._retrieve_timestamp_from_state()}

def set_parameters(model, parameters):
params_dict = zip(model.state_dict().keys(), parameters)
Expand Down
31 changes: 31 additions & 0 deletions e2e/pytorch/simulation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,42 @@
from typing import List, Tuple
import numpy as np

import flwr as fl
from flwr.common import Metrics


from client import client_fn
STATE_VAR = 'timestamp'


# Define metric aggregation function
def record_state_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Ensure that timestamps are monotonically increasing."""
states = []
for _, m in metrics:
# split string and covert timestamps to float
states.append([float(tt) for tt in m[STATE_VAR].split(',')])

for client_state in states:
if len(client_state) == 1:
continue
deltas = np.diff(client_state)
assert np.all(deltas > 0), f"Timestamps are not monotonically increasing: {client_state}"

return {STATE_VAR: states}


strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=record_state_metrics)

hist = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)

assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98

# The checks in record_state_metrics don't do anythinng if client's state has a single entry
state_metrics_last_round = hist.metrics_distributed[STATE_VAR][-1]
assert len(state_metrics_last_round[1][0]) == 2*state_metrics_last_round[0], f"There should be twice as many entries in the client state as rounds"
39 changes: 39 additions & 0 deletions e2e/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,47 @@
from typing import List, Tuple
import numpy as np


import flwr as fl
from flwr.common import Metrics
STATE_VAR = 'timestamp'


# Define metric aggregation function
def record_state_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Ensure that timestamps are monotonically increasing."""
if not metrics:
return {}

if STATE_VAR not in metrics[0][1]:
# Do nothing if keyword is not present
return {}

states = []
for _, m in metrics:
# split string and covert timestamps to float
states.append([float(tt) for tt in m[STATE_VAR].split(',')])

for client_state in states:
if len(client_state) == 1:
continue
deltas = np.diff(client_state)
assert np.all(deltas > 0), f"Timestamps are not monotonically increasing: {client_state}"

return {STATE_VAR: states}


strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=record_state_metrics)

hist = fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)

assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98

if STATE_VAR in hist.metrics_distributed:
# The checks in record_state_metrics don't do anythinng if client's state has a single entry
state_metrics_last_round = hist.metrics_distributed[STATE_VAR][-1]
assert len(state_metrics_last_round[1][0]) == 2*state_metrics_last_round[0], f"There should be twice as many entries in the client state as rounds"

0 comments on commit dbf56d9

Please sign in to comment.