[WebUI] Add prompt format and stopping words for Qwen (#10066)
* add prompt format and stopping_words for qwen model

* performance optimization

* optimize

* update

* address review comments
sgwhat authored Feb 5, 2024
1 parent ef20adb commit 68b5cf0
Showing 2 changed files with 39 additions and 4 deletions.
12 changes: 12 additions & 0 deletions python/llm/example/Text-Generation-WebUI/modules/callbacks.py
@@ -41,6 +41,18 @@ def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        return shared.stop_everything


class StopWordsCriteria(transformers.StoppingCriteria):
    """Custom `StoppingCriteria` that stops generation once the newest token decodes to one of the given stop words."""
    def __init__(self, stop_words, tokenizer):
        self.stop_words = stop_words
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        """Returns True if the most recently generated token decodes to a stop word."""
        # Decode only the final token of the last sequence in the batch
        text = self.tokenizer.decode(input_ids[-1][-1])
        return text in self.stop_words


class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func
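For context, here is a minimal, self-contained sketch of how a stop-words criteria like the one above plugs into a Hugging Face generate() call. The model name, prompt, and generation settings are illustrative assumptions, not part of this commit; in the WebUI itself the class is imported from modules.callbacks.

import transformers

class StopWordsCriteria(transformers.StoppingCriteria):
    # Compact copy of the class added in callbacks.py above
    def __init__(self, stop_words, tokenizer):
        self.stop_words = stop_words
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs):
        # Stop once the newest token decodes to one of the stop words
        return self.tokenizer.decode(input_ids[-1][-1]) in self.stop_words

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
model = transformers.AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)

# Stop on Qwen's special turn-delimiting tokens, mirroring the WebUI change
criteria = transformers.StoppingCriteriaList(
    [StopWordsCriteria(["<|endoftext|>", "<|im_end|>", "<|im_start|>"], tokenizer)]
)
inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=64, stopping_criteria=criteria)
print(tokenizer.decode(output_ids[0]))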
31 changes: 27 additions & 4 deletions python/llm/example/Text-Generation-WebUI/modules/text_generation.py
@@ -11,7 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is adapted from
# https://github.com/oobabooga/text-generation-webui/blob/main/modules/text_generation.py
@@ -35,7 +35,8 @@
from modules.callbacks import (
    Iteratorize,
    Stream,
    _StopEverythingStoppingCriteria,
    StopWordsCriteria
)
from modules.extensions import apply_extensions
from modules.grammar.grammar_utils import initialize_grammar
@@ -331,6 +332,19 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
    if shared.args.deepspeed:
        generate_params.update({'synced_gpus': True})

    # Wrap the question in Qwen's ChatML-style prompt format
    QWEN_PROMPT_FORMAT = """
<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>user
{prompt}
<|im_end|>
<|im_start|>assistant
"""
    if shared.model.config.model_type == "qwen":
        question = QWEN_PROMPT_FORMAT.format(prompt=question)

    # Encode the input
    input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
    output = input_ids[0]
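To see what the Qwen wrapper above produces, here is a self-contained illustration; the question string is made up.

QWEN_PROMPT_FORMAT = """
<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>user
{prompt}
<|im_end|>
<|im_start|>assistant
"""
print(QWEN_PROMPT_FORMAT.format(prompt="What is BigDL?"))
# <|im_start|>system
# You are a helpful assistant.
# <|im_end|>
# <|im_start|>user
# What is BigDL?
# <|im_end|>
# <|im_start|>assistant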
@@ -346,10 +360,19 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
        generate_params.update({'inputs_embeds': inputs_embeds})

    # Stopping criteria / eos token
    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
    generate_params['eos_token_id'] = eos_token_ids
    generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
    generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())

    # Qwen's special tokens delimit turns, so treat them as stop words
    if shared.model.config.model_type == "qwen":
        stopping_words = ["<|endoftext|>", "<|im_end|>", "<|im_start|>"]
        generate_params['stopping_criteria'].append(StopWordsCriteria(stopping_words, shared.tokenizer))

    # custom_stopping_strings is a comma-separated, quoted string from the UI;
    # parse it once instead of re-parsing it for every character of the field
    if isinstance(state['custom_stopping_strings'], str) and state['custom_stopping_strings']:
        stopping_words = [item.strip().strip('"') for item in state['custom_stopping_strings'].split(',')]
        generate_params['stopping_criteria'].append(StopWordsCriteria(stopping_words, shared.tokenizer))


    # Logits processor
    processor = state.get('logits_processor', LogitsProcessorList([]))
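For reference, a small sketch of how a comma-separated custom_stopping_strings value parses into stop words; the input string is a made-up example.

custom = '"User:", "<|im_end|>"'
stopping_words = [item.strip().strip('"') for item in custom.split(',')]
print(stopping_words)  # ['User:', '<|im_end|>']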
