-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlstm_fedavg.py
68 lines (53 loc) · 2.01 KB
/
lstm_fedavg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Train/test of multiple LSTM based time forecasting models with FedAvg"""
import numpy as np
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
from approach.lstm_forecast import ShallowForecastLSTM
from util.utils import train_one_step, test_model, predict, SequenceDataset
from util.fed_utils import FedAvg_loop
# Load the data set. The file must be downloaded first; see readme.md in the
# datasets directory. Without this load, the lines below fail with NameError.
x_data = np.load("datasets/traffic.npy")
# Rows are time steps (split chronologically below); columns are the individual
# time-series components (one per client). Unpacking requires a 2-D array.
T_size, p = x_data.shape
T_train = int(0.70 * T_size)  # the first 70% of the time steps form the training split
lag = 48  # passed to SequenceDataset as `lag` (presumably the input window length)
n_clients = 2  # reduced number of time series components: one component per client
if n_clients > p:
    # Fail early with a clear message instead of an IndexError while building loaders.
    raise ValueError(
        f"n_clients={n_clients} exceeds the number of components in x_data ({p})"
    )
# Standardize each client's component with its own scaler. Each scaler is fit on
# the training split only and then reused on the test split, so no test-set
# statistics leak into training. Both splits are wrapped in a SequenceDataset
# and a DataLoader: training batches are shuffled, test batches are not.
batch_size = 2**3
scalers = [StandardScaler(with_mean=True, with_std=True) for _ in range(n_clients)]
train_loaders = []
test_loaders = []
for k, scaler in enumerate(scalers):
    column = x_data[:, k][:, np.newaxis]  # component k as a single-column 2-D array
    train_set = SequenceDataset(scaler.fit_transform(column[:T_train]), lag=lag)
    train_loaders.append(DataLoader(train_set, batch_size=batch_size, shuffle=True))
    test_set = SequenceDataset(scaler.transform(column[T_train:]), lag=lag)
    test_loaders.append(DataLoader(test_set, batch_size=batch_size, shuffle=False))
# Train the per-client models with FedAvg. FedAvg_loop receives the model
# *class* ShallowForecastLSTM (a callable), not an instance, together with the
# keyword arguments used to construct it. n_inputs=1 matches the single-column
# series built for each client above.
architecture = {'n_inputs': 1, 'n_hidden': 10}
fedavg_results = FedAvg_loop(
    call_basemodel=ShallowForecastLSTM,
    archi_basemodel=architecture,
    train_sets=train_loaders,
    test_sets=test_loaders,
    n_cr=5,  # NOTE(review): presumably the number of communication rounds — confirm in fed_utils
    n_local_epochs=5,
    lr=10 ** -2,
)
models, local_losses_bfr_fedavg, local_losses_aft_fedavg, global_loss = fedavg_results
# Plot each client's local loss as recorded before the FedAvg aggregation step.
# NOTE(review): assumes local_losses_bfr_fedavg converts to a 2-D array with one
# column per client; the column indexing below relies on that layout.
losses_bfr = np.asarray(local_losses_bfr_fedavg)  # convert once, outside the loop
for k in range(n_clients):
    plt.plot(losses_bfr[:, k], label=f'client {k}: loss before FedAvg', alpha=0.4)
# TODO: also plot local_losses_aft_fedavg (label it "loss after FedAvg") once its
# layout is confirmed to match local_losses_bfr_fedavg.
plt.legend()  # without a legend the per-curve labels are never displayed
plt.show()