Commit

bug fixes, rev version number
jeremyfowers committed Jan 28, 2025
1 parent d246b93 commit 9d75764
Showing 7 changed files with 17 additions and 10 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test_lemonade.yml
@@ -45,6 +45,7 @@ jobs:
         shell: bash -el {0}
         run: |
           pylint src/lemonade --rcfile .pylintrc --disable E0401
+          pylint examples --rcfile .pylintrc --disable E0401,E0611 --jobs=1
       - name: Test HF+CPU server
         if: runner.os == 'Windows'
         timeout-minutes: 10
2 changes: 1 addition & 1 deletion examples/lemonade/demos/chat/chat_hybrid.py
@@ -1,6 +1,6 @@
 import sys
 from threading import Thread, Event
-from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers import StoppingCriteriaList
 from lemonade.tools.chat import StopOnEvent
 from lemonade import leap
 from lemonade.tools.ort_genai.oga import OrtGenaiStreamer
5 changes: 3 additions & 2 deletions examples/lemonade/demos/chat/chat_start.py
@@ -1,9 +1,9 @@
 import sys
 from threading import Thread, Event
-from transformers import StoppingCriteriaList
-from lemonade.tools.chat import StopOnEvent
 from queue import Queue
 from time import sleep
+from transformers import StoppingCriteriaList
+from lemonade.tools.chat import StopOnEvent
 
 
 class TextStreamer:
@@ -43,6 +43,7 @@ def generate_placeholder(
     Not needed once we integrate with LEAP.
     """
 
+    # pylint: disable=line-too-long
     response = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."
 
     for word in response.split(" "):
5 changes: 2 additions & 3 deletions examples/lemonade/demos/search/search_start.py
@@ -1,11 +1,10 @@
 import sys
 from threading import Thread, Event
+from queue import Queue
+from time import sleep
 from transformers import StoppingCriteriaList
 from lemonade.tools.chat import StopOnEvent
 
-# These imports are not needed when we add the LLM
-from queue import Queue
-from time import sleep
 
 employee_handbook = """
 1. You will work very hard every day.\n
6 changes: 4 additions & 2 deletions src/lemonade/tools/ort_genai/oga.py
@@ -113,6 +113,7 @@ def generate(
         self,
         input_ids,
         max_new_tokens=512,
+        min_new_tokens=0,
         do_sample=True,
         top_k=50,
         top_p=1.0,
@@ -135,6 +136,7 @@
         params.pad_token_id = pad_token_id
 
         max_length = len(input_ids) + max_new_tokens
+        min_length = len(input_ids) + min_new_tokens
 
         if use_oga_pre_6_api:
             params.input_ids = input_ids
@@ -147,7 +149,7 @@
                 top_p=search_config.get("top_p", top_p),
                 temperature=search_config.get("temperature", temperature),
                 max_length=max_length,
-                min_length=0,
+                min_length=min_length,
                 early_stopping=search_config.get("early_stopping", False),
                 length_penalty=search_config.get("length_penalty", 1.0),
                 num_beams=search_config.get("num_beams", 1),
@@ -167,7 +169,7 @@
                 top_p=top_p,
                 temperature=temperature,
                 max_length=max_length,
-                min_length=max_length,
+                min_length=min_length,
             )
         params.try_graph_capture_with_max_batch_size(1)
 
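The net effect of the oga.py change: callers can now request a minimum completion length, and both OGA code paths receive the same min_length value instead of 0 or max_length. A minimal sketch of the arithmetic, assuming a hypothetical helper (resolve_length_bounds is not part of oga.py, and the token ids are placeholders):

# Sketch only: mirrors the length arithmetic added to generate().
def resolve_length_bounds(input_ids, max_new_tokens=512, min_new_tokens=0):
    # Both bounds are measured from the start of the prompt, so the prompt
    # length is added to each requested token count.
    prompt_len = len(input_ids)
    return prompt_len + min_new_tokens, prompt_len + max_new_tokens

# With min_new_tokens == max_new_tokens the completion length is pinned:
# generation can neither stop early nor run long.
print(resolve_length_bounds(list(range(12)), max_new_tokens=256, min_new_tokens=256))
# -> (268, 268)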
6 changes: 5 additions & 1 deletion src/lemonade/tools/ort_genai/oga_bench.py
@@ -161,7 +161,11 @@ def run(
         model.generate(input_ids, max_new_tokens=output_tokens)
 
         for _ in tqdm.tqdm(range(iterations), desc="iterations"):
-            outputs = model.generate(input_ids, max_new_tokens=output_tokens)
+            outputs = model.generate(
+                input_ids,
+                max_new_tokens=output_tokens,
+                min_new_tokens=output_tokens,
+            )
 
             token_len = len(outputs[0]) - input_ids_len
 
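A short, hypothetical illustration of why the benchmark pins min_new_tokens to output_tokens: with the completion length fixed, every iteration decodes the same number of tokens, so per-iteration throughput figures are directly comparable. The helper name and timing values below are assumptions, not part of oga_bench.py:

# Hypothetical sketch: average decode throughput across benchmark iterations,
# assuming each iteration produced exactly tokens_per_iteration tokens.
def mean_tokens_per_second(durations_s, tokens_per_iteration):
    return sum(tokens_per_iteration / d for d in durations_s) / len(durations_s)

# Example timings (seconds) for three iterations of a 256-token completion.
print(mean_tokens_per_second([1.9, 2.1, 2.0], tokens_per_iteration=256))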
2 changes: 1 addition & 1 deletion src/turnkeyml/version.py
@@ -1 +1 @@
__version__ = "5.0.2"
__version__ = "5.0.3"
