-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdvc_dataset.py
180 lines (141 loc) · 7.04 KB
/
dvc_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
'''
Adapted from:
https://github.com/ZhihaoHu/PyTorchVideoCompression/blob/master/DVC/dataset.py
'''
import os
import torch
import imageio
import numpy as np
import torch.utils.data as data
# from DVC.subnet.basics import *
# from DVC.subnet.ms_ssim_torch import ms_ssim
from DVC.augmentation import random_flip, random_crop_and_pad_image_and_labels
from test_video import PSNR, ms_ssim
# The following channel counts come from DCVC_net.py
out_channel_mv = 128
out_channel_N = 64
out_channel_M = 96
vimeo_data_path = ''  # e.g. 'H:/Data/vimeo_septuplet/vimeo_septuplet/sequences/' or '/mnt/h/Data/vimeo_septuplet/vimeo_septuplet/sequences/'
vimeo_test_list_path = '/mnt/data3/zhaojunzhang/vimeo_septuplet/mini_dvc_test_10k.txt'
class DataSet(data.Dataset):
    """Vimeo-90k training dataset for DVC/DCVC.

    Each item is a cropped, randomly flipped (input, reference) frame pair
    plus four fresh uniform quantization-noise tensors sized to match the
    feature / hyperprior / motion-vector / motion-hyperprior code maps.
    """

    def __init__(self, path='/mnt/data3/zhaojunzhang/vimeo_septuplet/mini_dvc_test_10k.txt',
                 im_height=256, im_width=256,
                 rootdir='/mnt/data3/zhaojunzhang/vimeo_septuplet/sequences/'):
        """
        Args:
            path: text file listing frame paths relative to ``rootdir``,
                one per line (default mirrors ``vimeo_test_list_path``).
            im_height: crop height of the returned frames.
            im_width: crop width of the returned frames.
            rootdir: directory the listed frame paths are relative to
                (new keyword; default preserves the previous hard-coded path).
        """
        self.image_input_list, self.image_ref_list = self.get_vimeo(
            rootdir=rootdir, filefolderlist=path)
        self.im_height = im_height
        self.im_width = im_width
        # Zero templates only fix the noise shapes; fresh uniform noise is
        # drawn per item in __getitem__.
        self.featurenoise = torch.zeros([out_channel_M, self.im_height // 16, self.im_width // 16])
        self.znoise = torch.zeros([out_channel_N, self.im_height // 64, self.im_width // 64])
        self.mvnois = torch.zeros([out_channel_mv, self.im_height // 16, self.im_width // 16])
        self.mvznois = torch.zeros([out_channel_N, self.im_height // 64, self.im_width // 64])  # modified
        print("dataset find image: ", len(self.image_input_list))
        # TODO: switch to using the full dataset?

    def get_vimeo(self, rootdir, filefolderlist):
        """Pair each listed frame with the frame two indices earlier.

        A listed path like ``.../im3.png`` gets ``.../im1.png`` as its
        reference (the single digit before ``.png`` minus 2).

        Returns:
            (input_paths, ref_paths): two parallel lists of file paths.
        """
        with open(filefolderlist) as f:
            lines = f.readlines()
        fns_train_input = []
        fns_train_ref = []
        for line in lines:
            y = os.path.join(rootdir, line.rstrip())
            fns_train_input.append(y)
            # Frame index is the single digit just before ".png".
            refnumber = int(y[-5:-4]) - 2
            refname = y[0:-5] + str(refnumber) + '.png'
            fns_train_ref.append(refname)
        return fns_train_input, fns_train_ref

    def __len__(self):
        return len(self.image_input_list)

    def __getitem__(self, index):
        """Return (input, ref, feature_noise, z_noise, mv_noise, mv_z_noise)."""
        input_image = imageio.imread(self.image_input_list[index])
        ref_image = imageio.imread(self.image_ref_list[index])
        input_image = input_image.astype(np.float32) / 255.0
        ref_image = ref_image.astype(np.float32) / 255.0
        # HWC -> CHW tensors in [0, 1].
        input_image = torch.from_numpy(input_image.transpose(2, 0, 1)).float()
        ref_image = torch.from_numpy(ref_image.transpose(2, 0, 1)).float()
        input_image, ref_image = random_crop_and_pad_image_and_labels(
            input_image, ref_image, [self.im_height, self.im_width])
        input_image, ref_image = random_flip(input_image, ref_image)
        # Fresh U(-0.5, 0.5) noise per item, used to simulate quantization.
        quant_noise_feature = torch.nn.init.uniform_(torch.zeros_like(self.featurenoise), -0.5, 0.5)
        quant_noise_z = torch.nn.init.uniform_(torch.zeros_like(self.znoise), -0.5, 0.5)
        quant_noise_mv = torch.nn.init.uniform_(torch.zeros_like(self.mvnois), -0.5, 0.5)
        quant_noise_z_mv = torch.nn.init.uniform_(torch.zeros_like(self.mvznois), -0.5, 0.5)
        return input_image, ref_image, quant_noise_feature, quant_noise_z, quant_noise_mv, quant_noise_z_mv
class RawDataSet(data.Dataset):
    """Vimeo-90k dataset yielding raw (input, reference) frame pairs.

    Like DataSet but with no cropping, flipping, or quantization noise:
    each item is the full-resolution frame pair as float CHW tensors in [0, 1].
    """

    def __init__(self, path='/mnt/data3/zhaojunzhang/vimeo_septuplet/mini_dvc_test_10k.txt',
                 rootdir='/mnt/data3/zhaojunzhang/vimeo_septuplet/sequences/'):
        """
        Args:
            path: text file listing frame paths relative to ``rootdir``,
                one per line (default mirrors ``vimeo_test_list_path``).
            rootdir: directory the listed frame paths are relative to
                (new keyword; default preserves the previous hard-coded path).
        """
        self.image_input_list, self.image_ref_list = self.get_vimeo(
            rootdir=rootdir, filefolderlist=path)
        print("dataset find image: ", len(self.image_input_list))

    def get_vimeo(self, rootdir, filefolderlist):
        """Pair each listed frame with the frame two indices earlier.

        A listed path like ``.../im3.png`` gets ``.../im1.png`` as its
        reference (the single digit before ``.png`` minus 2).

        Returns:
            (input_paths, ref_paths): two parallel lists of file paths.
        """
        with open(filefolderlist) as f:
            lines = f.readlines()
        fns_train_input = []
        fns_train_ref = []
        for line in lines:
            y = os.path.join(rootdir, line.rstrip())
            fns_train_input.append(y)
            # Frame index is the single digit just before ".png".
            refnumber = int(y[-5:-4]) - 2
            refname = y[0:-5] + str(refnumber) + '.png'
            fns_train_ref.append(refname)
        return fns_train_input, fns_train_ref

    def __len__(self):
        return len(self.image_input_list)

    def __getitem__(self, index):
        """Return the (input, reference) pair as float CHW tensors in [0, 1]."""
        input_image = imageio.imread(self.image_input_list[index])
        ref_image = imageio.imread(self.image_ref_list[index])
        input_image = input_image.astype(np.float32) / 255.0
        ref_image = ref_image.astype(np.float32) / 255.0
        # HWC -> CHW tensors in [0, 1].
        input_image = torch.from_numpy(input_image.transpose(2, 0, 1)).float()
        ref_image = torch.from_numpy(ref_image.transpose(2, 0, 1)).float()
        return input_image, ref_image
class UVGDataSet(data.Dataset):
    """UVG evaluation dataset of (input, reference) frame pairs.

    For every sequence listed in ``filelist``, frames are sampled in steps
    of 12 (GOP size): the reference is frame ``im{12*i+1}.png`` (taken from
    ``refdir`` inside the sequence folder) and the input is ``im{12*i+3}.png``.
    """

    def __init__(self, root="/mnt/data3/zhaojunzhang/uvg4dcvc/images/",
                 filelist="/mnt/data3/zhaojunzhang/uvg4dcvc/originalv.txt",
                 refdir='', testfull=False, im_height=256, im_width=256):
        """
        Args:
            root: directory containing one sub-folder of PNG frames per sequence.
            filelist: text file with one sequence folder name per line.
            refdir: sub-folder inside each sequence holding reference frames
                (e.g. decoded I-frames); '' means the sequence folder itself.
            testfull: if True, sample every 12-frame group; otherwise only the first.
            im_height: crop height of the returned frames.
            im_width: crop width of the returned frames.
        """
        self.im_height = im_height
        self.im_width = im_width
        with open(filelist) as f:
            folders = f.readlines()
        self.ref = []
        self.input = []
        self.hevcclass = []  # kept for interface compatibility; not populated here
        for folder in folders:
            seq = folder.rstrip()
            # The PNG count determines how many 12-frame groups the sequence holds.
            cnt = sum(1 for im in os.listdir(os.path.join(root, seq)) if im.endswith('.png'))
            framerange = cnt // 12 if testfull else 1
            for i in range(framerange):
                num = i * 12 + 1
                refpath = os.path.join(root, seq, refdir, 'im' + str(num).zfill(3) + '.png')
                inputpath = os.path.join(root, seq, 'im' + str(num + 2).zfill(3) + '.png')
                self.ref.append(refpath)
                self.input.append(inputpath)

    def __len__(self):
        return len(self.ref)

    def __getitem__(self, index):
        """Return the (input, reference) pair as cropped, flipped CHW tensors."""
        input_image = imageio.imread(self.input[index])
        ref_image = imageio.imread(self.ref[index])
        input_image = input_image.astype(np.float32) / 255.0
        ref_image = ref_image.astype(np.float32) / 255.0
        # HWC -> CHW tensors in [0, 1].
        input_image = torch.from_numpy(input_image.transpose(2, 0, 1)).float()
        ref_image = torch.from_numpy(ref_image.transpose(2, 0, 1)).float()
        # NOTE(review): random crop/flip augmentation on an evaluation set is
        # unusual — confirm this is intended rather than full-frame evaluation.
        input_image, ref_image = random_crop_and_pad_image_and_labels(
            input_image, ref_image, [self.im_height, self.im_width])
        input_image, ref_image = random_flip(input_image, ref_image)
        return input_image, ref_image