-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_val.py
135 lines (119 loc) · 4.41 KB
/
train_val.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torch.nn as nn
import numpy as np
# import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms, models
from torch.utils.data import DataLoader, Dataset
import os
from sklearn.model_selection import train_test_split
print('————执行train_val.py文件————')

# Dataset roots. The training root holds one sub-directory per category.
root_train = 'D://新建文件夹/植物识别/plant-seedlings-classification/train'
root_test = 'D://新建文件夹/植物识别/plant-seedlings-classification/test'

# Map every category (sub-directory name) to an integer label, numbered in
# the order os.listdir returns the entries.
list_train_cate = os.listdir(root_train)
cate2label = {cate: label for label, cate in enumerate(list_train_cate)}

# Load every training image into memory as a numpy array, together with the
# integer label of the category directory it was found in.
list_train_img = []
list_train_label = []
for cate in list_train_cate:
    cate_dir = os.path.join(root_train, cate)
    list_train_name = os.listdir(cate_dir)
    for name in list_train_name:
        # The context manager closes the image file as soon as its pixels
        # have been copied into the array, instead of waiting for the
        # garbage collector.
        with Image.open(os.path.join(cate_dir, name)) as img:
            list_train_img.append(np.array(img))
        list_train_label.append(cate2label[cate])
# Preview the first nine loaded images as thumbnails in a 3x3 grid.
fig, axe = plt.subplots(nrows=3, ncols=3, figsize=(50, 50))
axe = axe.flatten()
for i in range(len(axe)):
    panel = axe[i]
    panel.imshow(list_train_img[i])
    panel.axis('off')
plt.show()
# To display a single image on its own instead:
# img_idx = Image.fromarray(list_train_img[0])
# img_idx.show()
# Split into training (70%) and validation (30%) sets. Images and labels go
# through ONE call so both are partitioned with the same indices by
# construction, rather than relying on two separate calls happening to share
# a random_state and input length.
list_train_img, list_val_img, list_train_label, list_val_label = train_test_split(
    list_train_img,
    list_train_label,
    test_size=0.3,
    random_state=24,
)

# Preprocessing applied to every sample: resize to 250x250, then convert the
# PIL image to a tensor.
transform = transforms.Compose([
    transforms.Resize((250, 250)),
    transforms.ToTensor(),
])
class Mydata(Dataset):
    """Dataset pairing in-memory image arrays with their class labels."""

    def __init__(self, x, y, transform):
        """Keep the samples ``x``, the labels ``y`` and the ``transform`` to apply."""
        self.X, self.Y = x, y
        self.transform = transform

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for sample ``idx``.

        The stored array is converted to an RGB PIL image, run through
        ``self.transform`` and cast to a float tensor; the label is
        returned as a long tensor.
        """
        raw_image, raw_label = self.X[idx], self.Y[idx]
        label = torch.from_numpy(np.array(raw_label)).type(torch.LongTensor)
        rgb_image = Image.fromarray(raw_image).convert('RGB')
        image = self.transform(rgb_image).type(torch.FloatTensor)
        return image, label
# Wrap both splits in datasets that share the same preprocessing.
train_set = Mydata(list_train_img, list_train_label, transform)
val_set = Mydata(list_val_img, list_val_label, transform)

# Both loaders yield shuffled batches of 16 with no worker subprocesses.
train_loader = DataLoader(train_set, batch_size=16, shuffle=True, num_workers=0)
val_loader = DataLoader(val_set, batch_size=16, shuffle=True, num_workers=0)
print('————数据加载完毕————')
# ResNet-18 with pretrained weights; only the final fully-connected layer is
# replaced so the network predicts this dataset's categories.
model = models.resnet18(pretrained=True)
# One output logit per category. The head was previously sized with
# len(list_train_img), i.e. one logit per training IMAGE, which allocated
# thousands of outputs that no label ever targets.
# NOTE: checkpoints saved with the old head shape will not load into this one.
model.fc = nn.Linear(model.fc.in_features, len(cate2label))

# Move the model to the GPU when one is available, then build the optimizer
# over the parameters of the model that will actually be trained.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fun = nn.CrossEntropyLoss()
print('————开始训练————')
loss_mean = 100  # lowest mean training loss of any epoch so far
acc_epoch = 0  # highest validation accuracy of any epoch so far
for epoch in range(30):
    # ---- training pass ----
    model = model.train()
    correct = 0
    total = 0
    loss_sum = 0
    list_loss = []
    for step, (images, targets) in enumerate(train_loader):
        images, targets = images.to(DEVICE), targets.to(DEVICE)
        batch_loss = loss_fun(model(images), targets)
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        loss_sum += loss_value
        list_loss.append(loss_value)
        # Report the average over the last 7 batches, then restart the window.
        if (step + 1) % 7 == 0:
            print('Epoch:{}||Index:{}||Avg_Loss:{}'.format(epoch, step, loss_sum / 7))
            loss_sum = 0

    # Checkpoint whenever the epoch's mean training loss improves.
    epoch_loss = np.mean(list_loss)
    print('平均loss为{}'.format(epoch_loss))
    if epoch_loss < loss_mean:
        loss_mean = epoch_loss
        torch.save(model.state_dict(), 'D://新建文件夹/植物识别/weight_best_loss.pth')

    # ---- validation pass (no gradients, eval-mode layers) ----
    model.eval()
    with torch.no_grad():
        for data_val, label_val in val_loader:
            data_val, label_val = data_val.to(DEVICE), label_val.to(DEVICE)
            # Index of the largest logit along the class dimension.
            _, pred_val = model(data_val).max(dim=1)
            total += label_val.size(0)
            correct += (pred_val == label_val).sum().item()

    # Checkpoint whenever validation accuracy improves.
    val_acc = correct / total
    print('验证集准确率为{}'.format(val_acc))
    if val_acc > acc_epoch:
        acc_epoch = val_acc
        torch.save(model.state_dict(), 'D://新建文件夹/植物识别/weight_best_acc.pth')
print('————训练结束————')