-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathpreprocess.py
73 lines (63 loc) · 2.44 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import sg_utils
import numpy as np
from collections import Counter
from pycoco.coco import COCO
from nltk import pos_tag, word_tokenize
def get_vocab_top_k(vocab, k):
    """Return a copy of `vocab` keeping only the first k entries of each field.

    The vocabulary lists ('words', 'counts', 'poss') are aligned and stored
    in descending count order, so truncating every list to its first k
    elements keeps the k most frequent words with their metadata. The input
    dict is not modified.
    """
    return {key: values[:k] for (key, values) in vocab.items()}
def get_vocab_counts(image_ids, coco_caps, max_cap, vocab):
    """Count, per image, in how many of its first `max_cap` captions each
    vocabulary word appears.

    Args:
        image_ids: list of COCO image ids.
        coco_caps: COCO caption API object (provides getAnnIds / loadAnns).
        max_cap:   number of captions to use per image; every image is
                   required (asserted) to have at least this many.
        vocab:     dict with an aligned 'words' list, as built by get_vocab.

    Returns:
        A (len(image_ids), len(vocab['words'])) float64 array where entry
        [i, j] is the number of the first `max_cap` captions of image i
        that contain vocab['words'][j] at least once (presence per caption,
        not token frequency).
    """
    # One-time O(1) word -> column map; the original list.index() lookup was
    # O(V) per token, making the whole pass accidentally quadratic.
    word_index = {w: j for (j, w) in enumerate(vocab['words'])}
    # np.float was removed from NumPy; builtin float means float64 here.
    counts = np.zeros((len(image_ids), len(word_index)), dtype=float)
    for i, image_id in enumerate(image_ids):
        ann_ids = coco_caps.getAnnIds(image_id)
        assert len(ann_ids) >= max_cap, \
            'less than {:d} number of captions for image {:d}'.format(max_cap, image_id)
        # Sort ids so the subset of captions used is deterministic.
        ann_ids.sort()
        anns = coco_caps.loadAnns(ann_ids[:max_cap])
        for ann in anns:
            tokens = word_tokenize(str(ann['caption']).lower())
            # A set: each caption contributes at most 1 per word.
            cols = {word_index[tok] for tok in tokens if tok in word_index}
            for j in cols:
                counts[i, j] += 1
    return counts
def get_vocab(imset, coco_caps, punctuations, mapping):
    """Build a vocabulary over all captions in `coco_caps`.

    Args:
        imset:        unused; kept for interface compatibility with callers.
        coco_caps:    COCO caption API object (getImgIds / getAnnIds / loadAnns).
        punctuations: collection of tokens to drop from the final vocabulary.
        mapping:      dict mapping raw NLTK POS tags to coarse tags; tags it
                      does not cover (or maps to None) become 'other'.

    Returns:
        dict with three aligned lists, sorted by descending count:
          'words'  - unique lowercased tokens (punctuation removed),
          'counts' - total occurrences of each word over all captions,
          'poss'   - the word's most frequent coarse POS tag.
    """
    image_ids = coco_caps.getImgIds()
    image_ids.sort()
    # Tokenize and POS-tag every caption of every image.
    tagged = []
    for image_id in image_ids:
        ann_ids = coco_caps.getAnnIds(image_id)
        anns = coco_caps.loadAnns(ann_ids)
        tagged.append([pos_tag(word_tokenize(str(a['caption']).lower())) for a in anns])
    # Flatten to one list of (word, raw_pos) pairs and coarsen the POS tags.
    pairs = [pair for per_image in tagged for per_ann in per_image for pair in per_ann]
    pairs = [(w, 'other') if mapping.get(p) is None else (w, mapping[p])
             for (w, p) in pairs]
    # Counts of (word, coarse_pos) pairs, most frequent first.
    vcb = Counter(pairs).most_common()
    # Single pass instead of the original O(W^2) per-word index scan:
    # most_common() is sorted by descending count, so the FIRST time a word
    # is seen its pos is the word's most frequent pos (same as argmax).
    totals = {}
    best_pos = {}
    for ((w, p), c) in vcb:
        totals[w] = totals.get(w, 0) + c
        if w not in best_pos:
            best_pos[w] = p
    # Reproduce the original ordering exactly: alphabetical first, then a
    # reversed stable argsort on counts (ties end up reverse-alphabetical).
    words = sorted(totals)
    counts = [totals[w] for w in words]
    poss = [best_pos[w] for w in words]
    order = np.argsort(counts)[::-1]
    words = [words[i] for i in order]
    poss = [poss[i] for i in order]
    counts = [counts[i] for i in order]
    # Drop punctuation tokens, keeping the three lists aligned.
    keep = [i for (i, w) in enumerate(words) if w not in punctuations]
    vocab = {'words': [words[i] for i in keep],
             'counts': [counts[i] for i in keep],
             'poss': [poss[i] for i in keep]}
    return vocab