Skip to content

Commit

Permalink
Minor changes fixing onnxruntime_genai issue and input_path
Browse files Browse the repository at this point in the history
Signed-off-by: Akshay Sonawane <[email protected]>
  • Loading branch information
apsonawane committed Jan 15, 2025
1 parent 4e7450d commit 453ca89
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/lemonade/tools/ort_genai/oga.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def generate(

max_length = len(input_ids) + max_new_tokens

- params.input_ids = input_ids
if self.config and "search" in self.config:
search_config = self.config["search"]
params.set_search_options(
Expand Down Expand Up @@ -159,10 +158,10 @@ def generate(
params.try_graph_capture_with_max_batch_size(1)

generator = og.Generator(self.model, params)
+ generator.append_tokens(input_ids)

if streamer is None:
prompt_start_time = time.perf_counter()
- generator.compute_logits()
generator.generate_next_token()
prompt_end_time = time.perf_counter()

Expand All @@ -173,7 +172,6 @@ def generate(
token_gen_times = []
while not generator.is_done():
token_gen_start_time = time.perf_counter()
- generator.compute_logits()
generator.generate_next_token()
token_gen_end_time = time.perf_counter()

Expand All @@ -194,7 +192,6 @@ def generate(
stop_early = False

while not generator.is_done() and not stop_early:
- generator.compute_logits()
generator.generate_next_token()

new_token = generator.get_next_tokens()[0]
Expand Down Expand Up @@ -252,6 +249,13 @@ def parser(add_help: bool = True) -> argparse.ArgumentParser:
add_help=add_help,
)

+ parser.add_argument(
+ "-ip",
+ "--input_path",
+ default="",
+ help= "the local huggingface model in your disk",
+ )

parser.add_argument(
"-d",
"--device",
Expand Down Expand Up @@ -303,6 +307,7 @@ def run(
self,
state: State,
input: str,
+ input_path: str = "",
device: str = "igpu",
dtype: str = "int4",
int4_block_size: int = None,
Expand Down Expand Up @@ -448,7 +453,7 @@ def run(
try:
model_builder.create_model(
checkpoint, # model_name
- "", # input_path
+ input_path, # input_path
full_model_path, # output_path
dtype, # precision
execution_providers[device], # execution_provider
Expand Down

0 comments on commit 453ca89

Please sign in to comment.