# landmark.py
# Forked from yinguobing/cnn-facial-landmark.
"""
Convolutional Neural Network for facial landmarks detection.
"""
import argparse
from typing import Iterator

import cv2
import matplotlib.pyplot as plt
import numpy
import torch
from torch import optim
from torch.utils.data import DataLoader

from IterableTFRecordDataset import IterableTFRecordDataset
from model import get_landmark_model

# TFRecord feature keys and the expected decoded image shape (height, width, channels).
IMAGE = "image/encoded"
MARKS = "label/marks"
INPUT_SHAPE = (128, 128, 3)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_record', default='train.record', type=str, help='Training record file')
    parser.add_argument('--train_index', default='train.index', type=str, help='Training record index file')
    parser.add_argument('--epochs', default=1, type=int, help='Epochs for training')
    parser.add_argument('--batch_size', default=16, type=int, help='Training batch size')
    return parser.parse_args()
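
# Example invocation (the record/index paths are just the defaults above; substitute your own
# data files, and pick epochs/batch_size to suit your hardware):
#   python landmark.py --train_record train.record --train_index train.index --epochs 10 --batch_size 32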

def decode_image(features: dict):
    # Decode the encoded image bytes, then move HWC -> CHW for PyTorch
    # (numpy.reshape would scramble the pixels here; transpose permutes the axes instead).
    features[IMAGE] = cv2.imdecode(features[IMAGE], -1).reshape(INPUT_SHAPE).transpose(2, 0, 1)
    return features
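
# For reference, a batch yielded by the DataLoader below is expected to look roughly like
# (assuming IterableTFRecordDataset returns dicts keyed by IMAGE and MARKS):
#   batch[IMAGE] -> uint8 tensor of shape (batch_size, 3, 128, 128)
#   batch[MARKS] -> the raw "label/marks" feature, presumably the flattened landmark
#                   coordinates, used directly as the regression target.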

def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int) -> DataLoader:
    return DataLoader(dataset, batch_size=batch_size)


def get_loss_function() -> torch.nn.Module:
    return torch.nn.MSELoss()


def get_optimizer(parameters: Iterator[torch.nn.Parameter], lr: float) -> torch.optim.Optimizer:
    return optim.Adam(params=parameters, lr=lr)


def split_dataset(dataset, validation_size):
    # random_split needs a map-style dataset (len() and indexing), so this assumes
    # IterableTFRecordDataset exposes both despite its name.
    return torch.utils.data.random_split(dataset, (len(dataset) - validation_size, validation_size))

def main(train_record: str, train_index: str, epochs: int, batch_size: int, lr: float,
         output_size: int, validation_size: int):
    """Train the landmark model and plot the per-epoch loss curves."""
    train_set, val_set = split_dataset(
        IterableTFRecordDataset(
            train_record,
            train_index,
            description={IMAGE: "byte", MARKS: "byte"},
            shuffle_queue_size=1024,
            transform=decode_image
        ),
        validation_size
    )

    # Move the model to the GPU to match the .cuda() calls on the batches below.
    model: torch.nn.Module = get_landmark_model(output_size=output_size).cuda()
    criterion: torch.nn.Module = get_loss_function()
    optimizer: torch.optim.Optimizer = get_optimizer(model.parameters(), lr)

    train_data_loader = get_data_loader(train_set, batch_size)
    val_data_loader = get_data_loader(val_set, batch_size)

    train_losses = []
    val_losses = []
    for epoch in range(epochs):
        train_loss = []
        val_loss = []

        # Training pass.
        model.train()
        for data in train_data_loader:
            optimizer.zero_grad()
            outputs = model(data[IMAGE].float().cuda())
            targets = data[MARKS].float().cuda()
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())

        # Validation pass.
        with torch.no_grad():
            model.eval()
            for data in val_data_loader:
                outputs = model(data[IMAGE].float().cuda())
                targets = data[MARKS].float().cuda()
                loss = criterion(outputs, targets)
                val_loss.append(loss.item())

        avg_train_loss = sum(train_loss) / len(train_loss)
        avg_val_loss = sum(val_loss) / len(val_loss)
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        print(
            f"""Epoch {epoch}:
            Train: loss = {avg_train_loss}
            Validation: loss = {avg_val_loss}
            """
        )

    # Plot the loss curves for both splits.
    plt.plot(numpy.array(train_losses), label="train")
    plt.plot(numpy.array(val_losses), label="validation")
    plt.legend()
    plt.show()

if __name__ == '__main__':
    OUTPUT_SIZE = 1583
    LEARNING_RATE = 0.001
    VALIDATION_SIZE = 500

    args = get_args()
    main(
        args.train_record,
        args.train_index,
        args.epochs,
        args.batch_size,
        LEARNING_RATE,
        OUTPUT_SIZE,
        VALIDATION_SIZE
    )