From 12eec95baf6979af20fc28edcc9442cb203a1802 Mon Sep 17 00:00:00 2001
From: Julian Risch
Date: Mon, 24 Jun 2024 07:13:11 +0200
Subject: [PATCH 1/3] feat: Use non-gated tokenizer as fallback for mistral

---
 .../amazon_bedrock/chat/adapters.py | 12 +++++-
 .../tests/test_chat_generator.py    | 37 +++++++++++++++----
 2 files changed, 41 insertions(+), 8 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 9d33a682d..9c8e7392b 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import os
 from abc import ABC, abstractmethod
 from typing import Any, Callable, ClassVar, Dict, List
 
@@ -332,7 +333,16 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
         # Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer
         # a) we should get good estimates for the prompt length
         # b) we can use apply_chat_template with the template above to delineate ChatMessages
-        tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+        # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model.
+        if os.environ.get("HF_TOKEN"):
+            tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+        else:
+            tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
+            logger.warning(f"Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
+                           f"estimating the prompt length because no HF_TOKEN was found. Using "
+                           f"NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var "
+                           f"HF_TOKEN containing a Hugging Face token and make sure you have access to the model.")
+
         self.prompt_handler = DefaultPromptHandler(
             tokenizer=tokenizer,
             model_max_length=model_max_length,
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index ff807bcbe..e41419fe1 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,7 +1,10 @@
+import logging
 import os
 from typing import Optional, Type
+from unittest.mock import patch
 
 import pytest
+from _pytest.monkeypatch import MonkeyPatch
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
 
@@ -183,13 +186,6 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
         assert body == expected_body
 
 
-@pytest.mark.skipif(
-    not os.environ.get("HF_API_TOKEN", None),
-    reason=(
-        "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
-        "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
-    ),
-)
 class TestMistralAdapter:
     def test_prepare_body_with_default_params(self) -> None:
         layer = MistralChatAdapter(generation_kwargs={})
@@ -245,6 +241,33 @@ def test_mistral_chat_template_incorrect_order(self):
         except Exception as e:
             assert "Conversation roles must alternate user/assistant/" in str(e)
 
+    def test_use_mistral_adapter_without_hf_token(self, monkeypatch: MonkeyPatch, caplog) -> None:
+        monkeypatch.delenv("HF_TOKEN", raising=False)
+        with (
+            patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
+            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
+            caplog.at_level(logging.WARNING)
+        ):
+            MistralChatAdapter(generation_kwargs={})
+        mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf")
+        assert "no HF_TOKEN was found" in caplog.text
+
+    def test_use_mistral_adapter_with_hf_token(self, monkeypatch: MonkeyPatch) -> None:
+        monkeypatch.setenv("HF_TOKEN", "test")
+        with (
+            patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
+            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler")
+        ):
+            MistralChatAdapter(generation_kwargs={})
+        mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1")
+
+    @pytest.mark.skipif(
+        not os.environ.get("HF_API_TOKEN", None),
+        reason=(
+            "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
+            "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
+        ),
+    )
     @pytest.mark.parametrize("model_name", MISTRAL_MODELS)
     @pytest.mark.integration
     def test_default_inference_params(self, model_name, chat_messages):

From 510192a4572c7296bd9fe6d1b97beeabd978f656 Mon Sep 17 00:00:00 2001
From: Julian Risch
Date: Mon, 24 Jun 2024 07:20:37 +0200
Subject: [PATCH 2/3] formatting

---
 .../generators/amazon_bedrock/chat/adapters.py   | 10 ++++++----
 .../amazon_bedrock/tests/test_chat_generator.py  |  8 ++++----
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 9c8e7392b..9f3524e83 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -338,10 +338,12 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
             tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
         else:
             tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
-            logger.warning(f"Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
-                           f"estimating the prompt length because no HF_TOKEN was found. Using "
-                           f"NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var "
-                           f"HF_TOKEN containing a Hugging Face token and make sure you have access to the model.")
+            logger.warning(
+                "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
+                "estimating the prompt length because no HF_TOKEN was found. Using "
+                "NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var "
+                "HF_TOKEN containing a Hugging Face token and make sure you have access to the model."
+            )
 
         self.prompt_handler = DefaultPromptHandler(
             tokenizer=tokenizer,
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index e41419fe1..d962e5030 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -246,7 +246,7 @@ def test_use_mistral_adapter_without_hf_token(self, monkeypatch: MonkeyPatch, ca
         with (
             patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
             patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
-            caplog.at_level(logging.WARNING)
+            caplog.at_level(logging.WARNING),
         ):
             MistralChatAdapter(generation_kwargs={})
         mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf")
@@ -256,7 +256,7 @@ def test_use_mistral_adapter_with_hf_token(self, monkeypatch: MonkeyPatch) -> No
         monkeypatch.setenv("HF_TOKEN", "test")
         with (
            patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
-            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler")
+            patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
         ):
             MistralChatAdapter(generation_kwargs={})
         mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1")
@@ -264,8 +264,8 @@ def test_use_mistral_adapter_with_hf_token(self, monkeypatch: MonkeyPatch) -> No
     @pytest.mark.skipif(
         not os.environ.get("HF_API_TOKEN", None),
         reason=(
-            "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
-            "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
+            "To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
+            "have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
         ),
     )
     @pytest.mark.parametrize("model_name", MISTRAL_MODELS)

From a50af3e03f79a13b6249aca856c4aee0fe478203 Mon Sep 17 00:00:00 2001
From: Julian Risch
Date: Tue, 25 Jun 2024 23:58:18 +0200
Subject: [PATCH 3/3] fix linter issues

---
 .../components/generators/amazon_bedrock/chat/adapters.py | 5 +++--
 integrations/amazon_bedrock/tests/test_chat_generator.py  | 5 ++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
index 9f3524e83..162100934 100644
--- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
+++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -334,10 +334,11 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
         # a) we should get good estimates for the prompt length
         # b) we can use apply_chat_template with the template above to delineate ChatMessages
         # Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model.
+        tokenizer: PreTrainedTokenizer
         if os.environ.get("HF_TOKEN"):
-            tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+            tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
         else:
-            tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
+            tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
             logger.warning(
                 "Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
                 "estimating the prompt length because no HF_TOKEN was found. Using "
diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py
index d962e5030..3e62b56ea 100644
--- a/integrations/amazon_bedrock/tests/test_chat_generator.py
+++ b/integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -4,7 +4,6 @@
 from unittest.mock import patch
 
 import pytest
-from _pytest.monkeypatch import MonkeyPatch
 from haystack.components.generators.utils import print_streaming_chunk
 from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk
 
@@ -241,7 +240,7 @@ def test_mistral_chat_template_incorrect_order(self):
         except Exception as e:
             assert "Conversation roles must alternate user/assistant/" in str(e)
 
-    def test_use_mistral_adapter_without_hf_token(self, monkeypatch: MonkeyPatch, caplog) -> None:
+    def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None:
         monkeypatch.delenv("HF_TOKEN", raising=False)
         with (
             patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
@@ -252,7 +251,7 @@ def test_use_mistral_adapter_without_hf_token(self, monkeypatch: MonkeyPatch, ca
         mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf")
         assert "no HF_TOKEN was found" in caplog.text
 
-    def test_use_mistral_adapter_with_hf_token(self, monkeypatch: MonkeyPatch) -> None:
+    def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None:
         monkeypatch.setenv("HF_TOKEN", "test")
         with (
             patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,