-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathaugment.py
66 lines (56 loc) · 2.25 KB
/
augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# encoding: utf-8
"""
Created by Gözde Gül Şahin
20.05.2018
Code to play with augmentation options and test on a single file
"""
__author__ = 'Gözde Gül Şahin'
from IO import conllud
from SP import augmenter
import codecs
import argparse
def main():
    """Read the command-line options and hand them to ``augment``.

    Options: ``-infile`` (UD file to read), ``-outfile`` (file to write),
    ``-maxrot`` (maximum rotations per sentence), ``-prob`` (operation
    probability) and ``-operation`` (``rotate`` or ``crop``).
    """
    cli = argparse.ArgumentParser()
    # (flag, value type, default, help text) for every supported option.
    option_table = (
        ('-infile', str,
         './data/ud-treebanks-v2.1/UD_Turkish/tr-ud-test.conllu',
         'UD file to augment'),
        ('-outfile', str,
         './data/ud-treebanks-v2.1/UD_Turkish/augmented.conllu',
         'Output file'),
        ('-maxrot', int, 3,
         'Maximum number of rotation operations per sentence'),
        ('-prob', float, 0.7,
         'Probability of the augmentation operation'),
        ('-operation', str, 'rotate', 'rotate|crop'),
    )
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback,
                         help=description)

    # Run the selected augmentation and save the results.
    augment(cli.parse_args())
def _write_sentences(fout, sentences):
    """Write each sentence as tab-joined rows, followed by one blank line."""
    for sentence in sentences:
        for row in sentence:
            fout.write(u"\t".join(row))
            fout.write(u"\n")
        fout.write(u"\n")


def augment(args):
    """Augment every sentence of a UD file and write the results to disk.

    Args:
        args: namespace with the attributes ``infile``, ``outfile``,
            ``operation`` ("rotate" or "crop") and ``maxrot``. An optional
            ``prob`` attribute is forwarded to the augmenter as ``prob``;
            when it is missing or ``None`` the value 1.0 is used.

    Raises:
        ValueError: if ``args.operation`` is neither "rotate" nor "crop".
            The check runs before the output file is opened, so a mistyped
            operation no longer leaves an empty output file behind.
    """
    operation = args.operation
    if operation not in ("rotate", "crop"):
        raise ValueError(
            "Unknown operation %r (expected 'rotate' or 'crop')" % (operation,)
        )

    # Honour the -prob option instead of always forcing 1.0.
    prob = getattr(args, "prob", None)
    if prob is None:
        prob = 1.0

    # Read the input first so a bad input path does not truncate the output.
    ud_sents = conllud.conllUD(args.infile).sents

    # Dependency labels handed to the augmenters as ``aloi``
    # (presumably "labels of interest" -- confirm against SP.augmenter).
    loi = [u"nsubj", u"dobj", u"iobj", u"obj", u"obl"]
    pl = u"root"
    # for predicate
    multilabs = [u"case", u"fixed", u"flat", u"cop", u"compound"]

    # The context manager closes the file even if an augmenter raises.
    with codecs.open(args.outfile, 'w', 'utf-8') as fout:
        for sent in ud_sents:
            if operation == "rotate":
                rotator = augmenter.rotator(
                    sent, aloi=loi, pl=pl, multilabs=multilabs, prob=prob
                )
                aug_sents = rotator.rotate(maxshuffle=args.maxrot)
            else:
                cropper = augmenter.cropper(
                    sent, aloi=loi, pl=pl, multilabs=multilabs, prob=prob
                )
                aug_sents = cropper.crop()
            _write_sentences(fout, aug_sents)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()