Skip to content

Commit

Permalink
refactor: avoid downloading tokenizer if truncate is False (#1152)
Browse files Browse the repository at this point in the history
* avoid downloading tokenizer if truncate is False

* fix
  • Loading branch information
anakin87 authored Oct 23, 2024
1 parent f57ec1a commit ae207f0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,16 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
# We read model_max_length here: it is not sent to the model but used to truncate the prompt if needed
model_max_length = kwargs.get("model_max_length", 4096)

# Truncate prompt if prompt tokens > model_max_length-max_length
# (max_length is the length of the generated text)
# we use GPT2 tokenizer which will likely provide good token count approximation

self.prompt_handler = DefaultPromptHandler(
tokenizer="gpt2",
model_max_length=model_max_length,
max_length=self.max_length or 100,
)
# we initialize the prompt handler only if truncate is True: we avoid unnecessarily downloading the tokenizer
if self.truncate:
# Truncate prompt if prompt tokens > model_max_length-max_length
# (max_length is the length of the generated text)
# we use GPT2 tokenizer which will likely provide good token count approximation
self.prompt_handler = DefaultPromptHandler(
tokenizer="gpt2",
model_max_length=model_max_length,
max_length=self.max_length or 100,
)

model_adapter_cls = self.get_model_adapter(model=model)
if not model_adapter_cls:
Expand Down
8 changes: 8 additions & 0 deletions integrations/amazon_bedrock/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ def test_constructor_prompt_handler_initialized(mock_boto3_session, mock_prompt_
assert layer.prompt_handler.model_max_length == 4096


def test_prompt_handler_absent_when_truncate_false(mock_boto3_session):
    """
    Check that constructing the generator with truncate=False leaves the
    instance without a ``prompt_handler`` attribute.
    """
    bedrock_generator = AmazonBedrockGenerator(
        model="anthropic.claude-v2",
        truncate=False,
    )
    # The attribute must be entirely absent, not merely set to None.
    handler_present = hasattr(bedrock_generator, "prompt_handler")
    assert not handler_present


def test_constructor_with_model_kwargs(mock_boto3_session):
"""
Test that model_kwargs are correctly set in the constructor
Expand Down

0 comments on commit ae207f0

Please sign in to comment.