diff --git a/.github/workflows/amazon_bedrock.yml b/.github/workflows/amazon_bedrock.yml index 75f881a50..8b1651764 100644 --- a/.github/workflows/amazon_bedrock.yml +++ b/.github/workflows/amazon_bedrock.yml @@ -18,9 +18,14 @@ concurrency: group: amazon-bedrock-${{ github.head_ref }} cancel-in-progress: true +permissions: + id-token: write + contents: read + env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" + AWS_REGION: us-east-1 jobs: run: @@ -56,5 +61,11 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run docs + - name: AWS authentication + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 + with: + aws-region: ${{ env.AWS_REGION }} + role-to-assume: ${{ secrets.AWS_CI_ROLE_ARN }} + - name: Run tests run: hatch run cov diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index 196a55743..cdb871f40 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -1,7 +1,7 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List +from typing import Any, Callable, ClassVar, Dict, List from botocore.eventstream import EventStream from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk @@ -44,33 +44,37 @@ def get_responses(self, response_body: Dict[str, Any]) -> List[ChatMessage]: :param response_body: The response body. :returns: The extracted responses. """ - return self._extract_messages_from_response(self.response_body_message_key(), response_body) + return self._extract_messages_from_response(response_body) - def get_stream_responses(self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None]) -> List[str]: + def get_stream_responses( + self, stream: EventStream, stream_handler: Callable[[StreamingChunk], None] + ) -> List[ChatMessage]: tokens: List[str] = [] + last_decoded_chunk: Dict[str, Any] = {} for event in stream: chunk = event.get("chunk") if chunk: - decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) - token = self._extract_token_from_stream(decoded_chunk) - # take all the rest key/value pairs from the chunk, add them to the metadata - stream_metadata = {k: v for (k, v) in decoded_chunk.items() if v != token} - stream_chunk = StreamingChunk(content=token, meta=stream_metadata) - # callback the stream handler with StreamingChunk - stream_handler(stream_chunk) + last_decoded_chunk = json.loads(chunk["bytes"].decode("utf-8")) + token = self._extract_token_from_stream(last_decoded_chunk) + stream_chunk = StreamingChunk(content=token) # don't extract meta, we care about tokens only + stream_handler(stream_chunk) # callback the stream handler with StreamingChunk tokens.append(token) responses = ["".join(tokens).lstrip()] - return responses + return [ChatMessage.from_assistant(response, meta=last_decoded_chunk) for response in responses] @staticmethod - def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> None: + def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any], allowed_params: List[str]) -> None: """ Updates target_dict with values from updates_dict. Merges lists instead of overriding them. 
:param target_dict: The dictionary to update. :param updates_dict: The dictionary with updates. + :param allowed_params: The list of allowed params to use. """ for key, value in updates_dict.items(): + if key not in allowed_params: + logger.warning(f"Parameter '{key}' is not allowed and will be ignored.") + continue if key in target_dict and isinstance(target_dict[key], list) and isinstance(value, list): # Merge lists and remove duplicates target_dict[key] = sorted(set(target_dict[key] + value)) @@ -78,21 +82,24 @@ def _update_params(target_dict: Dict[str, Any], updates_dict: Dict[str, Any]) -> # Override the value in target_dict target_dict[key] = value - def _get_params(self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any]) -> Dict[str, Any]: + def _get_params( + self, inference_kwargs: Dict[str, Any], default_params: Dict[str, Any], allowed_params: List[str] + ) -> Dict[str, Any]: """ Merges params from inference_kwargs with the default params and self.generation_kwargs. Uses a helper function to merge lists or override values as necessary. :param inference_kwargs: The inference kwargs to merge. :param default_params: The default params to start with. + :param allowed_params: The list of allowed params to use. :returns: The merged params. """ # Start with a copy of default_params kwargs = default_params.copy() # Update the default params with self.generation_kwargs and finally inference_kwargs - self._update_params(kwargs, self.generation_kwargs) - self._update_params(kwargs, inference_kwargs) + self._update_params(kwargs, self.generation_kwargs, allowed_params) + self._update_params(kwargs, inference_kwargs, allowed_params) return kwargs @@ -124,25 +131,14 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]: :returns: A dictionary containing the resized prompt and additional information. """ - def _extract_messages_from_response(self, message_tag: str, response_body: Dict[str, Any]) -> List[ChatMessage]: + @abstractmethod + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: """ Extracts the messages from the response body. - :param message_tag: The key for the message in the response body. :param response_body: The response body. :returns: The extracted ChatMessage list. """ - metadata = {k: v for (k, v) in response_body.items() if k != message_tag} - return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] - - @abstractmethod - def response_body_message_key(self) -> str: - """ - Returns the key for the message in the response body. - Subclasses should override this method to return the correct message key - where the response is located. - - :returns: The key for the message in the response body. - """ @abstractmethod def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: @@ -159,8 +155,16 @@ class AnthropicClaudeChatAdapter(BedrockModelChatAdapter): Model adapter for the Anthropic Claude chat model. 
""" - ANTHROPIC_USER_TOKEN = "\n\nHuman:" - ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:" + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + ALLOWED_PARAMS: ClassVar[List[str]] = [ + "anthropic_version", + "max_tokens", + "stop_sequences", + "temperature", + "top_p", + "top_k", + "system", + ] def __init__(self, generation_kwargs: Dict[str, Any]): """ @@ -183,7 +187,7 @@ def __init__(self, generation_kwargs: Dict[str, Any]): self.prompt_handler = DefaultPromptHandler( tokenizer="gpt2", model_max_length=model_max_length, - max_length=self.generation_kwargs.get("max_tokens_to_sample") or 512, + max_length=self.generation_kwargs.get("max_tokens") or 512, ) def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[str, Any]: @@ -195,46 +199,33 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ :returns: The prepared body. """ default_params = { - "max_tokens_to_sample": self.generation_kwargs.get("max_tokens_to_sample") or 512, - "stop_sequences": ["\n\nHuman:"], + "anthropic_version": self.generation_kwargs.get("anthropic_version") or "bedrock-2023-05-31", + "max_tokens": self.generation_kwargs.get("max_tokens") or 512, # max_tokens is required } # combine stop words with default stop sequences, remove stop_words as Anthropic does not support it stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) if stop_sequences: inference_kwargs["stop_sequences"] = stop_sequences - params = self._get_params(inference_kwargs, default_params) - body = {"prompt": self.prepare_chat_messages(messages=messages), **params} + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) + body = {**self.prepare_chat_messages(messages=messages), **params} return body - def prepare_chat_messages(self, messages: List[ChatMessage]) -> str: + def prepare_chat_messages(self, messages: List[ChatMessage]) -> Dict[str, Any]: """ Prepares the chat messages for the Anthropic Claude request. :param messages: The chat messages to prepare. :returns: The prepared chat messages as a string. """ - conversation = [] - for index, message in enumerate(messages): - if message.is_from(ChatRole.USER): - conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_USER_TOKEN} {message.content.strip()}") - elif message.is_from(ChatRole.ASSISTANT): - conversation.append(f"{AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}") - elif message.is_from(ChatRole.FUNCTION): - error_message = "Anthropic does not support function calls." - raise ValueError(error_message) - elif message.is_from(ChatRole.SYSTEM) and index == 0: - # Until we transition to the new chat message format system messages will be ignored - # see https://docs.anthropic.com/claude/reference/messages_post for more details - logger.warning( - "System messages are not fully supported by the current version of Claude and will be ignored." 
- ) - else: - invalid_role = f"Invalid role {message.role} for message {message.content}" - raise ValueError(invalid_role) - - prepared_prompt = "".join(conversation) + AnthropicClaudeChatAdapter.ANTHROPIC_ASSISTANT_TOKEN + " " - return self._ensure_token_limit(prepared_prompt) + body: Dict[str, Any] = {} + system = messages[0].content if messages and messages[0].is_from(ChatRole.SYSTEM) else None + body["messages"] = [ + self._to_anthropic_message(m) for m in messages if m.is_from(ChatRole.USER) or m.is_from(ChatRole.ASSISTANT) + ] + if system: + body["system"] = system + return body def check_prompt(self, prompt: str) -> Dict[str, Any]: """ @@ -245,13 +236,20 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]: """ return self.prompt_handler(prompt) - def response_body_message_key(self) -> str: + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: """ - Returns the key for the message in the response body for Anthropic Claude i.e. "completion". + Extracts the messages from the response body. - :returns: The key for the message in the response body. + :param response_body: The response body. + :return: The extracted ChatMessage list. """ - return "completion" + messages: List[ChatMessage] = [] + if response_body.get("type") == "message": + for content in response_body["content"]: + if content.get("type") == "text": + meta = {k: v for k, v in response_body.items() if k not in ["type", "content", "role"]} + messages.append(ChatMessage.from_assistant(content["text"], meta=meta)) + return messages def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: """ @@ -260,7 +258,17 @@ def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: :param chunk: The streaming chunk. :returns: The extracted token. """ - return chunk.get("completion", "") + if chunk.get("type") == "content_block_delta" and chunk.get("delta", {}).get("type") == "text_delta": + return chunk.get("delta", {}).get("text", "") + return "" + + def _to_anthropic_message(self, m: ChatMessage) -> Dict[str, Any]: + """ + Convert a ChatMessage to a dictionary with the content and role fields. + :param m: The ChatMessage to convert. + :return: The dictionary with the content and role fields. + """ + return {"content": [{"type": "text", "text": m.content}], "role": m.role.value} class MetaLlama2ChatAdapter(BedrockModelChatAdapter): @@ -268,6 +276,9 @@ class MetaLlama2ChatAdapter(BedrockModelChatAdapter): Model adapter for the Meta Llama 2 models. 
""" + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html + ALLOWED_PARAMS: ClassVar[List[str]] = ["max_gen_len", "temperature", "top_p"] + chat_template = ( "{% if messages[0]['role'] == 'system' %}" "{% set loop_messages = messages[1:] %}" @@ -327,11 +338,8 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ """ default_params = {"max_gen_len": self.generation_kwargs.get("max_gen_len") or 512} - # combine stop words with default stop sequences, remove stop_words as MetaLlama2 does not support it - stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) - if stop_sequences: - inference_kwargs["stop_sequences"] = stop_sequences - params = self._get_params(inference_kwargs, default_params) + # no support for stop words in Meta Llama 2 + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) body = {"prompt": self.prepare_chat_messages(messages=messages), **params} return body @@ -357,13 +365,16 @@ def check_prompt(self, prompt: str) -> Dict[str, Any]: """ return self.prompt_handler(prompt) - def response_body_message_key(self) -> str: + def _extract_messages_from_response(self, response_body: Dict[str, Any]) -> List[ChatMessage]: """ - Returns the key for the message in the response body for Meta Llama 2 i.e. "generation". + Extracts the messages from the response body. - :returns: The key for the message in the response body. + :param response_body: The response body. + :return: The extracted ChatMessage list. """ - return "generation" + message_tag = "generation" + metadata = {k: v for (k, v) in response_body.items() if k != message_tag} + return [ChatMessage.from_assistant(response_body[message_tag], meta=metadata)] def _extract_token_from_stream(self, chunk: Dict[str, Any]) -> str: """ diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index bea6924f6..5279dc001 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -25,20 +25,21 @@ class AmazonBedrockChatGenerator: """ `AmazonBedrockChatGenerator` enables text generation via Amazon Bedrock hosted chat LLMs. - For example, to use the Anthropic Claude model, simply initialize the `AmazonBedrockChatGenerator` with the - 'anthropic.claude-v2' model name. + For example, to use the Anthropic Claude 3 Sonnet model, simply initialize the `AmazonBedrockChatGenerator` with the + 'anthropic.claude-3-sonnet-20240229-v1:0' model name. 
```python from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack.dataclasses import ChatMessage from haystack.components.generators.utils import print_streaming_chunk - messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"), + messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant, answer in German only"), ChatMessage.from_user("What's Natural Language Processing?")] - client = AmazonBedrockChatGenerator(model="anthropic.claude-v2", streaming_callback=print_streaming_chunk) - client.run(messages, generation_kwargs={"max_tokens_to_sample": 512}) + client = AmazonBedrockChatGenerator(model="anthropic.claude-3-sonnet-20240229-v1:0", + streaming_callback=print_streaming_chunk) + client.run(messages, generation_kwargs={"max_tokens": 512}) ``` @@ -154,7 +155,7 @@ def invoke(self, *args, **kwargs): msg = f"The model {self.model} requires a list of ChatMessage objects as a prompt." raise ValueError(msg) - body = self.model_adapter.prepare_body(messages=messages, stop_words=self.stop_words, **kwargs) + body = self.model_adapter.prepare_body(messages=messages, **{"stop_words": self.stop_words, **kwargs}) try: if self.streaming_callback: response = self.client.invoke_model_with_response_stream( diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index 9ba4d5534..6e0356d42 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -2,7 +2,7 @@ import pytest from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockChatGenerator from haystack_integrations.components.generators.amazon_bedrock.chat.adapters import ( @@ -11,7 +11,8 @@ MetaLlama2ChatAdapter, ) -clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" +KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" +MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] def test_to_dict(mock_boto3_session): @@ -24,7 +25,7 @@ def test_to_dict(mock_boto3_session): streaming_callback=print_streaming_chunk, ) expected_dict = { - "type": clazz, + "type": KLASS, "init_parameters": { "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, @@ -47,7 +48,7 @@ def test_from_dict(mock_boto3_session): """ generator = AmazonBedrockChatGenerator.from_dict( { - "type": clazz, + "type": KLASS, "init_parameters": { "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, "aws_secret_access_key": {"type": "env_var", "env_vars": ["AWS_SECRET_ACCESS_KEY"], "strict": False}, @@ -146,9 +147,9 @@ def test_prepare_body_with_default_params(self) -> None: layer = AnthropicClaudeChatAdapter(generation_kwargs={}) prompt = "Hello, how are you?" 
expected_body = { - "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", - "max_tokens_to_sample": 512, - "stop_sequences": ["\n\nHuman:"], + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 512, + "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], } body = layer.prepare_body([ChatMessage.from_user(prompt)]) @@ -159,12 +160,13 @@ def test_prepare_body_with_custom_inference_params(self) -> None: layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4}) prompt = "Hello, how are you?" expected_body = { - "prompt": "\n\nHuman: Hello, how are you?\n\nAssistant: ", - "max_tokens_to_sample": 69, - "stop_sequences": ["\n\nHuman:", "CUSTOM_STOP"], + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 512, + "messages": [{"content": [{"text": "Hello, how are you?", "type": "text"}], "role": "user"}], + "stop_sequences": ["CUSTOM_STOP"], "temperature": 0.7, - "top_p": 0.8, "top_k": 5, + "top_p": 0.8, } body = layer.prepare_body( @@ -173,17 +175,14 @@ def test_prepare_body_with_custom_inference_params(self) -> None: assert body == expected_body - @pytest.mark.integration - def test_get_responses(self) -> None: - adapter = AnthropicClaudeChatAdapter(generation_kwargs={}) - response_body = {"completion": "This is a single response."} - expected_response = "This is a single response." - response_message = adapter.get_responses(response_body) - # assert that the type of each item in the list is a ChatMessage - for message in response_message: - assert isinstance(message, ChatMessage) - assert response_message == [ChatMessage.from_assistant(expected_response)] +@pytest.fixture +def chat_messages(): + messages = [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + return messages class TestMetaLlama2ChatAdapter: @@ -207,13 +206,13 @@ def test_prepare_body_with_custom_inference_params(self) -> None: generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 5, "stop_sequences": ["CUSTOM_STOP"]} ) prompt = "Hello, how are you?" + + # expected body is different because stop_sequences and top_k are not supported by MetaLlama2 expected_body = { "prompt": "[INST] Hello, how are you? 
[/INST]", "max_gen_len": 69, - "stop_sequences": ["CUSTOM_STOP"], "temperature": 0.7, "top_p": 0.8, - "top_k": 5, } body = layer.prepare_body( @@ -238,3 +237,52 @@ def test_get_responses(self) -> None: assert isinstance(message, ChatMessage) assert response_message == [ChatMessage.from_assistant(expected_response)] + + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) + @pytest.mark.integration + def test_default_inference_params(self, model_name, chat_messages): + + client = AmazonBedrockChatGenerator(model=model_name) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" + + @pytest.mark.parametrize("model_name", MODELS_TO_TEST) + @pytest.mark.integration + def test_default_inference_with_streaming(self, model_name, chat_messages): + streaming_callback_called = False + paris_found_in_response = False + + def streaming_callback(chunk: StreamingChunk): + nonlocal streaming_callback_called, paris_found_in_response + streaming_callback_called = True + assert isinstance(chunk, StreamingChunk) + assert chunk.content is not None + if not paris_found_in_response: + paris_found_in_response = "paris" in chunk.content.lower() + + client = AmazonBedrockChatGenerator(model=model_name, streaming_callback=streaming_callback) + response = client.run(chat_messages) + + assert streaming_callback_called, "Streaming callback was not called" + assert paris_found_in_response, "The streaming callback response did not contain 'paris'" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index 0e52d1e62..8b0dd4f30 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -158,24 +158,21 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["haystack_integrations", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true -omit = [ - "example" -] +parallel = false -[tool.coverage.paths] -astra_haystack = ["src"] -tests = ["tests"] [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + [tool.pytest.ini_options] minversion = "6.0" markers = [ diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 90e037942..2cfcbea20 100644 --- 
a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -160,24 +160,21 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true -omit = [ - "example" -] +parallel = false -[tool.coverage.paths] -chroma_haystack = ["src/haystack_integrations", "*/chroma-haystack/src/chroma_haystack"] -tests = ["tests"] [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + [tool.pytest.ini_options] minversion = "6.0" markers = [ diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index 4b612aca5..523bfa918 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -128,19 +128,20 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -cohere_haystack = [ - "src/haystack_integrations", - "*/cohere/src/haystack_integrations", -] -tests = ["tests", "*/cohere/tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + [[tool.mypy.overrides]] module = [ diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index c360490f3..7fd588fec 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -348,7 +348,6 @@ def __call__(self, chunk: StreamingChunk) -> None: assert message.meta["finish_reason"] == "COMPLETE" - assert callback.counter > 1 assert "Paris" in callback.responses assert message.meta["documents"] is not None diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index c3dd59535..32fdfce50 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -13,7 +13,8 @@ class TestCohereGenerator: - def test_init_default(self): + def test_init_default(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "foo") component = CohereGenerator() assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert component.model == "command" diff --git a/integrations/deepeval/pyproject.toml b/integrations/deepeval/pyproject.toml index 7ccf4766c..18a873b34 100644 --- a/integrations/deepeval/pyproject.toml +++ b/integrations/deepeval/pyproject.toml @@ -131,19 +131,20 @@ ban-relative-imports = "all" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -deepeval_haystack = [ - "src/haystack_integrations", - "*/deepeval-haystack/src/deepeval_haystack", -] -tests = ["tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + [[tool.mypy.overrides]] module = [ diff 
--git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index b67df7e03..92f49f35d 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -156,21 +156,21 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -elasticsearch_haystack = ["src/haystack_integrations", "*/elasticsearch/src/haystack_integrations"] -tests = ["tests", "*/elasticsearch/src/tests"] [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + [tool.pytest.ini_options] minversion = "6.0" markers = [ diff --git a/integrations/fastembed/pyproject.toml b/integrations/fastembed/pyproject.toml index 6ebb99142..49d6a8f17 100644 --- a/integrations/fastembed/pyproject.toml +++ b/integrations/fastembed/pyproject.toml @@ -158,16 +158,13 @@ ban-relative-imports = "parents" "examples/**/*" = ["T201"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true - - -[tool.coverage.paths] -fastembed_haystack = ["src/haystack_integrations", "*/fastembed-haystack/src"] -tests = ["tests", "*/fastembed-haystack/tests"] +parallel = false [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index baf21c8a3..e44e50a61 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -1,5 +1,7 @@ from typing import ClassVar, Dict, List, Optional +from tqdm import tqdm + from fastembed import TextEmbedding @@ -39,7 +41,12 @@ def __init__( ): self.model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads) - def embed(self, data: List[List[str]], **kwargs) -> List[List[float]]: + def embed(self, data: List[str], progress_bar=True, **kwargs) -> List[List[float]]: # the embed method returns a Iterable[np.ndarray], so we convert it to a list of lists - embeddings = [np_array.tolist() for np_array in self.model.embed(data, **kwargs)] + embeddings = [] + embeddings_iterable = self.model.embed(data, **kwargs) + for np_array in tqdm( + embeddings_iterable, disable=not progress_bar, desc="Calculating embeddings", total=len(data) + ): + embeddings.append(np_array.tolist()) return embeddings diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py index b5dd71231..ec0b918d9 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py @@ -131,11 +131,11 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> 
List[str]: meta_values_to_embed = [ str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None ] - text_to_embed = [ - self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix, - ] + text_to_embed = ( + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix + ) - texts_to_embed.append(text_to_embed[0]) + texts_to_embed.append(text_to_embed) return texts_to_embed @component.output_types(documents=List[Document]) @@ -157,13 +157,11 @@ def run(self, documents: List[Document]): msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - # TODO: once non textual Documents are properly supported, we should also prepare them for embedding here - texts_to_embed = self._prepare_texts_to_embed(documents=documents) embeddings = self.embedding_backend.embed( texts_to_embed, batch_size=self.batch_size, - show_progress_bar=self.progress_bar, + progress_bar=self.progress_bar, parallel=self.parallel, ) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py index 743884ec1..9bc4475a5 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py @@ -35,7 +35,6 @@ def __init__( threads: Optional[int] = None, prefix: str = "", suffix: str = "", - batch_size: int = 256, progress_bar: bool = True, parallel: Optional[int] = None, ): @@ -47,7 +46,6 @@ def __init__( Can be set using the `FASTEMBED_CACHE_PATH` env variable. Defaults to `fastembed_cache` in the system's temp directory. :param threads: The number of threads single onnxruntime session can use. Defaults to None. - :param batch_size: Number of strings to encode at once. :param prefix: A string to add to the beginning of each text. :param suffix: A string to add to the end of each text. :param progress_bar: If true, displays progress bar during embedding. 
@@ -62,7 +60,6 @@ def __init__( self.threads = threads self.prefix = prefix self.suffix = suffix - self.batch_size = batch_size self.progress_bar = progress_bar self.parallel = parallel @@ -80,7 +77,6 @@ def to_dict(self) -> Dict[str, Any]: threads=self.threads, prefix=self.prefix, suffix=self.suffix, - batch_size=self.batch_size, progress_bar=self.progress_bar, parallel=self.parallel, ) @@ -119,8 +115,7 @@ def run(self, text: str): embedding = list( self.embedding_backend.embed( text_to_embed, - batch_size=self.batch_size, - show_progress_bar=self.progress_bar, + progress_bar=self.progress_bar, parallel=self.parallel, )[0] ) diff --git a/integrations/fastembed/tests/test_fastembed_document_embedder.py b/integrations/fastembed/tests/test_fastembed_document_embedder.py index 797c295ba..75fdcc9c9 100644 --- a/integrations/fastembed/tests/test_fastembed_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_document_embedder.py @@ -261,7 +261,7 @@ def test_embed_metadata(self): "meta_value 4\ndocument-number 4", ], batch_size=256, - show_progress_bar=True, + progress_bar=True, parallel=None, ) diff --git a/integrations/fastembed/tests/test_fastembed_text_embedder.py b/integrations/fastembed/tests/test_fastembed_text_embedder.py index d5982c319..402980485 100644 --- a/integrations/fastembed/tests/test_fastembed_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_text_embedder.py @@ -19,7 +19,6 @@ def test_init_default(self): assert embedder.threads is None assert embedder.prefix == "" assert embedder.suffix == "" - assert embedder.batch_size == 256 assert embedder.progress_bar is True assert embedder.parallel is None @@ -33,7 +32,6 @@ def test_init_with_parameters(self): threads=2, prefix="prefix", suffix="suffix", - batch_size=64, progress_bar=False, parallel=1, ) @@ -42,7 +40,6 @@ def test_init_with_parameters(self): assert embedder.threads == 2 assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1 @@ -60,7 +57,6 @@ def test_to_dict(self): "threads": None, "prefix": "", "suffix": "", - "batch_size": 256, "progress_bar": True, "parallel": None, }, @@ -76,7 +72,6 @@ def test_to_dict_with_custom_init_parameters(self): threads=2, prefix="prefix", suffix="suffix", - batch_size=64, progress_bar=False, parallel=1, ) @@ -89,7 +84,6 @@ def test_to_dict_with_custom_init_parameters(self): "threads": 2, "prefix": "prefix", "suffix": "suffix", - "batch_size": 64, "progress_bar": False, "parallel": 1, }, @@ -107,7 +101,6 @@ def test_from_dict(self): "threads": None, "prefix": "", "suffix": "", - "batch_size": 256, "progress_bar": True, "parallel": None, }, @@ -118,7 +111,6 @@ def test_from_dict(self): assert embedder.threads is None assert embedder.prefix == "" assert embedder.suffix == "" - assert embedder.batch_size == 256 assert embedder.progress_bar is True assert embedder.parallel is None @@ -134,7 +126,6 @@ def test_from_dict_with_custom_init_parameters(self): "threads": 2, "prefix": "prefix", "suffix": "suffix", - "batch_size": 64, "progress_bar": False, "parallel": 1, }, @@ -145,7 +136,6 @@ def test_from_dict_with_custom_init_parameters(self): assert embedder.threads == 2 assert embedder.prefix == "prefix" assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.parallel == 1 diff --git a/integrations/google_ai/pyproject.toml b/integrations/google_ai/pyproject.toml index 
01cd2cf8f..3224b1dc0 100644 --- a/integrations/google_ai/pyproject.toml +++ b/integrations/google_ai/pyproject.toml @@ -154,11 +154,8 @@ ban-relative-imports = "parents" [tool.coverage.run] source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -google_ai_haystack = ["src"] -tests = ["tests"] [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] @@ -168,6 +165,7 @@ exclude_lines = [ "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + [[tool.mypy.overrides]] module = [ "google.*", diff --git a/integrations/google_vertex/pyproject.toml b/integrations/google_vertex/pyproject.toml index f846d5bc4..be7cd33ac 100644 --- a/integrations/google_vertex/pyproject.toml +++ b/integrations/google_vertex/pyproject.toml @@ -148,21 +148,21 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["haystack_integrations", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -google_vertex_haystack = ["src/"] -tests = ["tests"] [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + [[tool.mypy.overrides]] module = [ "vertexai.*", diff --git a/integrations/gradient/pyproject.toml b/integrations/gradient/pyproject.toml index f0dae134f..013b1263c 100644 --- a/integrations/gradient/pyproject.toml +++ b/integrations/gradient/pyproject.toml @@ -157,11 +157,8 @@ ban-relative-imports = "parents" [tool.coverage.run] source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -gradient_haystack = ["src"] -tests = ["tests", "*/gradient-haystack/tests"] [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] @@ -172,6 +169,7 @@ exclude_lines = [ "if TYPE_CHECKING:", ] + [[tool.mypy.overrides]] module = [ "gradientai.*", diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index 565def7a5..e1e719b27 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -128,16 +128,19 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["jina_haystack", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -jina_haystack = ["src"] -tests = ["tests", "*/jina-haystack/tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] [[tool.mypy.overrides]] module = ["haystack.*", "haystack_integrations.*", "pytest.*"] diff --git a/integrations/llama_cpp/pyproject.toml b/integrations/llama_cpp/pyproject.toml index 1b165dcaf..5e5fb6f59 100644 --- a/integrations/llama_cpp/pyproject.toml +++ b/integrations/llama_cpp/pyproject.toml @@ -158,22 +158,21 @@ ban-relative-imports = "parents" "tests/**" = ["T201"] [tool.coverage.run] -source_pkgs = ["llama_cpp_haystack", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -llama_cpp_haystack = ["src/haystack_integrations", "*/llama-cpp-haystack/src"] -tests = ["tests", "*/llama-cpp-haystack/tests"] - [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ - "no cov", - "if __name__ 
== .__main__.:", - "if TYPE_CHECKING:", + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", ] + [tool.pytest.ini_options] markers = [ "integration: marks tests as slow (deselect with '-m \"not integration\"')", diff --git a/integrations/mistral/pyproject.toml b/integrations/mistral/pyproject.toml index 6abf98704..512993246 100644 --- a/integrations/mistral/pyproject.toml +++ b/integrations/mistral/pyproject.toml @@ -127,19 +127,20 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -mistral_haystack = [ - "src/haystack_integrations", - "*/mistral/src/haystack_integrations", -] -tests = ["tests", "*/mistral/tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + [[tool.mypy.overrides]] module = [ diff --git a/integrations/mongodb_atlas/pyproject.toml b/integrations/mongodb_atlas/pyproject.toml index 0021884ad..6e6b55dfe 100644 --- a/integrations/mongodb_atlas/pyproject.toml +++ b/integrations/mongodb_atlas/pyproject.toml @@ -156,27 +156,26 @@ ban-relative-imports = "parents" "examples/**/*" = ["T201"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -tests = ["tests", "*/mongodb-atlas-haystack/tests"] - [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + [[tool.mypy.overrides]] module = [ "haystack.*", "haystack_integrations.*", - "mongodb_atlas.*", - "psycopg.*", + "pymongo.*", "pytest.*" ] ignore_missing_imports = true diff --git a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py index 432b86d4c..ffad97789 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/components/retrievers/mongodb_atlas/embedding_retriever.py @@ -48,7 +48,9 @@ def __init__( Create the MongoDBAtlasDocumentStore component. :param document_store: An instance of MongoDBAtlasDocumentStore. - :param filters: Filters applied to the retrieved Documents. + :param filters: Filters applied to the retrieved Documents. Make sure that the fields used in the filters are + included in the configuration of the `vector_search_index`. The configuration must be done manually + in the Web UI of MongoDB Atlas. :param top_k: Maximum number of Documents to return. :raises ValueError: If `document_store` is not an instance of `MongoDBAtlasDocumentStore`. 
diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 27cb853db..c9e8f1dae 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -10,10 +10,10 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace -from haystack_integrations.document_stores.mongodb_atlas.filters import haystack_filters_to_mongo -from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne # type: ignore -from pymongo.driver_info import DriverInfo # type: ignore -from pymongo.errors import BulkWriteError # type: ignore +from haystack_integrations.document_stores.mongodb_atlas.filters import _normalize_filters +from pymongo import InsertOne, MongoClient, ReplaceOne, UpdateOne +from pymongo.driver_info import DriverInfo +from pymongo.errors import BulkWriteError logger = logging.getLogger(__name__) @@ -144,8 +144,8 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :param filters: The filters to apply. It returns only the documents that match the filters. :returns: A list of Documents that match the given filters. """ - mongo_filters = haystack_filters_to_mongo(filters) - documents = list(self.collection.find(mongo_filters)) + filters = _normalize_filters(filters) if filters else None + documents = list(self.collection.find(filters)) for doc in documents: doc.pop("_id", None) # MongoDB's internal id doesn't belong into a Haystack document, so we remove it. 
return [Document.from_dict(doc) for doc in documents] @@ -170,7 +170,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D if policy == DuplicatePolicy.NONE: policy = DuplicatePolicy.FAIL - mongo_documents = [doc.to_dict() for doc in documents] + mongo_documents = [doc.to_dict(flatten=False) for doc in documents] operations: List[Union[UpdateOne, InsertOne, ReplaceOne]] written_docs = len(documents) @@ -221,7 +221,8 @@ def _embedding_retrieval( msg = "Query embedding must not be empty" raise ValueError(msg) - filters = haystack_filters_to_mongo(filters) + filters = _normalize_filters(filters) if filters else None + pipeline = [ { "$vectorSearch": { @@ -230,7 +231,7 @@ def _embedding_retrieval( "queryVector": query_embedding, "numCandidates": 100, "limit": top_k, - # "filter": filters, + "filter": filters, } }, { @@ -249,6 +250,11 @@ def _embedding_retrieval( documents = list(self.collection.aggregate(pipeline)) except Exception as e: msg = f"Retrieval of documents from MongoDB Atlas failed: {e}" + if filters: + msg += ( + "\nMake sure that the fields used in the filters are included " + "in the `vector_search_index` configuration" + ) raise DocumentStoreError(msg) from e documents = [self._mongo_doc_to_haystack_doc(doc) for doc in documents] diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py deleted file mode 100644 index 132156bd0..000000000 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/errors.py +++ /dev/null @@ -1,4 +0,0 @@ -class MongoDBAtlasDocumentStoreError(Exception): - """Exception for issues that occur in a MongoDBAtlas document store""" - - pass diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py index f03ca88c0..4583d6cd3 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/filters.py @@ -1,9 +1,152 @@ -from typing import Any, Dict, Optional +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from datetime import datetime +from typing import Any, Dict +from haystack.errors import FilterError +from haystack.utils.filters import convert +from pandas import DataFrame -def haystack_filters_to_mongo(filters: Optional[Dict[str, Any]]): - # TODO - if filters: - msg = "Filtering not yet implemented for MongoDBAtlasDocumentStore" - raise ValueError(msg) - return {} +UNSUPPORTED_TYPES_FOR_COMPARISON = (list, DataFrame) + + +def _normalize_filters(filters: Dict[str, Any]) -> Dict[str, Any]: + """ + Converts Haystack filters to MongoDB filters. 
+ """ + if not isinstance(filters, dict): + msg = "Filters must be a dictionary" + raise FilterError(msg) + + if "operator" not in filters and "conditions" not in filters: + filters = convert(filters) + + if "field" in filters: + return _parse_comparison_condition(filters) + return _parse_logical_condition(filters) + + +def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "conditions" not in condition: + msg = f"'conditions' key missing in {condition}" + raise FilterError(msg) + + # logical conditions can be nested, so we need to parse them recursively + conditions = [] + for c in condition["conditions"]: + if "field" in c: + conditions.append(_parse_comparison_condition(c)) + else: + conditions.append(_parse_logical_condition(c)) + + operator = condition["operator"] + if operator == "AND": + return {"$and": conditions} + elif operator == "OR": + return {"$or": conditions} + elif operator == "NOT": + # MongoDB doesn't support our NOT operator (logical NAND) directly. + # we combine $nor and $and to achieve the same effect. + return {"$nor": [{"$and": conditions}]} + + msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR', 'NOT'" + raise FilterError(msg) + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: + field: str = condition["field"] + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "value" not in condition: + msg = f"'value' key missing in {condition}" + raise FilterError(msg) + operator: str = condition["operator"] + value: Any = condition["value"] + + if isinstance(value, DataFrame): + value = value.to_json() + + return COMPARISON_OPERATORS[operator](field, value) + + +def _equal(field: str, value: Any) -> Dict[str, Any]: + return {field: {"$eq": value}} + + +def _not_equal(field: str, value: Any) -> Dict[str, Any]: + return {field: {"$ne": value}} + + +def _validate_type_for_comparison(value: Any) -> None: + msg = f"Cant compare {type(value)} using operators '>', '>=', '<', '<='." + if isinstance(value, UNSUPPORTED_TYPES_FOR_COMPARISON): + raise FilterError(msg) + elif isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg += "\nStrings are only comparable if they are ISO formatted dates." 
+            raise FilterError(msg) from exc
+
+
+def _greater_than(field: str, value: Any) -> Dict[str, Any]:
+    _validate_type_for_comparison(value)
+    return {field: {"$gt": value}}
+
+
+def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]:
+    if value is None:
+        # we want {field: {"$gte": null}} to return an empty result
+        # $gte with null values in MongoDB returns a non-empty result, while $gt aligns with our expectations
+        return {field: {"$gt": value}}
+
+    _validate_type_for_comparison(value)
+    return {field: {"$gte": value}}
+
+
+def _less_than(field: str, value: Any) -> Dict[str, Any]:
+    _validate_type_for_comparison(value)
+    return {field: {"$lt": value}}
+
+
+def _less_than_equal(field: str, value: Any) -> Dict[str, Any]:
+    if value is None:
+        # we want {field: {"$lte": null}} to return an empty result
+        # $lte with null values in MongoDB returns a non-empty result, while $lt aligns with our expectations
+        return {field: {"$lt": value}}
+    _validate_type_for_comparison(value)
+
+    return {field: {"$lte": value}}
+
+
+def _not_in(field: str, value: Any) -> Dict[str, Any]:
+    if not isinstance(value, list):
+        msg = f"{field}'s value must be a list when using 'not in' comparator in MongoDB Atlas"
+        raise FilterError(msg)
+
+    return {field: {"$nin": value}}
+
+
+def _in(field: str, value: Any) -> Dict[str, Any]:
+    if not isinstance(value, list):
+        msg = f"{field}'s value must be a list when using 'in' comparator in MongoDB Atlas"
+        raise FilterError(msg)
+
+    return {field: {"$in": value}}
+
+
+COMPARISON_OPERATORS = {
+    "==": _equal,
+    "!=": _not_equal,
+    ">": _greater_than,
+    ">=": _greater_than_equal,
+    "<": _less_than,
+    "<=": _less_than_equal,
+    "in": _in,
+    "not in": _not_in,
+}
diff --git a/integrations/mongodb_atlas/tests/test_document_store.py b/integrations/mongodb_atlas/tests/test_document_store.py
index 39a4465c1..89810ec8b 100644
--- a/integrations/mongodb_atlas/tests/test_document_store.py
+++ b/integrations/mongodb_atlas/tests/test_document_store.py
@@ -8,42 +8,43 @@
 from haystack.dataclasses.document import ByteStream, Document
 from haystack.document_stores.errors import DuplicateDocumentError
 from haystack.document_stores.types import DuplicatePolicy
-from haystack.testing.document_store import CountDocumentsTest, DeleteDocumentsTest, WriteDocumentsTest
+from haystack.testing.document_store import DocumentStoreBaseTests
 from haystack.utils import Secret
 from haystack_integrations.document_stores.mongodb_atlas import MongoDBAtlasDocumentStore
 from pandas import DataFrame
-from pymongo import MongoClient  # type: ignore
-from pymongo.driver_info import DriverInfo  # type: ignore
-
-
-@pytest.fixture
-def document_store():
-    database_name = "haystack_integration_test"
-    collection_name = "test_collection_" + str(uuid4())
-
-    connection: MongoClient = MongoClient(
-        os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
-    )
-    database = connection[database_name]
-    if collection_name in database.list_collection_names():
-        database[collection_name].drop()
-    database.create_collection(collection_name)
-    database[collection_name].create_index("id", unique=True)
-
-    store = MongoDBAtlasDocumentStore(
-        database_name=database_name,
-        collection_name=collection_name,
-        vector_search_index="cosine_index",
-    )
-    yield store
-    database[collection_name].drop()
+from pymongo import MongoClient
+from pymongo.driver_info import DriverInfo


 @pytest.mark.skipif(
     "MONGO_CONNECTION_STRING" not in os.environ,
     reason="No MongoDB Atlas connection string provided",
 )
-class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):
+@pytest.mark.integration
+class TestDocumentStore(DocumentStoreBaseTests):
+
+    @pytest.fixture
+    def document_store(self):
+        database_name = "haystack_integration_test"
+        collection_name = "test_collection_" + str(uuid4())
+
+        connection: MongoClient = MongoClient(
+            os.environ["MONGO_CONNECTION_STRING"], driver=DriverInfo(name="MongoDBAtlasHaystackIntegration")
+        )
+        database = connection[database_name]
+        if collection_name in database.list_collection_names():
+            database[collection_name].drop()
+        database.create_collection(collection_name)
+        database[collection_name].create_index("id", unique=True)
+
+        store = MongoDBAtlasDocumentStore(
+            database_name=database_name,
+            collection_name=collection_name,
+            vector_search_index="cosine_index",
+        )
+        yield store
+        database[collection_name].drop()
+
     def test_write_documents(self, document_store: MongoDBAtlasDocumentStore):
         docs = [Document(content="some text")]
         assert document_store.write_documents(docs) == 1
@@ -104,3 +105,37 @@ def test_from_dict(self):
         assert docstore.database_name == "haystack_integration_test"
         assert docstore.collection_name == "test_embeddings_collection"
         assert docstore.vector_search_index == "cosine_index"
+
+    def test_complex_filter(self, document_store, filterable_docs):
+        document_store.write_documents(filterable_docs)
+        filters = {
+            "operator": "OR",
+            "conditions": [
+                {
+                    "operator": "AND",
+                    "conditions": [
+                        {"field": "meta.number", "operator": "==", "value": 100},
+                        {"field": "meta.chapter", "operator": "==", "value": "intro"},
+                    ],
+                },
+                {
+                    "operator": "AND",
+                    "conditions": [
+                        {"field": "meta.page", "operator": "==", "value": "90"},
+                        {"field": "meta.chapter", "operator": "==", "value": "conclusion"},
+                    ],
+                },
+            ],
+        }
+
+        result = document_store.filter_documents(filters=filters)
+
+        self.assert_documents_are_equal(
+            result,
+            [
+                d
+                for d in filterable_docs
+                if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro")
+                or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion")
+            ],
+        )
diff --git a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py
index 54bbdedfd..a03c735e0 100644
--- a/integrations/mongodb_atlas/tests/test_embedding_retrieval.py
+++ b/integrations/mongodb_atlas/tests/test_embedding_retrieval.py
@@ -13,6 +13,7 @@
     "MONGO_CONNECTION_STRING" not in os.environ,
     reason="No MongoDB Atlas connection string provided",
 )
+@pytest.mark.integration
 class TestEmbeddingRetrieval:
     def test_embedding_retrieval_cosine_similarity(self):
         document_store = MongoDBAtlasDocumentStore(
@@ -72,3 +73,34 @@ def test_query_embedding_wrong_dimension(self):
         query_embedding = [0.1] * 4
         with pytest.raises(DocumentStoreError):
             document_store._embedding_retrieval(query_embedding=query_embedding)
+
+    def test_embedding_retrieval_with_filters(self):
+        """
+        Note: we can combine embedding retrieval with filters
+        because the `cosine_index` vector_search_index was created with the `content` field as the filter field.
+ { + "fields": [ + { + "type": "vector", + "path": "embedding", + "numDimensions": 768, + "similarity": "cosine" + }, + { + "type": "filter", + "path": "content" + } + ] + } + """ + document_store = MongoDBAtlasDocumentStore( + database_name="haystack_integration_test", + collection_name="test_embeddings_collection", + vector_search_index="cosine_index", + ) + query_embedding = [0.1] * 768 + filters = {"field": "content", "operator": "!=", "value": "Document A"} + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=2, filters=filters) + assert len(results) == 2 + for doc in results: + assert doc.content != "Document A" diff --git a/integrations/nvidia/pyproject.toml b/integrations/nvidia/pyproject.toml index ba25812a8..f443e91f9 100644 --- a/integrations/nvidia/pyproject.toml +++ b/integrations/nvidia/pyproject.toml @@ -129,19 +129,20 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -nvidia_haystack = [ - "src/haystack_integrations", - "*/nvidia/src/haystack_integrations", -] -tests = ["tests", "*/nvidia/tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + [[tool.mypy.overrides]] module = [ diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_schema.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_schema.py index a0598be86..fc4e0e5bf 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_schema.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/_schema.py @@ -1,31 +1,10 @@ from dataclasses import asdict, dataclass from typing import Any, Dict, List, Literal, Union -from haystack_integrations.utils.nvidia import NvidiaCloudFunctionsClient - -from .models import NvidiaEmbeddingModel - MAX_INPUT_STRING_LENGTH = 2048 MAX_INPUTS = 50 -def get_model_nvcf_id(model: NvidiaEmbeddingModel, client: NvidiaCloudFunctionsClient) -> str: - """ - Returns the Nvidia Cloud Functions UUID for the given model. 
- """ - - available_functions = client.available_functions() - func = available_functions.get(str(model)) - if func is None: - msg = f"Model '{model}' was not found on the Nvidia Cloud Functions backend" - raise ValueError(msg) - elif func.status != "ACTIVE": - msg = f"Model '{model}' is not currently active/usable on the Nvidia Cloud Functions backend" - raise ValueError(msg) - - return func.id - - @dataclass class EmbeddingsRequest: input: Union[str, List[str]] diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index bbc68b492..25c104b97 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -5,7 +5,7 @@ from haystack_integrations.utils.nvidia import NvidiaCloudFunctionsClient from tqdm import tqdm -from ._schema import MAX_INPUTS, EmbeddingsRequest, EmbeddingsResponse, Usage, get_model_nvcf_id +from ._schema import MAX_INPUTS, EmbeddingsRequest, EmbeddingsResponse, Usage from .models import NvidiaEmbeddingModel @@ -96,7 +96,7 @@ def warm_up(self): if self._initialized: return - self.nvcf_id = get_model_nvcf_id(self.model, self.client) + self.nvcf_id = self.client.get_model_nvcf_id(str(self.model)) self._initialized = True def to_dict(self) -> Dict[str, Any]: diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index a2636b4b8..a377934e3 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -4,7 +4,7 @@ from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.utils.nvidia import NvidiaCloudFunctionsClient -from ._schema import EmbeddingsRequest, EmbeddingsResponse, get_model_nvcf_id +from ._schema import EmbeddingsRequest, EmbeddingsResponse from .models import NvidiaEmbeddingModel @@ -74,7 +74,7 @@ def warm_up(self): if self._initialized: return - self.nvcf_id = get_model_nvcf_id(self.model, self.client) + self.nvcf_id = self.client.get_model_nvcf_id(str(self.model)) self._initialized = True def to_dict(self) -> Dict[str, Any]: diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py index e873bc332..3a315843d 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py @@ -1,3 +1,7 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH +# SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from .generator import NvidiaGenerator +from .models import NvidiaGeneratorModel + +__all__ = ["NvidiaGenerator", "NvidiaGeneratorModel"] diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_schema.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_schema.py new file mode 100644 index 000000000..4e19d05ac --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/_schema.py @@ 
-0,0 +1,69 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import asdict, dataclass +from typing import Any, Dict, List, Optional + + +@dataclass +class Message: + content: str + role: str + + +@dataclass +class GenerationRequest: + messages: List[Message] + temperature: float = 0.2 + top_p: float = 0.7 + max_tokens: int = 1024 + seed: Optional[int] = None + bad: Optional[List[str]] = None + stop: Optional[List[str]] = None + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +@dataclass +class Choice: + index: int + message: Message + finish_reason: str + + +@dataclass +class Usage: + completion_tokens: int + prompt_tokens: int + total_tokens: int + + +@dataclass +class GenerationResponse: + id: str + choices: List[Choice] + usage: Usage + + @classmethod + def from_dict(cls, data: dict) -> "GenerationResponse": + try: + return cls( + id=data["id"], + choices=[ + Choice( + index=choice["index"], + message=Message(content=choice["message"]["content"], role=choice["message"]["role"]), + finish_reason=choice["finish_reason"], + ) + for choice in data["choices"] + ], + usage=Usage( + completion_tokens=data["usage"]["completion_tokens"], + prompt_tokens=data["usage"]["prompt_tokens"], + total_tokens=data["usage"]["total_tokens"], + ), + ) + except (KeyError, TypeError) as e: + msg = f"Failed to parse {cls.__name__} from data: {data}" + raise ValueError(msg) from e diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/__init__.py index e873bc332..6b5e14dc1 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/chat/__init__.py @@ -1,3 +1,3 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH +# SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py new file mode 100644 index 000000000..46550baab --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Optional, Union + +from haystack import component, default_from_dict, default_to_dict +from haystack.utils.auth import Secret, deserialize_secrets_inplace +from haystack_integrations.utils.nvidia import NvidiaCloudFunctionsClient + +from ._schema import GenerationRequest, GenerationResponse, Message +from .models import NvidiaGeneratorModel + + +@component +class NvidiaGenerator: + """ + A component for generating text using generative models provided by + [NVIDIA AI Foundation Endpoints](https://www.nvidia.com/en-us/ai-data-science/foundation-models/). 
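+    The model is resolved to its Nvidia Cloud Functions ID in `warm_up()`, which must be
+    called before the first call to `run()`.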
+ + Usage example: + ```python + from haystack_integrations.components.generators.nvidia import NvidiaGenerator, NvidiaGeneratorModel + + generator = NvidiaGenerator( + model=NvidiaGeneratorModel.NV_LLAMA2_RLHF_70B, + model_arguments={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + "seed": None, + "bad": None, + "stop": None, + }, + ) + generator.warm_up() + + result = generator.run(prompt="What is the answer?") + print(result["replies"]) + print(result["meta"]) + ``` + """ + + def __init__( + self, + model: Union[str, NvidiaGeneratorModel], + api_key: Secret = Secret.from_env_var("NVIDIA_API_KEY"), + model_arguments: Optional[Dict[str, Any]] = None, + ): + """ + Create a NvidiaGenerator component. + + :param model: + Name of the model to use for text generation. + See the [Nvidia catalog](https://catalog.ngc.nvidia.com/ai-foundation-models) + for more information on the supported models. + :param api_key: + Nvidia API key to use for authentication. + :param model_arguments: + Additional arguments to pass to the model provider. Different models accept different arguments. + Search your model in the [Nvidia catalog](https://catalog.ngc.nvidia.com/ai-foundation-models) + to know the supported arguments. + + :raises ValueError: If `model` is not supported. + """ + if isinstance(model, str): + model = NvidiaGeneratorModel.from_str(model) + + self._model = model + self._api_key = api_key + self._model_arguments = model_arguments or {} + # This is initialized in warm_up + self._model_id = None + + self._client = NvidiaCloudFunctionsClient( + api_key=api_key, + headers={ + "Content-Type": "application/json", + "Accept": "application/json", + }, + ) + + def warm_up(self): + """ + Initializes the component. + """ + if self._model_id is not None: + return + self._model_id = self._client.get_model_nvcf_id(str(self._model)) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, model=str(self._model), api_key=self._api_key.to_dict(), model_arguments=self._model_arguments + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "NvidiaGenerator": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + init_params = data.get("init_parameters", {}) + deserialize_secrets_inplace(init_params, ["api_key"]) + return default_from_dict(cls, data) + + @component.output_types(replies=List[str], meta=List[Dict[str, Any]], usage=Dict[str, int]) + def run(self, prompt: str): + """ + Queries the model with the provided prompt. + + :param prompt: + Text to be sent to the generative model. + :returns: + A dictionary with the following keys: + - `replies` - Replies generated by the model. + - `meta` - Metadata for each reply. + - `usage` - Usage statistics for the model. + """ + if self._model_id is None: + msg = "The generation model has not been loaded. Call warm_up() before running." 
+ raise RuntimeError(msg) + + messages = [Message(role="user", content=prompt)] + request = GenerationRequest(messages=messages, **self._model_arguments).to_dict() + json_response = self._client.query_function(self._model_id, request) + + replies = [] + meta = [] + data = GenerationResponse.from_dict(json_response) + for choice in data.choices: + replies.append(choice.message.content) + meta.append( + { + "role": choice.message.role, + "finish_reason": choice.finish_reason, + } + ) + + usage = { + "completion_tokens": data.usage.completion_tokens, + "prompt_tokens": data.usage.prompt_tokens, + "total_tokens": data.usage.total_tokens, + } + + return {"replies": replies, "meta": meta, "usage": usage} diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/models.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/models.py new file mode 100644 index 000000000..448fb7aec --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/models.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from enum import Enum + + +class NvidiaGeneratorModel(Enum): + """ + Generator models supported by NvidiaGenerator and NvidiaChatGenerator. + """ + + NV_LLAMA2_RLHF_70B = "playground_nv_llama2_rlhf_70b" + STEERLM_LLAMA_70B = "playground_steerlm_llama_70b" + NEMOTRON_STEERLM_8B = "playground_nemotron_steerlm_8b" + NEMOTRON_QA_8B = "playground_nemotron_qa_8b" + + def __str__(self): + return self.value + + @classmethod + def from_str(cls, string: str) -> "NvidiaGeneratorModel": + """ + Create a generator model from a string. + + :param string: + String to convert. + :returns: + A generator model. + """ + enum_map = {e.value: e for e in NvidiaGeneratorModel} + models = enum_map.get(string) + if models is None: + msg = f"Unknown model '{string}'. Supported models are: {list(enum_map.keys())}" + raise ValueError(msg) + return models diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py index e582b09ba..b486f05b3 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/client.py @@ -64,3 +64,19 @@ def available_functions(self) -> Dict[str, AvailableNvidiaCloudFunctions]: ) for f in response.json()["functions"] } + + def get_model_nvcf_id(self, model: str) -> str: + """ + Returns the Nvidia Cloud Functions UUID for the given model. 
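+
+        :param model: The model name to look up.
+        :returns: The Nvidia Cloud Functions UUID for the model.
+        :raises ValueError: If the model is not found or is not currently active on the backend.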
+ """ + + available_functions = self.available_functions() + func = available_functions.get(model) + if func is None: + msg = f"Model '{model}' was not found on the Nvidia Cloud Functions backend" + raise ValueError(msg) + elif func.status != "ACTIVE": + msg = f"Model '{model}' is not currently active/usable on the Nvidia Cloud Functions backend" + raise ValueError(msg) + + return func.id diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index 4f19633e8..ed8af93c9 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -20,6 +20,9 @@ def available_functions(self): ) } + def get_model_nvcf_id(self, model): + return "fake-id" + class TestNvidiaDocumentEmbedder: def test_init_default(self, monkeypatch): diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py new file mode 100644 index 000000000..b10b60951 --- /dev/null +++ b/integrations/nvidia/tests/test_generator.py @@ -0,0 +1,176 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import patch + +import pytest +from haystack.utils import Secret +from haystack_integrations.components.generators.nvidia import NvidiaGenerator +from haystack_integrations.components.generators.nvidia.models import NvidiaGeneratorModel + + +class TestNvidiaGenerator: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaGenerator(NvidiaGeneratorModel.NV_LLAMA2_RLHF_70B) + + assert generator._api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert generator._model == NvidiaGeneratorModel.NV_LLAMA2_RLHF_70B + assert generator._model_arguments == {} + + def test_init_with_parameters(self): + generator = NvidiaGenerator( + api_key=Secret.from_token("fake-api-key"), + model="playground_nemotron_steerlm_8b", + model_arguments={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + "seed": None, + "bad": None, + "stop": None, + }, + ) + assert generator._api_key == Secret.from_token("fake-api-key") + assert generator._model == NvidiaGeneratorModel.NEMOTRON_STEERLM_8B + assert generator._model_arguments == { + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + "seed": None, + "bad": None, + "stop": None, + } + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("NVIDIA_API_KEY", raising=False) + with pytest.raises(ValueError): + NvidiaGenerator("playground_nemotron_steerlm_8b") + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaGenerator(NvidiaGeneratorModel.NEMOTRON_STEERLM_8B) + data = generator.to_dict() + assert data == { + "type": "haystack_integrations.components.generators.nvidia.generator.NvidiaGenerator", + "init_parameters": { + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, + "model": "playground_nemotron_steerlm_8b", + "model_arguments": {}, + }, + } + + def test_to_dict_with_custom_init_parameters(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaGenerator( + model=NvidiaGeneratorModel.NEMOTRON_STEERLM_8B, + model_arguments={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + "seed": None, + "bad": None, + "stop": None, + }, + ) + data = generator.to_dict() + assert data == { + "type": 
"haystack_integrations.components.generators.nvidia.generator.NvidiaGenerator", + "init_parameters": { + "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, + "model": "playground_nemotron_steerlm_8b", + "model_arguments": { + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + "seed": None, + "bad": None, + "stop": None, + }, + }, + } + + @patch("haystack_integrations.components.generators.nvidia.generator.NvidiaCloudFunctionsClient") + def test_run(self, mock_client): + generator = NvidiaGenerator( + model=NvidiaGeneratorModel.NEMOTRON_STEERLM_8B, + api_key=Secret.from_token("fake-api-key"), + model_arguments={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + "seed": None, + "bad": None, + "stop": None, + }, + ) + mock_client.get_model_nvcf_id.return_value = "some_id" + generator._client = mock_client + generator.warm_up() + mock_client.get_model_nvcf_id.assert_called_once_with("playground_nemotron_steerlm_8b") + + mock_client.query_function.return_value = { + "id": "some_id", + "choices": [ + { + "index": 0, + "message": {"content": "42", "role": "assistant"}, + "finish_reason": "stop", + } + ], + "usage": {"total_tokens": 21, "prompt_tokens": 19, "completion_tokens": 2}, + } + result = generator.run(prompt="What is the answer?") + mock_client.query_function.assert_called_once_with( + "some_id", + { + "messages": [ + {"content": "What is the answer?", "role": "user"}, + ], + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + "seed": None, + "bad": None, + "stop": None, + }, + ) + assert result == { + "replies": ["42"], + "meta": [ + { + "finish_reason": "stop", + "role": "assistant", + }, + ], + "usage": { + "total_tokens": 21, + "prompt_tokens": 19, + "completion_tokens": 2, + }, + } + + @pytest.mark.skipif( + not os.environ.get("NVIDIA_API_KEY", None), + reason="Export an env var called NVIDIA_API_KEY containing the Nvidia API key to run this test.", + ) + @pytest.mark.integration + def test_run_integration(self): + generator = NvidiaGenerator( + model=NvidiaGeneratorModel.NV_LLAMA2_RLHF_70B, + model_arguments={ + "temperature": 0.2, + "top_p": 0.7, + "max_tokens": 1024, + "seed": None, + "bad": None, + "stop": None, + }, + ) + generator.warm_up() + result = generator.run(prompt="What is the answer?") + + assert result["replies"] + assert result["meta"] + assert result["usage"] diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index b4239308b..8ba2f6783 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -18,6 +18,9 @@ def available_functions(self): ) } + def get_model_nvcf_id(self, model): + return "fake-id" + class TestNvidiaTextEmbedder: def test_init_default(self, monkeypatch): diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index 794fa73fa..f2ed07c18 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -157,21 +157,21 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -opensearch_haystack = ["src/haystack_integrations", "*/opensearch-haystack/src"] -tests = ["tests", "*/opensearch-haystack/tests"] [tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true exclude_lines = [ "no cov", "if __name__ == 
.__main__.:", "if TYPE_CHECKING:", ] + [tool.pytest.ini_options] minversion = "6.0" markers = [ diff --git a/integrations/optimum/pyproject.toml b/integrations/optimum/pyproject.toml index d89e02b2c..e6903b4dc 100644 --- a/integrations/optimum/pyproject.toml +++ b/integrations/optimum/pyproject.toml @@ -148,16 +148,20 @@ ban-relative-imports = "parents" "tests/**" = ["T201"] [tool.coverage.run] -source_pkgs = ["optimum", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -optimum = ["src/haystack_integrations"] -tests = ["tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + [[tool.mypy.overrides]] module = [ diff --git a/integrations/pinecone/pyproject.toml b/integrations/pinecone/pyproject.toml index 33eec8789..d0639983e 100644 --- a/integrations/pinecone/pyproject.toml +++ b/integrations/pinecone/pyproject.toml @@ -143,17 +143,19 @@ ban-relative-imports = "parents" "examples/**/*" = ["T201"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true -omit = ["examples"] +parallel = false -[tool.coverage.paths] -pinecone_haystack = ["src/*"] -tests = ["tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] [tool.pytest.ini_options] minversion = "6.0" diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index 1db58ea0d..40a97d9b9 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -127,21 +127,22 @@ ban-relative-imports = "parents" # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] + [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -qdrant_haystack = [ - "src/qdrant_haystack", - "*/qdrant-haystack/src/qdrant_haystack", +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", ] -tests = ["tests", "*/qdrant-haystack/tests"] -[tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ diff --git a/integrations/ragas/pyproject.toml b/integrations/ragas/pyproject.toml index 080dc68e4..504f6ddb1 100644 --- a/integrations/ragas/pyproject.toml +++ b/integrations/ragas/pyproject.toml @@ -132,19 +132,20 @@ ban-relative-imports = "all" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -ragas_haystack = [ - "src/haystack_integrations", - "*/ragas-haystack/src/ragas_haystack", -] -tests = ["tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + [[tool.mypy.overrides]] module = [ diff --git 
a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py index 71dacd6c7..5c8613553 100644 --- a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py +++ b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/evaluator.py @@ -11,7 +11,6 @@ from .metrics import ( METRIC_DESCRIPTORS, InputConverters, - MetricParamsValidator, OutputConverters, RagasMetric, ) @@ -66,7 +65,7 @@ def __init__( on required parameters. """ self.metric = metric if isinstance(metric, RagasMetric) else RagasMetric.from_str(metric) - self.metric_params = metric_params or {} + self.metric_params = metric_params self.descriptor = METRIC_DESCRIPTORS[self.metric] self._init_backend() @@ -79,10 +78,24 @@ def _init_backend(self): self._backend_callable = RagasEvaluator._invoke_evaluate def _init_metric(self): - MetricParamsValidator.validate_metric_parameters( - self.metric, self.descriptor.init_parameters, self.metric_params - ) - self._backend_metric = self.descriptor.backend(**self.metric_params) + if self.descriptor.init_parameters is not None: + if self.metric_params is None: + msg = f"Ragas metric '{self.metric}' expected init parameters but got none" + raise ValueError(msg) + elif not all(k in self.descriptor.init_parameters for k in self.metric_params.keys()): + msg = ( + f"Invalid init parameters for Ragas metric '{self.metric}'. " + f"Expected: {self.descriptor.init_parameters}" + ) + raise ValueError(msg) + elif self.metric_params is not None: + msg = ( + f"Invalid init parameters for Ragas metric '{self.metric}'. " + f"None expected but {self.metric_params} given" + ) + raise ValueError(msg) + metric_params = self.metric_params or {} + self._backend_metric = self.descriptor.backend(**metric_params) @staticmethod def _invoke_evaluate(dataset: Dataset, metric: Metric) -> Result: diff --git a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py index 72f3e8a3b..ed807aa81 100644 --- a/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py +++ b/integrations/ragas/src/haystack_integrations/components/evaluators/ragas/metrics.py @@ -134,7 +134,7 @@ class MetricDescriptor: backend: Type[Metric] input_parameters: Dict[str, Type] input_converter: Callable[[Any], Iterable[Dict[str, str]]] - output_converter: Callable[[Result, RagasMetric, Dict[str, Any]], List[MetricResult]] + output_converter: Callable[[Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult]] init_parameters: Optional[List[str]] = None @classmethod @@ -143,7 +143,9 @@ def new( metric: RagasMetric, backend: Type[Metric], input_converter: Callable[[Any], Iterable[Dict[str, str]]], - output_converter: Optional[Callable[[Result, RagasMetric, Dict[str, Any]], List[MetricResult]]] = None, + output_converter: Optional[ + Callable[[Result, RagasMetric, Optional[Dict[str, Any]]], List[MetricResult]] + ] = None, *, init_parameters: Optional[List[str]] = None, ) -> "MetricDescriptor": @@ -166,24 +168,6 @@ def new( ) -class MetricParamsValidator: - """ - Validates metric parameters. - - Depending on the metric type, different metric parameters are allowed. - The validator functions are responsible for validating the parameters and raising an error if they are invalid. 
- """ - - @staticmethod - def validate_metric_parameters(metric: RagasMetric, allowed: List[str], received: Dict[str, Any]) -> None: - if not set(received).issubset(allowed): - msg = ( - f"Invalid init parameters for Ragas metric '{metric}'. " - f"Allowed metric parameters {allowed} but got '{received}'" - ) - raise ValueError(msg) - - class InputConverters: """ Converters for input parameters. @@ -292,12 +276,15 @@ def _extract_default_results(output: Result, metric_name: str) -> List[MetricRes raise ValueError(msg) from e @staticmethod - def default(output: Result, metric: RagasMetric, _: Dict) -> List[MetricResult]: + def default(output: Result, metric: RagasMetric, _: Optional[Dict]) -> List[MetricResult]: metric_name = metric.value return OutputConverters._extract_default_results(output, metric_name) @staticmethod - def aspect_critique(output: Result, _: RagasMetric, metric_params: Dict[str, Any]) -> List[MetricResult]: + def aspect_critique(output: Result, _: RagasMetric, metric_params: Optional[Dict[str, Any]]) -> List[MetricResult]: + if metric_params is None: + msg = "Aspect critique metric requires metric parameters" + raise ValueError(msg) metric_name = metric_params["name"] return OutputConverters._extract_default_results(output, metric_name) @@ -307,55 +294,50 @@ def aspect_critique(output: Result, _: RagasMetric, metric_params: Dict[str, Any RagasMetric.ANSWER_CORRECTNESS, AnswerCorrectness, InputConverters.question_response_ground_truth, # type: ignore - init_parameters=["name", "weights", "answer_similarity"], + init_parameters=["weights"], ), RagasMetric.FAITHFULNESS: MetricDescriptor.new( RagasMetric.FAITHFULNESS, Faithfulness, InputConverters.question_context_response, # type: ignore - init_parameters=["name"], ), RagasMetric.ANSWER_SIMILARITY: MetricDescriptor.new( RagasMetric.ANSWER_SIMILARITY, AnswerSimilarity, InputConverters.response_ground_truth, # type: ignore - init_parameters=["name", "model_name", "threshold"], + init_parameters=["threshold"], ), RagasMetric.CONTEXT_PRECISION: MetricDescriptor.new( RagasMetric.CONTEXT_PRECISION, ContextPrecision, InputConverters.question_context_ground_truth, # type: ignore - init_parameters=["name"], ), RagasMetric.CONTEXT_UTILIZATION: MetricDescriptor.new( RagasMetric.CONTEXT_UTILIZATION, ContextUtilization, InputConverters.question_context_response, # type: ignore - init_parameters=["name"], ), RagasMetric.CONTEXT_RECALL: MetricDescriptor.new( RagasMetric.CONTEXT_RECALL, ContextRecall, InputConverters.question_context_ground_truth, # type: ignore - init_parameters=["name"], ), RagasMetric.ASPECT_CRITIQUE: MetricDescriptor.new( RagasMetric.ASPECT_CRITIQUE, AspectCritique, InputConverters.question_context_response, # type: ignore OutputConverters.aspect_critique, - init_parameters=["name", "definition", "strictness", "llm"], + init_parameters=["name", "definition", "strictness"], ), RagasMetric.CONTEXT_RELEVANCY: MetricDescriptor.new( RagasMetric.CONTEXT_RELEVANCY, ContextRelevancy, InputConverters.question_context, # type: ignore - init_parameters=["name"], ), RagasMetric.ANSWER_RELEVANCY: MetricDescriptor.new( RagasMetric.ANSWER_RELEVANCY, AnswerRelevancy, InputConverters.question_context_response, # type: ignore - init_parameters=["name", "strictness", "embeddings"], + init_parameters=["strictness"], ), } diff --git a/integrations/ragas/tests/test_evaluator.py b/integrations/ragas/tests/test_evaluator.py index 09f8cdfcc..0decc96cd 100644 --- a/integrations/ragas/tests/test_evaluator.py +++ 
b/integrations/ragas/tests/test_evaluator.py @@ -5,11 +5,10 @@ import pytest from datasets import Dataset from haystack import DeserializationError +from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric from ragas.evaluation import Result from ragas.metrics.base import Metric -from haystack_integrations.components.evaluators.ragas import RagasEvaluator, RagasMetric - DEFAULT_QUESTIONS = [ "Which is the most popular global sport?", "Who created the Python language?", @@ -59,43 +58,58 @@ def evaluate(self, _, metric: Metric, **kwargs): return output_map[self.metric] -def test_evaluator_metric_init_params(): +@pytest.mark.parametrize( + "metric, init_params", + [ + (RagasMetric.ANSWER_CORRECTNESS, {"weights": [0.5, 0.5]}), + (RagasMetric.FAITHFULNESS, None), + (RagasMetric.ANSWER_SIMILARITY, {"threshold": 0.5}), + (RagasMetric.CONTEXT_PRECISION, None), + (RagasMetric.CONTEXT_UTILIZATION, None), + (RagasMetric.CONTEXT_RECALL, None), + ( + RagasMetric.ASPECT_CRITIQUE, + { + "name": "harmfulness", + "definition": "Does the submission causes or has the potential to " + "cause harm to individuals, groups, or society at " + "large?", + }, + ), + (RagasMetric.CONTEXT_RELEVANCY, None), + (RagasMetric.ANSWER_RELEVANCY, {"strictness": 2}), + ], +) +def test_evaluator_valid_metric_init_params(metric, init_params): eval = RagasEvaluator( - RagasMetric.ASPECT_CRITIQUE, - metric_params={ - 'name': 'harmfulness', - 'definition': 'Does the submission causes or has the potential to cause ' - 'harm to individuals, groups, or society at large?', - }, + metric, + metric_params=init_params, ) - assert eval.metric_params == { - 'definition': 'Does the submission causes or has the potential to cause harm to ' - 'individuals, groups, or society at large?', - 'name': 'harmfulness', - } + assert eval.metric_params == init_params - with pytest.raises(ValueError, match="Expects a name"): - RagasEvaluator(RagasMetric.ASPECT_CRITIQUE, metric_params=None) - - with pytest.raises(ValueError, match="Expects a name"): - RagasEvaluator(RagasMetric.ASPECT_CRITIQUE, metric_params={}) - - with pytest.raises(ValueError, match="Expects a name"): + msg = f"Invalid init parameters for Ragas metric '{metric}'. 
" + with pytest.raises(ValueError, match=msg): RagasEvaluator( - RagasMetric.ASPECT_CRITIQUE, - metric_params={"definition": "custom definition"}, + metric, + metric_params={"invalid_param": "invalid_value"}, ) - with pytest.raises(ValueError, match="Expects definition"): - RagasEvaluator( - RagasMetric.ASPECT_CRITIQUE, - metric_params={"name": "custom name"}, - ) - with pytest.raises(ValueError, match="Invalid init parameters"): +@pytest.mark.parametrize( + "metric", + [ + RagasMetric.ANSWER_CORRECTNESS, + RagasMetric.ANSWER_SIMILARITY, + RagasMetric.ASPECT_CRITIQUE, + RagasMetric.ANSWER_RELEVANCY, + ], +) +def test_evaluator_fails_with_no_metric_init_params(metric): + msg = f"Ragas metric '{metric}' expected init parameters but got none" + with pytest.raises(ValueError, match=msg): RagasEvaluator( - RagasMetric.FAITHFULNESS, - metric_params={"check_numbers": True}, + metric, + metric_params=None, ) @@ -103,10 +117,10 @@ def test_evaluator_serde(): init_params = { "metric": RagasMetric.ASPECT_CRITIQUE, "metric_params": { - 'name': 'harmfulness', - 'definition': 'Does the submission causes or has the potential to ' - 'cause harm to individuals, groups, or society at ' - 'large?', + "name": "harmfulness", + "definition": "Does the submission causes or has the potential to " + "cause harm to individuals, groups, or society at " + "large?", }, } eval = RagasEvaluator(**init_params) @@ -126,9 +140,13 @@ def test_evaluator_serde(): @pytest.mark.parametrize( "current_metric, inputs, params", [ - (RagasMetric.ANSWER_CORRECTNESS, {"questions": [], "responses": [], "ground_truths": []}, None), + ( + RagasMetric.ANSWER_CORRECTNESS, + {"questions": [], "responses": [], "ground_truths": []}, + {"weights": [0.5, 0.5]}, + ), (RagasMetric.FAITHFULNESS, {"questions": [], "contexts": [], "responses": []}, None), - (RagasMetric.ANSWER_SIMILARITY, {"responses": [], "ground_truths": []}, None), + (RagasMetric.ANSWER_SIMILARITY, {"responses": [], "ground_truths": []}, {"threshold": 0.5}), (RagasMetric.CONTEXT_PRECISION, {"questions": [], "contexts": [], "ground_truths": []}, None), (RagasMetric.CONTEXT_UTILIZATION, {"questions": [], "contexts": [], "responses": []}, None), (RagasMetric.CONTEXT_RECALL, {"questions": [], "contexts": [], "ground_truths": []}, None), @@ -136,14 +154,14 @@ def test_evaluator_serde(): RagasMetric.ASPECT_CRITIQUE, {"questions": [], "contexts": [], "responses": []}, { - 'name': 'harmfulness', - 'definition': 'Does the submission causes or has the potential to ' - 'cause harm to individuals, groups, or society at ' - 'large?', + "name": "harmfulness", + "definition": "Does the submission causes or has the potential to " + "cause harm to individuals, groups, or society at " + "large?", }, ), (RagasMetric.CONTEXT_RELEVANCY, {"questions": [], "contexts": []}, None), - (RagasMetric.ANSWER_RELEVANCY, {"questions": [], "contexts": [], "responses": []}, None), + (RagasMetric.ANSWER_RELEVANCY, {"questions": [], "contexts": [], "responses": []}, {"strictness": 2}), ], ) def test_evaluator_valid_inputs(current_metric, inputs, params): @@ -170,9 +188,9 @@ def test_evaluator_valid_inputs(current_metric, inputs, params): RagasMetric.ANSWER_RELEVANCY, {"questions": [""], "responses": [], "contexts": []}, "Mismatching counts ", - None, + {"strictness": 2}, ), - (RagasMetric.ANSWER_RELEVANCY, {"responses": []}, "expected input parameter ", None), + (RagasMetric.ANSWER_RELEVANCY, {"responses": []}, "expected input parameter ", {"strictness": 2}), ], ) def test_evaluator_invalid_inputs(current_metric, 
inputs, error_string, params): @@ -195,7 +213,7 @@ def test_evaluator_invalid_inputs(current_metric, inputs, error_string, params): RagasMetric.ANSWER_CORRECTNESS, {"questions": ["q1"], "responses": ["r1"], "ground_truths": ["gt1"]}, [[(None, 0.5)]], - None, + {"weights": [0.5, 0.5]}, ), ( RagasMetric.FAITHFULNESS, @@ -203,7 +221,12 @@ def test_evaluator_invalid_inputs(current_metric, inputs, error_string, params): [[(None, 1.0)]], None, ), - (RagasMetric.ANSWER_SIMILARITY, {"responses": ["r3"], "ground_truths": ["gt3"]}, [[(None, 1.0)]], None), + ( + RagasMetric.ANSWER_SIMILARITY, + {"responses": ["r3"], "ground_truths": ["gt3"]}, + [[(None, 1.0)]], + {"threshold": 0.5}, + ), ( RagasMetric.CONTEXT_PRECISION, {"questions": ["q4"], "contexts": [["c4"]], "ground_truths": ["gt44"]}, @@ -227,10 +250,10 @@ def test_evaluator_invalid_inputs(current_metric, inputs, error_string, params): {"questions": ["q7"], "contexts": [["c7"]], "responses": ["r7"]}, [[("harmfulness", 1.0)]], { - 'name': 'harmfulness', - 'definition': 'Does the submission causes or has the potential to ' - 'cause harm to individuals, groups, or society at ' - 'large?', + "name": "harmfulness", + "definition": "Does the submission causes or has the potential to " + "cause harm to individuals, groups, or society at " + "large?", }, ), ( @@ -243,7 +266,7 @@ def test_evaluator_invalid_inputs(current_metric, inputs, error_string, params): RagasMetric.ANSWER_RELEVANCY, {"questions": ["q9"], "contexts": [["c9"]], "responses": ["r9"]}, [[(None, 0.4)]], - None, + {"strictness": 2}, ), ], ) @@ -277,14 +300,18 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_para ( RagasMetric.ANSWER_CORRECTNESS, {"questions": DEFAULT_QUESTIONS, "responses": DEFAULT_RESPONSES, "ground_truths": DEFAULT_GROUND_TRUTHS}, - None, + {"weights": [0.5, 0.5]}, ), ( RagasMetric.FAITHFULNESS, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, None, ), - (RagasMetric.ANSWER_SIMILARITY, {"responses": DEFAULT_QUESTIONS, "ground_truths": DEFAULT_GROUND_TRUTHS}, None), + ( + RagasMetric.ANSWER_SIMILARITY, + {"responses": DEFAULT_QUESTIONS, "ground_truths": DEFAULT_GROUND_TRUTHS}, + {"threshold": 0.5}, + ), ( RagasMetric.CONTEXT_PRECISION, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "ground_truths": DEFAULT_GROUND_TRUTHS}, @@ -304,17 +331,17 @@ def test_evaluator_outputs(current_metric, inputs, expected_outputs, metric_para RagasMetric.ASPECT_CRITIQUE, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, { - 'name': 'harmfulness', - 'definition': 'Does the submission causes or has the potential to ' - 'cause harm to individuals, groups, or society at ' - 'large?', + "name": "harmfulness", + "definition": "Does the submission causes or has the potential to " + "cause harm to individuals, groups, or society at " + "large?", }, ), (RagasMetric.CONTEXT_RELEVANCY, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS}, None), ( RagasMetric.ANSWER_RELEVANCY, {"questions": DEFAULT_QUESTIONS, "contexts": DEFAULT_CONTEXTS, "responses": DEFAULT_RESPONSES}, - None, + {"strictness": 2}, ), ], ) @@ -326,7 +353,7 @@ def test_integration_run(metric, inputs, metric_params): eval = RagasEvaluator(**init_params) output = eval.run(**inputs) - assert type(output) == dict + assert isinstance(output, dict) assert len(output) == 1 assert "results" in output assert len(output["results"]) == len(next(iter(inputs.values()))) diff --git 
a/integrations/uptrain/pyproject.toml b/integrations/uptrain/pyproject.toml index 0e9166adc..36ebdc00b 100644 --- a/integrations/uptrain/pyproject.toml +++ b/integrations/uptrain/pyproject.toml @@ -137,19 +137,20 @@ ban-relative-imports = "all" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -uptrain_haystack = [ - "src/haystack_integrations", - "*/uptrain-haystack/src/uptrain_haystack", -] -tests = ["tests"] [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + [[tool.mypy.overrides]] module = [ diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index 00aa500e6..421c2ce18 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -133,17 +133,20 @@ ban-relative-imports = "parents" "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["src", "tests"] +source = ["haystack_integrations"] branch = true -parallel = true +parallel = false -[tool.coverage.paths] -weaviate_haystack = ["src/haystack_integrations", "*/weaviate-haystack/src"] -tests = ["tests", "*/weaviate-haystack/tests"] - [tool.coverage.report] -exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +omit = ["*/tests/*", "*/__init__.py"] +show_missing=true +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + [[tool.mypy.overrides]] module = [