-
Notifications
You must be signed in to change notification settings - Fork 176
/
test.py
155 lines (119 loc) · 6.62 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import argparse
import os
import torch
from torch import nn
from torch.nn import functional as F
import torchgeometry as tgm
from datasets import VITONDataset, VITONDataLoader
from networks import SegGenerator, GMM, ALIASGenerator
from utils import gen_noise, load_checkpoint, save_images
def get_opt(argv=None):
    """Build the inference command-line parser and parse the options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``; passing a list lets
            the options be built programmatically (tests, notebooks).

    Returns:
        argparse.Namespace with the parsed options.

    Raises:
        SystemExit: raised by argparse when ``--name`` is missing or an
            argument is invalid (e.g. a value outside ``choices``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('-b', '--batch_size', type=int, default=1)
    parser.add_argument('-j', '--workers', type=int, default=1)
    parser.add_argument('--load_height', type=int, default=1024)
    parser.add_argument('--load_width', type=int, default=768)
    parser.add_argument('--shuffle', action='store_true')
    parser.add_argument('--dataset_dir', type=str, default='./datasets/')
    parser.add_argument('--dataset_mode', type=str, default='test')
    parser.add_argument('--dataset_list', type=str, default='test_pairs.txt')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints/')
    parser.add_argument('--save_dir', type=str, default='./results/')
    # Print a progress line every `display_freq` batches.
    parser.add_argument('--display_freq', type=int, default=1)
    # Checkpoint file names, resolved relative to --checkpoint_dir.
    parser.add_argument('--seg_checkpoint', type=str, default='seg_final.pth')
    parser.add_argument('--gmm_checkpoint', type=str, default='gmm_final.pth')
    parser.add_argument('--alias_checkpoint', type=str, default='alias_final.pth')

    # common
    parser.add_argument('--semantic_nc', type=int, default=13, help='# of human-parsing map classes')
    parser.add_argument('--init_type', choices=['normal', 'xavier', 'xavier_uniform', 'kaiming', 'orthogonal', 'none'], default='xavier')
    parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')

    # for GMM
    parser.add_argument('--grid_size', type=int, default=5)

    # for ALIASGenerator
    parser.add_argument('--norm_G', type=str, default='spectralaliasinstance')
    parser.add_argument('--ngf', type=int, default=64, help='# of generator filters in the first conv layer')
    parser.add_argument('--num_upsampling_layers', choices=['normal', 'more', 'most'], default='most',
                        help='If \'more\', add upsampling layer between the two middle resnet blocks. '
                             'If \'most\', also add one more (upsampling + resnet) layer at the end of the generator.')

    return parser.parse_args(argv)
# Spatial size (height, width) at which the segmentation generator and the GMM
# run. The segmentation output is upsampled back to (load_height, load_width),
# and the GMM's sampling grid is applied to the full-resolution cloth.
_COARSE_SIZE = (256, 192)

# Channels of the one-hot map built from the segmentation argmax; every source
# index listed in _PARSE_GROUPS must be smaller than this.
_NUM_SEG_CLASSES = 13

# Merges the one-hot segmentation channels into the parse channels fed to the
# GMM / ALIAS stages: output channel -> [name, source channel indices].
_PARSE_GROUPS = {
    0: ['background', [0]],
    1: ['paste', [2, 4, 7, 8, 9, 10, 11]],
    2: ['upper', [3]],
    3: ['hair', [1]],
    4: ['left_arm', [5]],
    5: ['right_arm', [6]],
    6: ['noise', [12]],
}


def _predict_parse(opt, seg, up, gauss, parse_agnostic, pose, c, cm):
    """Part 1: predict the merged full-resolution parse map for the target cloth.

    Runs ``seg`` at ``_COARSE_SIZE``, upsamples and blurs its output, takes the
    per-pixel argmax as a one-hot ``_NUM_SEG_CLASSES``-channel map, and merges
    those channels according to ``_PARSE_GROUPS``.

    Returns:
        Float tensor (batch, len(_PARSE_GROUPS), opt.load_height, opt.load_width).
    """
    def down(x):
        return F.interpolate(x, size=_COARSE_SIZE, mode='bilinear', align_corners=False)

    parse_agnostic_down = down(parse_agnostic)
    pose_down = down(pose)
    c_masked_down = down(c * cm)
    cm_down = down(cm)
    seg_input = torch.cat((cm_down, c_masked_down, parse_agnostic_down, pose_down,
                           gen_noise(cm_down.size()).cuda()), dim=1)

    # Blur the upsampled per-class scores before picking the winning class.
    labels = gauss(up(seg(seg_input))).argmax(dim=1)[:, None]

    one_hot = torch.zeros(labels.size(0), _NUM_SEG_CLASSES, opt.load_height, opt.load_width,
                          dtype=torch.float, device=labels.device)
    one_hot.scatter_(1, labels, 1.0)

    # Each merged channel is the sum of its (mutually exclusive) source channels.
    return torch.stack([one_hot[:, sources].sum(dim=1) for _, sources in _PARSE_GROUPS.values()],
                       dim=1)


def _warp_cloth(gmm, img_agnostic, parse, pose, c, cm):
    """Part 2: estimate a sampling grid with the GMM and warp the cloth and its mask.

    The GMM sees nearest-downsampled inputs at ``_COARSE_SIZE``. The returned
    grid is then sampled against the full-resolution ``c`` and ``cm``.

    Returns:
        Tuple ``(warped_c, warped_cm)``.
    """
    def down(x):
        return F.interpolate(x, size=_COARSE_SIZE, mode='nearest')

    # parse[:, 2:3] is the 'upper' group of _PARSE_GROUPS.
    gmm_input = torch.cat((down(parse[:, 2:3]), down(pose), down(img_agnostic)), dim=1)
    _, warped_grid = gmm(gmm_input, down(c))
    warped_c = F.grid_sample(c, warped_grid, padding_mode='border', align_corners=False)
    warped_cm = F.grid_sample(cm, warped_grid, padding_mode='border', align_corners=False)
    return warped_c, warped_cm


def _synthesize(alias, img_agnostic, pose, parse, warped_c, warped_cm):
    """Part 3: run the ALIAS generator on the warped cloth and the parse map.

    ``misalign_mask`` is the part of the 'upper' parse region not covered by the
    warped cloth mask (negative differences clamped to 0). ``parse_div`` is
    ``parse`` with that mask appended as an extra channel and subtracted from
    the 'upper' channel.
    """
    misalign_mask = (parse[:, 2:3] - warped_cm).clamp(min=0.0)
    parse_div = torch.cat((parse, misalign_mask), dim=1)
    parse_div[:, 2:3] -= misalign_mask
    return alias(torch.cat((img_agnostic, pose, warped_c), dim=1), parse, parse_div, misalign_mask)


def test(opt, seg, gmm, alias):
    """Run unpaired try-on inference over the test set and save the results.

    For every batch from ``VITONDataLoader`` the pipeline does three things:
    it predicts a parse map with ``seg``, warps the unpaired cloth with ``gmm``,
    and synthesizes the output image with ``alias``.

    Images are written by ``save_images`` into ``<opt.save_dir>/<opt.name>``.
    Each is named ``<img_name up to its first '_'>_<c_name>``.

    Args:
        opt: Options namespace (see ``get_opt``). Here it reads load_height,
            load_width, save_dir, name and display_freq; it is also passed to
            the dataset and loader.
        seg, gmm, alias: Networks. They are expected to already be on CUDA,
            since all inputs are moved with ``.cuda()``.
    """
    up = nn.Upsample(size=(opt.load_height, opt.load_width), mode='bilinear', align_corners=False)
    gauss = tgm.image.GaussianBlur((15, 15), (3, 3))
    gauss.cuda()

    test_dataset = VITONDataset(opt)
    test_loader = VITONDataLoader(opt, test_dataset)
    save_path = os.path.join(opt.save_dir, opt.name)

    with torch.no_grad():
        for i, inputs in enumerate(test_loader.data_loader):
            img_names = inputs['img_name']
            c_names = inputs['c_name']['unpaired']
            img_agnostic = inputs['img_agnostic'].cuda()
            parse_agnostic = inputs['parse_agnostic'].cuda()
            pose = inputs['pose'].cuda()
            c = inputs['cloth']['unpaired'].cuda()
            cm = inputs['cloth_mask']['unpaired'].cuda()

            parse = _predict_parse(opt, seg, up, gauss, parse_agnostic, pose, c, cm)
            warped_c, warped_cm = _warp_cloth(gmm, img_agnostic, parse, pose, c, cm)
            output = _synthesize(alias, img_agnostic, pose, parse, warped_c, warped_cm)

            unpaired_names = ['{}_{}'.format(img_name.split('_')[0], c_name)
                              for img_name, c_name in zip(img_names, c_names)]
            save_images(output, unpaired_names, save_path)

            if (i + 1) % opt.display_freq == 0:
                print("step: {}".format(i + 1))
def main():
    """Script entry point: parse options, build and load the networks, run ``test``.

    Creates ``<opt.save_dir>/<opt.name>`` if needed. Loads the seg/GMM/ALIAS
    checkpoints from ``opt.checkpoint_dir``, moves the networks to CUDA in
    eval mode, and runs inference.
    """
    opt = get_opt()
    print(opt)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.join(opt.save_dir, opt.name), exist_ok=True)

    seg = SegGenerator(opt, input_nc=opt.semantic_nc + 8, output_nc=opt.semantic_nc)
    gmm = GMM(opt, inputA_nc=7, inputB_nc=3)

    # ALIASGenerator is built while opt.semantic_nc == 7 (presumably it sizes its
    # parse-conditioned layers from it). Restore the caller-supplied value
    # afterwards rather than a hard-coded 13, so --semantic_nc still reaches the
    # dataset built later in test().
    semantic_nc = opt.semantic_nc
    opt.semantic_nc = 7
    try:
        alias = ALIASGenerator(opt, input_nc=9)
    finally:
        opt.semantic_nc = semantic_nc

    load_checkpoint(seg, os.path.join(opt.checkpoint_dir, opt.seg_checkpoint))
    load_checkpoint(gmm, os.path.join(opt.checkpoint_dir, opt.gmm_checkpoint))
    load_checkpoint(alias, os.path.join(opt.checkpoint_dir, opt.alias_checkpoint))

    seg.cuda().eval()
    gmm.cuda().eval()
    alias.cuda().eval()
    test(opt, seg, gmm, alias)


if __name__ == '__main__':
    main()