server.py
import os
from typing import Dict, List, Optional, Tuple, Union

import flwr as fl
import numpy as np
from flwr.common import EvaluateRes, FitRes, Parameters, Scalar, parameters_to_ndarrays
from flwr.server.client_proxy import ClientProxy
# Earlier strategy (kept here for reference): saved the aggregated weights after each round.
# class SaveModelStrategy(fl.server.strategy.FedAvg):
#     def aggregate_fit(self, rnd, results, failures):
#         aggregated_weights = super().aggregate_fit(rnd, results, failures)
#         if aggregated_weights is not None:
#             # Save aggregated_weights
#             print(f"Saving round {rnd} aggregated_weights...")
#             np.savez(f"round-{rnd}-weights.npz", *aggregated_weights)
#         return aggregated_weights
class AggregateCustomMetricStrategy(fl.server.strategy.FedAvg):
    def aggregate_fit(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate fit results with FedAvg, then save the round's global weights."""
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(rnd, results, failures)
        if aggregated_parameters is not None:
            # Convert the aggregated Parameters object to a list of NumPy ndarrays
            # and save them so the global model can be restored later
            aggregated_ndarrays = parameters_to_ndarrays(aggregated_parameters)
            print(f"Saving round {rnd} aggregated weights...")
            os.makedirs("round_weights", exist_ok=True)
            np.savez(f"round_weights/round-{rnd}-weights.npz", *aggregated_ndarrays)
        return aggregated_parameters, aggregated_metrics
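    # A sketch of how the saved file could be reused later (an assumption, not part
    # of this script): np.savez stores the arrays in order, so the global weights
    # can be restored as a plain list of ndarrays, e.g.
    #
    #   data = np.load("round_weights/round-5-weights.npz")
    #   weights = [data[key] for key in data.files]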
    def aggregate_evaluate(
        self,
        rnd: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation losses using a weighted average."""
        if not results:
            return None, {}
        # Weigh each client's accuracy by the number of examples it evaluated on
        accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results]
        examples = [r.num_examples for _, r in results]
        # Aggregate and print the custom metric
        accuracy_aggregated = sum(accuracies) / sum(examples)
        # wandb.log({"round": rnd, "server_aggregated_accuracy": accuracy_aggregated})
        print(
            f"Round {rnd} accuracy aggregated from client results: {accuracy_aggregated}"
        )
        # Call aggregate_evaluate from the base class (FedAvg) to aggregate losses
        return super().aggregate_evaluate(rnd, results, failures)
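# Optional alternative (a sketch, not wired into the strategy below): recent Flower
# releases let FedAvg aggregate client metrics through an
# `evaluate_metrics_aggregation_fn` callback instead of overriding
# aggregate_evaluate. The callback receives a list of (num_examples, metrics)
# tuples, one per client.
def weighted_average(metrics: List[Tuple[int, Dict[str, Scalar]]]) -> Dict[str, Scalar]:
    """Example-weighted average of the clients' reported "accuracy"."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}
# If used, it would be passed as evaluate_metrics_aggregation_fn=weighted_average
# when constructing the strategy.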
# Create the strategy and run the server
# strategy = SaveModelStrategy()
strategy = AggregateCustomMetricStrategy(
    fraction_fit=0.5,
    fraction_evaluate=0.5,
    min_fit_clients=4,
    min_evaluate_clients=4,
    min_available_clients=4,
)

# Start the Flower server for five rounds of federated learning
fl.server.start_server(
    server_address="127.0.0.1:8080",
    # config={"num_rounds": 5},  # pre-1.0 style config
    config=fl.server.ServerConfig(num_rounds=5),
    # grpc_max_message_length=1024 * 1024 * 1024,  # raise the gRPC message limit to 1 GB if needed
    strategy=strategy,
)
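# For reference, a minimal sketch of the client-side evaluate method this strategy
# expects (an assumption, not part of this file): each client must include an
# "accuracy" entry in the metrics dict it returns, e.g. in a fl.client.NumPyClient
# subclass:
#
#   def evaluate(self, parameters, config):
#       loss, accuracy, num_examples = ...  # evaluate the local model
#       return float(loss), num_examples, {"accuracy": float(accuracy)}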