-
Notifications
You must be signed in to change notification settings - Fork 32
/
client.py
89 lines (81 loc) · 3.33 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# The structure of the client.
# It should include the following functions:
# 1. Client initialization: dataloaders, model (including optimizer)
# 2. Client model update
# 3. Client sends updates to the server
# 4. Client receives updates from the server
# 5. Client modifies the local model based on feedback from the server
from torch.autograd import Variable
import torch
from models.initialize_model import initialize_model
import copy
class Client():
    """Federated-learning client.

    Holds a local model plus train/test dataloaders, performs local
    optimization steps, and exchanges the shared-layer weights with an
    edge server (send / receive / sync).
    """

    def __init__(self, id, train_loader, test_loader, args, device):
        self.id = id
        self.train_loader = train_loader
        self.test_loader = test_loader
        # Local model wrapper (network + optimizer + lr scheduler),
        # built by the project-level factory.
        self.model = initialize_model(args, device)
        # Latest shared-layer state_dict received from the edge server.
        self.receiver_buffer = {}
        self.batch_size = args.batch_size
        # Number of completed local passes; drives the lr schedule.
        self.epoch = 0
        # Reserved for timing records.
        self.clock = []

    def local_update(self, num_iter, device):
        """Run `num_iter` optimization steps over the train loader.

        :param num_iter: total number of batches to optimize on
        :param device: torch device the batches are moved to
        :return: mean loss over the `num_iter` steps
        """
        itered_num = 0
        loss = 0.0
        end = False
        # Upper bound of 1000 passes: a single local update is never
        # expected to need that many epochs before reaching num_iter.
        for _ in range(1000):
            for inputs, labels in self.train_loader:
                # Variable() was a deprecated no-op wrapper (PyTorch >= 0.4);
                # tensors are moved to the target device directly.
                inputs = inputs.to(device)
                labels = labels.to(device)
                loss += self.model.optimize_model(input_batch=inputs,
                                                  label_batch=labels)
                itered_num += 1
                if itered_num >= num_iter:
                    end = True
                    break
            # One pass (full or truncated) counts as one local epoch;
            # advance the lr schedule exactly once per pass.
            self.epoch += 1
            self.model.exp_lr_sheduler(epoch=self.epoch)
            if end:
                break
        loss /= num_iter
        return loss

    def test_model(self, device):
        """Evaluate the local model on the test loader.

        :param device: torch device the batches are moved to
        :return: (correct, total) prediction counts
        """
        correct = 0.0
        total = 0.0
        with torch.no_grad():
            for inputs, labels in self.test_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = self.model.test_model(input_batch=inputs)
                # Predicted class = argmax over the class dimension.
                _, predict = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predict == labels).sum().item()
        return correct, total

    def send_to_edgeserver(self, edgeserver):
        """Push a deep copy of the shared-layer weights to the edge server.

        Deep-copied so later local training cannot mutate what was sent.
        """
        edgeserver.receive_from_client(
            client_id=self.id,
            cshared_state_dict=copy.deepcopy(
                self.model.shared_layers.state_dict())
        )
        return None

    def receive_from_edgeserver(self, shared_state_dict):
        """Buffer the shared-layer state_dict sent down by the edge server."""
        self.receiver_buffer = shared_state_dict
        return None

    def sync_with_edgeserver(self):
        """Load the buffered global weights into the local model.

        The global state_dict must already be stored in `receiver_buffer`
        (via receive_from_edgeserver).
        :return: None
        """
        self.model.update_model(self.receiver_buffer)
        return None