feat: Use non-gated tokenizer as fallback for mistral in AmazonBedrockChatGenerator #843

Merged 3 commits on Jun 26, 2024
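The change makes the adapter's tokenizer choice conditional on the HF_TOKEN environment variable: with a token, the gated mistralai/Mistral-7B-Instruct-v0.1 tokenizer is used; without one, the non-gated NousResearch/Llama-2-7b-chat-hf tokenizer stands in for prompt-length estimation. A minimal standalone sketch of that pattern, with model names taken from the diff below (the helper name is illustrative only and not part of the PR):

import logging
import os

from transformers import AutoTokenizer, PreTrainedTokenizer

logger = logging.getLogger(__name__)

def load_mistral_tokenizer() -> PreTrainedTokenizer:
    # Illustrative helper: prefer the gated Mistral tokenizer, fall back to a non-gated one.
    if os.environ.get("HF_TOKEN"):
        # A Hugging Face token is available, so the gated repo can be accessed.
        return AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
    # No token: the gated repo is inaccessible, so use a non-gated tokenizer as an approximation.
    logger.warning("No HF_TOKEN found; falling back to NousResearch/Llama-2-7b-chat-hf as the tokenizer.")
    return AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")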
integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -1,5 +1,6 @@
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, ClassVar, Dict, List

@@ -332,7 +333,19 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
# Use `mistralai/Mistral-7B-Instruct-v0.1` as the tokenizer; all Mistral models likely use the same tokenizer, so:
# a) we should get good estimates for the prompt length
# b) we can use apply_chat_template with the template above to delineate ChatMessages
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
# Mistral models are gated on the HF Hub. If no HF_TOKEN is found, we use a non-gated alternative tokenizer model.
tokenizer: PreTrainedTokenizer
if os.environ.get("HF_TOKEN"):
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
else:
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
logger.warning(
"Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
"estimating the prompt length because no HF_TOKEN was found. Using "
"NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var "
"HF_TOKEN containing a Hugging Face token and make sure you have access to the model."
)

self.prompt_handler = DefaultPromptHandler(
tokenizer=tokenizer,
model_max_length=model_max_length,
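For context, a hedged usage sketch of how the fallback path would be hit in practice. The Bedrock model id and the AWS credential setup are assumptions for illustration and are not part of this PR:

import os

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

# Assumption: AWS credentials with Bedrock access are already configured in the environment.
os.environ.pop("HF_TOKEN", None)  # simulate an environment without a Hugging Face token

# Assumed Bedrock model id for a Mistral model; substitute the id available in your region.
generator = AmazonBedrockChatGenerator(model="mistral.mistral-7b-instruct-v0:2")

# The MistralChatAdapter now logs a warning and uses the non-gated tokenizer instead of failing.
result = generator.run(messages=[ChatMessage.from_user("Summarize what Amazon Bedrock is in one sentence.")])
print(result["replies"][0])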
integrations/amazon_bedrock/tests/test_chat_generator.py (36 changes: 29 additions & 7 deletions)
@@ -1,5 +1,7 @@
import logging
import os
from typing import Optional, Type
from unittest.mock import patch

import pytest
from haystack.components.generators.utils import print_streaming_chunk
@@ -183,13 +185,6 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
assert body == expected_body


@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason=(
"To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
"have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
),
)
class TestMistralAdapter:
def test_prepare_body_with_default_params(self) -> None:
layer = MistralChatAdapter(generation_kwargs={})
@@ -245,6 +240,33 @@ def test_mistral_chat_template_incorrect_order(self):
except Exception as e:
assert "Conversation roles must alternate user/assistant/" in str(e)

def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None:
monkeypatch.delenv("HF_TOKEN", raising=False)
with (
patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
caplog.at_level(logging.WARNING),
):
MistralChatAdapter(generation_kwargs={})
mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf")
assert "no HF_TOKEN was found" in caplog.text

def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None:
monkeypatch.setenv("HF_TOKEN", "test")
with (
patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
):
MistralChatAdapter(generation_kwargs={})
mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1")

@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason=(
"To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
"have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
),
)
@pytest.mark.parametrize("model_name", MISTRAL_MODELS)
@pytest.mark.integration
def test_default_inference_params(self, model_name, chat_messages):
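Conversely, to exercise the gated path that the second unit test covers, the token only needs to be present in the environment before the adapter is constructed. A short sketch, assuming the import path implied by the patch targets above and that access to the gated repo has already been granted on the Hub:

import os

from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import MistralChatAdapter

os.environ["HF_TOKEN"] = "hf_..."  # placeholder; use a real Hugging Face token with access to the gated model
adapter = MistralChatAdapter(generation_kwargs={})  # loads mistralai/Mistral-7B-Instruct-v0.1 as the tokenizer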