Commit
Fix the gibberish output from llm-prompt
apsonawane committed Feb 1, 2025
1 parent 4f3be13 commit e6cf6e0
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion src/lemonade/tools/ort_genai/oga.py
@@ -23,6 +23,7 @@
from turnkeyml.tools import FirstTool
import turnkeyml.common.status as status
import turnkeyml.common.printing as printing
import numpy as np
from lemonade.tools.adapter import (
ModelAdapter,
TokenizerAdapter,
@@ -102,8 +103,10 @@ class OrtGenaiModel(ModelAdapter):
def __init__(self, input_folder):
super().__init__()
self.model = og.Model(input_folder)
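# Store the model path and a tokenizer; generate() uses them to look up the model type and rebuild the prompt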
self.model_path = input_folder
self.type = "ort-genai"
self.config = self.load_config(input_folder)
self.tokenizer = og.Tokenizer(self.model)

def load_config(self, input_folder):
config_path = os.path.join(input_folder, "genai_config.json")
@@ -124,7 +127,37 @@ def generate(
streamer: OrtGenaiStreamer = None,
pad_token_id=None,
stopping_criteria=None,
chat_template="",
):

# Get model type
model_type = None
if hasattr(self.model, "type"):
model_type = self.model.type
else:
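# Fall back to reading the model type from genai_config.json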
import json

with open(os.path.join(self.model_path, "genai_config.json"), "r") as f:
genai_config = json.load(f)
model_type = genai_config["model"]["type"]

# Set chat template
if chat_template:
if chat_template.count('{') != 1 or chat_template.count('}') != 1:
raise ValueError("Chat template must contain exactly one pair of curly braces wrapping the input placeholder, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
else:
if model_type.startswith("phi2") or model_type.startswith("phi3"):
chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
elif model_type.startswith("phi4"):
chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>'
elif model_type.startswith("llama3"):
chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
elif model_type.startswith("llama2"):
chat_template = '<s>{input}'
elif model_type.startswith("qwen2"):
chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n'
else:
raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template")
params = og.GeneratorParams(self.model)

# There is a breaking API change in OGA 0.6.0
Expand All @@ -144,6 +177,13 @@ def generate(
if use_oga_pre_6_api:
params.input_ids = input_ids

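# Convert the incoming token IDs to int32 and decode them back to text so the chat template can be applied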
if isinstance(input_ids, list):
input_ids_np = np.array(input_ids, dtype=np.int32)
else:
input_ids_np = input_ids.cpu().numpy().astype(np.int32)

decoded_prompt = self.tokenizer.decode(input_ids_np)

if self.config and "search" in self.config:
search_config = self.config["search"]
params.set_search_options(
@@ -177,8 +217,13 @@ def generate(
params.try_graph_capture_with_max_batch_size(1)

generator = og.Generator(self.model, params)
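# Apply the model-specific chat template to the decoded prompt, then re-encode it for generation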
prompt = chat_template.format(input=decoded_prompt)

input_tokens = self.tokenizer.encode(prompt)

if use_oga_post_6_api:
- generator.append_tokens(input_ids)
+ generator.append_tokens(input_tokens)

if streamer is None:
prompt_start_time = time.perf_counter()

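For context, the fix works by decoding the incoming token IDs back to plain text, wrapping that text in a model-specific chat template, and re-encoding the templated prompt before handing it to the generator. Below is a minimal standalone sketch of that flow, not the committed code: the model folder, the prompt, and the max_length value are illustrative, the Phi-3 template string is taken from the diff above, and onnxruntime-genai 0.6 or newer is assumed.

import onnxruntime_genai as og

model_dir = "path/to/phi3-model"  # hypothetical model folder
user_prompt = "What is the capital of France?"  # illustrative prompt

model = og.Model(model_dir)
tokenizer = og.Tokenizer(model)

# Token IDs as they would arrive in OrtGenaiModel.generate()
input_ids = tokenizer.encode(user_prompt)

# Decode back to text and wrap in the model-specific chat template (Phi-3 shown)
chat_template = "<|user|>\n{input} <|end|>\n<|assistant|>"
prompt = chat_template.format(input=tokenizer.decode(input_ids))

# Re-encode the templated prompt and generate with the post-0.6 OGA API
params = og.GeneratorParams(model)
params.set_search_options(max_length=256)
generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode(prompt))

while not generator.is_done():
    generator.generate_next_token()

print(tokenizer.decode(generator.get_sequence(0)))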