import asyncio
import ctranslate2
import transformers
from typing import Dict, List, Union
from stopes_snippet.split_clean import splitAndClean

translator = ctranslate2.Translator(
    "nllb-200-3.3B-converted", device='auto', compute_type="float16",
    device_index=[0, 1])

# Cache the tokenizers, since they can be expensive to create.
tokenizers: Dict[str, Union[transformers.PreTrainedTokenizer,
                            transformers.PreTrainedTokenizerFast]] = {}
async def translate_batch_async(src_lang_flores: List[str],
                                tgt_lang_flores: List[str],
                                batches_textArrArr: List[List[str]]) -> List[List[str]]:
    """Translate several batches of texts, one FLORES-200 language pair per batch.

    Returns the translations in the same nested structure as batches_textArrArr.
    """
    # Flatten the passed batches down to simple flat arrays,
    # tracking the FLORES-200 language codes for each string.
    flatBatches_textArr: List[str] = []
    flatBatches_src_lang_flores: List[str] = []
    flatBatches_tgt_lang_flores: List[str] = []
    for i, batch_textArr in enumerate(batches_textArrArr):
        thisSrcLang_flores = src_lang_flores[i]
        thisTgtLang_flores = tgt_lang_flores[i]
        for text in batch_textArr:
            flatBatches_textArr.append(text)
            flatBatches_src_lang_flores.append(thisSrcLang_flores)
            flatBatches_tgt_lang_flores.append(thisTgtLang_flores)

    # Further divide each string into sentences using a sentence splitter.
    sentences: List[str] = []
    # How many sentences each passed string was split into.
    sentences_counts: List[int] = []
    sentences_src_lang_flores: List[str] = []
    sentences_tgt_lang_flores: List[str] = []
    for i, thisTranslateText in enumerate(flatBatches_textArr):
        thisSrcLang_flores = flatBatches_src_lang_flores[i]
        thisTgtLang_flores = flatBatches_tgt_lang_flores[i]
        theseSentences: List[str] = splitAndClean(
            thisSrcLang_flores, thisTranslateText)
        sentences_counts.append(len(theseSentences))
        sentences.extend(theseSentences)
        sentences_src_lang_flores.extend(
            [thisSrcLang_flores] * len(theseSentences))
        sentences_tgt_lang_flores.extend(
            [thisTgtLang_flores] * len(theseSentences))

    # Tokenize the sentences.
    sentences_tokenised: List[List[str]] = []
    for i, sentence in enumerate(sentences):
        thisSrcLang_flores = sentences_src_lang_flores[i]
        if thisSrcLang_flores not in tokenizers:
            tokenizers[thisSrcLang_flores] = transformers.AutoTokenizer.from_pretrained(
                "nllb-200-3.3B", src_lang=thisSrcLang_flores)
        tokenizer = tokenizers[thisSrcLang_flores]
        thisSentenceTokens = tokenizer.convert_ids_to_tokens(
            tokenizer.encode(sentence))
        sentences_tokenised.append(thisSentenceTokens)

    # OK, let's translate already..
    print(f"Processing batch: {len(sentences_tokenised)}")

    def sync_func():
        return translator.translate_batch(
            sentences_tokenised,
            target_prefix=[[thisTgtLang_flores]
                           for thisTgtLang_flores in sentences_tgt_lang_flores],
            max_batch_size=128
        )

    # Run sync_func asynchronously so we don't block the event loop.
    # This allows other requests to be handled in the meantime.
    loop = asyncio.get_event_loop()
    results = await loop.run_in_executor(None, sync_func)
    # Drop the leading target-language token from each hypothesis.
    targets = [result.hypotheses[0][1:] for result in results]
    # The NLLB vocabulary is shared across languages, so the tokenizer left over
    # from the loop above can be reused for decoding.
    sentences_translations = [tokenizer.decode(
        tokenizer.convert_tokens_to_ids(target)) for target in targets]

    # Reconstruct the strings we split with the sentence splitter.
    flatBatches_translations: List[str] = []
    for count in sentences_counts:
        # Joining with a space; ideally this would be language specific..
        flatBatches_translations.append(
            ' '.join(sentences_translations[:count]))
        # Remove these items from the list.
        sentences_translations[:count] = []

    # Assemble back into the passed batch structure.
    batches_translationsArrArr: List[List[str]] = []
    # Loop over the input batches.
    for batch_textArr in batches_textArrArr:
        batch_translationArr: List[str] = []
        # Loop over the strings passed in each batch.
        for text in batch_textArr:
            batch_translationArr.append(flatBatches_translations.pop(0))
        batches_translationsArrArr.append(batch_translationArr)
    return batches_translationsArrArr
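
# Usage sketch (illustrative only, not part of the original module): the input
# texts below are invented examples; the returned structure mirrors the nested
# shape of the input batches.
#
#     batches = [["Hello world. How are you?"], ["Guten Morgen."]]
#     results = asyncio.run(translate_batch_async(
#         ["eng_Latn", "deu_Latn"], ["fra_Latn", "eng_Latn"], batches))
#     # results[0] holds the French translations of batch 0,
#     # results[1] the English translations of batch 1.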
# def translate_sync(src_lang_flores: str, tgt_lang_flores: str, textArr: List[str]) -> List[str]:
#     if (src_lang_flores not in tokenizers):
#         tokenizers[src_lang_flores] = transformers.AutoTokenizer.from_pretrained(
#             "nllb-200-3.3B", src_lang=src_lang_flores)
#     tokenizer = tokenizers[src_lang_flores]
#     target_prefix = [tgt_lang_flores]
#     arg1 = [tokenizer.convert_ids_to_tokens(
#         tokenizer.encode(text)) for text in textArr]
#     arg2 = target_prefix = [target_prefix]*len(textArr)
#     print(f"Processing batch: {len(arg2)}")
#     results = translator.translate_batch(
#         arg1,
#         target_prefix=arg2,
#         max_batch_size=128
#     )
#     targets = [result.hypotheses[0][1:] for result in results]
#     translations = [tokenizer.decode(
#         tokenizer.convert_tokens_to_ids(target)) for target in targets]
#     return translations
# async def translate_async(src_lang_flores: str, tgt_lang_flores: str, textArr: List[str]) -> List[str]:
#     if (src_lang_flores not in tokenizers):
#         tokenizers[src_lang_flores] = transformers.AutoTokenizer.from_pretrained(
#             "nllb-200-3.3B", src_lang=src_lang_flores)
#     tokenizer = tokenizers[src_lang_flores]
#     target_prefix = [tgt_lang_flores]
#     arg1 = [tokenizer.convert_ids_to_tokens(
#         tokenizer.encode(text)) for text in textArr]
#     arg2 = target_prefix = [target_prefix]*len(textArr)
#     print(f"Processing batch: {len(arg2)}")
#
#     def sync_func():
#         return translator.translate_batch(
#             arg1,
#             target_prefix=arg2,
#             max_batch_size=128
#         )
#
#     # Run sync_func asynchronously so we don't block the event loop.
#     # This allows other requests to be handled in the meantime.
#     loop = asyncio.get_event_loop()
#     results = await loop.run_in_executor(None, sync_func)
#     targets = [result.hypotheses[0][1:] for result in results]
#     translations = [tokenizer.decode(
#         tokenizer.convert_tokens_to_ids(target)) for target in targets]
#     # print(translations)
#     return translations
# Map two-letter Google Translate codes to FLORES-200 codes.
googleToFlores200Codes = {
    'af': 'afr_Latn',
    'sq': 'als_Latn',
    'am': 'amh_Ethi',
    'ar': 'arb_Arab',
    'hy': 'hye_Armn',
    'az': 'azj_Latn',  # North Azerbaijani
    'eu': 'eus_Latn',
    'be': 'bel_Cyrl',
    'bn': 'ben_Beng',
    'bs': 'bos_Latn',
    'bg': 'bul_Cyrl',
    'ca': 'cat_Latn',
    'ceb': 'ceb_Latn',
    'zh-CN': 'zho_Hans',
    'zh-TW': 'zho_Hant',
    # Nope:
    # co: 'Corsican',
    'hr': 'hrv_Latn',
    'cs': 'ces_Latn',
    'da': 'dan_Latn',
    'nl': 'nld_Latn',
    'en': 'eng_Latn',
    'eo': 'epo_Latn',
    'et': 'est_Latn',
    'fi': 'fin_Latn',
    'fr': 'fra_Latn',
    # Nope:
    # fy: 'Frisian',
    'gl': 'glg_Latn',
    'ka': 'kat_Geor',
    'de': 'deu_Latn',
    'el': 'ell_Grek',
    'gu': 'guj_Gujr',
    'ht': 'hat_Latn',
    'ha': 'hau_Latn',
    # Nope:
    # haw: 'Hawaiian',
    'iw': 'heb_Hebr',
    'hi': 'hin_Deva',
    # Nope:
    # hmn: 'Hmong',
    'hu': 'hun_Latn',
    'is': 'isl_Latn',
    'ig': 'ibo_Latn',
    'id': 'ind_Latn',
    'ga': 'gle_Latn',
    'it': 'ita_Latn',
    'ja': 'jpn_Jpan',
    'jv': 'jav_Latn',
    'kn': 'kan_Knda',
    'kk': 'kaz_Cyrl',
    'km': 'khm_Khmr',
    'ko': 'kor_Hang',
    # Kurmanji (Northern Kurdish). There is also Sorani (Central Kurdish).
    'ku': 'kmr_Latn',
    'ky': 'kir_Cyrl',
    'lo': 'lao_Laoo',
    # Nope:
    # la: 'Latin',
    'lv': 'lvs_Latn',
    'lt': 'lit_Latn',
    'lb': 'ltz_Latn',
    'mk': 'mkd_Cyrl',
    'mg': 'plt_Latn',
    'ms': 'zsm_Latn',
    'ml': 'mal_Mlym',
    'mt': 'mlt_Latn',
    'mi': 'mri_Latn',
    'mr': 'mar_Deva',
    'mn': 'khk_Cyrl',
    'my': 'mya_Mymr',
    'ne': 'npi_Deva',
    # Bokmål, not Nynorsk:
    'no': 'nob_Latn',
    'ny': 'nya_Latn',
    # Southern Pashto
    'ps': 'pbt_Arab',
    'fa': 'pes_Arab',
    'pl': 'pol_Latn',
    'pt': 'por_Latn',
    # Nope:
    # pa: 'Punjabi',
    'ro': 'ron_Latn',
    'ru': 'rus_Cyrl',
    'sm': 'smo_Latn',
    'gd': 'gla_Latn',
    # Note: Cyrillic!
    'sr': 'srp_Cyrl',
    # Nope:
    # st: 'Sesotho',
    'sn': 'sna_Latn',
    'sd': 'snd_Arab',
    'si': 'sin_Sinh',
    'sk': 'slk_Latn',
    'sl': 'slv_Latn',
    'so': 'som_Latn',
    'es': 'spa_Latn',
    'su': 'sun_Latn',
    'sw': 'swh_Latn',
    'sv': 'swe_Latn',
    'tl': 'tgl_Latn',
    'tg': 'tgk_Cyrl',
    'ta': 'tam_Taml',
    'te': 'tel_Telu',
    'th': 'tha_Thai',
    'tr': 'tur_Latn',
    'uk': 'ukr_Cyrl',
    'ur': 'urd_Arab',
    'uz': 'uzn_Latn',
    'vi': 'vie_Latn',
    'cy': 'cym_Latn',
    'xh': 'xho_Latn',
    # 'Eastern' Yiddish
    'yi': 'ydd_Hebr',
    'yo': 'yor_Latn',
    'zu': 'zul_Latn',
    # New:
    'rw': 'kin_Latn',
    'or': 'ory_Orya',
    # Tatar of Tatarstan; there is also Crimean Tatar:
    'tt': 'tat_Cyrl',
    'tk': 'tuk_Latn',
    'ug': 'uig_Arab',
}
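

# A minimal, hypothetical driver showing how the Google-to-FLORES-200 mapping
# above could feed translate_batch_async. The example texts and the
# __main__ guard are assumptions for illustration, not part of the service.
if __name__ == "__main__":
    # Look up FLORES-200 codes from two-letter Google Translate codes.
    src_codes = [googleToFlores200Codes['en'], googleToFlores200Codes['de']]
    tgt_codes = [googleToFlores200Codes['fr'], googleToFlores200Codes['en']]
    demo_batches = [["Hello world. How are you today?"],
                    ["Guten Morgen. Wie geht es dir?"]]
    demo_translations = asyncio.run(
        translate_batch_async(src_codes, tgt_codes, demo_batches))
    for batch in demo_translations:
        for line in batch:
            print(line)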