diff --git a/src/lemonade/tools/ort_genai/oga.py b/src/lemonade/tools/ort_genai/oga.py index e487ae8..9f2bb32 100644 --- a/src/lemonade/tools/ort_genai/oga.py +++ b/src/lemonade/tools/ort_genai/oga.py @@ -23,6 +23,7 @@ from turnkeyml.tools import FirstTool import turnkeyml.common.status as status import turnkeyml.common.printing as printing +import numpy as np from lemonade.tools.adapter import ( ModelAdapter, TokenizerAdapter, @@ -102,8 +103,10 @@ class OrtGenaiModel(ModelAdapter): def __init__(self, input_folder): super().__init__() self.model = og.Model(input_folder) + self.model_path = input_folder self.type = "ort-genai" self.config = self.load_config(input_folder) + self.tokenizer = og.Tokenizer(self.model) def load_config(self, input_folder): config_path = os.path.join(input_folder, "genai_config.json") @@ -124,7 +127,37 @@ def generate( streamer: OrtGenaiStreamer = None, pad_token_id=None, stopping_criteria=None, + chat_template = "", ): + + # Get model type + model_type = None + if hasattr(self.model, "type"): + model_type = self.model.type + else: + import json, os + + with open(os.path.join(self.model_path, "genai_config.json"), "r") as f: + genai_config = json.load(f) + model_type = genai_config["model"]["type"] + + # Set chat template + if chat_template: + if chat_template.count('{') != 1 or chat_template.count('}') != 1: + raise ValueError("Chat template must have exactly one pair of curly braces with input word in it, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'") + else: + if model_type.startswith("phi2") or model_type.startswith("phi3"): + chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>' + elif model_type.startswith("phi4"): + chat_template = '<|im_start|>user<|im_sep|>\n{input}<|im_end|>\n<|im_start|>assistant<|im_sep|>' + elif model_type.startswith("llama3"): + chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>' + elif model_type.startswith("llama2"): + chat_template = '{input}' + elif model_type.startswith("qwen2"): + chat_template = '<|im_start|>user\n{input}<|im_end|>\n<|im_start|>assistant\n' + else: + raise ValueError(f"Chat Template for model type {model_type} is not known. Please provide chat template using --chat_template") params = og.GeneratorParams(self.model) # There is a breaking API change in OGA 0.6.0 @@ -144,6 +177,13 @@ def generate( if use_oga_pre_6_api: params.input_ids = input_ids + if isinstance(input_ids, list): + input_ids_np = np.array(input_ids, dtype=np.int32) + else: + input_ids_np = input_ids.cpu().numpy().astype(np.int32) + + decoded_prompt = self.tokenizer.decode(input_ids_np) + if self.config and "search" in self.config: search_config = self.config["search"] params.set_search_options( @@ -177,8 +217,13 @@ def generate( params.try_graph_capture_with_max_batch_size(1) generator = og.Generator(self.model, params) + prompt = decoded_prompt + prompt = f'{chat_template.format(input=decoded_prompt)}' + + input_tokens = self.tokenizer.encode(prompt) + if use_oga_post_6_api: - generator.append_tokens(input_ids) + generator.append_tokens(input_tokens) if streamer is None: prompt_start_time = time.perf_counter()