-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
107 lines (84 loc) · 3.62 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from masked_eda import eda, load_mlm
# Easy data augmentation techniques for text classification
# Jason Wei and Kai Zou
# arguments to be parsed from command line
import argparse

# Only --input is mandatory; every other option is optional and parses to
# None when omitted, which the code below replaces with a default.
ap = argparse.ArgumentParser()
ap.add_argument("--input", required=True, type=str,
                help="input file of unaugmented data")
# This is the file the augmented sentences are written to, so the help text
# must say "augmented" (it previously repeated the --input wording).
ap.add_argument("--output", required=False, type=str,
                help="output file of augmented data")
ap.add_argument("--mask_model", required=False, type=str,
                help="uncased name of masked language model used in method")
ap.add_argument("--num_aug", required=False, type=int,
                help="number of augmented sentences per original sentence")
ap.add_argument("--alpha_sr", required=False, type=float,
                help="percent of words in each sentence to be replaced by synonyms")
ap.add_argument("--alpha_ri", required=False, type=float,
                help="percent of words in each sentence to be inserted")
ap.add_argument("--alpha_rs", required=False, type=float,
                help="percent of words in each sentence to be swapped")
ap.add_argument("--alpha_rd", required=False, type=float,
                help="percent of words in each sentence to be deleted")
# Parsing happens at import time: argparse prints usage and exits the process
# when --input is missing or a value cannot be converted.
args = ap.parse_args()
# the output file: the explicit --output value, otherwise the input's file
# name prefixed with 'eda_' in the input's own directory
if not args.output:
    from os.path import dirname, basename, join
    output = join(dirname(args.input), 'eda_' + basename(args.input))
else:
    output = args.output

# uncased name of masked language model used in method
# bert , roberta , distilbert
# (an omitted or empty value falls back to 'distilbert')
mask_model = args.mask_model or 'distilbert'

# number of augmented sentences to generate per original sentence
# (an omitted value, or an explicit 0, falls back to 9)
num_aug = args.num_aug or 9

# For the four alphas only a missing option selects the default of 0.1, so
# an explicit 0 is honoured and switches that operation's share to zero.

# how much to replace each word by synonyms
alpha_sr = 0.1 if args.alpha_sr is None else args.alpha_sr

# how much to insert new words that are synonyms
alpha_ri = 0.1 if args.alpha_ri is None else args.alpha_ri

# how much to swap words
alpha_rs = 0.1 if args.alpha_rs is None else args.alpha_rs

# how much to delete words
alpha_rd = 0.1 if args.alpha_rd is None else args.alpha_rd

# reject a configuration in which every alpha is zero
if not any((alpha_sr, alpha_ri, alpha_rs, alpha_rd)):
    ap.error('At least one alpha should be greater than zero')
# generate more data with standard augmentation
def gen_eda(train_orig, output_file, mask_model, alpha_sr, alpha_ri, alpha_rs, alpha_rd, num_aug=9):
    """Write augmented copies of every labelled sentence in a file.

    Each non-blank line of ``train_orig`` must hold a label and a sentence
    separated by a tab.  The sentence is passed to ``eda`` and every
    sentence it returns is written to ``output_file`` as
    ``label<TAB>sentence`` on its own line.

    Args:
        train_orig: Path of the tab-separated input file.
        output_file: Path of the file to create (or overwrite).
        mask_model: Model name handed to ``load_mlm`` and ``eda``.
        alpha_sr: Forwarded to ``eda`` as ``alpha_sr``.
        alpha_ri: Forwarded to ``eda`` as ``alpha_ri``.
        alpha_rs: Forwarded to ``eda`` as ``alpha_rs``.
        alpha_rd: Forwarded to ``eda`` as ``p_rd``.
        num_aug: Forwarded to ``eda`` as ``num_aug``.

    Returns:
        None.  If ``load_mlm`` reports failure, a message is printed and the
        function returns before the output file is created or truncated.

    Raises:
        OSError: If either file cannot be opened.
        IndexError: If a non-blank line contains no tab separator.
    """
    # Read the whole input up front and release the handle immediately.
    with open(train_orig, 'r') as reader:
        lines = reader.readlines()

    # initialize masked language model
    loaded = load_mlm(mask_model)
    if not loaded:
        print('mlm loading failed')
        return

    # Report progress about every 5% of the input.  The step is clamped to 1
    # because int(0.05 * n) is 0 for fewer than 20 lines, which would make
    # the modulo below divide by zero.
    total = len(lines)
    progress_step = max(1, int(0.05 * total))

    # The output is opened only once the model is known to be usable, and the
    # context manager closes it even if eda() raises part-way through.
    with open(output_file, 'w') as writer:
        for i, line in enumerate(lines):
            if i % progress_step == 0:
                print('{}\tpercent done!'.format(int(100 * i / total)))

            # Strip only the line terminator: a final line without a trailing
            # newline must not lose its last character.
            record = line.rstrip('\n')
            if not record.strip():
                # Blank lines (e.g. trailing empty lines) carry no example.
                continue

            parts = record.split('\t')
            label = parts[0]
            sentence = parts[1]
            aug_sentences = eda(sentence, mask_model=mask_model, alpha_sr=alpha_sr, alpha_ri=alpha_ri,
                                alpha_rs=alpha_rs, p_rd=alpha_rd, num_aug=num_aug)
            for aug_sentence in aug_sentences:
                writer.write(label + "\t" + aug_sentence + '\n')

    print("generated augmented sentences with eda for " + train_orig +
          " to " + output_file + " with num_aug=" + str(num_aug))
# main function
if __name__ == "__main__":
    # Gather the resolved command-line settings, then generate the augmented
    # sentences from the input file and write them to the output file.
    eda_settings = {
        'mask_model': mask_model,
        'alpha_sr': alpha_sr,
        'alpha_ri': alpha_ri,
        'alpha_rs': alpha_rs,
        'alpha_rd': alpha_rd,
        'num_aug': num_aug,
    }
    gen_eda(args.input, output, **eda_settings)