-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathdataloader.py
75 lines (62 loc) · 2.13 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import json
import nltk
import time
import torch
from PIL import Image
class DataLoader:
    """Load an image-captioning dataset that lives in a single directory.

    On construction the loader reads ``captions.txt`` from ``dir_path`` and
    then opens every ``.jpg`` file found directly in that directory, keeping
    the transformed images in memory.

    Attributes:
        images: dict mapping image file name -> ``transform(image)``.
        captions_dict: dict merged from the JSON objects in ``captions.txt``;
            each value is iterated as a collection of caption strings
            (presumably keyed by image file name -- confirm against the
            file that produces ``captions.txt``).
        vocab: object exposing ``get_id(token)``.
        transform: callable applied to each opened PIL image.
    """

    def __init__(self, dir_path, vocab, transform):
        """Read captions and images from ``dir_path``.

        Args:
            dir_path: directory holding ``captions.txt`` and the images.
            vocab: vocabulary object providing ``get_id(token)``.
            transform: callable invoked on every opened PIL image.
        """
        self.images = None
        self.captions_dict = None
        self.vocab = vocab
        self.transform = transform
        self.load_captions(dir_path)
        self.load_images(dir_path)

    def load_images(self, dir_path):
        """Open each ``.jpg`` file in ``dir_path`` and store its transform.

        The result replaces ``self.images`` with a dict keyed by file name.
        Entries whose extension is not exactly ``.jpg`` are ignored.
        """
        images = {}
        for name in os.listdir(dir_path):
            # splitext copes with entries that contain no dot (splitting on
            # '.' and indexing [1] raised IndexError for those) and with
            # names containing several dots (only the last suffix counts).
            if os.path.splitext(name)[1] != '.jpg':
                continue
            # Close the underlying file deterministically instead of leaving
            # one open handle per image. The pixel data is loaded before the
            # handle is released so the image stays usable even when the
            # transform hands the same object back.
            with Image.open(os.path.join(dir_path, name)) as img:
                img.load()
                images[name] = self.transform(img)
        self.images = images

    def load_captions(self, dir_path):
        """Parse ``captions.txt`` (one JSON object per line) in ``dir_path``.

        All objects are merged into a single dict stored in
        ``self.captions_dict``; a key seen on a later line overwrites the
        earlier value. Blank lines are skipped.

        Raises:
            OSError: if ``captions.txt`` cannot be opened.
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        captions_path = os.path.join(dir_path, 'captions.txt')
        captions_dict = {}
        # JSON text is UTF-8; do not depend on the platform default encoding.
        with open(captions_path, encoding='utf-8') as f:
            for line in f:
                # A blank or trailing empty line is not a JSON document.
                if not line.strip():
                    continue
                captions_dict.update(json.loads(line))
        self.captions_dict = captions_dict

    def caption2ids(self, caption):
        """Convert a caption string into a list of vocabulary ids.

        The caption is lower-cased and tokenized with NLTK's
        ``word_tokenize``; the ids of ``'<start>'`` and ``'<end>'`` are
        placed before and after the token ids.
        """
        vocab = self.vocab
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        ids = [vocab.get_id('<start>')]
        ids.extend(vocab.get_id(word) for word in tokens)
        ids.append(vocab.get_id('<end>'))
        return ids

    def gen_data(self):
        """Flatten ``captions_dict`` into two parallel lists.

        Returns:
            A tuple ``(images, captions)`` where ``images[k]`` is the key
            under which the caption was stored and ``captions[k]`` is that
            caption encoded by ``caption2ids``. A key with several captions
            is repeated once per caption.
        """
        images = []
        captions = []
        for image_id, curr_captions in self.captions_dict.items():
            images.extend([image_id] * len(curr_captions))
            captions.extend(
                self.caption2ids(caption) for caption in curr_captions
            )
        return images, captions

    def get_image(self, image_id):
        """Return the transformed image stored under ``image_id``.

        Raises:
            KeyError: if no image with that file name was loaded.
        """
        return self.images[image_id]
def shuffle_data(data, seed=0):
    """Shuffle parallel image/caption sequences with one shared permutation.

    Args:
        data: pair ``(images, captions)`` of index-able sequences; every
            index valid for ``images`` must also be valid for ``captions``.
        seed: value passed to ``torch.manual_seed`` before drawing the
            permutation, so equal seeds give equal orderings.

    Returns:
        A tuple ``(shuffled_imgs, shuffled_captions)`` of new lists in which
        position ``k`` of both lists comes from the same original index.

    Note:
        This reseeds torch's global random number generator as a side
        effect, exactly as callers of the original implementation observed.
    """
    images, captions = data
    torch.manual_seed(seed)
    # tolist() yields plain Python ints, so the lookups below are ordinary
    # list indexing rather than going through a 0-dim tensor per element.
    order = torch.randperm(len(images)).tolist()
    shuffled_imgs = [images[idx] for idx in order]
    shuffled_captions = [captions[idx] for idx in order]
    return shuffled_imgs, shuffled_captions