|
| 1 | +import random |
| 2 | +import os |
| 3 | +import shutil |
| 4 | + |
| 5 | +data_path = os.path.abspath(os.path.join(os.getcwd(), "../..")) |
| 6 | + |
| 7 | +train_lines = 29000 |
| 8 | +val_lines = 1014 |
| 9 | +test_2016_lines = 1000 |
| 10 | +test_2017_lines = 1000 |
| 11 | +test_coco_lines = 461 |
| 12 | + |
| 13 | +dic = { |
| 14 | + 'c': '[MASK_C]', |
| 15 | + 'p': '[MASK_P]', |
| 16 | + 'n': '[MASK_N]', |
| 17 | + 'ns': '[MASK_NS]', |
| 18 | +} |
| 19 | + |
| 20 | +pos_color = open('../multi30k.color.position', 'r', encoding='utf-8') |
| 21 | +pos_noun = open('../multi30k.noun.position', 'r', encoding='utf-8') |
| 22 | +pos_nouns = open('../multi30k.nouns.position', 'r', encoding='utf-8') |
| 23 | +pos_people = open('../multi30k.people.position', 'r', encoding='utf-8') |
| 24 | + |
| 25 | +#pos_bpe_color = open('multi30k.color.bpe.position', 'w', encoding='utf-8') |
| 26 | +#pos_bpe_noun = open('multi30k.noun.bpe.position', 'w', encoding='utf-8') |
| 27 | +#pos_bpe_nouns = open('multi30k.nouns.bpe.position', 'w', encoding='utf-8') |
| 28 | +#pos_bpe_people = open('multi30k.people.bpe.position', 'w', encoding='utf-8') |
| 29 | + |
| 30 | +def record_origin_pos(pos): |
| 31 | + l = [] |
| 32 | + for line in pos: |
| 33 | + l.append(line.strip().split()) |
| 34 | + return l |
| 35 | + |
| 36 | +def get_position(f, o, l): |
| 37 | + num = 0 |
| 38 | + for line in f: |
| 39 | + if line[0] != '-1': |
| 40 | + for i in line: |
| 41 | + if i in l[num].keys(): # bpe used on this sentence |
| 42 | + y = l[num][i] |
| 43 | + for j in y: |
| 44 | + o.write(j+' ') |
| 45 | + else: |
| 46 | + o.write(i+' ') |
| 47 | + else: # no masking token in this line |
| 48 | + o.write('-1') |
| 49 | + o.write('\n') |
| 50 | + num += 1 |
| 51 | + |
| 52 | +def get_matching(): |
| 53 | + _l = [] # list of origin2bpe matching |
| 54 | + with open('origin2bpe.en-de.match', 'r', encoding='utf-8') as f: |
| 55 | + for sentence in f: |
| 56 | + dic = {} |
| 57 | + if sentence.strip() == '-1': |
| 58 | + _l.append(dic) # empty dict |
| 59 | + else: |
| 60 | + x = sentence.strip().split(' ') |
| 61 | + for i in x: |
| 62 | + dic[i.split(':')[0]] = i.split(':')[1].split('-') |
| 63 | + _l.append(dic) |
| 64 | + return _l |
| 65 | + |
| 66 | +def get_mask_bpe_pos(l, list_matching): |
| 67 | + new_l = [] |
| 68 | + for i, j in zip(l, list_matching): |
| 69 | + #print(i ,j) |
| 70 | + tmp = [] |
| 71 | + for tuple_pos in i: |
| 72 | + if tuple_pos[0] in j.keys(): |
| 73 | + for bpe_pos in j[tuple_pos[0]]: |
| 74 | + tmp.append((bpe_pos, tuple_pos[1])) |
| 75 | + else: |
| 76 | + tmp.append(tuple_pos) |
| 77 | + new_l.append(tmp) |
| 78 | + return new_l |
| 79 | + |
| 80 | +if __name__ == '__main__': |
| 81 | + list_matching = get_matching() |
| 82 | + |
| 83 | + # record origin text data's position |
| 84 | + _pos_people = record_origin_pos(pos_people) |
| 85 | + _pos_color = record_origin_pos(pos_color) |
| 86 | + _pos_noun = record_origin_pos(pos_noun) |
| 87 | + _pos_nouns = record_origin_pos(pos_nouns) |
| 88 | + pos_people.close() |
| 89 | + pos_color.close() |
| 90 | + pos_noun.close() |
| 91 | + pos_nouns.close() |
| 92 | + |
| 93 | + #get_position(_pos_color, pos_bpe_color, list_matching) |
| 94 | + #get_position(_pos_people, pos_bpe_people, list_matching) |
| 95 | + #get_position(_pos_noun, pos_bpe_noun, list_matching) |
| 96 | + #get_position(_pos_nouns, pos_bpe_nouns, list_matching) |
| 97 | + #pos_bpe_people.close() |
| 98 | + #pos_bpe_color.close() |
| 99 | + #pos_bpe_noun.close() |
| 100 | + #pos_bpe_nouns.close() |
| 101 | + |
| 102 | + # masking 1-4 |
| 103 | + for num in range(1, 5): |
| 104 | + l = [] # list of masking origin text data's position |
| 105 | + for p, c, n, ns in zip(_pos_people, _pos_color, _pos_noun, _pos_nouns): |
| 106 | + where = [] |
| 107 | + #if p[0] != '-1': |
| 108 | + # for i in p: |
| 109 | + # where.append((i, 'p')) |
| 110 | + #if c[0] != '-1': |
| 111 | + # for i in c: |
| 112 | + # where.append((i, 'c')) |
| 113 | + if n[0] != '-1': |
| 114 | + for i in n: |
| 115 | + where.append((i, 'n')) |
| 116 | + if ns[0] != '-1': |
| 117 | + for i in ns: |
| 118 | + where.append((i, 'ns')) |
| 119 | + |
| 120 | + if len(where) > num: |
| 121 | + where = random.sample(where, num) |
| 122 | + l.append(where) |
| 123 | + |
| 124 | + language = 'multi30k-en-de' |
| 125 | + mask_token = 'mask'+str(num) |
| 126 | + new_dir = os.path.join(data_path, language+'.'+mask_token) |
| 127 | + |
| 128 | + if not os.path.exists(new_dir): |
| 129 | + os.mkdir(new_dir) |
| 130 | + |
| 131 | + multi30k = open(os.path.join(data_path, 'multi30k', language+'.bpe.en'), 'r', encoding='utf-8') |
| 132 | + out_train = open(os.path.join(new_dir, 'train.en'), 'w', encoding='utf-8') |
| 133 | + out_valid = open(os.path.join(new_dir, 'valid.en'), 'w', encoding='utf-8') |
| 134 | + out_test_2016 = open(os.path.join(new_dir, 'test.2016.en'), 'w', encoding='utf-8') |
| 135 | + out_test_2017 = open(os.path.join(new_dir, 'test.2017.en'), 'w', encoding='utf-8') |
| 136 | + out_test_coco = open(os.path.join(new_dir, 'test.coco.en'), 'w', encoding='utf-8') |
| 137 | + |
| 138 | + new_l = get_mask_bpe_pos(l, list_matching) |
| 139 | + |
| 140 | + # write |
| 141 | + tmp = [] |
| 142 | + for line, position in zip(multi30k, new_l): |
| 143 | + lines = line.strip().split() |
| 144 | + for i in position: |
| 145 | + #print(i) |
| 146 | + lines[int(i[0])] = dic[i[1]] |
| 147 | + |
| 148 | + tmp.append(' '.join(lines)+'\n') |
| 149 | + |
| 150 | + for idx, i in enumerate(tmp): |
| 151 | + if idx < train_lines: |
| 152 | + out_train.write(i) |
| 153 | + elif train_lines <= idx and idx < train_lines+val_lines: |
| 154 | + out_valid.write(i) |
| 155 | + elif train_lines+val_lines <= idx and idx < train_lines+val_lines+test_2016_lines: |
| 156 | + out_test_2016.write(i) |
| 157 | + elif train_lines+val_lines+test_2016_lines <= idx and idx < train_lines+val_lines+test_2016_lines+test_2017_lines: |
| 158 | + out_test_2017.write(i) |
| 159 | + else: |
| 160 | + out_test_coco.write(i) |
| 161 | + |
| 162 | + # copy target language file |
| 163 | + target_language = language.split('-')[-1] |
| 164 | + shutil.copyfile(os.path.join(data_path, language, 'train.'+target_language), os.path.join(new_dir, 'train.'+target_language)) |
| 165 | + shutil.copyfile(os.path.join(data_path, language, 'valid.'+target_language), os.path.join(new_dir, 'valid.'+target_language)) |
| 166 | + shutil.copyfile(os.path.join(data_path, language, 'test.2016.'+target_language), os.path.join(new_dir, 'test.2016.'+target_language)) |
| 167 | + shutil.copyfile(os.path.join(data_path, language, 'test.2017.'+target_language), os.path.join(new_dir, 'test.2017.'+target_language)) |
| 168 | + shutil.copyfile(os.path.join(data_path, language, 'test.coco.'+target_language), os.path.join(new_dir, 'test.coco.'+target_language)) |
| 169 | + |
| 170 | + multi30k.close() |
| 171 | + out_train.close() |
| 172 | + out_valid.close() |
| 173 | + out_test_2016.close() |
| 174 | + out_test_2017.close() |
| 175 | + out_test_coco.close() |
0 commit comments