-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutk_main.py
89 lines (77 loc) · 2.95 KB
/
utk_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from torch import nn, optim
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torchvision import models, transforms
from utils.ethnicity_detection_check_result import check_result
from trainer import Trainer
from datasets.utk_dataset import NUM_ETHNICITY_BUCKETS, UTKDataset
# Data-loading settings and filesystem locations for the UTKFace ethnicity
# classifier.
BATCH_SIZE: int = 128  # samples per DataLoader batch
DATA_LOADER_NUM_WORKERS: int = 10  # DataLoader worker processes per loader
IMAGE_DIR: str = 'race/UTKFace'  # image directory handed to UTKDataset
MODEL_PATH: str = 'models/utk_model_resnet_50.pt'  # passed to Trainer (presumably the checkpoint path)
def main():
    """Fine-tune a pretrained ResNet-50 to predict UTKFace ethnicity buckets.

    Picks CUDA when available (CPU otherwise), replaces the network's final
    fully-connected layer with an ``NUM_ETHNICITY_BUCKETS``-way classifier,
    and hands the model, loss, optimizer and the loaders returned by
    ``_split_data()`` to ``Trainer`` (``num_epochs=5``, ``print_every=500``),
    then runs ``train()`` followed by ``test()``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    # Use a pretrained RESNET-50 model and swap its classification head for
    # one sized to the number of ethnicity buckets.
    # NOTE(review): `pretrained=True` is deprecated in newer torchvision in
    # favour of `weights=...`; kept here for compatibility with older releases.
    model = models.resnet50(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, NUM_ETHNICITY_BUCKETS)
    # Move once, after the head is replaced, so every parameter (including
    # the new head) ends up on `device`.
    model = model.to(device=device)

    loss_func = CrossEntropyLoss().to(device=device)
    # dtype depends on the loss function (integer class-index targets).
    # The tensor type must match the device: torch.cuda.LongTensor is
    # unusable on a CPU-only machine, so fall back to torch.LongTensor there.
    # Trainer presumably casts targets with this type — not visible here.
    dtype = torch.cuda.LongTensor if device.type == 'cuda' else torch.LongTensor
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    loader_train, loader_val, loader_test = _split_data()
    model_trainer = Trainer(
        model, loss_func, dtype, optimizer, device,
        loader_train, loader_val, loader_test, check_result,
        MODEL_PATH, num_epochs=5, print_every=500,
    )
    model_trainer.train()
    model_trainer.test()
# Per-channel normalization statistics shared by the train and eval
# pipelines (channel order presumably RGB — TODO confirm against how these
# values were computed).
_UTK_NORM_MEAN = [0.59702533, 0.4573939, 0.3917105]
_UTK_NORM_STD = [0.25691032, 0.22929442, 0.22493552]


def _build_transform(augment):
    """Return the image preprocessing pipeline.

    Converts the input to a PIL image, resizes it to 224x224, optionally
    applies a random horizontal flip (``augment=True``, used for training
    only), converts to a tensor and normalizes with the shared channel
    statistics.
    """
    steps = [transforms.ToPILImage(), transforms.Resize((224, 224))]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(_UTK_NORM_MEAN, _UTK_NORM_STD),
    ]
    return transforms.Compose(steps)


def _make_loader(dataset, indices):
    """Wrap ``dataset`` in a DataLoader that draws only ``indices``, in random order."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        num_workers=DATA_LOADER_NUM_WORKERS,
        sampler=sampler.SubsetRandomSampler(indices),
    )


def _split_data():
    """Build the train / validation / test DataLoaders over ``IMAGE_DIR``.

    Three ``UTKDataset`` instances are created over the same directory: the
    training one with augmentation, and validation/test sharing the
    non-augmented transform. They are partitioned by index into a rough
    8:1:1 split.

    Returns:
        Tuple ``(loader_train, loader_val, loader_test)``.
    """
    train_transform = _build_transform(augment=True)
    val_transform = _build_transform(augment=False)

    train_dataset = UTKDataset(IMAGE_DIR, train_transform)
    val_dataset = UTKDataset(IMAGE_DIR, val_transform)
    test_dataset = UTKDataset(IMAGE_DIR, val_transform)

    # Do a rough 8:1:1 split between training set, validation set and test set.
    # The split only yields disjoint subsets if all three datasets enumerate
    # the same images in the same order — TODO confirm UTKDataset's ordering
    # is deterministic.
    # NOTE(review): membership is decided by contiguous index ranges, not a
    # shuffled permutation. If the dataset's ordering correlates with labels
    # (e.g. files enumerated in sorted order with labels encoded in the
    # filename), the three subsets will be biased; consider permuting the
    # indices with a fixed seed.
    num_train = int(len(train_dataset) * 0.8)
    num_val = int(len(val_dataset) * 0.1)

    loader_train = _make_loader(train_dataset, range(num_train))
    loader_val = _make_loader(val_dataset, range(num_train, num_train + num_val))
    loader_test = _make_loader(
        test_dataset, range(num_train + num_val, len(test_dataset))
    )
    return loader_train, loader_val, loader_test
# Script entry point: train and then evaluate only when run directly,
# not when imported.
if __name__ == '__main__':
    main()