From e1ac048f949f8252f622fd8c7f98d5bacd6a9782 Mon Sep 17 00:00:00 2001
From: shibing624
Date: Sun, 13 Oct 2024 03:53:17 +0800
Subject: [PATCH] update batch infer.

---
 examples/gpt/demo.py         |  3 +--
 pycorrector/gpt/gpt_model.py | 42 ++++++++++++++++++++----------------------
 2 files changed, 21 insertions(+), 24 deletions(-)

diff --git a/examples/gpt/demo.py b/examples/gpt/demo.py
index 75bdf910..af53b44d 100644
--- a/examples/gpt/demo.py
+++ b/examples/gpt/demo.py
@@ -19,14 +19,13 @@
     ]
     m = GptCorrector("shibing624/chinese-text-correction-1.5b")
 
-    batch_res = m.correct_batch(error_sentences, prefix_prompt="文本纠错:\n\n")
+    batch_res = m.correct_batch(error_sentences, system_prompt=None, prefix_prompt="文本纠错:\n\n")
     for i in batch_res:
         print(i)
         print()
 
     batch_res = m.correct_batch(error_sentences,
                                 system_prompt="你是一个中文文本纠错助手。请根据用户提供的原始文本,生成纠正后的文本。",
-                                prompt_template_name="qwen",
                                 prefix_prompt="")
     for i in batch_res:
         print(i)

diff --git a/pycorrector/gpt/gpt_model.py b/pycorrector/gpt/gpt_model.py
index 31f70f89..d591845a 100644
--- a/pycorrector/gpt/gpt_model.py
+++ b/pycorrector/gpt/gpt_model.py
@@ -166,6 +166,8 @@ def __init__(
         else:
             self.tokenizer.pad_token = self.tokenizer.eos_token
             logger.debug("Add pad token: {}".format(self.tokenizer.pad_token))
+        if self.model.config.architectures[0] == "Qwen2ForCausalLM":
+            self.tokenizer.padding_side = "left"
 
         self.args.model_type = model_type
         if model_name is None:
@@ -550,31 +552,28 @@ def predict(
             )
 
             if prompt_template_name:
-                outputs = []
-                for s in batch:
-                    messages = [[s, '']]
-                    prompt = prompt_template.get_prompt(messages=messages, system_prompt=system_prompt)
-                    inputs_tokens = self.tokenizer(prompt, return_tensors="pt", padding=True)
-                    input_ids = inputs_tokens['input_ids'].to(self.device)
-                    output = self.model.generate(input_ids=input_ids, **generation_kwargs, **kwargs)
-                    outputs.append(output[0])
+                prompts = [prompt_template.get_prompt(messages=[[s, '']], system_prompt=system_prompt) for s in batch]
+                inputs = self.tokenizer(prompts, padding=True, return_tensors='pt')
+                input_ids = inputs['input_ids'].to(self.device)
+                outputs = self.model.generate(input_ids, **generation_kwargs, **kwargs)
             else:
-                outputs = []
+                conversation = []
                 for s in batch:
                     messages = []
                     if system_prompt:
                         messages.append({'role': 'system', 'content': system_prompt})
                     messages.append({'role': 'user', 'content': s})
-                    input_id = self.tokenizer.apply_chat_template(
-                        conversation=messages,
-                        tokenize=True,
-                        add_generation_prompt=False,
-                        return_tensors='pt'
-                    )
-                    output = self.model.generate(input_id.to(self.device), **generation_kwargs, **kwargs)
-                    outputs.append(output[0])
+                    conversation.append(messages)
+                inputs = self.tokenizer.apply_chat_template(
+                    conversation=conversation,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                    return_tensors='pt',
+                    padding=True,
+                )
+                outputs = self.model.generate(inputs.to(self.device), **generation_kwargs, **kwargs)
 
-            for prompt, generated_sequence in zip(batch, outputs):
+            for input_text, generated_sequence in zip(batch, outputs):
                 # Decode text
                 gen_text = self.tokenizer.decode(generated_sequence, skip_special_tokens=True)
                 stop_str = self.tokenizer.eos_token or prompt_template.stop_str
@@ -582,10 +581,9 @@ def predict(
                 if pos != -1:
                     gen_text = gen_text[:pos]
                 if skip_prompt:
-                    gen_text = gen_text.split(prompt, 1)[-1]
-                    if "assistant" in gen_text:
-                        gen_text = gen_text.split("assistant", 1)[-1]
-                    gen_text = gen_text.strip()
+                    gen_text = gen_text.split(input_text, 1)[-1]
+                    if gen_text.startswith("\nassistant\n"):
+                        gen_text = gen_text.split("\nassistant\n", 1)[-1]
                 all_outputs.append(gen_text)
 
         return all_outputs
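Note: the change above boils down to two things: pad Qwen2 batches on the left, and run one batched generate() call instead of generating sentence by sentence. Below is a minimal standalone sketch of that pattern, assuming a Hugging Face causal LM; the example sentences and generation settings are illustrative, only the model name is taken from the demo. Inside pycorrector the equivalent logic lives in GptModel.predict().

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "shibing624/chinese-text-correction-1.5b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
model.eval()

# Decoder-only models continue from the last token of each row, so a padded
# batch must be padded on the left; with right padding the model would be
# asked to continue from pad tokens. This is what the patch enables for
# Qwen2ForCausalLM via tokenizer.padding_side = "left".
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Illustrative inputs mirroring the demo's prefix_prompt usage.
prompts = ["文本纠错:\n\n" + s for s in ["少先队员因该为老人让坐。", "他法语说的很流莉。"]]
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

with torch.no_grad():
    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],  # lets generate() skip pad positions
        max_new_tokens=128,
    )

# With left padding every row has the same prompt length, so slicing off
# that many tokens leaves exactly the newly generated text.
prompt_len = inputs["input_ids"].shape[1]
for output_ids in outputs:
    print(tokenizer.decode(output_ids[prompt_len:], skip_special_tokens=True))

Token-level slicing is shown here as an alternative to the patch's string-level gen_text.split(input_text, 1)[-1]; it works because left padding aligns every prompt to the same length, so the generated tokens always start at the same offset.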