feat: Use non-gated tokenizer as fallback for mistral in AmazonBedrockChatGenerator (#843)

* feat: Use non-gated tokenizer as fallback for mistral

* formatting

* fix linter issues
julian-risch authored Jun 26, 2024
1 parent 05ccdb2 commit 6945503
Showing 2 changed files with 43 additions and 8 deletions.
15 changes: 14 additions & 1 deletion integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py
@@ -1,5 +1,6 @@
import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Callable, ClassVar, Dict, List

@@ -332,7 +333,19 @@ def __init__(self, generation_kwargs: Dict[str, Any]):
# Use `mistralai/Mistral-7B-v0.1` as tokenizer, all mistral models likely use the same tokenizer
# a) we should get good estimates for the prompt length
# b) we can use apply_chat_template with the template above to delineate ChatMessages
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
# Mistral models are gated on HF Hub. If no HF_TOKEN is found we use a non-gated alternative tokenizer model.
tokenizer: PreTrainedTokenizer
if os.environ.get("HF_TOKEN"):
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
else:
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
logger.warning(
"Gated mistralai/Mistral-7B-Instruct-v0.1 model cannot be used as a tokenizer for "
"estimating the prompt length because no HF_TOKEN was found. Using "
"NousResearch/Llama-2-7b-chat-hf instead. To use a mistral tokenizer export an env var "
"HF_TOKEN containing a Hugging Face token and make sure you have access to the model."
)

self.prompt_handler = DefaultPromptHandler(
tokenizer=tokenizer,
model_max_length=model_max_length,
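For reference, a minimal sketch of the new fallback from the caller's side. It assumes MistralChatAdapter is importable from the haystack_integrations.components.generators.amazon_bedrock.chat.adapters module patched in the tests below; the token value is a placeholder.

import os

from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import MistralChatAdapter

# Without HF_TOKEN, the adapter falls back to the non-gated
# NousResearch/Llama-2-7b-chat-hf tokenizer and logs the warning shown above.
os.environ.pop("HF_TOKEN", None)
adapter = MistralChatAdapter(generation_kwargs={})

# With HF_TOKEN set (and access granted to the gated model on the Hugging Face Hub),
# the mistralai/Mistral-7B-Instruct-v0.1 tokenizer is used for prompt length estimation.
os.environ["HF_TOKEN"] = "hf_..."  # placeholder token
adapter = MistralChatAdapter(generation_kwargs={})
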
36 changes: 29 additions & 7 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,5 +1,7 @@
import logging
import os
from typing import Optional, Type
from unittest.mock import patch

import pytest
from haystack.components.generators.utils import print_streaming_chunk
@@ -183,13 +185,6 @@ def test_prepare_body_with_custom_inference_params(self) -> None:
assert body == expected_body


@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason=(
"To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
"have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
),
)
class TestMistralAdapter:
def test_prepare_body_with_default_params(self) -> None:
layer = MistralChatAdapter(generation_kwargs={})
@@ -245,6 +240,33 @@ def test_mistral_chat_template_incorrect_order(self):
except Exception as e:
assert "Conversation roles must alternate user/assistant/" in str(e)

def test_use_mistral_adapter_without_hf_token(self, monkeypatch, caplog) -> None:
monkeypatch.delenv("HF_TOKEN", raising=False)
with (
patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
caplog.at_level(logging.WARNING),
):
MistralChatAdapter(generation_kwargs={})
mock_pretrained.assert_called_with("NousResearch/Llama-2-7b-chat-hf")
assert "no HF_TOKEN was found" in caplog.text

def test_use_mistral_adapter_with_hf_token(self, monkeypatch) -> None:
monkeypatch.setenv("HF_TOKEN", "test")
with (
patch("transformers.AutoTokenizer.from_pretrained") as mock_pretrained,
patch("haystack_integrations.components.generators.amazon_bedrock.chat.adapters.DefaultPromptHandler"),
):
MistralChatAdapter(generation_kwargs={})
mock_pretrained.assert_called_with("mistralai/Mistral-7B-Instruct-v0.1")

@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason=(
"To run this test, you need to set the HF_API_TOKEN environment variable. The associated account must also "
"have requested access to the gated model `mistralai/Mistral-7B-Instruct-v0.1`"
),
)
@pytest.mark.parametrize("model_name", MISTRAL_MODELS)
@pytest.mark.integration
def test_default_inference_params(self, model_name, chat_messages):
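For context, a rough end-to-end sketch of how a Mistral model on Bedrock might be used through AmazonBedrockChatGenerator after this change. The import paths, model id, and credential handling are assumptions not shown in this diff, and valid AWS credentials are required.

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator

# HF_TOKEN is now optional: without it, prompt length estimation uses the
# non-gated fallback tokenizer instead of the gated Mistral tokenizer.
generator = AmazonBedrockChatGenerator(model="mistral.mistral-7b-instruct-v0:2")
result = generator.run(messages=[ChatMessage.from_user("What is Amazon Bedrock?")])
print(result["replies"][0].content)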
