feat: ChatGPTGenerator #5692

Closed
wants to merge 34 commits
Changes from 1 commit
Commits (34)
0fc2bac
add generators module
ZanSara Aug 30, 2023
7f6325c
add tests for module helper
ZanSara Aug 30, 2023
47b6799
add chatgpt generator
ZanSara Aug 30, 2023
4e8fcb3
add init and serialization tests
ZanSara Aug 30, 2023
cbf7701
test component
ZanSara Aug 30, 2023
419f615
reno
ZanSara Aug 30, 2023
49ff654
Merge branch 'main' into generators-module
ZanSara Aug 30, 2023
4edeb8e
Merge branch 'generators-module' into chatgpt-generator
ZanSara Aug 30, 2023
08e9c62
reno
ZanSara Aug 30, 2023
a984e67
more tests
ZanSara Aug 30, 2023
612876a
add another test
ZanSara Aug 31, 2023
ec8e14a
Merge branch 'generators-module' of github.com:deepset-ai/haystack in…
ZanSara Aug 31, 2023
366b0ff
Merge branch 'generators-module' into chatgpt-generator
ZanSara Aug 31, 2023
e9c3de7
chat token limit
ZanSara Aug 31, 2023
725fabe
move into openai
ZanSara Aug 31, 2023
4d4f9d4
Merge branch 'generators-module' into chatgpt-generator
ZanSara Aug 31, 2023
c3bef8f
fix test
ZanSara Aug 31, 2023
c1a7696
improve tests
ZanSara Aug 31, 2023
246ca63
Merge branch 'generators-module' into chatgpt-generator
ZanSara Aug 31, 2023
ec809e4
add e2e test and small fixes
ZanSara Aug 31, 2023
5d946f8
linting
ZanSara Aug 31, 2023
aa9ce33
Add ChatGPTGenerator example
vblagoje Aug 31, 2023
9310057
review feedback
ZanSara Aug 31, 2023
7c36db1
Merge branch 'chatgpt-generator' of github.com:deepset-ai/haystack in…
ZanSara Aug 31, 2023
b2e421d
support for metadata
ZanSara Aug 31, 2023
6d81d79
Merge branch 'main' into chatgpt-generator
ZanSara Aug 31, 2023
2895697
mypy
ZanSara Aug 31, 2023
1538d61
mypy
ZanSara Sep 1, 2023
02cd61f
extract backend from generator and make it accept chats
ZanSara Sep 1, 2023
84332c6
fix tests
ZanSara Sep 1, 2023
329b54d
mypy
ZanSara Sep 4, 2023
5ee2aac
query->complete
ZanSara Sep 4, 2023
429a3ae
mypy
ZanSara Sep 4, 2023
c0b237d
Merge branch 'main' into chatgpt-generator
ZanSara Sep 4, 2023
chat token limit
ZanSara committed Aug 31, 2023
commit e9c3de74862b0fb4c96bff43137c0641ec7498cb
44 changes: 44 additions & 0 deletions e2e/preview/components/test_chatgpt_generator.py
@@ -0,0 +1,44 @@
import os
import pytest
from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator


@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_chatgpt_generator_run():
component = ChatGPTGenerator(api_key=os.environ.get("OPENAI_API_KEY"))
results = component.run(prompts=["What's the capital of France?", "What's the capital of Germany?"], n=1)

assert len(results["replies"]) == 2
assert len(results["replies"][0]) == 1
assert "Paris" in results["replies"][0][0]
assert len(results["replies"][1]) == 1
assert "Berlin" in results["replies"][1][0]


# @pytest.mark.skipif(
# not os.environ.get("OPENAI_API_KEY", None),
# reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
# )
# def test_chatgpt_generator_run_streaming():

# class Callback:
# def __init__(self):
#         self.responses = ""

# def __call__(self, token):
# self.responses += token
# return token

# callback = Callback()

# component = ChatGPTGenerator(
# os.environ.get("OPENAI_API_KEY"), stream=True, streaming_callback=callback
# )
# results = component.run(prompts=["test-prompt-1", "test-prompt-2"])

# assert results == {"replies": [["test-response-a"], ["test-response-b"]]}

# assert callback.responses == "test-response-a\ntest-response-b\n"
44 changes: 42 additions & 2 deletions haystack/preview/components/generators/_helpers.py
@@ -1,3 +1,5 @@
from typing import List

import logging


@@ -18,11 +20,49 @@ def enforce_token_limit(prompt: str, tokenizer, max_tokens_limit: int) -> str:
tokens_count = len(tokens)
if tokens_count > max_tokens_limit:
logger.warning(
"The prompt has been truncated from %s tokens to %s tokens so that the prompt fits within the max token "
"limit. Reduce the length of the prompt to prevent it from being cut off.",
"The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. "
"Reduce the length of the prompt to prevent it from being cut off.",
tokens_count,
max_tokens_limit,
)
tokenized_payload = tokenizer.encode(prompt)
prompt = tokenizer.decode(tokenized_payload[:max_tokens_limit])
return prompt


def enforce_token_limit_chat(
prompts: List[str], tokenizer, max_tokens_limit: int, tokens_per_message_overhead: int
) -> List[str]:
"""
    Ensure that the combined length of the prompts is within the max tokens limit of the model.
    If needed, truncate the prompt text so that it fits within the limit.

    :param prompts: The prompts to be sent to the generative model.
:param tokenizer: The tokenizer used to encode the prompt.
:param max_tokens_limit: The max tokens limit of the model.
:param tokens_per_message_overhead: The number of tokens that are added to the prompt text for each message.
:return: A list of prompts that fits within the max tokens limit of the model.
"""
prompts_lens = [len(tokenizer.encode(prompt)) for prompt in prompts]
if (total_prompt_length := sum(prompts_lens) + (tokens_per_message_overhead * len(prompts))) <= max_tokens_limit:
return prompts

logger.warning(
"The prompts have been truncated from %s tokens to %s tokens to fit within the max token limit. "
"Reduce the length of the prompt to prevent it from being cut off.",
total_prompt_length,
max_tokens_limit,
)
    cut_prompts = []
    cut_prompts_lens = []
    for prompt, prompt_len in zip(prompts, prompts_lens):
        prompt_len_with_overhead = prompt_len + tokens_per_message_overhead
        if sum(cut_prompts_lens) + prompt_len_with_overhead <= max_tokens_limit:
            cut_prompts.append(prompt)
            cut_prompts_lens.append(prompt_len_with_overhead)
        else:
            # Truncate this prompt to the token budget left over by the prompts kept so far.
            remaining_tokens = max_tokens_limit - sum(cut_prompts_lens)
            cut_prompts.append(enforce_token_limit(prompt, tokenizer, remaining_tokens))
    return cut_prompts
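For reference, a minimal usage sketch of enforce_token_limit_chat (not part of the diff). The WhitespaceTokenizer class is an illustrative stand-in that mirrors the mocks in the unit tests further down; the real generator passes its own tokenizer.

# Illustrative only: a whitespace "tokenizer" stand-in, mirroring the unit-test mocks below.
from haystack.preview.components.generators._helpers import enforce_token_limit_chat


class WhitespaceTokenizer:
    def encode(self, text):
        return text.split()

    def decode(self, tokens):
        return " ".join(tokens)


truncated = enforce_token_limit_chat(
    prompts=["System Prompt", "This is a test prompt."],
    tokenizer=WhitespaceTokenizer(),
    max_tokens_limit=7,
    tokens_per_message_overhead=2,
)
print(truncated)  # ['System Prompt', 'This is a'], as asserted in the unit tests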
19 changes: 7 additions & 12 deletions haystack/preview/components/generators/openai/_helpers.py
@@ -123,19 +123,14 @@ def raise_for_status(response: requests.Response):
:raises OpenAIError: If the response status code is not 200.
"""
if response.status_code >= 400:
-        openai_error: OpenAIError
        if response.status_code == 429:
-            openai_error = OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
-        elif response.status_code == 401:
-            openai_error = OpenAIUnauthorizedError(f"API key is invalid: {response.text}")
-        else:
-            openai_error = OpenAIError(
-                f"OpenAI returned an error.\n"
-                f"Status code: {response.status_code}\n"
-                f"Response body: {response.text}",
-                status_code=response.status_code,
-            )
-        raise openai_error
+            raise OpenAIRateLimitError(f"API rate limit exceeded: {response.text}")
+        if response.status_code == 401:
+            raise OpenAIUnauthorizedError(f"API key is invalid: {response.text}")
+        raise OpenAIError(
+            f"OpenAI returned an error.\n" f"Status code: {response.status_code}\n" f"Response body: {response.text}",
+            status_code=response.status_code,
+        )


def check_truncated_answers(result: Dict[str, Any], payload: Dict[str, Any]):
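As an illustration (not part of the diff), this is how the flattened raise_for_status maps an HTTP error to one of the custom exceptions; the fake response object and the import paths are assumptions based on the file paths shown above.

# Illustrative sketch: exercising raise_for_status with a fake response object.
from unittest.mock import Mock

from haystack.preview.components.generators.openai._helpers import raise_for_status
from haystack.preview.components.generators.openai.errors import OpenAIRateLimitError

fake_response = Mock(status_code=429, text="Rate limit reached")  # stands in for requests.Response
try:
    raise_for_status(fake_response)
except OpenAIRateLimitError as error:
    print(error)  # API rate limit exceeded: Rate limit reached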
18 changes: 12 additions & 6 deletions haystack/preview/components/generators/openai/chatgpt.py
@@ -4,7 +4,7 @@

from haystack.preview.lazy_imports import LazyImport
from haystack.preview import component, default_from_dict, default_to_dict
-from haystack.preview.components.generators._helpers import enforce_token_limit
+from haystack.preview.components.generators._helpers import enforce_token_limit_chat
from haystack.preview.components.generators.openai._helpers import (
default_streaming_callback,
query_chat_model,
@@ -21,6 +21,9 @@
logger = logging.getLogger(__name__)


TOKENS_PER_MESSAGE_OVERHEAD = 4


@component
class ChatGPTGenerator:
"""
@@ -260,13 +263,16 @@ def run(

replies = []
for prompt in prompts:
+            system_prompt, prompt = enforce_token_limit_chat(
+                prompts=[system_prompt, prompt],
+                tokenizer=self.tokenizer,
+                max_tokens_limit=self.max_tokens_limit,
+                tokens_per_message_overhead=TOKENS_PER_MESSAGE_OVERHEAD,
+            )
+
            payload = {
                **parameters,
-                "messages": enforce_token_limit(
-                    prompt=[{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
-                    tokenizer=self.tokenizer,
-                    max_tokens_limit=self.max_tokens_limit,
-                ),
+                "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}],
}
if stream:
reply = query_chat_model_stream(
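For clarity, a sketch (not part of the diff) of the per-prompt payload that run() now builds; the values below are placeholder assumptions and nothing is sent to OpenAI here.

# Hypothetical values, for illustration only.
system_prompt = "You are a helpful assistant."    # possibly truncated by enforce_token_limit_chat
prompt = "What's the capital of France?"          # possibly truncated by enforce_token_limit_chat
parameters = {"model": "gpt-3.5-turbo", "n": 1}   # assumption: whatever parameters run() assembled
payload = {
    **parameters,
    "messages": [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ],
}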
3 changes: 3 additions & 0 deletions haystack/preview/components/generators/openai/errors.py
@@ -10,6 +10,9 @@ def __init__(self, message: Optional[str] = None, status_code: Optional[int] = None):
self.message = message
self.status_code = status_code

    def __str__(self):
        return f"{self.message} (status code {self.status_code})" if self.status_code else f"{self.message}"


class OpenAIRateLimitError(OpenAIError):
"""
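A quick illustration (not part of the diff) of the __str__ shown above, with made-up message and status values:

from haystack.preview.components.generators.openai.errors import OpenAIError

error = OpenAIError("OpenAI returned an error.", status_code=500)
print(str(error))  # OpenAI returned an error. (status code 500)
print(str(OpenAIError("OpenAI returned an error.")))  # OpenAI returned an error.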
11 changes: 11 additions & 0 deletions test/preview/components/conftest.py
@@ -0,0 +1,11 @@
from unittest.mock import patch
import pytest


@pytest.fixture(autouse=True)
def tenacity_wait():
"""
    Patch tenacity's sleep (tenacity.nap.time) so that retry waits don't slow down tests.
"""
with patch("tenacity.nap.time"):
yield
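For context, a sketch (not part of the diff) of what this autouse fixture buys: a hypothetical test module under test/preview/components/ can exercise tenacity-retried code without real sleeps. The flaky_call function and test below are illustrative assumptions.

import tenacity

ATTEMPTS = []


@tenacity.retry(wait=tenacity.wait_fixed(10), stop=tenacity.stop_after_attempt(3), reraise=True)
def flaky_call():
    # Fails twice, then succeeds on the third attempt.
    ATTEMPTS.append(1)
    if len(ATTEMPTS) < 3:
        raise ValueError("transient failure")
    return "ok"


def test_flaky_call_retries_without_sleeping():
    # Without the fixture, the two 10-second waits between attempts would really sleep;
    # with tenacity.nap.time patched, they return immediately.
    assert flaky_call() == "ok"
    assert len(ATTEMPTS) == 3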
@@ -20,12 +20,6 @@
)


-@pytest.fixture(autouse=True)
-def tenacity_wait():
-    with patch("tenacity.nap.time"):
-        yield


@pytest.mark.unit
def test_raise_for_status_200():
response = Mock()
39 changes: 36 additions & 3 deletions test/preview/components/generators/test_helpers.py
@@ -2,7 +2,7 @@

import pytest

-from haystack.preview.components.generators._helpers import enforce_token_limit
+from haystack.preview.components.generators._helpers import enforce_token_limit, enforce_token_limit_chat


@pytest.mark.unit
@@ -13,8 +13,8 @@ def test_enforce_token_limit_above_limit(caplog):

assert enforce_token_limit("This is a test prompt.", tokenizer=tokenizer, max_tokens_limit=3) == "This is a"
assert caplog.records[0].message == (
"The prompt has been truncated from 5 tokens to 3 tokens so that the prompt fits within the max token "
"limit. Reduce the length of the prompt to prevent it from being cut off."
"The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token limit. "
"Reduce the length of the prompt to prevent it from being cut off."
)


@@ -29,3 +29,36 @@ def test_enforce_token_limit_below_limit(caplog):
== "This is a test prompt."
)
assert not caplog.records


@pytest.mark.unit
def test_enforce_token_limit_chat_above_limit(caplog):
tokenizer = Mock()
tokenizer.encode = lambda text: text.split()
tokenizer.decode = lambda tokens: " ".join(tokens)

assert enforce_token_limit_chat(
["System Prompt", "This is a test prompt."],
tokenizer=tokenizer,
max_tokens_limit=7,
tokens_per_message_overhead=2,
) == ["System Prompt", "This is a"]
assert caplog.records[0].message == (
"The prompts have been truncated from 11 tokens to 7 tokens to fit within the max token limit. "
"Reduce the length of the prompt to prevent it from being cut off."
)


@pytest.mark.unit
def test_enforce_token_limit_chat_below_limit(caplog):
tokenizer = Mock()
tokenizer.encode = lambda text: text.split()
tokenizer.decode = lambda tokens: " ".join(tokens)

assert enforce_token_limit_chat(
["System Prompt", "This is a test prompt."],
tokenizer=tokenizer,
max_tokens_limit=100,
tokens_per_message_overhead=2,
) == ["System Prompt", "This is a test prompt."]
assert not caplog.records
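To make the numbers in the chat tests explicit, a worked example (not part of the diff) of the token accounting they expect:

# Token accounting behind test_enforce_token_limit_chat_above_limit, with a whitespace tokenizer
# and an overhead of 2 tokens per message.
prompt_tokens = [2, 5]                 # "System Prompt" and "This is a test prompt."
overhead = 2 * len(prompt_tokens)      # 2 tokens of overhead for each of the 2 messages
total = sum(prompt_tokens) + overhead  # 2 + 5 + 4 = 11 tokens requested, hence the warning
budget = 7                             # max_tokens_limit in the test
# The first message consumes 2 + 2 = 4 tokens, which leaves 7 - 4 = 3 tokens of text for the
# second message, so "This is a test prompt." is cut down to "This is a".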