diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 9d33a682d..162100934 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 from abc import ABC, abstractmethod
 from typing import Any, Callable, ClassVar, Dict, List
@@ -332,7 +333,19 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
         # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer
         # a) we should get good estimates for the prompt length
         # b) we can use apply_chat_template with the template above to delineate ChatMessages
-        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+        # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model.
+        tokenizer: PreTrainedTokenizer
+        if os.environ.get("HF_TOKEN"):
+            tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+        else:
+            tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
+            logger.warning(
+                "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
+                "estimating the prompt length because no HF_TOKEN was found. Using "
+                "NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var "
+                "HF_TOKEN containing a Hugging Face token and make sure you have access to the model."
+            )
+
         self.prompt_handler = DefaultPromptHandler(
             tokenizer=tokenizer,
             model_max_length=model_max_length,
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index ff807bcbe..3e62b56ea 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,5 +1,7 @@
+import logging
 import os
 from typing import Optional, Type
+from unittest.mock import patch
 
 import pytest
 from haystack.components.generators.utils import print_streaming_chunk
@@ -183,13 +185,6 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
         assert body == expected_body
 
 
-@pytest.mark.skipif(
-    not os.environ.get("HF_API_TOKEN", None),
-    reason=(
-        "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
-        "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
-    ),
-)
 class TestMistralAdapter:
     def test_prepare_body_with_default_params(self) -> None:
         layer = MistralChatAdapter(generation_kwargs={})
@@ -245,6 +240,33 @@ def test_mistral_chat_template_incorrect_order(self):
         except Exception as e:
             assert "Conversation roles must alternate user/assistant/" in str(e)
 
+    def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None:
+        monkeypatch.delenv("HF_TOKEN", raising=False)
+        with (
+            patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
+            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
+            caplog.at_level(logging.WARNING),
+        ):
+            MistralChatAdapter(generation_kwargs={})
+            mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf")
+            assert "no HF_TOKEN was found" in caplog.text
+
+    def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None:
+        monkeypatch.setenv("HF_TOKEN", "test")
+        with (
+            patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
+            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
+        ):
+            MistralChatAdapter(generation_kwargs={})
+            mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1")
+
+    @pytest.mark.skipif(
+        not os.environ.get("HF_API_TOKEN", None),
+        reason=(
+            "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
+            "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
+        ),
+    )
     @pytest.mark.parametrize("model_name", MISTRAL_MODELS)
     @pytest.mark.integration
     def test_default_inference_params(self, model_name, chat_messages):