forked from WynMew/Age-Gender-Predication
-
Notifications
You must be signed in to change notification settings - Fork 0
/
AgeGEvaResNet34_256.py
96 lines (84 loc) · 2.75 KB
/
AgeGEvaResNet34_256.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import torch
import torch.nn as nn
import numpy as np
import os
import pickle
import torchvision
from torch.autograd import Variable
from torchvision import datasets, models, transforms
import torch.optim as optim
import torch.nn.functional as F
from os.path import exists, join, basename, dirname
from os import makedirs, remove
import shutil
from torch.optim import lr_scheduler
import re
from dataloaderimdbwikiAgeG import *
from AgeGPreModelResNet34_256 import *
# Select the first CUDA device and echo the working directory; the checkpoint
# path used below is relative, so it is resolved against this directory.
torch.cuda.set_device(0)
cwd = os.getcwd()
print(cwd)

# Build the network and move it to the GPU, then restore the trained weights.
# map_location returns each storage unchanged, so the checkpoint is
# deserialized onto the CPU and copied into the CUDA model by load_state_dict.
ckpt_path = 'imdbwikiAgeGPreResNet34Det256_CrossEntloss.pth.tar'
model = AgeGPre()
model.cuda()
checkpoint = torch.load(
    ckpt_path,
    map_location=lambda storage, loc: storage,
)
model.load_state_dict(checkpoint['state_dict'])
# ---------------------------------------------------------------------------
# Evaluate the loaded model on a labelled test list.
#
# Each non-empty line of TEST_LIST is "<image path> <age> <gender>", separated
# by single spaces; <age> is parsed as an int and <gender> as a float that is
# compared against the predicted class 0/1.
#
# Per image this prints the predicted gender, the image path, the gender label,
# and "label age : predicted age". At the end it prints, in order:
#   counter, counter / n   -- age predictions within +/-5 of the label
#   Gcounter, Gcounter / n -- correct gender predictions
#   diff, diff / n         -- summed / mean absolute age error
# where n (lineNum) is the number of samples actually evaluated.
# ---------------------------------------------------------------------------
TEST_LIST = "/home/miaoqianwen/AgePre/detTestG"

# Count and read the *same* file, and evaluate every sample in it, so the
# denominators below match the samples that were scored.
with open(TEST_LIST) as lmfile:
    samples = [ln.rstrip('\n') for ln in lmfile if ln.strip()]
lineNum = len(samples)

counter = 0   # age predictions with |pred - label| <= 5
Gcounter = 0  # gender predictions equal to the label
diff = 0      # running sum of |pred - label| for age

# Switch to inference mode once, before the loop.
model.eval()
for line in samples:
    fields = line.split(' ')
    ImgName = fields[0]
    iAge = int(fields[1])
    gender_label = float(fields[2])

    img = io.imread(ImgName)
    if img.ndim < 3:
        # Single-channel image: replicate to three channels.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] > 3:
        # Drop extra channels (e.g. alpha) so the network gets 3 channels.
        img = img[:, :, :3]
    inp = cv2.resize(img, (256, 256))

    # HWC -> CHW, scale [0, 255] -> [0, 1] -> [-1, 1], add a batch dim of 1.
    tensor = torch.from_numpy(inp.transpose((2, 0, 1))).float().div(255.0)
    imgI = (tensor.unsqueeze_(0) - 0.5) / 0.5
    imgI = Variable(imgI.cuda())

    agePre, genderPre = model(imgI)
    # Predicted age is the arg-max index over the first batch item's outputs.
    _, i = torch.max(agePre[0], 0)
    # reshape(-1)[0] works whether the index is a 0-d tensor (PyTorch >= 0.4)
    # or a 1-element tensor (older releases); plain [0] fails on 0-d arrays.
    i = int(i.cpu().data.numpy().reshape(-1)[0])
    # NOTE(review): assumes the gender head yields one score per image that is
    # thresholded at 0.5 -- confirm against AgeGPre.
    gP = genderPre.cpu().data.numpy().reshape(-1)[0]

    pred_gender = 0 if gP < 0.5 else 1
    print("Gender Pre:", pred_gender)
    if gender_label == pred_gender:
        Gcounter += 1

    print(ImgName)
    print("label gender", ": ", gender_label)
    print(iAge, ":", i)
    print("--------")

    age_err = abs(i - iAge)
    if age_err <= 5:
        counter += 1
    diff += age_err

# Ratios are NaN (instead of ZeroDivisionError) when the list has no samples.
print(counter)
print(counter / lineNum if lineNum else float('nan'))
print(Gcounter)
print(Gcounter / lineNum if lineNum else float('nan'))
print(diff)
print(diff / lineNum if lineNum else float('nan'))