forked from OpenNMT/OpenNMT-py
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtext_baseline.py
132 lines (111 loc) · 4.99 KB
/
text_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# coding=utf8
from __future__ import absolute_import
from __future__ import division, print_function, unicode_literals
from sumy.parsers.html import HtmlParser
from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.lsa import LsaSummarizer as Summarizer
from sumy.summarizers.lex_rank import LexRankSummarizer as LexRank
from sumy.nlp.stemmers import Stemmer
from sumy.utils import get_stop_words
import argparse
from rouge import Rouge
import os
import sys
# Python 2 only: re-enable setdefaultencoding so implicit str/unicode
# conversions use UTF-8.  `reload` and `setdefaultencoding` do not exist
# on Python 3 (where str is already Unicode), so guard by version instead
# of crashing at import time.
if sys.version_info[0] == 2:
    reload(sys)  # noqa: F821 -- `reload` is a builtin on Python 2
    sys.setdefaultencoding('utf-8')
# Deeply nested parses in sumy/rouge can exceed the default limit.
sys.setrecursionlimit(1000000)
def LexRank_Text(originalText, LANGUAGE="chinese"):
    """Summarize a text with LexRank and return its top-ranked sentence.

    Args:
        originalText: Raw text to summarize.
        LANGUAGE: Language used for tokenization, stemming and stop words.

    Returns:
        str: The single highest-ranked sentence, or None when the
        summarizer yields no sentences (e.g. empty/untokenizable input).
    """
    doc_parser = PlaintextParser.from_string(originalText, Tokenizer(LANGUAGE))
    stemmer = Stemmer(LANGUAGE)
    summarizer = LexRank(stemmer)
    summarizer.stop_words = get_stop_words(LANGUAGE)
    # Request a one-sentence summary and return the first sentence produced.
    for sentence in summarizer(doc_parser.document, 1):
        return str(sentence)
    # Explicit None instead of the original implicit fall-through, so the
    # caller's None check reads as intentional.
    return None
def read_and_filter(_file_path):
    """Parse a session-per-line corpus into flat token lists.

    Each line holds sessions separated by '||'; each session has exactly 11
    tab-separated fields, of which field 9 is the item name and field 10 the
    item comment (both whitespace-tokenized).

    Args:
        _file_path: Path to the corpus file.

    Returns:
        tuple: (new_lines, session_lengths) where new_lines[i] is the
        concatenated name+comment tokens of every session on line i, and
        session_lengths[i] is the name-token count of the LAST session on
        line i.

    Raises:
        AssertionError: If any session does not have exactly 11 fields.
    """
    new_lines = []
    session_lengths = []
    with open(_file_path, 'r') as f:
        for line in f:
            # Split each session on '\t' exactly once (the original re-split
            # the same string three times per session).
            sessions = [s.split('\t') for s in line.strip('\n').split('||')]
            for fields in sessions:
                assert len(fields) == 11
            names = [fields[9].split() for fields in sessions]
            comments = [fields[10].split() for fields in sessions]
            merged = []
            for name_toks, comment_toks in zip(names, comments):
                merged.extend(name_toks)
                merged.extend(comment_toks)
            new_lines.append(merged)
            # NOTE(review): preserved from the original -- this records the
            # name-token count of only the *last* session on the line
            # (loop-variable leak); presumably intended as the per-line
            # filter key used downstream. Confirm before "fixing".
            session_lengths.append(len(name_toks))
    return new_lines, session_lengths
def write_file(_lines, _path):
    """Write every element of *_lines* to *_path*, one per line.

    No trailing newline is appended after the final line.
    """
    joined = '\n'.join(_lines)
    with open(_path, 'w') as out:
        out.write(joined)
def read_file(_path):
    """Return the lines of *_path* with surrounding newline characters removed."""
    with open(_path, 'r') as src:
        return [row.strip('\n') for row in src]
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Parameters description")
    parser.add_argument('-src', type=str, help="Source file path.")
    parser.add_argument('-tgt', type=str, help="Target file path.")
    parser.add_argument('-baseline', type=str,
                        help="Baseline Method for Explanation Generation.")
    parser.add_argument('-save', type=str, help="Path to save the output file.")
    parser.add_argument('-test', type=int, default=-1,
                        help='Test the code with a small number of examples.')

    # argparse's type=bool is a trap: bool('False') is True, so any value on
    # the command line enabled reprocessing.  Parse the string explicitly
    # (CLI shape unchanged: "-re true" / "-re false").
    def _str2bool(value):
        return str(value).strip().lower() in ('1', 'true', 'yes', 'y')

    parser.add_argument('-re', type=_str2bool, default=False,
                        help="Reprocess data.")
    args = parser.parse_args()

    LANGUAGE = "chinese"
    assert args.baseline in ['lexrank']
    if args.baseline == 'lexrank':
        src_path = args.save + '/lexrank_src.txt'
        tgt_path = args.save + '/lexrank_tgt.txt'
        if os.path.exists(src_path) and os.path.exists(tgt_path) and not args.re:
            # Reuse the filtered corpus cached by a previous run.
            src_lines = read_file(src_path)
            tgt_lines = read_file(tgt_path)
        else:
            _src_lines, _lengths = read_and_filter(args.src)
            _tgt_lines, _ = read_and_filter(args.tgt)
            src_lines = []
            tgt_lines = []
            # The original combined assert used its second comparison as the
            # assertion *message*, so the _lengths alignment was never
            # actually checked; assert both alignments explicitly.
            assert len(_src_lines) == len(_tgt_lines)
            assert len(_src_lines) == len(_lengths)
            for src_line, tgt_line, length in zip(_src_lines, _tgt_lines, _lengths):
                if length < 20:
                    src_lines.append(' '.join(src_line))
                    tgt_lines.append(' '.join(tgt_line))
            write_file(src_lines, src_path)
            write_file(tgt_lines, tgt_path)

        if args.test > 0:
            src_lines = src_lines[0:args.test]
            tgt_lines = tgt_lines[0:args.test]

        _sum_lines = []
        for i, _line in enumerate(src_lines):
            print('process: {}/{}'.format(i, len(src_lines)))
            # Summarize once per line (the original called LexRank_Text twice,
            # doubling the most expensive step of the pipeline).
            _summary = LexRank_Text(_line, LANGUAGE)
            _sum_lines.append(_summary if _summary is not None else 'None')
        assert len(_sum_lines) == len(tgt_lines)

        rouge = Rouge()
        write_file(_sum_lines, args.save + '/lexrank_sumy.txt')
        scores = rouge.get_scores(_sum_lines, tgt_lines, avg=True)
        _sentence_msg = 'ROUGE-1: P:{}\tR:{}\tF1:{}'.format(
            scores['rouge-1']['p'], scores['rouge-1']['r'], scores['rouge-1']['f'])
        _sentence_msg += '\nROUGE-2: P:{}\tR:{}\tF1:{}'.format(
            scores['rouge-2']['p'], scores['rouge-2']['r'], scores['rouge-2']['f'])
        _sentence_msg += '\nROUGE-L: P:{}\tR:{}\tF1:{}'.format(
            scores['rouge-l']['p'], scores['rouge-l']['r'], scores['rouge-l']['f'])
        print(_sentence_msg)