# youtube_dataset.py (forked from seoungwugoh/STM)
import os
import glob
from os import path

import numpy as np
from PIL import Image

import torch
import torchvision
from torch.utils import data


class YOUTUBE_VOS_MO_Test(data.Dataset):
    # Multi-object YouTube-VOS test-set loader: one item per video.
    def __init__(self, root, all_frames_root, start_idx, end_idx):
        self.root = root
        self.mask_dir = path.join(root, 'Annotations')
        self.image_dir = path.join(root, 'JPEGImages')
        self.all_frames_image_dir = path.join(all_frames_root, 'JPEGImages')
        self.videos = []
        self.num_skip_frames = {}
        self.num_frames = {}
        self.shape = {}
        self.frames_name = {}

        vid_list = sorted(os.listdir(self.image_dir))
        print('This process handles video %d to %d out of %d' % (start_idx, end_idx - 1, len(vid_list)))
        vid_list = vid_list[start_idx:end_idx]

        for vid in vid_list:
            self.videos.append(vid)
            # Frame counts for the subsampled evaluation frames and for all frames of the video.
            self.num_skip_frames[vid] = len(os.listdir(path.join(self.image_dir, vid)))
            self.num_frames[vid] = len(os.listdir(path.join(self.all_frames_image_dir, vid)))
            self.frames_name[vid] = sorted(os.listdir(path.join(self.all_frames_image_dir, vid)))
            # Record the spatial resolution from the first annotation mask.
            first_mask = sorted(os.listdir(path.join(self.mask_dir, vid)))[0]
            _mask = np.array(Image.open(path.join(self.mask_dir, vid, first_mask)).convert('P'))
            self.shape[vid] = np.shape(_mask)

        self.K = 7  # maximum number of object channels (background included)

    def __len__(self):
        return len(self.videos)

    def To_onehot(self, mask):
        # Convert an (H, W) label mask into a (K, H, W) one-hot mask.
        M = np.zeros((self.K, mask.shape[0], mask.shape[1]), dtype=np.uint8)
        for k in range(self.K):
            M[k] = (mask == k).astype(np.uint8)
        return M

    def All_to_onehot(self, masks):
        # Convert (N, H, W) label masks into (K, N, H, W) one-hot masks.
        Ms = np.zeros((self.K, masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
        for n in range(masks.shape[0]):
            Ms[:, n] = self.To_onehot(masks[n])
        return Ms

    def __getitem__(self, index):
        video = self.videos[index]
        info = {}
        info['name'] = video
        info['num_frames'] = self.num_frames[video]
        info['num_skip_frames'] = self.num_skip_frames[video]
        info['num_objects'] = 0
        info['frames_name'] = self.frames_name[video]

        N_all_frames = np.empty((self.num_frames[video],) + self.shape[video] + (3,), dtype=np.float32)
        N_frames = np.empty((self.num_skip_frames[video],) + self.shape[video] + (3,), dtype=np.float32)
        N_masks = np.empty((self.num_skip_frames[video],) + self.shape[video], dtype=np.uint8)

        # Annotated (subsampled) frames and their masks.
        for i, f in enumerate(sorted(os.listdir(path.join(self.image_dir, video)))):
            img_file = path.join(self.image_dir, video, f)
            N_frames[i] = np.array(Image.open(img_file).convert('RGB')) / 255.
            mask_file = path.join(self.mask_dir, video, f.replace('.jpg', '.png'))
            N_masks[i] = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            info['num_objects'] = max(info['num_objects'], N_masks[i].max())

        # All frames of the video (no annotations available here).
        for i, f in enumerate(info['frames_name']):
            img_file = path.join(self.all_frames_image_dir, video, f)
            N_all_frames[i] = np.array(Image.open(img_file).convert('RGB')) / 255.

        # Reorder (N, H, W, C) -> (C, N, H, W) and convert to float tensors.
        Fs = torch.from_numpy(np.transpose(N_frames.copy(), (3, 0, 1, 2)).copy()).float()
        Ms = torch.from_numpy(self.All_to_onehot(N_masks).copy()).float()
        AFs = torch.from_numpy(np.transpose(N_all_frames.copy(), (3, 0, 1, 2)).copy()).float()
        return Fs, Ms, AFs, info


if __name__ == '__main__':
    pass
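    # A minimal usage sketch (not part of the original file): the paths below are
    # assumptions based on the usual YouTube-VOS layout with a 'valid' split and a
    # 'valid_all_frames' split; adjust them to your local setup.
    ytb_root = '../YouTube-VOS/valid'                        # hypothetical path
    ytb_all_frames_root = '../YouTube-VOS/valid_all_frames'  # hypothetical path
    dataset = YOUTUBE_VOS_MO_Test(ytb_root, ytb_all_frames_root, start_idx=0, end_idx=5)
    loader = data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
    for Fs, Ms, AFs, info in loader:
        # Fs: (1, 3, N_skip, H, W), Ms: (1, K, N_skip, H, W), AFs: (1, 3, N_all, H, W)
        print(info['name'], Fs.shape, Ms.shape, AFs.shape)
        break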