Commit e1ac048

update batch infer.
shibing624 committed Oct 12, 2024
1 parent 61e6fad commit e1ac048
Showing 2 changed files with 21 additions and 24 deletions.
examples/gpt/demo.py: 3 changes (1 addition, 2 deletions)
@@ -19,14 +19,13 @@
 ]
 m = GptCorrector("shibing624/chinese-text-correction-1.5b")
 
-batch_res = m.correct_batch(error_sentences, prefix_prompt="文本纠错:\n\n")
+batch_res = m.correct_batch(error_sentences, system_prompt=None, prefix_prompt="文本纠错:\n\n")
 for i in batch_res:
     print(i)
     print()
 
 batch_res = m.correct_batch(error_sentences,
                             system_prompt="你是一个中文文本纠错助手。请根据用户提供的原始文本,生成纠正后的文本。",
-                            prompt_template_name="qwen",
                             prefix_prompt="")
 for i in batch_res:
     print(i)
pycorrector/gpt/gpt_model.py: 42 changes (20 additions, 22 deletions)
@@ -166,6 +166,8 @@ def __init__(
         else:
             self.tokenizer.pad_token = self.tokenizer.eos_token
             logger.debug("Add pad token: {}".format(self.tokenizer.pad_token))
+        if self.model.config.architectures[0] == "Qwen2ForCausalLM":
+            self.tokenizer.padding_side = "left"
 
         self.args.model_type = model_type
         if model_name is None:
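The new left-padding setting is what makes the batched generation in the next hunk work: a decoder-only model continues generating from the last token of each row, so padded rows must end with real prompt tokens rather than padding. A minimal sketch of the difference, assuming a recent transformers release (the checkpoint name and sample strings are illustrative, not taken from this repo):

# Illustrative sketch: effect of padding_side on a batched prompt.
# The checkpoint name is an assumption for demonstration purposes.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

batch = ["你好", "少先队员因该为老人让坐。"]

tokenizer.padding_side = "right"
right_padded = tokenizer(batch, padding=True, return_tensors="pt")
# The shorter prompt now ends in pad tokens, so generate() would append
# new tokens after the padding instead of after the prompt.

tokenizer.padding_side = "left"
left_padded = tokenizer(batch, padding=True, return_tensors="pt")
# Padding precedes the prompt; every row ends with real prompt tokens,
# which is what batched decoder-only generation needs.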
@@ -550,42 +552,38 @@ def predict(
             )
 
             if prompt_template_name:
-                outputs = []
-                for s in batch:
-                    messages = [[s, '']]
-                    prompt = prompt_template.get_prompt(messages=messages, system_prompt=system_prompt)
-                    inputs_tokens = self.tokenizer(prompt, return_tensors="pt", padding=True)
-                    input_ids = inputs_tokens['input_ids'].to(self.device)
-                    output = self.model.generate(input_ids=input_ids, **generation_kwargs, **kwargs)
-                    outputs.append(output[0])
+                prompts = [prompt_template.get_prompt(messages=[[s, '']], system_prompt=system_prompt) for s in batch]
+                inputs = self.tokenizer(prompts, padding=True, return_tensors='pt')
+                input_ids = inputs['input_ids'].to(self.device)
+                outputs = self.model.generate(input_ids, **generation_kwargs, **kwargs)
             else:
-                outputs = []
+                conversation = []
                 for s in batch:
                     messages = []
                     if system_prompt:
                         messages.append({'role': 'system', 'content': system_prompt})
                     messages.append({'role': 'user', 'content': s})
-                    input_id = self.tokenizer.apply_chat_template(
-                        conversation=messages,
-                        tokenize=True,
-                        add_generation_prompt=False,
-                        return_tensors='pt'
-                    )
-                    output = self.model.generate(input_id.to(self.device), **generation_kwargs, **kwargs)
-                    outputs.append(output[0])
+                    conversation.append(messages)
+                inputs = self.tokenizer.apply_chat_template(
+                    conversation=conversation,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                    return_tensors='pt',
+                    padding=True,
+                )
+                outputs = self.model.generate(inputs.to(self.device), **generation_kwargs, **kwargs)
 
-            for prompt, generated_sequence in zip(batch, outputs):
+            for input_text, generated_sequence in zip(batch, outputs):
                 # Decode text
                 gen_text = self.tokenizer.decode(generated_sequence, skip_special_tokens=True)
                 stop_str = self.tokenizer.eos_token or prompt_template.stop_str
                 pos = gen_text.find(stop_str)
                 if pos != -1:
                     gen_text = gen_text[:pos]
                 if skip_prompt:
-                    gen_text = gen_text.split(prompt, 1)[-1]
-                    if "assistant" in gen_text:
-                        gen_text = gen_text.split("assistant", 1)[-1]
-                        gen_text = gen_text.strip()
+                    gen_text = gen_text.split(input_text, 1)[-1]
+                    if gen_text.startswith("\nassistant\n"):
+                        gen_text = gen_text.split("\nassistant\n", 1)[-1]
                 all_outputs.append(gen_text)
 
         return all_outputs
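
For reference, this is the overall pattern the batched path above follows: render a chat prompt per input, left-pad the batch, run a single generate call, and strip the prompt from each output. A self-contained sketch assuming a recent transformers release (the checkpoint name, sample sentences, and token-slicing cleanup are illustrative assumptions, not this repo's API):

# Minimal sketch of batched chat-style generation with left padding.
# Checkpoint name and example sentences are assumptions for illustration.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer.padding_side = "left"  # pad on the left so each row ends with prompt tokens
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

conversations = [
    [{"role": "system", "content": "你是一个中文文本纠错助手。"},
     {"role": "user", "content": "少先队员因该为老人让坐。"}],
    [{"role": "user", "content": "下个星期,我跟我朋唯一起去南京玩儿。"}],
]
# Render each conversation with the model's chat template, then batch-tokenize.
prompts = [
    tokenizer.apply_chat_template(c, tokenize=False, add_generation_prompt=True)
    for c in conversations
]
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=128)

# With left padding, the prompt occupies the last input_ids.shape[1] positions of
# every row, so the newly generated tokens can be sliced off uniformly.
new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))

Slicing on token positions works in this sketch because left padding makes every prompt end at the same index, so no string matching against the prompt text is needed.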
