Skip to content

Commit

Permalink
Merge branch 'main' into sparse-emb-eq
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Apr 23, 2024
2 parents 29fd8c0 + 081757c commit a5d646e
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 19 deletions.
2 changes: 1 addition & 1 deletion haystack/components/generators/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class HuggingFaceAPIGenerator:
from haystack.utils import Secret
generator = HuggingFaceAPIGenerator(api_type="serverless_inference_api",
api_params={"model": "mistralai/Mistral-7B-v0.1"},
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
token=Secret.from_token("<your-api-key>"))
result = generator.run(prompt="What's Natural Language Processing?")
Expand Down
9 changes: 4 additions & 5 deletions test/components/generators/chat/test_hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_init_tgi_no_url(self):
def test_to_dict(self, mock_check_valid_model):
generator = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "mistralai/Mistral-7B-v0.1"},
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"temperature": 0.6},
stop_words=["stop", "words"],
Expand All @@ -136,14 +136,14 @@ def test_to_dict(self, mock_check_valid_model):
init_params = result["init_parameters"]

assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API
assert init_params["api_params"] == {"model": "mistralai/Mistral-7B-v0.1"}
assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}

def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceAPIChatGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "mistralai/Mistral-7B-v0.1"},
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"temperature": 0.6},
stop_words=["stop", "words"],
Expand All @@ -154,7 +154,7 @@ def test_from_dict(self, mock_check_valid_model):
# now deserialize, call from_dict
generator_2 = HuggingFaceAPIChatGenerator.from_dict(result)
assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
assert generator_2.api_params == {"model": "mistralai/Mistral-7B-v0.1"}
assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"}
assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}
assert generator_2.streaming_callback is streaming_callback_handler
Expand Down Expand Up @@ -225,7 +225,6 @@ def mock_iter(self):

# Generate text response with streaming callback
response = generator.run(chat_messages)
print(response)

# check kwargs passed to text_generation
_, kwargs = mock_chat_completion.call_args
Expand Down
22 changes: 9 additions & 13 deletions test/components/generators/test_hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_init_tgi_no_url(self):
def test_to_dict(self, mock_check_valid_model):
generator = HuggingFaceAPIGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "mistralai/Mistral-7B-v0.1"},
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"temperature": 0.6},
stop_words=["stop", "words"],
Expand All @@ -128,7 +128,7 @@ def test_to_dict(self, mock_check_valid_model):
init_params = result["init_parameters"]

assert init_params["api_type"] == HFGenerationAPIType.SERVERLESS_INFERENCE_API
assert init_params["api_params"] == {"model": "mistralai/Mistral-7B-v0.1"}
assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"}
assert init_params["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert init_params["generation_kwargs"] == {
"temperature": 0.6,
Expand All @@ -139,7 +139,7 @@ def test_to_dict(self, mock_check_valid_model):
def test_from_dict(self, mock_check_valid_model):
generator = HuggingFaceAPIGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "mistralai/Mistral-7B-v0.1"},
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"temperature": 0.6},
stop_words=["stop", "words"],
Expand All @@ -150,7 +150,7 @@ def test_from_dict(self, mock_check_valid_model):
# now deserialize, call from_dict
generator_2 = HuggingFaceAPIGenerator.from_dict(result)
assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API
assert generator_2.api_params == {"model": "mistralai/Mistral-7B-v0.1"}
assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"}
assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False)
assert generator_2.generation_kwargs == {
"temperature": 0.6,
Expand All @@ -164,7 +164,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(
):
generator = HuggingFaceAPIGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "mistralai/Mistral-7B-v0.1"},
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
token=Secret.from_env_var("ENV_VAR", strict=False),
generation_kwargs={"temperature": 0.6},
stop_words=["stop", "words"],
Expand Down Expand Up @@ -194,7 +194,7 @@ def test_generate_text_response_with_valid_prompt_and_generation_parameters(

def test_generate_text_with_custom_generation_parameters(self, mock_check_valid_model, mock_text_generation):
generator = HuggingFaceAPIGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "mistralai/Mistral-7B-v0.1"}
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}
)

generation_kwargs = {"temperature": 0.8, "max_new_tokens": 100}
Expand All @@ -217,9 +217,7 @@ def test_generate_text_with_custom_generation_parameters(self, mock_check_valid_
assert len(response["meta"]) > 0
assert [isinstance(reply, str) for reply in response["replies"]]

def test_generate_text_with_streaming_callback(
self, mock_check_valid_model, mock_auto_tokenizer, mock_text_generation
):
def test_generate_text_with_streaming_callback(self, mock_check_valid_model, mock_text_generation):
streaming_call_count = 0

# Define the streaming callback function
Expand All @@ -228,10 +226,9 @@ def streaming_callback_fn(chunk: StreamingChunk):
streaming_call_count += 1
assert isinstance(chunk, StreamingChunk)

# Create an instance of HuggingFaceRemoteGenerator
generator = HuggingFaceAPIGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "mistralai/Mistral-7B-v0.1"},
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
streaming_callback=streaming_callback_fn,
)

Expand Down Expand Up @@ -282,12 +279,11 @@ def mock_iter(self):
def test_run_serverless(self):
generator = HuggingFaceAPIGenerator(
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API,
api_params={"model": "mistralai/Mistral-7B-v0.1"},
api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
generation_kwargs={"max_new_tokens": 20},
)

response = generator.run("How are you?")

# Assert that the response contains the generated replies
assert "replies" in response
assert isinstance(response["replies"], list)
Expand Down

0 comments on commit a5d646e

Please sign in to comment.