Skip to content

Commit 14d20fb

Browse files
update
1 parent 117a44c commit 14d20fb

16 files changed

+331091
-4
lines changed

README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ sh train_mmt.sh
6464
sh translation_mmt.sh
6565
```
6666

67-
# masking data
67+
# create masking data
6868
```bash
6969
pip3 install stanfordcorenlp
7070
wget https://nlp.stanford.edu/software/stanford-corenlp-latest.zip
@@ -73,10 +73,10 @@ cd fairseq_mmt
7373
python3 record_masking_position.py
7474

7575
cd data/masking
76-
cd en2de
76+
# create en-de masking data
7777
python3 match_origin2bpe_position.py
78-
python3 get_bpe_position.py # create mask1-4 data
79-
python3 create_masking_multi30k.py # create mask color&people data
78+
python3 create_maskding1234_multi30k.py # create mask1-4 data
79+
python3 create_maskingcp_multi30k.py # create mask color&people data
8080

8181
sh preprocess_mmt.sh
8282
```
+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
import random
import os
import shutil

# Root of the data tree: two directory levels above the working directory
# (this script is run from data/masking/<lang-pair> per the README).
data_path = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Multi30k split sizes (sentences per split) used to cut the concatenated
# corpus back into train/valid/test files.
train_lines = 29000
val_lines = 1014
test_2016_lines = 1000
test_2017_lines = 1000
test_coco_lines = 461

# Mask-category code -> special mask token inserted into the source text.
dic = {
    'c': '[MASK_C]',
    'p': '[MASK_P]',
    'n': '[MASK_N]',
    'ns': '[MASK_NS]',
}
# Token-position files for the whole concatenated Multi30k corpus: one line
# per sentence, either a space-separated list of token indices or '-1' when
# no token of that category was found (presumably produced by
# record_masking_position.py — confirm against that script).
pos_color = open('../multi30k.color.position', 'r', encoding='utf-8')
pos_noun = open('../multi30k.noun.position', 'r', encoding='utf-8')
pos_nouns = open('../multi30k.nouns.position', 'r', encoding='utf-8')
pos_people = open('../multi30k.people.position', 'r', encoding='utf-8')

# Optional outputs for BPE-space positions; deliberately disabled in this run
# (see the commented get_position() calls in __main__).
#pos_bpe_color = open('multi30k.color.bpe.position', 'w', encoding='utf-8')
#pos_bpe_noun = open('multi30k.noun.bpe.position', 'w', encoding='utf-8')
#pos_bpe_nouns = open('multi30k.nouns.bpe.position', 'w', encoding='utf-8')
#pos_bpe_people = open('multi30k.people.bpe.position', 'w', encoding='utf-8')
def record_origin_pos(pos):
    """Read a *.position file into memory.

    pos: iterable of lines (an open position file).
    Returns a list with one entry per line: the whitespace-split tokens
    (e.g. ['3', '7'] for position lines, ['-1'] for empty-position lines).
    """
    return [line.strip().split() for line in pos]
def get_position(f, o, l):
    """Translate origin-token positions to BPE positions and write them out.

    f: iterable of per-sentence token-position lists (from record_origin_pos).
    o: writable file-like object; receives one line per sentence, with each
       position followed by a single space (trailing space preserved), or
       the literal '-1' when the sentence has no positions.
    l: per-sentence dicts mapping an origin position (str) to the list of
       BPE positions it expands to (from get_matching); unsplit tokens are
       absent from the dict.
    """
    for num, line in enumerate(f):
        if line[0] != '-1':
            for i in line:
                if i in l[num]:  # bpe used on this token: expand it
                    for j in l[num][i]:
                        o.write(j+' ')
                else:
                    o.write(i+' ')
        else:  # no masking token in this line
            o.write('-1')
        o.write('\n')
def get_matching(path='origin2bpe.en-de.match'):
    """Load the origin-to-BPE position matching file.

    path: matching file to read; defaults to the original hard-coded
        'origin2bpe.en-de.match' so existing callers are unchanged.

    Each line describes one sentence: either '-1' (no BPE splits) or
    space-separated entries 'origin:first-last' mapping an origin token
    position to the range of BPE positions it was split into.

    Returns a list with one dict per sentence mapping the origin position
    (str) to a list of BPE positions (list of str); empty dict for '-1'.
    """
    matches = []  # list of origin2bpe matchings, one dict per sentence
    with open(path, 'r', encoding='utf-8') as f:
        for sentence in f:
            # renamed from `dic` to avoid shadowing the module-level
            # mask-token dictionary of the same name
            mapping = {}
            stripped = sentence.strip()
            if stripped != '-1':
                for entry in stripped.split(' '):
                    origin, bpe_range = entry.split(':')
                    mapping[origin] = bpe_range.split('-')
            matches.append(mapping)
    return matches
def get_mask_bpe_pos(l, list_matching):
    """Remap (position, category) pairs from origin-token space to BPE space.

    l: per-sentence lists of (origin position, category code) tuples chosen
        for masking.
    list_matching: per-sentence dicts origin position -> list of BPE
        positions (as produced by get_matching()).

    Returns a new per-sentence list where every position that was split by
    BPE is expanded into one tuple per BPE sub-token, keeping the category;
    unsplit positions pass through unchanged.
    """
    new_l = []
    for pairs, match in zip(l, list_matching):
        tmp = []
        for tuple_pos in pairs:
            if tuple_pos[0] in match:  # token was split by BPE: expand
                for bpe_pos in match[tuple_pos[0]]:
                    tmp.append((bpe_pos, tuple_pos[1]))
            else:  # unsplit token: position carries over unchanged
                tmp.append(tuple_pos)
        new_l.append(tmp)
    return new_l
if __name__ == '__main__':
    # Origin-position -> BPE-position mapping, one dict per sentence.
    list_matching = get_matching()

    # record origin text data's position
    _pos_people = record_origin_pos(pos_people)
    _pos_color = record_origin_pos(pos_color)
    _pos_noun = record_origin_pos(pos_noun)
    _pos_nouns = record_origin_pos(pos_nouns)
    pos_people.close()
    pos_color.close()
    pos_noun.close()
    pos_nouns.close()

    # Optional dump of BPE-space position files; disabled in this run.
    #get_position(_pos_color, pos_bpe_color, list_matching)
    #get_position(_pos_people, pos_bpe_people, list_matching)
    #get_position(_pos_noun, pos_bpe_noun, list_matching)
    #get_position(_pos_nouns, pos_bpe_nouns, list_matching)
    #pos_bpe_people.close()
    #pos_bpe_color.close()
    #pos_bpe_noun.close()
    #pos_bpe_nouns.close()

    # masking 1-4: build one dataset per num where at most `num` tokens per
    # sentence are replaced with mask tokens.
    for num in range(1, 5):
        l = []  # per-sentence (origin position, category) pairs to mask
        for p, c, n, ns in zip(_pos_people, _pos_color, _pos_noun, _pos_nouns):
            where = []
            # People/color candidates are deliberately excluded here; only
            # singular/plural nouns are masked in the mask1-4 datasets.
            #if p[0] != '-1':
            #    for i in p:
            #        where.append((i, 'p'))
            #if c[0] != '-1':
            #    for i in c:
            #        where.append((i, 'c'))
            if n[0] != '-1':
                for i in n:
                    where.append((i, 'n'))
            if ns[0] != '-1':
                for i in ns:
                    where.append((i, 'ns'))

            if len(where) > num:
                # NOTE(review): random is unseeded, so the sampled mask
                # positions differ between runs; seed random for
                # reproducible datasets.
                where = random.sample(where, num)
            l.append(where)

        language = 'multi30k-en-de'
        mask_token = 'mask'+str(num)
        new_dir = os.path.join(data_path, language+'.'+mask_token)

        if not os.path.exists(new_dir):
            os.mkdir(new_dir)

        # Source side (English, BPE-applied) of the concatenated corpus.
        multi30k = open(os.path.join(data_path, 'multi30k', language+'.bpe.en'), 'r', encoding='utf-8')
        out_train = open(os.path.join(new_dir, 'train.en'), 'w', encoding='utf-8')
        out_valid = open(os.path.join(new_dir, 'valid.en'), 'w', encoding='utf-8')
        out_test_2016 = open(os.path.join(new_dir, 'test.2016.en'), 'w', encoding='utf-8')
        out_test_2017 = open(os.path.join(new_dir, 'test.2017.en'), 'w', encoding='utf-8')
        out_test_coco = open(os.path.join(new_dir, 'test.coco.en'), 'w', encoding='utf-8')

        # Remap sampled origin positions into BPE space before editing text.
        new_l = get_mask_bpe_pos(l, list_matching)

        # write: replace each selected BPE position with its mask token
        tmp = []
        for line, position in zip(multi30k, new_l):
            lines = line.strip().split()
            for i in position:
                lines[int(i[0])] = dic[i[1]]

            tmp.append(' '.join(lines)+'\n')

        # Cut the concatenated corpus back into the standard Multi30k splits
        # (order: train, valid, test2016, test2017, testcoco).
        for idx, i in enumerate(tmp):
            if idx < train_lines:
                out_train.write(i)
            elif train_lines <= idx < train_lines+val_lines:
                out_valid.write(i)
            elif train_lines+val_lines <= idx < train_lines+val_lines+test_2016_lines:
                out_test_2016.write(i)
            elif train_lines+val_lines+test_2016_lines <= idx < train_lines+val_lines+test_2016_lines+test_2017_lines:
                out_test_2017.write(i)
            else:
                out_test_coco.write(i)

        # copy target language file (target side is left unmasked)
        # NOTE(review): targets are read from <data_path>/multi30k-en-de/ while
        # the source text comes from <data_path>/multi30k/ — confirm both
        # directories exist with these exact names.
        target_language = language.split('-')[-1]
        shutil.copyfile(os.path.join(data_path, language, 'train.'+target_language), os.path.join(new_dir, 'train.'+target_language))
        shutil.copyfile(os.path.join(data_path, language, 'valid.'+target_language), os.path.join(new_dir, 'valid.'+target_language))
        shutil.copyfile(os.path.join(data_path, language, 'test.2016.'+target_language), os.path.join(new_dir, 'test.2016.'+target_language))
        shutil.copyfile(os.path.join(data_path, language, 'test.2017.'+target_language), os.path.join(new_dir, 'test.2017.'+target_language))
        shutil.copyfile(os.path.join(data_path, language, 'test.coco.'+target_language), os.path.join(new_dir, 'test.coco.'+target_language))

        multi30k.close()
        out_train.close()
        out_valid.close()
        out_test_2016.close()
        out_test_2017.close()
        out_test_coco.close()
+81
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
import os
import shutil

# Root of the data tree: two directory levels above the working directory
# (this script is run from data/masking/<lang-pair> per the README).
data_path = os.path.abspath(os.path.join(os.getcwd(), "../.."))

# Mask-dataset code -> category name used in the *.bpe.position file names.
dic1 = {
    'maskc': 'color',
    'maskp': 'people',
    'maskn': 'noun',
    'maskns': 'nouns',
}

# Mask-dataset code -> special mask token written into the masked text.
dic2 = {
    'maskc': '[MASK_C]',
    'maskp': '[MASK_P]',
    'maskn': '[MASK_N]',
    'maskns': '[MASK_NS]',
}

# Multi30k split sizes (sentences per split) used to cut the concatenated
# corpus back into train/valid/test files.
train_lines = 29000
val_lines = 1014
test_2016_lines = 1000
test_2017_lines = 1000
test_coco_lines = 461
if __name__ == "__main__":
    language = 'multi30k-en-de'
    # Build the color-masked and people-masked variants of the corpus.
    for mask_token in ['maskc', 'maskp']:
        new_dir = os.path.join(data_path, language+'.'+mask_token)

        if not os.path.exists(new_dir):
            os.mkdir(new_dir)

        # BPE-space positions for this category: one line per sentence,
        # space-separated indices or '-1' when nothing is to be masked.
        pos = open('multi30k.'+dic1[mask_token]+'.bpe.position', 'r', encoding='utf-8')
        # Source side (English, BPE-applied) of the concatenated corpus.
        multi30k = open(os.path.join(data_path, 'multi30k', language+'.bpe.en'), 'r', encoding='utf-8')
        out_train = open(os.path.join(new_dir, 'train.en'), 'w', encoding='utf-8')
        out_valid = open(os.path.join(new_dir, 'valid.en'), 'w', encoding='utf-8')
        out_test_2016 = open(os.path.join(new_dir, 'test.2016.en'), 'w', encoding='utf-8')
        out_test_2017 = open(os.path.join(new_dir, 'test.2017.en'), 'w', encoding='utf-8')
        out_test_coco = open(os.path.join(new_dir, 'test.coco.en'), 'w', encoding='utf-8')

        # Replace each listed BPE position with the category's mask token.
        tmp = []
        for line, position in zip(multi30k, pos):
            x = position.strip().split()
            lines = line.strip().split()
            for i in x:
                if i == '-1':  # sentinel: nothing to mask in this sentence
                    break
                lines[int(i)] = dic2[mask_token]

            tmp.append(' '.join(lines)+'\n')

        # write: cut the concatenated corpus back into the standard Multi30k
        # splits (order: train, valid, test2016, test2017, testcoco).
        for idx, i in enumerate(tmp):
            if idx < train_lines:
                out_train.write(i)
            elif train_lines <= idx < train_lines+val_lines:
                out_valid.write(i)
            elif train_lines+val_lines <= idx < train_lines+val_lines+test_2016_lines:
                out_test_2016.write(i)
            elif train_lines+val_lines+test_2016_lines <= idx < train_lines+val_lines+test_2016_lines+test_2017_lines:
                out_test_2017.write(i)
            else:
                out_test_coco.write(i)

        pos.close()
        multi30k.close()
        out_train.close()
        out_valid.close()
        out_test_2016.close()
        out_test_2017.close()
        out_test_coco.close()

        # copy target language file (target side is left unmasked)
        # NOTE(review): targets are read from <data_path>/multi30k-en-de/ while
        # the source text comes from <data_path>/multi30k/ — confirm both
        # directories exist with these exact names.
        target_language = language.split('-')[-1]
        shutil.copyfile(os.path.join(data_path, language, 'train.'+target_language), os.path.join(new_dir, 'train.'+target_language))
        shutil.copyfile(os.path.join(data_path, language, 'valid.'+target_language), os.path.join(new_dir, 'valid.'+target_language))
        shutil.copyfile(os.path.join(data_path, language, 'test.2016.'+target_language), os.path.join(new_dir, 'test.2016.'+target_language))
        shutil.copyfile(os.path.join(data_path, language, 'test.2017.'+target_language), os.path.join(new_dir, 'test.2017.'+target_language))
        shutil.copyfile(os.path.join(data_path, language, 'test.coco.'+target_language), os.path.join(new_dir, 'test.coco.'+target_language))

0 commit comments

Comments
 (0)