From 68b5cf02582e6107b5d57930d1c2439666356f86 Mon Sep 17 00:00:00 2001 From: SONG Ge <38711238+sgwhat@users.noreply.github.com> Date: Mon, 5 Feb 2024 18:23:13 +0800 Subject: [PATCH] [WebUI] Add prompt format and stopping words for Qwen (#10066) * add prompt format and stopping_words for qwen model * performance optimization * optimize * update * meet comments --- .../modules/callbacks.py | 12 +++++++ .../modules/text_generation.py | 31 ++++++++++++++++--- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/python/llm/example/Text-Generation-WebUI/modules/callbacks.py b/python/llm/example/Text-Generation-WebUI/modules/callbacks.py index 8f753dfd8d1..0911ef03bad 100644 --- a/python/llm/example/Text-Generation-WebUI/modules/callbacks.py +++ b/python/llm/example/Text-Generation-WebUI/modules/callbacks.py @@ -41,6 +41,18 @@ def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> b return shared.stop_everything +class StopWordsCriteria(transformers.StoppingCriteria): + """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed.""" + def __init__(self, stop_words, tokenizer): + self.stop_words = stop_words + self.tokenizer = tokenizer + + def __call__(self, input_ids, scores, **kwargs): + """Returns true if all generated sequences contain any of the end-of-function strings.""" + text = self.tokenizer.decode(input_ids[-1][-1]) + return text in self.stop_words + + class Stream(transformers.StoppingCriteria): def __init__(self, callback_func=None): self.callback_func = callback_func diff --git a/python/llm/example/Text-Generation-WebUI/modules/text_generation.py b/python/llm/example/Text-Generation-WebUI/modules/text_generation.py index cf8227f4d6f..8f746b32c81 100644 --- a/python/llm/example/Text-Generation-WebUI/modules/text_generation.py +++ b/python/llm/example/Text-Generation-WebUI/modules/text_generation.py @@ -11,7 +11,7 @@ # distributed under the License is distributed on an "AS IS" BASIS, # 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. +# limitations under the License. # This file is adapted from # https://github.com/oobabooga/text-generation-webui/blob/main/modules/text_generation.py @@ -35,7 +35,8 @@ from modules.callbacks import ( Iteratorize, Stream, - _StopEverythingStoppingCriteria + _StopEverythingStoppingCriteria, + StopWordsCriteria ) from modules.extensions import apply_extensions from modules.grammar.grammar_utils import initialize_grammar @@ -331,6 +332,19 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings if shared.args.deepspeed: generate_params.update({'synced_gpus': True}) + #tune the prompt based on qwen + QWEN_PROMPT_FORMAT = """ + <|im_start|>system + You are a helpful assistant. + <|im_end|> + <|im_start|>user + {prompt} + <|im_end|> + <|im_start|>assistant + """ + if shared.model.config.model_type == "qwen": + question = QWEN_PROMPT_FORMAT.format(prompt=question) + # Encode the input input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state)) output = input_ids[0] @@ -346,10 +360,19 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings generate_params.update({'inputs_embeds': inputs_embeds}) # Stopping criteria / eos token + generate_params['stopping_criteria'] = transformers.StoppingCriteriaList() eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else [] generate_params['eos_token_id'] = eos_token_ids - generate_params['stopping_criteria'] = transformers.StoppingCriteriaList() - generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria()) + + if shared.model.config.model_type == "qwen": + stopping_words = ["<|endoftext|>", "<|im_end|>", "<|im_start|>"] + 
generate_params['stopping_criteria'].append(StopWordsCriteria(stopping_words, shared.tokenizer)) + + for st in state['custom_stopping_strings']: + if type(st) is str: + stopping_words = [item.strip().strip('"') for item in [state['custom_stopping_strings']][0].split(',')] + generate_params['stopping_criteria'].append(StopWordsCriteria(stopping_words, shared.tokenizer)) + # Logits processor processor = state.get('logits_processor', LogitsProcessorList([]))