-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathchat_robot.py
101 lines (66 loc) · 2.41 KB
/
chat_robot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# -*- coding: utf-8 -*-
"""
Created on Wed May 30 15:47:57 2018
@author: shen1994
"""
from data_process import DataProcess
from test import load_model
from test import common_prediction
def clean_repeat_words(words, pad_word=None):
    """Drop padding tokens and collapse repeated punctuation in a token list.

    Every token equal to ``pad_word`` is removed. A punctuation token from the
    fixed set below is skipped when it is identical to the token most recently
    kept, so a run such as ``[",", ","]`` becomes ``[","]``. Other repeated
    tokens are kept unchanged.

    Args:
        words: sequence of decoded tokens.
        pad_word: token to discard. Defaults to
            ``DataProcess(use_word2cut=False).__VOCAB__[0]`` (presumably the
            padding symbol -- confirm in data_process). Callers that already
            know the token can pass it to avoid constructing a ``DataProcess``.

    Returns:
        A new list with ``pad_word`` tokens removed and punctuation runs
        collapsed.
    """
    if pad_word is None:
        pad_word = DataProcess(use_word2cut=False).__VOCAB__[0]
    repeat_words = (",", ";", "。", "?", "!")
    new_words = []
    for word in words:
        # Filter on the *current* token; the previous version tested the
        # preceding token, which kept the pad and dropped the word after it.
        if word == pad_word:
            continue
        # Compare with the last *kept* token so a pad between two identical
        # punctuation marks does not stop them from being collapsed.
        if new_words and word == new_words[-1] and word in repeat_words:
            continue
        new_words.append(word)
    return new_words
def assembly_word(words):
    """Build the reply text from one decoded token sequence.

    The reply is the span that ends just before the first EOS marker
    (``__VOCAB__[3]``). It starts right after the first GO marker
    (``__VOCAB__[2]``), or at the beginning of ``words`` when no GO marker is
    present. The span is cleaned with ``clean_repeat_words`` and concatenated.

    A fixed fallback reply is returned when:

    * there is no EOS marker, or EOS is the first token;
    * a GO marker exists but is not followed by at least one token before EOS;
    * nothing is left after cleaning.

    Args:
        words: indexable sequence of predicted tokens (joined with ``"".join``,
            so presumably strings).

    Returns:
        The reply string, or the fallback reply.
    """
    fallback = u"小哥哥,对不起呢,我不知道。"
    vocab = DataProcess(use_word2cut=False).__VOCAB__

    def first_position(token):
        # Index of the first element equal to ``token``, or -1 if absent.
        return next((pos for pos, item in enumerate(words) if item == token), -1)

    end = first_position(vocab[3])
    if end <= 0:
        # No EOS at all, or EOS right at the start: nothing to answer with.
        return fallback

    start = first_position(vocab[2])
    if start == -1:
        body = list(words[:end])
    elif start + 1 < end:
        body = list(words[start + 1:end])
    else:
        # GO sits at/after the token just before EOS: empty or inverted span.
        return fallback

    body = clean_repeat_words(body)
    if not body:
        return fallback
    return "".join(body)
def run():
    """Answer a couple of hard-coded demo questions with the trained model.

    Loads the weights from ``model/seq2seq_model_weights.h5`` and predicts
    token sequences for all questions in one ``common_prediction`` call. Each
    question is then printed together with the reply built by
    ``assembly_word``, framed by separator lines.
    """
    separator = "------------------------------\n"
    questions = [u"我喜欢你?", u"品尝大董意境菜时兴奋不已,并起身激情拥抱"]
    model = load_model("model/seq2seq_model_weights.h5")
    prediction_words = common_prediction(model, questions)
    for position, question in enumerate(questions):
        print(separator)
        print("问: " + question)
        reply = assembly_word(prediction_words[position])
        print("答: " + reply)
        print(separator)


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    run()