-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
109 lines (80 loc) · 3.06 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import glob
import h5py
import random
import numpy as np
from PIL import Image
import torch.utils.data as data
import torchvision.transforms as transforms
def random_crop(hr, lr, size, scale):
    """Cut an aligned random patch from a high-/low-resolution image pair.

    A ``size`` x ``size`` window is chosen uniformly at random inside ``lr``;
    the window at the same location, scaled by ``scale``, is cut from ``hr``
    (``size*scale`` x ``size*scale``).

    Args:
        hr: High-resolution array, indexed as ``hr[row, col, ...]``.
        lr: Low-resolution array, indexed as ``lr[row, col, ...]``. Both 2-D
            arrays and 3-D arrays with a trailing extra axis (e.g. channels)
            are accepted.
        size: Side length, in pixels, of the low-resolution patch.
        scale: Integer factor between ``lr`` and ``hr`` coordinates.

    Returns:
        ``(crop_hr, crop_lr)``, both independent copies rather than views.

    Raises:
        ValueError: If ``lr`` is smaller than ``size`` along either axis.
    """
    # Take the first two axes so both (H, W) and (H, W, C) inputs work;
    # ``shape[:-1]`` only unpacked into two values for 3-D arrays.
    h, w = lr.shape[:2]
    if h < size or w < size:
        raise ValueError(
            "patch size {} exceeds low-resolution image size {}x{}".format(size, h, w)
        )
    # Draw x before y to keep the same random sequence as before.
    x = random.randint(0, w - size)
    y = random.randint(0, h - size)
    hsize = size * scale
    hx, hy = x * scale, y * scale
    crop_lr = lr[y:y + size, x:x + size].copy()
    crop_hr = hr[hy:hy + hsize, hx:hx + hsize].copy()
    return crop_hr, crop_lr
def random_flip_and_rotate(im1, im2):
    """Apply one randomly chosen flip/rotation identically to two images.

    Three random draws are made, in this order: a vertical-flip coin, a
    horizontal-flip coin, and a number of 90-degree turns from {0, 1, 2, 3}.
    The same choices are applied to both inputs, so paired patches stay
    aligned.

    Returns:
        A tuple ``(im1_out, im2_out)``. Each is a fresh copy rather than a
        view of the input.
    """
    flip_vertical = random.random() < 0.5
    flip_horizontal = random.random() < 0.5
    quarter_turns = random.choice([0, 1, 2, 3])

    def _augment(image):
        # Apply the pre-drawn choices in a fixed order: flips, then rotation.
        if flip_vertical:
            image = np.flipud(image)
        if flip_horizontal:
            image = np.fliplr(image)
        return np.rot90(image, quarter_turns).copy()

    return _augment(im1), _augment(im2)
class TrainDataset(data.Dataset):
    """Randomly cropped and augmented HR/LR training pairs from one HDF5 file.

    The file must contain an ``"HR"`` group and one ``"X<scale>"`` group per
    requested scale. Every member of those groups is read into memory with
    ``[:]`` in group iteration order, and the file is closed before
    ``__init__`` returns. HR member ``i`` is paired with LR member ``i`` of
    each scale.

    Args:
        path: Path to the HDF5 file.
        size: Side length of the low-resolution patch returned per scale.
        scale: Upscaling factor to load, or ``0`` to load scales 2, 3 and 4
            together.
    """

    def __init__(self, path, size, scale):
        super().__init__()
        self.size = size
        self.scale = [2, 3, 4] if scale == 0 else [scale]
        # The context manager releases the HDF5 handle even if reading fails
        # part-way (e.g. a missing "X<scale>" group raises KeyError).
        with h5py.File(path, "r") as h5f:
            self.hr = [v[:] for v in h5f["HR"].values()]
            self.lr = [
                [v[:] for v in h5f["X{}".format(s)].values()] for s in self.scale
            ]
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, index):
        """Return a list with one ``(hr_tensor, lr_tensor)`` pair per scale.

        All random crops are drawn first, then all flip/rotate choices. This
        fixes the order of ``random`` draws for a given seed.
        """
        hr = self.hr[index]
        crops = [
            random_crop(hr, lr_images[index], self.size, s)
            for lr_images, s in zip(self.lr, self.scale)
        ]
        crops = [random_flip_and_rotate(crop_hr, crop_lr) for crop_hr, crop_lr in crops]
        return [(self.transform(crop_hr), self.transform(crop_lr)) for crop_hr, crop_lr in crops]

    def __len__(self):
        """Return the number of HR images loaded from the file."""
        return len(self.hr)
class TestDataset(data.Dataset):
    """HR/LR evaluation image pairs, loaded from PNG files on access.

    Two directory layouts are recognised, chosen by the directory's name:

    * Name contains ``"DIV"``: HR images in ``<dirname>_HR/*.png`` and LR
      images in ``<dirname>_LR_bicubic/X<scale>/*.png``.
    * Any other name: ``<dirname>/image_SRF_<scale>/HR/*.png`` and
      ``<dirname>/image_SRF_<scale>/LR/*.png``.

    The HR and LR file lists are sorted separately and paired by position.

    Args:
        dirname: Dataset directory; a trailing path separator is allowed.
        scale: Upscaling factor used to locate the LR images.

    Raises:
        ValueError: If the number of HR and LR files found differs.
    """

    def __init__(self, dirname, scale):
        super().__init__()
        # Normalise first so a trailing separator ("DIV2K_valid/") neither
        # empties ``self.name`` nor breaks the "<dirname>_HR" paths below.
        dirname = os.path.normpath(dirname)
        self.name = os.path.basename(dirname)
        self.scale = scale

        if "DIV" in self.name:
            hr_pattern = os.path.join("{}_HR".format(dirname), "*.png")
            lr_pattern = os.path.join("{}_LR_bicubic".format(dirname),
                                      "X{}/*.png".format(scale))
        else:
            hr_pattern = os.path.join(dirname, "image_SRF_{}/HR/*.png".format(scale))
            lr_pattern = os.path.join(dirname, "image_SRF_{}/LR/*.png".format(scale))

        self.hr = sorted(glob.glob(hr_pattern))
        self.lr = sorted(glob.glob(lr_pattern))
        # Pairs are matched by position, so unequal counts would silently
        # match the wrong images.
        if len(self.hr) != len(self.lr):
            raise ValueError(
                "found {} HR but {} LR images for {!r} at scale {}".format(
                    len(self.hr), len(self.lr), dirname, scale))

        self.transform = transforms.Compose([transforms.ToTensor()])

    def __getitem__(self, index):
        """Return ``(hr_tensor, lr_tensor, filename)`` for the pair at ``index``.

        Both images are converted to RGB. ``filename`` is the base name of the
        HR file.
        """
        # ``convert`` loads the pixel data, so the file can be closed right after.
        with Image.open(self.hr[index]) as hr_image:
            hr = hr_image.convert("RGB")
        with Image.open(self.lr[index]) as lr_image:
            lr = lr_image.convert("RGB")
        # basename handles the OS's own separator (glob returns "\\" on Windows).
        filename = os.path.basename(self.hr[index])
        return self.transform(hr), self.transform(lr), filename

    def __len__(self):
        """Return the number of HR/LR pairs."""
        return len(self.hr)