feat: generators (2.0) (#5690)
* add generators module

* add tests for module helper

* reno

* add another test

* move into openai

* improve tests
ZanSara authored Aug 31, 2023
1 parent 6787ad2 commit 5f1256a
Showing 6 changed files with 68 additions and 0 deletions.
Empty file.
Empty file.
33 changes: 33 additions & 0 deletions haystack/preview/components/generators/openai/_helpers.py
@@ -0,0 +1,33 @@
import logging

from haystack.preview.lazy_imports import LazyImport

with LazyImport("Run 'pip install tiktoken'") as tiktoken_import:
import tiktoken


logger = logging.getLogger(__name__)


def enforce_token_limit(prompt: str, tokenizer: "tiktoken.Encoding", max_tokens_limit: int) -> str:
"""
Ensure that the length of the prompt is within the max tokens limit of the model.
If needed, truncate the prompt text so that it fits within the limit.
:param prompt: Prompt text to be sent to the generative model.
:param tokenizer: The tokenizer used to encode the prompt.
:param max_tokens_limit: The max tokens limit of the model.
:return: The prompt text that fits within the max tokens limit of the model.
"""
tiktoken_import.check()
tokens = tokenizer.encode(prompt)
tokens_count = len(tokens)
if tokens_count > max_tokens_limit:
logger.warning(
"The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. "
"Reduce the length of the prompt to prevent it from being cut off.",
tokens_count,
max_tokens_limit,
)
prompt = tokenizer.decode(tokens[:max_tokens_limit])
return prompt
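
For context, a minimal usage sketch of the new helper (not part of this commit; the encoding name and token limit below are illustrative assumptions):

# Illustrative usage sketch -- the model/encoding name and limit are assumptions.
import tiktoken

from haystack.preview.components.generators.openai._helpers import enforce_token_limit

tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")  # pick the encoding matching your model
safe_prompt = enforce_token_limit("A very long prompt ...", tokenizer=tokenizer, max_tokens_limit=4096)
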
2 changes: 2 additions & 0 deletions releasenotes/notes/generators-module-261376beb9c031cc.yaml
@@ -0,0 +1,2 @@
preview:
  - Add generators module for LLM generator components.
20 changes: 20 additions & 0 deletions test/preview/components/generators/openai/test_openai_helpers.py
@@ -0,0 +1,20 @@
import pytest

from haystack.preview.components.generators.openai._helpers import enforce_token_limit


@pytest.mark.unit
def test_enforce_token_limit_above_limit(caplog, mock_tokenizer):
    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3)
    assert prompt == "This is a"
    assert caplog.records[0].message == (
        "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token "
        "limit. Reduce the length of the prompt to prevent it from being cut off."
    )


@pytest.mark.unit
def test_enforce_token_limit_below_limit(caplog, mock_tokenizer):
    prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100)
    assert prompt == "This is a test prompt."
    assert not caplog.records
13 changes: 13 additions & 0 deletions test/preview/conftest.py
@@ -0,0 +1,13 @@
from unittest.mock import Mock
import pytest


@pytest.fixture()
def mock_tokenizer():
"""
Tokenizes the string by splitting on spaces.
"""
tokenizer = Mock()
tokenizer.encode = lambda text: text.split()
tokenizer.decode = lambda tokens: " ".join(tokens)
return tokenizer
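
As a standalone illustration (not part of the commit), this is how the space-splitting mock drives the truncation asserted in the tests above:

# Sketch only: reproduces the fixture's behavior outside pytest.
from unittest.mock import Mock

tokenizer = Mock()
tokenizer.encode = lambda text: text.split()
tokenizer.decode = lambda tokens: " ".join(tokens)

tokens = tokenizer.encode("This is a test prompt.")  # 5 "tokens": ["This", "is", "a", "test", "prompt."]
print(tokenizer.decode(tokens[:3]))  # prints "This is a", the truncated prompt the first test expects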
