-
Notifications
You must be signed in to change notification settings - Fork 32
/
edge.py
79 lines (68 loc) · 2.98 KB
/
edge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# The structure of the edge server
# The edge should include the following functions:
# 1. Server initialization
# 2. Server receives updates from the client
# 3. Server sends the aggregated information back to clients
# 4. Server sends the updates to the cloud server
# 5. Server receives the aggregated information from the cloud server
import copy
from average import average_weights
class Edge():
    """Edge server sitting between clients and the cloud server.

    Each round the edge registers participating clients, buffers the shared
    state dicts they send, aggregates them with ``average_weights`` (together
    with each sender's training-sample count), and exchanges the resulting
    shared state with its clients and with the cloud server.
    """

    def __init__(self, id, cids, shared_layers):
        """
        :param id: Index of the edge
        :param cids: Indexes of all the clients under this edge
        :param shared_layers: Structure of the shared layers; only its
            ``state_dict()`` is used, as the initial shared state.

        Attributes set here:
            receiver_buffer: client id -> shared state dict received this round
            id_registration: ids of the clients registered in this round
            sample_registration: client id -> number of training samples of
                that client, recorded at registration
            all_trainsample_num: training samples of all clients under this
                edge (initialized to 0; not updated inside this class)
            shared_state_dict: current state dict of the shared layers
            clock: meant to record the time after each aggregation
                (not updated inside this class)
        """
        self.id = id
        self.cids = cids
        self.receiver_buffer = {}
        self.id_registration = []
        self.sample_registration = {}
        self.all_trainsample_num = 0
        self.shared_state_dict = shared_layers.state_dict()
        self.clock = []

    def refresh_edgeserver(self):
        """Reset all per-round state: received updates and registrations.

        The containers are emptied in place rather than replaced.
        """
        self.receiver_buffer.clear()
        self.id_registration.clear()
        self.sample_registration.clear()
        return None

    def client_register(self, client):
        """Register ``client`` for this round and record its sample count.

        :param client: object exposing ``id`` and ``train_loader.dataset``
        """
        self.id_registration.append(client.id)
        self.sample_registration[client.id] = len(client.train_loader.dataset)
        return None

    def receive_from_client(self, client_id, cshared_state_dict):
        """Buffer the shared state dict sent by ``client_id``.

        A second update from the same client in one round replaces the first.
        """
        self.receiver_buffer[client_id] = cshared_state_dict
        return None

    def _collect_updates(self):
        """Return ``(state_dicts, sample_nums)`` paired by client id.

        Both lists follow the order of ``receiver_buffer``, and each sample
        count is looked up with the same client id as its state dict, so the
        pairing never depends on registration order matching arrival order.

        :raises KeyError: if an update came from a client that never registered
        """
        state_dicts = []
        sample_nums = []
        for client_id, state_dict in self.receiver_buffer.items():
            try:
                sample_nums.append(self.sample_registration[client_id])
            except KeyError:
                raise KeyError(
                    "edge {}: received an update from unregistered client {!r}"
                    .format(self.id, client_id)
                ) from None
            state_dicts.append(state_dict)
        return state_dicts, sample_nums

    def aggregate(self, args):
        """Aggregate the buffered client updates into ``shared_state_dict``.

        :param args: unused; kept for call compatibility
        :return: None
        :raises KeyError: if an update came from a client that never registered
        """
        # Pair every update with its own client's sample count. Taking the
        # values() of the two dicts positionally would mis-weight updates
        # whenever clients send in a different order than they registered,
        # or when some registered clients never send.
        received_dicts, sample_nums = self._collect_updates()
        self.shared_state_dict = average_weights(w=received_dicts,
                                                 s_num=sample_nums)
        return None

    def send_to_client(self, client):
        """Send an independent deep copy of the shared state to ``client``."""
        client.receive_from_edgeserver(copy.deepcopy(self.shared_state_dict))
        return None

    def send_to_cloudserver(self, cloud):
        """Send a deep copy of the shared state to ``cloud``, tagged with this edge's id."""
        cloud.receive_from_edge(edge_id=self.id,
                                eshared_state_dict=copy.deepcopy(
                                    self.shared_state_dict))
        return None

    def receive_from_cloudserver(self, shared_state_dict):
        """Replace the shared state with the one received from the cloud.

        The dict is stored as given; no copy is made here.
        """
        self.shared_state_dict = shared_state_dict
        return None