Commit

Fix the gibberish output from llm-prompt
Signed-off-by: Akshay Sonawane <[email protected]>
apsonawane committed Feb 1, 2025
1 parent 4f3be13 commit efebc2d
Showing 1 changed file with 53 additions and 1 deletion.
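
At a glance, the change decodes the incoming prompt tokens back to text, wraps that text in a model-specific chat template, and re-encodes it before generation; the commit title suggests that feeding untemplated prompts is what produced the gibberish. Below is a minimal sketch of that flow, assuming a tokenizer with the onnxruntime_genai decode/encode interface; the helper name is illustrative and not part of the commit.

import numpy as np

def wrap_prompt_with_chat_template(tokenizer, input_ids, chat_template):
    """Decode raw prompt tokens, apply a chat template, and re-encode them."""
    # Prompt tokens may arrive as a plain Python list or as a torch tensor;
    # normalize to an int32 numpy array before decoding.
    if isinstance(input_ids, list):
        input_ids_np = np.array(input_ids, dtype=np.int32)
    else:
        input_ids_np = input_ids.cpu().numpy().astype(np.int32)

    decoded_prompt = tokenizer.decode(input_ids_np)      # token ids -> plain text
    prompt = chat_template.format(input=decoded_prompt)  # e.g. "<|user|>\n{input} <|end|>\n<|assistant|>"
    return tokenizer.encode(prompt)                      # tokens the generator should consume
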
src/lemonade/tools/ort_genai/oga.py (53 additions, 1 deletion)
@@ -15,6 +15,7 @@
import shutil
from fnmatch import fnmatch
from queue import Queue
import numpy as np
from packaging.version import Version
from huggingface_hub import snapshot_download
import onnxruntime_genai as og
@@ -102,8 +103,10 @@ class OrtGenaiModel(ModelAdapter):
    def __init__(self, input_folder):
        super().__init__()
        self.model = og.Model(input_folder)
        self.model_path = input_folder
        self.type = "ort-genai"
        self.config = self.load_config(input_folder)
        self.tokenizer = og.Tokenizer(self.model)

    def load_config(self, input_folder):
        config_path = os.path.join(input_folder, "genai_config.json")
@@ -124,7 +127,44 @@ def generate(
        streamer: OrtGenaiStreamer = None,
        pad_token_id=None,
        stopping_criteria=None,
        chat_template="",
    ):

        # Get model type
        model_type = None
        if hasattr(self.model, "type"):
            model_type = self.model.type
        else:
            with open(os.path.join(self.model_path, "genai_config.json"), "r") as f:
                genai_config = json.load(f)
            model_type = genai_config["model"]["type"]

        # Set chat template
        if chat_template:
            if chat_template.count("{") != 1 or chat_template.count("}") != 1:
                raise ValueError(
                    "Chat template must have exactly one pair of curly braces \
                    with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'"
                )
        else:
            if model_type.startswith("phi2") or model_type.startswith("phi3"):
                chat_template = "<|user|>\n{input} <|end|>\n<|assistant|>"
            elif model_type.startswith("phi4"):
                chat_template = "<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>"
            elif model_type.startswith("llama3"):
                chat_template = "<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|> \
                    <|start_header_id|>assistant<|end_header_id|>"
            elif model_type.startswith("llama2"):
                chat_template = "<s>{input}"
            elif model_type.startswith("qwen2"):
                chat_template = (
                    "<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n"
                )
            else:
                raise ValueError(
                    f"Chat Template for model type {model_type} is not known. \
                    Please provide chat template using --chat_template"
                )
        params = og.GeneratorParams(self.model)

        # There is a breaking API change in OGA 0.6.0
@@ -144,6 +184,13 @@ def generate(
        if use_oga_pre_6_api:
            params.input_ids = input_ids

        if isinstance(input_ids, list):
            input_ids_np = np.array(input_ids, dtype=np.int32)
        else:
            input_ids_np = input_ids.cpu().numpy().astype(np.int32)

        decoded_prompt = self.tokenizer.decode(input_ids_np)

        if self.config and "search" in self.config:
            search_config = self.config["search"]
            params.set_search_options(
@@ -177,8 +224,13 @@ def generate(
        params.try_graph_capture_with_max_batch_size(1)

        generator = og.Generator(self.model, params)
        prompt = decoded_prompt
        prompt = f"{chat_template.format(input=decoded_prompt)}"

        input_tokens = self.tokenizer.encode(prompt)

        if use_oga_post_6_api:
            generator.append_tokens(input_ids)
            generator.append_tokens(input_tokens)

        if streamer is None:
            prompt_start_time = time.perf_counter()

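For context, a hedged usage sketch (not part of the commit) of how a caller might pass a custom template when the model type is not one of the built-in defaults. The module path, the model folder name, the positional input_ids argument, and the handling of the return value are assumptions, since the full generate() signature is not shown in this diff.

from lemonade.tools.ort_genai.oga import OrtGenaiModel  # module path inferred from the file location

model = OrtGenaiModel("models/my-oga-model")  # hypothetical folder containing genai_config.json

# generate() accepts a plain list or a torch tensor of token ids,
# so convert the tokenizer's numpy output to a list.
input_ids = model.tokenizer.encode("Explain beam search in one sentence.").tolist()

# A custom template must contain exactly one "{input}" placeholder,
# otherwise generate() raises ValueError and points at --chat_template.
response = model.generate(
    input_ids,
    chat_template="<|user|>\n{input} <|end|>\n<|assistant|>",
)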