Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

增加使用模型实现本地翻译 #13

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,17 @@ PyPDF2使用相对简单,但只支持英文,对中文支持不太友好;

return ret

#### 本地翻译 ####
需要使用huggingface的opus-mt-en-zh模型进行翻译
也可以下载到本地,将文件夹放在
src/TranslateTool/TranslateTool/model/opus-mt-en-zh
```
--- src
--- --- TranslateTool
--- --- --- TranslateTool
--- --- --- --- model
```

### 写入文档 ###
#### 写TXT文档 ####
TXT文档的写比较简单,代码如下所示:
Expand Down
3 changes: 2 additions & 1 deletion src/TranslateTool/TranslateTool/T_Docx.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def translate(self):
content = para.text.strip() # 去掉多余空格

if content != '':
ret = translate_func(content)
# ret = translate_func(content)
ret = translate_local(content)
trans = ret if ret else '翻译失败'
# 写入新文件
new_doc.add_paragraph(content)
Expand Down
3 changes: 2 additions & 1 deletion src/TranslateTool/TranslateTool/T_Pdf_PyPDF2.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def translate(self):
for line in content_list:
line = line.strip()
if line:
ret = translate_func(line)
# ret = translate_func(line)
ret = translate_local(line)
trans = ret if ret else '翻译失败'
self.write(line + '\n')
self.write(trans)
Expand Down
3 changes: 2 additions & 1 deletion src/TranslateTool/TranslateTool/T_Pdf_pdfminer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def translate(self):
content = out.get_text().strip()
if content:
to_trans_content = content.replace("\r\n","")
ret = translate_func(to_trans_content)
# ret = translate_func(to_trans_content)
ret = translate_local(to_trans_content)
trans = ret if ret else '翻译失败'

self.write(content)
Expand Down
7 changes: 5 additions & 2 deletions src/TranslateTool/TranslateTool/T_Txt.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def translate(self):

temp_line = temp_line.strip()
if temp_line:
ret = translate_func(temp_line)
# ret = translate_func(temp_line)
ret = translate_local(temp_line)
trans = ret if ret else '翻译失败'
self.write(temp_line)
self.write(trans + '\n')
Expand All @@ -67,7 +68,8 @@ def translate(self):
line = f.readline()

if temp_line: # 防止拼接之后,有一部分未执行翻译
ret = translate_func(temp_line)
# ret = translate_func(temp_line)
ret = translate_local(temp_line)
trans = ret if ret else '翻译失败'
self.write(temp_line)
self.write(trans + '\n')
Expand All @@ -78,6 +80,7 @@ def translate(self):
Logger().write(self.fileName + '翻译完成,新文档:' + self.new_path)



def prepare(self):
'''准备:生成的文件名和路径'''

Expand Down
15 changes: 13 additions & 2 deletions src/TranslateTool/TranslateTool/TranslateFunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import requests # pip install requests
from Py4Js import *
from Logger import *

from translate import *

# 百度翻译方法
def baidu_translate(content,type=1):
Expand Down Expand Up @@ -187,13 +187,21 @@ def google_translate(content):
else:
return (False,ret)

def translate_local(content):
    '''Translate ``content`` with the local model and return the decoded text.

    Calls ``load_model()`` for the tokenizer/model pair on every invocation
    and hands both, together with ``content``, to ``translate()``.

    NOTE(review): this returns whatever ``translate()`` returns, while
    ``google_translate`` returns a ``(flag, result)`` tuple; confirm that
    ``translate_func`` copes with both shapes now that this function is in
    its ``funcs`` list.
    '''
    # With the model downloaded to local disk this can run offline.
    return translate(content, *load_model())

def translate_func(content):
'''集成百度、谷歌、有道多合一的翻译'''

funcs = [google_translate, youdao_translate] # baidu_translate,google_translate,youdao_translate
funcs = [translate_local,google_translate, youdao_translate] # baidu_translate,google_translate,youdao_translate
count = 0


# 循环调用百度、谷歌、有道API,其中如果谁调成功就返回,或者大于等于9次没有成功也返回。
while True:
for i in range(len(funcs)):
Expand All @@ -212,3 +220,6 @@ def translate_func(content):
return ''
else:
continue



48 changes: 48 additions & 0 deletions src/TranslateTool/TranslateTool/translate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import functools
import os

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class Config():
    """Settings for the local translation model."""

    # Directory holding the opus-mt-en-zh model files. It is resolved
    # relative to this module (<module dir>/model/opus-mt-en-zh) rather than
    # the current working directory, so loading works no matter where the
    # program is launched from. When launched from the repository root this
    # is the same directory as "src/TranslateTool/TranslateTool/model/opus-mt-en-zh".
    path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "model",
        "opus-mt-en-zh",
    )


@functools.lru_cache(maxsize=None)
def load_model():
    """Load the tokenizer and seq2seq model found at ``Config.path``.

    The result is memoized: the first call reads the model from disk and
    every later call returns the same ``(tokenizer, model)`` pair. Callers
    that translate one paragraph or line at a time would otherwise reload
    the whole model for each piece of text. A call that raises is not
    cached, so a later call tries again.

    Returns:
        tuple: ``(tokenizer, model)`` as produced by
        ``AutoTokenizer.from_pretrained`` and
        ``AutoModelForSeq2SeqLM.from_pretrained``.
    """
    model_path = Config().path
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    return tokenizer, model


def IN():
    """Prompt on standard input and return the line the user typed."""
    return input("请输入:")

def translate(text, tokenizer, model):
    """Translate ``text`` and return the decoded output string.

    Args:
        text: The source text to translate.
        tokenizer: Object providing ``encode`` and ``decode``.
        model: Object providing ``generate``.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    encoded = tokenizer.encode(text, return_tensors="pt")
    generated = model.generate(encoded)
    # Only the first generated sequence is decoded and returned.
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

def run():
    """Run an interactive console loop that translates each line typed.

    Loads the model once, prints a short banner, then repeatedly prompts via
    ``IN()`` and prints the translation of each line until the user enters
    exactly ``exit``.
    """
    print("---load_model---")
    tokenizer, model = load_model()
    print("---英译中---")
    print("输入exit结束")
    # iter(callable, sentinel) keeps calling IN() and stops as soon as the
    # returned line equals 'exit', without translating that line.
    for text in iter(IN, 'exit'):
        print(translate(text, tokenizer, model))


if __name__ == '__main__':
    run()