From 081757c6b9b1cadae5f783b3e1f872a9ecf87dc2 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Tue, 23 Apr 2024 13:56:07 +0200
Subject: [PATCH] test: replace mistral-7b with zephyr-7b-beta in tests (#7576)

* replace mistral-7b with gemma-2b-it in tests

* rm wrong comment

* change model
---
 .../components/generators/hugging_face_api.py |  2 +-
 .../generators/chat/test_hugging_face_api.py  |  9 ++++----
 .../generators/test_hugging_face_api.py       | 22 ++++++++-----------
 3 files changed, 14 insertions(+), 19 deletions(-)

diff --git a/haystack/components/generators/hugging_face_api.py b/haystack/components/generators/hugging_face_api.py
index ad1ede4ac0..803f432a63 100644
--- a/haystack/components/generators/hugging_face_api.py
+++ b/haystack/components/generators/hugging_face_api.py
@@ -35,7 +35,7 @@ class HuggingFaceAPIGenerator:
     from haystack.utils import Secret

     generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
-                                        api_params={"model": "mistralai/Mistral-7B-v0.1"},
+                                        api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
                                         token=Secret.from_token("<your-api-key>"))

     result = generator.run(prompt="What's Natural Language Processing?")
diff --git a/test/components/generators/chat/test_hugging_face_api.py b/test/components/generators/chat/test_hugging_face_api.py
index 0eb48e9bd5..df2b33618b 100644
--- a/test/components/generators/chat/test_hugging_face_api.py
+++ b/test/components/generators/chat/test_hugging_face_api.py
@@ -126,7 +126,7 @@ def test_init_tgi_no_url(self):
     def test_to_dict(self, mock_check_valid_model):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "mistralai/Mistral-7B-v0.1"},
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"temperature": 0.6},
             stop_words=["stop", "words"],
@@ -136,14 +136,14 @@ def test_to_dict(self, mock_check_valid_model):
         init_params = result["init_parameters"]

         assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API
-        assert init_params["api_params"] == {"model": "mistralai/Mistral-7B-v0.1"}
+        assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
         assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}

     def test_from_dict(self, mock_check_valid_model):
         generator = HuggingFaceAPIChatGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "mistralai/Mistral-7B-v0.1"},
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"temperature": 0.6},
             stop_words=["stop", "words"],
@@ -154,7 +154,7 @@ def test_from_dict(self, mock_check_valid_model):
         # now deserialize, call from_dict
         generator_2 = HuggingFaceAPIChatGenerator.from_dict(result)
         assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
-        assert generator_2.api_params == {"model": "mistralai/Mistral-7B-v0.1"}
+        assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
         assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
         assert generator_2.streaming_callback is streaming_callback_handler
@@ -225,7 +225,6 @@ def mock_iter(self):

         # Generate text response with streaming callback
         response = generator.run(chat_messages)
-        print(response)

         # check kwargs passed to text_generation
         _, kwargs = mock_chat_completion.call_args
diff --git a/test/components/generators/test_hugging_face_api.py b/test/components/generators/test_hugging_face_api.py
index 93d69585c7..8786e7f536 100644
--- a/test/components/generators/test_hugging_face_api.py
+++ b/test/components/generators/test_hugging_face_api.py
@@ -118,7 +118,7 @@ def test_init_tgi_no_url(self):
     def test_to_dict(self, mock_check_valid_model):
         generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "mistralai/Mistral-7B-v0.1"},
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"temperature": 0.6},
             stop_words=["stop", "words"],
@@ -128,7 +128,7 @@ def test_to_dict(self, mock_check_valid_model):
         init_params = result["init_parameters"]

         assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API
-        assert init_params["api_params"] == {"model": "mistralai/Mistral-7B-v0.1"}
+        assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
         assert init_params["generation_kwargs"] == {
             "temperature": 0.6,
@@ -139,7 +139,7 @@ def test_to_dict(self, mock_check_valid_model):
     def test_from_dict(self, mock_check_valid_model):
         generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "mistralai/Mistral-7B-v0.1"},
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"temperature": 0.6},
             stop_words=["stop", "words"],
@@ -150,7 +150,7 @@ def test_from_dict(self, mock_check_valid_model):
         # now deserialize, call from_dict
         generator_2 = HuggingFaceAPIGenerator.from_dict(result)
         assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
-        assert generator_2.api_params == {"model": "mistralai/Mistral-7B-v0.1"}
+        assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"}
         assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
         assert generator_2.generation_kwargs == {
             "temperature": 0.6,
@@ -164,7 +164,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
     ):
         generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "mistralai/Mistral-7B-v0.1"},
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             token=Secret.from_env_var("ENV_VAR", strict=False),
             generation_kwargs={"temperature": 0.6},
             stop_words=["stop", "words"],
@@ -194,7 +194,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(

     def test_generate_text_with_custom_generation_parameters(self, mock_check_valid_model, mock_text_generation):
         generator = HuggingFaceAPIGenerator(
-            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "mistralai/Mistral-7B-v0.1"}
+            api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}
         )

         generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
@@ -217,9 +217,7 @@ def test_generate_text_with_custom_generation_parameters(self, mock_check_valid
         assert len(response["meta"]) > 0
         assert [isinstance(reply, str) for reply in response["replies"]]

-    def test_generate_text_with_streaming_callback(
-        self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
-    ):
+    def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_text_generation):
         streaming_call_count = 0

         # Define the streaming callback function
@@ -228,10 +226,9 @@ def streaming_callback_fn(chunk: StreamingChunk):
             streaming_call_count += 1
             assert isinstance(chunk, StreamingChunk)

-        # Create an instance of HuggingFaceRemoteGenerator
         generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "mistralai/Mistral-7B-v0.1"},
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             streaming_callback=streaming_callback_fn,
         )

@@ -282,12 +279,11 @@ def mock_iter(self):
     def test_run_serverless(self):
         generator = HuggingFaceAPIGenerator(
             api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
-            api_params={"model": "mistralai/Mistral-7B-v0.1"},
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
             generation_kwargs={"max_new_tokens": 20},
        )

         response = generator.run("How are you?")

-        # Assert that the response contains the generated replies
         assert "replies" in response
         assert isinstance(response["replies"], list)
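
For quick manual verification of the change, here is a minimal sketch mirroring test_run_serverless above. It is not part of the patch, and it assumes a valid Hugging Face token exported under the (assumed) env var name HF_API_TOKEN plus network access to the Serverless Inference API:

from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.utils import Secret

# Configure the generator the same way the updated tests do,
# pointing at the new test model on the Serverless Inference API.
generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    token=Secret.from_env_var("HF_API_TOKEN"),  # assumed env var name
    generation_kwargs={"max_new_tokens": 20},
)

result = generator.run("How are you?")

# As asserted in test_run_serverless, the result dict carries a list of
# generated strings under "replies" (plus per-reply metadata under "meta").
print(result["replies"][0])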