diff --git a/.github/labeler.yml b/.github/labeler.yml index 85f15788f..4e2899e4b 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -14,6 +14,11 @@ integration:astra: - any-glob-to-any-file: "integrations/astra/**/*" - any-glob-to-any-file: ".github/workflows/astra.yml" +integration:azure-ai-search: + - changed-files: + - any-glob-to-any-file: "integrations/azure_ai_search/**/*" + - any-glob-to-any-file: ".github/workflows/azure_ai_search.yml" + integration:chroma: - changed-files: - any-glob-to-any-file: "integrations/chroma/**/*" diff --git a/.github/workflows/CI_pypi_release.yml b/.github/workflows/CI_pypi_release.yml index e29d3ea36..1162ca3b3 100644 --- a/.github/workflows/CI_pypi_release.yml +++ b/.github/workflows/CI_pypi_release.yml @@ -54,7 +54,7 @@ jobs: with: config: cliff.toml args: > - --include-path "${{ steps.pathfinder.outputs.project_path }}/*" + --include-path "${{ steps.pathfinder.outputs.project_path }}/**/*" --tag-pattern "${{ steps.pathfinder.outputs.project_path }}-v*" - name: Commit changelog diff --git a/.github/workflows/pgvector.yml b/.github/workflows/pgvector.yml index 0fe20e037..ab5c984ed 100644 --- a/.github/workflows/pgvector.yml +++ b/.github/workflows/pgvector.yml @@ -33,7 +33,7 @@ jobs: python-version: ["3.9", "3.10", "3.11"] services: pgvector: - image: ankane/pgvector:latest + image: pgvector/pgvector:pg17 env: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres diff --git a/README.md b/README.md index af83d045d..0f8b2f0ee 100644 --- a/README.md +++ b/README.md @@ -38,12 +38,12 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | | [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | | [instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | -| [jina-haystack](integrations/jina/) | Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | +| [jina-haystack](integrations/jina/) | Connector, Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / 
jina](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | | [langfuse-haystack](integrations/langfuse/) | Tracer | [![PyPI - Version](https://img.shields.io/pypi/v/langfuse-haystack.svg?color=orange)](https://pypi.org/project/langfuse-haystack) | [![Test / langfuse](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/langfuse.yml) | | [llama-cpp-haystack](integrations/llama_cpp/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/llama-cpp-haystack.svg?color=orange)](https://pypi.org/project/llama-cpp-haystack) | [![Test / llama-cpp](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/llama_cpp.yml) | | [mistral-haystack](integrations/mistral/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/mistral-haystack.svg)](https://pypi.org/project/mistral-haystack) | [![Test / mistral](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mistral.yml) | | [mongodb-atlas-haystack](integrations/mongodb_atlas/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/mongodb-atlas-haystack.svg?color=orange)](https://pypi.org/project/mongodb-atlas-haystack) | [![Test / mongodb-atlas](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/mongodb_atlas.yml) | -| [nvidia-haystack](integrations/nvidia/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/nvidia-haystack.svg?color=orange)](https://pypi.org/project/nvidia-haystack) | [![Test / nvidia](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml) | +| [nvidia-haystack](integrations/nvidia/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/nvidia-haystack.svg?color=orange)](https://pypi.org/project/nvidia-haystack) | [![Test / nvidia](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/nvidia.yml) | | [ollama-haystack](integrations/ollama/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/ollama-haystack.svg?color=orange)](https://pypi.org/project/ollama-haystack) | [![Test / ollama](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/ollama.yml) | | [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | | [optimum-haystack](integrations/optimum/) | Embedder | [![PyPI - 
Version](https://img.shields.io/pypi/v/optimum-haystack.svg)](https://pypi.org/project/optimum-haystack) | [![Test / optimum](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/optimum.yml) | @@ -85,3 +85,8 @@ GitHub. The GitHub Actions workflow will take care of the rest. git push --tags origin ``` 3. Wait for the CI to do its magic + +> [!IMPORTANT] +> When releasing a new integration version, always tag a commit that includes the changes for that integration +> (usually the PR merge commit). If you tag a commit that doesn't include changes for the integration being released, +> the generated changelog will be incorrect. diff --git a/cliff.toml b/cliff.toml index 29228543e..45e4c647b 100644 --- a/cliff.toml +++ b/cliff.toml @@ -19,11 +19,43 @@ body = """ ## [unreleased] {% endif %}\ {% for group, commits in commits | group_by(attribute="group") %} + {# + Skip the whole section if it contains only a single commit + and it's the commit that updated the changelog. + If we don't do this, we get an empty section, since we don't show + commits that update the changelog. + #}\ + {% if commits | length == 1 and commits[0].message == 'Update the changelog' %}\ + {% continue %}\ + {% endif %}\ ### {{ group | striptags | trim | upper_first }} - {% for commit in commits %} + {% for commit in commits %}\ + {# + Skip commits that update the changelog; they're not useful to the user. + #}\ + {% if commit.message == 'Update the changelog' %}\ + {% continue %}\ + {% endif %} - {% if commit.scope %}*({{ commit.scope }})* {% endif %}\ {% if commit.breaking %}[**breaking**] {% endif %}\ - {{ commit.message | upper_first }}\ + {# + We first try to render the conventional commit message if present. + If it's not a conventional commit, we use the PR title if present. + If the commit is neither conventional nor has a PR title set, + we fall back to whatever the commit message is. + + We do this because merging PRs with multiple commits whose titles + don't follow conventional commit guidelines can produce a commit + message that spans multiple lines. That makes the changelog + look messy, so we handle it this way. + #}\ + {% if commit.conventional %}\ + {{ commit.message | upper_first }}\ + {% elif commit.remote.pr_title %}\ + {{ commit.remote.pr_title | upper_first }} (#{{ commit.remote.pr_number }})\ + {% else %}\ + {{ commit.message | upper_first }}\ + {% endif %}\ {% endfor %} {% endfor %}\n """ @@ -35,7 +67,7 @@ footer = """ trim = true # postprocessors postprocessors = [ - # { pattern = '<REPO>', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL + # { pattern = '<REPO>', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL ] [git] @@ -47,24 +79,26 @@ filter_unconventional = false split_commits = false # regex for preprocessing the commit messages commit_preprocessors = [ - # Replace issue numbers - #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](/issues/${2}))"}, - # Check spelling of the commit with https://github.com/crate-ci/typos - # If the spelling is incorrect, it will be automatically fixed. - #{ pattern = '.*', replace_command = 'typos --write-changes -' }, + # Replace issue numbers + #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](/issues/${2}))"}, + # Check spelling of the commit with https://github.com/crate-ci/typos + # If the spelling is incorrect, it will be automatically fixed.
+ #{ pattern = '.*', replace_command = 'typos --write-changes -' }, ] # regex for parsing and grouping commits commit_parsers = [ - { message = "^feat", group = "๐Ÿš€ Features" }, - { message = "^fix", group = "๐Ÿ› Bug Fixes" }, - { message = "^doc", group = "๐Ÿ“š Documentation" }, - { message = "^perf", group = "โšก Performance" }, - { message = "^refactor", group = "๐Ÿšœ Refactor" }, - { message = "^style", group = "๐ŸŽจ Styling" }, - { message = "^test", group = "๐Ÿงช Testing" }, - { message = "^chore|^ci", group = "โš™๏ธ Miscellaneous Tasks" }, - { body = ".*security", group = "๐Ÿ›ก๏ธ Security" }, - { message = "^revert", group = "โ—€๏ธ Revert" }, + { message = "^feat", group = "๐Ÿš€ Features" }, + { message = "^fix", group = "๐Ÿ› Bug Fixes" }, + { message = "^refactor", group = "๐Ÿšœ Refactor" }, + { message = "^doc", group = "๐Ÿ“š Documentation" }, + { message = "^perf", group = "โšก Performance" }, + { message = "^style", group = "๐ŸŽจ Styling" }, + { message = "^test", group = "๐Ÿงช Testing" }, + { body = ".*security", group = "๐Ÿ›ก๏ธ Security" }, + { message = "^revert", group = "โ—€๏ธ Revert" }, + { message = "^ci", group = "โš™๏ธ CI" }, + { message = "^chore", group = "๐Ÿงน Chores" }, + { message = ".*", group = "๐ŸŒ€ Miscellaneous" }, ] # protect breaking changes from being skipped due to matching a skipping commit_parser protect_breaking_commits = false @@ -82,3 +116,7 @@ topo_order = false sort_commits = "oldest" # limit the number of commits included in the changelog. # limit_commits = 42 + +[remote.github] +owner = "deepset-ai" +repo = "haystack-core-integrations" diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py index b2efefdc8..2ebd35979 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/__init__.py @@ -4,4 +4,4 @@ from .document_embedder import AmazonBedrockDocumentEmbedder from .text_embedder import AmazonBedrockTextEmbedder -__all__ = ["AmazonBedrockTextEmbedder", "AmazonBedrockDocumentEmbedder"] +__all__ = ["AmazonBedrockDocumentEmbedder", "AmazonBedrockTextEmbedder"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py index 9dc8cbcc5..f15601f57 100755 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/embedders/amazon_bedrock/document_embedder.py @@ -243,7 +243,7 @@ def run(self, documents: List[Document]): - `documents`: The `Document`s with the `embedding` field populated. :raises AmazonBedrockInferenceError: If the inference fails. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "AmazonBedrockDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the AmazonBedrockTextEmbedder." 
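A note on the `AmazonBedrockDocumentEmbedder.run` change above: the added parentheses are behavior-preserving, since `and` binds more tightly than `or` in Python, so the original expression already parsed as `not isinstance(documents, list) or (documents and not isinstance(documents[0], Document))`; the change only makes that precedence explicit. A minimal sketch of the validation logic, with the exception type assumed because the `raise` statement falls outside the hunk:

```python
from typing import List

from haystack import Document


def _validate_documents(documents: List[Document]) -> None:
    # Reject anything that is not a list, and reject non-empty lists whose
    # first element is not a Document. An empty list passes validation.
    if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
        msg = (
            "AmazonBedrockDocumentEmbedder expects a list of Documents as input."
            "In case you want to embed a string, please use the AmazonBedrockTextEmbedder."
        )
        # The actual exception type is not shown in the diff; TypeError is an assumption.
        raise TypeError(msg)


_validate_documents([Document(content="hello")])  # passes
_validate_documents([])                           # passes: an empty list is allowed
```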
diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 2d33beb42..ab3f0dfd5 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -4,4 +4,4 @@ from .chat.chat_generator import AmazonBedrockChatGenerator from .generator import AmazonBedrockGenerator -__all__ = ["AmazonBedrockGenerator", "AmazonBedrockChatGenerator"] +__all__ = ["AmazonBedrockChatGenerator", "AmazonBedrockGenerator"] diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py index f5e8f8181..cbb5ee370 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/adapters.py @@ -212,6 +212,8 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ stop_sequences = inference_kwargs.get("stop_sequences", []) + inference_kwargs.pop("stop_words", []) if stop_sequences: inference_kwargs["stop_sequences"] = stop_sequences + # pop stream kwarg from inference_kwargs as Anthropic does not support it (if provided) + inference_kwargs.pop("stream", None) params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) body = {**self.prepare_chat_messages(messages=messages), **params} return body @@ -384,6 +386,10 @@ def prepare_body(self, messages: List[ChatMessage], **inference_kwargs) -> Dict[ stop_words = inference_kwargs.pop("stop_words", []) if stop_words: inference_kwargs["stop"] = stop_words + + # pop stream kwarg from inference_kwargs as Mistral does not support it (if provided) + inference_kwargs.pop("stream", None) + params = self._get_params(inference_kwargs, default_params, self.ALLOWED_PARAMS) body = {"prompt": self.prepare_chat_messages(messages=messages), **params} return body diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index da783979f..8d6a5c3ee 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -17,7 +17,7 @@ ) KLASS = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator" -MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1", "meta.llama2-13b-chat-v1"] +MODELS_TO_TEST = ["anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-v2:1"] MODELS_TO_TEST_WITH_TOOLS = ["anthropic.claude-3-haiku-20240307-v1:0"] MISTRAL_MODELS = [ "mistral.mistral-7b-instruct-v0:2", @@ -248,10 +248,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.model_adapter.get_responses = MagicMock( return_value=[ - ChatMessage( + ChatMessage.from_assistant( content="Some text", - role=ChatRole.ASSISTANT, - name=None, meta={ "model": "claude-3-sonnet-20240229", "index": 0, diff --git a/integrations/anthropic/CHANGELOG.md b/integrations/anthropic/CHANGELOG.md index a219da6a5..a7cdc7d09 100644 --- a/integrations/anthropic/CHANGELOG.md +++ 
b/integrations/anthropic/CHANGELOG.md @@ -1,11 +1,29 @@ # Changelog +## [unreleased] + +### โš™๏ธ CI + +- Adopt uv as installer (#1142) + +### ๐Ÿงน Chores + +- Update ruff linting scripts and settings (#1105) + +### ๐ŸŒ€ Miscellaneous + +- Add AnthropicVertexChatGenerator component (#1192) + ## [integrations/anthropic-v1.1.0] - 2024-09-20 ### ๐Ÿš€ Features - Add Anthropic prompt caching support, add example (#1006) +### ๐ŸŒ€ Miscellaneous + +- Chore: Update Anthropic example, use ChatPromptBuilder properly (#978) + ## [integrations/anthropic-v1.0.0] - 2024-08-12 ### ๐Ÿ› Bug Fixes @@ -20,12 +38,18 @@ - Do not retry tests in `hatch run test` command (#954) + ## [integrations/anthropic-v0.4.1] - 2024-07-17 -### โš™๏ธ Miscellaneous Tasks +### ๐Ÿงน Chores - Update ruff invocation to include check parameter (#853) +### ๐ŸŒ€ Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Add meta deprecration warning (#910) + ## [integrations/anthropic-v0.4.0] - 2024-06-21 ### ๐Ÿš€ Features @@ -33,12 +57,24 @@ - Update Anthropic/Cohere for tools use (#790) - Update Anthropic default models, pydocs (#839) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Retry tests to reduce flakyness (#836) +### ๐ŸŒ€ Miscellaneous + +- Remove references to Python 3.7 (#601) +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Docs: add missing api references (#728) + ## [integrations/anthropic-v0.2.0] - 2024-03-15 +### ๐ŸŒ€ Miscellaneous + +- Docs: Replace amazon-bedrock with anthropic in readme (#584) +- Chore: Use the correct sonnet model name (#587) + ## [integrations/anthropic-v0.1.0] - 2024-03-15 ### ๐Ÿš€ Features diff --git a/integrations/anthropic/pydoc/config.yml b/integrations/anthropic/pydoc/config.yml index 9c1e39daf..bd3811571 100644 --- a/integrations/anthropic/pydoc/config.yml +++ b/integrations/anthropic/pydoc/config.yml @@ -4,6 +4,7 @@ loaders: modules: [ "haystack_integrations.components.generators.anthropic.generator", "haystack_integrations.components.generators.anthropic.chat.chat_generator", + "haystack_integrations.components.generators.anthropic.chat.vertex_chat_generator", ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py index c2c1ee40d..12c588dc4 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from .chat.chat_generator import AnthropicChatGenerator +from .chat.vertex_chat_generator import AnthropicVertexChatGenerator from .generator import AnthropicGenerator -__all__ = ["AnthropicGenerator", "AnthropicChatGenerator"] +__all__ = ["AnthropicChatGenerator", "AnthropicGenerator", "AnthropicVertexChatGenerator"] diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py new file mode 100644 index 000000000..4ece944cd --- /dev/null +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py @@ -0,0 +1,135 @@ +import os +from typing import Any, Callable, Dict, 
Optional + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import StreamingChunk +from haystack.utils import deserialize_callable, serialize_callable + +from anthropic import AnthropicVertex + +from .chat_generator import AnthropicChatGenerator + +logger = logging.getLogger(__name__) + + +@component +class AnthropicVertexChatGenerator(AnthropicChatGenerator): + """ + + Enables text generation using state-of-the-art Claude 3 LLMs via the Anthropic Vertex AI API. + It supports models such as `Claude 3.5 Sonnet`, `Claude 3 Opus`, `Claude 3 Sonnet`, and `Claude 3 Haiku`, + accessible through the Vertex AI API endpoint. + + To use AnthropicVertexChatGenerator, you must have a GCP project with Vertex AI enabled. + Additionally, ensure that the desired Anthropic model is activated in the Vertex AI Model Garden. + Before making requests, you may need to authenticate with GCP using `gcloud auth login`. + For more details, refer to the [guide](https://docs.anthropic.com/en/api/claude-on-vertex-ai). + + Any valid text generation parameters for the Anthropic messaging API can be passed to + the AnthropicVertex API. Users can provide these parameters directly to the component via + the `generation_kwargs` parameter in `__init__` or the `run` method. + + For more details on the parameters supported by the Anthropic API, refer to the + Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages). + + ```python + from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + from haystack.dataclasses import ChatMessage + + messages = [ChatMessage.from_user("What's Natural Language Processing?")] + client = AnthropicVertexChatGenerator( + model="claude-3-sonnet@20240229", + project_id="your-project-id", region="your-region" + ) + response = client.run(messages) + print(response) + + >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that + >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing + >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and + >> communicate in natural languages like English, Spanish, or Chinese.', role=<ChatRole.ASSISTANT: 'assistant'>, + >> name=None, meta={'model': 'claude-3-sonnet@20240229', 'index': 0, 'finish_reason': 'end_turn', + >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]} + ``` + + For more details on supported models and their capabilities, refer to the Anthropic + [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude). + + """ + + def __init__( + self, + region: str, + project_id: str, + model: str = "claude-3-5-sonnet@20240620", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ignore_tools_thinking_messages: bool = True, + ): + """ + Creates an instance of AnthropicVertexChatGenerator. + + :param region: The region where the Anthropic model is deployed. + :param project_id: The GCP project ID where the Anthropic model is deployed. + :param model: The name of the model to use. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to + the AnthropicVertex endpoint.
See Anthropic [documentation](https://docs.anthropic.com/claude/reference/messages_post) + for more details. + + Supported generation_kwargs parameters are: + - `system`: The system message to be passed to the model. + - `max_tokens`: The maximum number of tokens to generate. + - `metadata`: A dictionary of metadata to be passed to the model. + - `stop_sequences`: A list of strings that the model should stop generating at. + - `temperature`: The temperature to use for sampling. + - `top_p`: The top_p value to use for nucleus sampling. + - `top_k`: The top_k value to use for top-k sampling. + - `extra_headers`: A dictionary of extra headers to be passed to the model (e.g., for beta features). + :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves + "chain of thought" messages before returning the actual function names and parameters in a message. If + `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool + use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use) + documentation for more details. + """ + self.region = region or os.environ.get("REGION") + self.project_id = project_id or os.environ.get("PROJECT_ID") + self.model = model + self.generation_kwargs = generation_kwargs or {} + self.streaming_callback = streaming_callback + self.client = AnthropicVertex(region=self.region, project_id=self.project_id) + self.ignore_tools_thinking_messages = ignore_tools_thinking_messages + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + region=self.region, + project_id=self.project_id, + model=self.model, + streaming_callback=callback_name, + generation_kwargs=self.generation_kwargs, + ignore_tools_thinking_messages=self.ignore_tools_thinking_messages, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: + The deserialized component instance.
+ """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) diff --git a/integrations/anthropic/tests/test_vertex_chat_generator.py b/integrations/anthropic/tests/test_vertex_chat_generator.py new file mode 100644 index 000000000..a67e801ad --- /dev/null +++ b/integrations/anthropic/tests/test_vertex_chat_generator.py @@ -0,0 +1,197 @@ +import os + +import anthropic +import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + + +class TestAnthropicVertexChatGenerator: + def test_init_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is None + assert not component.generation_kwargs + assert component.ignore_tools_thinking_messages + + def test_init_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ignore_tools_thinking_messages=False, + ) + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.ignore_tools_thinking_messages is False + + def test_to_dict_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": None, + "generation_kwargs": {}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." 
+ "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_lambda_streaming_callback(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=lambda x: x, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "tests.test_vertex_chat_generator.", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_from_dict(self): + data = { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + component = AnthropicVertexChatGenerator.from_dict(data) + assert component.model == "claude-3-5-sonnet@20240620" + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_run(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + response = component.run(chat_messages) + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_with_params(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator( + region="us-central1", project_id="test-project-id", generation_kwargs={"max_tokens": 10, "temperature": 0.5} + ) + response = component.run(chat_messages) + + # check that the component calls the Anthropic API with the correct parameters + _, kwargs = mock_chat_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run 
this test.", + ) + @pytest.mark.integration + def test_live_run_wrong_model(self, chat_messages): + component = AnthropicVertexChatGenerator( + model="something-obviously-wrong", region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID") + ) + with pytest.raises(anthropic.NotFoundError): + component.run(chat_messages) + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.", + ) + @pytest.mark.integration + def test_default_inference_params(self, chat_messages): + client = AnthropicVertexChatGenerator( + region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID"), model="claude-3-sonnet@20240229" + ) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" + + # Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint, + # remaining tests are skipped for AnthropicVertexChatGenerator as they are already tested in AnthropicChatGenerator. diff --git a/integrations/astra/CHANGELOG.md b/integrations/astra/CHANGELOG.md index fff6cb65f..6ad660a0e 100644 --- a/integrations/astra/CHANGELOG.md +++ b/integrations/astra/CHANGELOG.md @@ -1,16 +1,29 @@ # Changelog +## [integrations/astra-v0.9.4] - 2024-11-25 + +### ๐ŸŒ€ Miscellaneous + +- Fix: Astra - fix embedding retrieval top-k limit (#1210) + ## [integrations/astra-v0.10.0] - 2024-10-22 ### ๐Ÿš€ Features - Update astradb integration for latest client library (#1145) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI -- Update ruff linting scripts and settings (#1105) - Adopt uv as installer (#1142) +### ๐Ÿงน Chores + +- Update ruff linting scripts and settings (#1105) + +### ๐ŸŒ€ Miscellaneous + +- Fix: #1047 Remove count_documents from delete_documents (#1049) + ## [integrations/astra-v0.9.3] - 2024-09-12 ### ๐Ÿ› Bug Fixes @@ -22,8 +35,13 @@ - Do not retry tests in `hatch run test` command (#954) + ## [integrations/astra-v0.9.2] - 2024-07-22 +### ๐ŸŒ€ Miscellaneous + +- Normalize logical filter conditions (#874) + ## [integrations/astra-v0.9.1] - 2024-07-15 ### ๐Ÿš€ Features @@ -37,27 +55,48 @@ - Fix typing checks - `Astra` - Fallback to default filter policy when deserializing retrievers without the init parameter (#896) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Retry tests to reduce flakyness (#836) +### ๐ŸŒ€ Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Fix: Incorrect astra not equal operator (#868) +- Chore: Minor retriever pydoc fix (#884) + ## [integrations/astra-v0.7.0] - 2024-05-15 ### ๐Ÿ› Bug Fixes - Make unit tests pass (#720) +### ๐ŸŒ€ Miscellaneous + +- Chore: change the pydoc renderer class (#718) +- [Astra DB] Explicit projection when reading from Astra DB (#733) + ## [integrations/astra-v0.6.0] - 2024-04-24 ### ๐Ÿ› Bug Fixes - Pass namespace in the docstore 
init (#683) +### ๐ŸŒ€ Miscellaneous + +- Chore: add license classifiers (#680) +- Bug fix for document_store.py (#618) + ## [integrations/astra-v0.5.1] - 2024-04-09 ### ๐Ÿ› Bug Fixes -- Fix haystack-ai pin (#649) +- Fix `haystack-ai` pins (#649) + +### ๐ŸŒ€ Miscellaneous + +- Remove references to Python 3.7 (#601) +- Make Document Stores initially skip `SparseEmbedding` (#606) ## [integrations/astra-v0.5.0] - 2024-03-18 @@ -67,9 +106,15 @@ - Small consistency improvements (#536) - Disable-class-def (#556) +### ๐ŸŒ€ Miscellaneous + +- Fix example code for Astra DB pipeline (#481) +- Make tests show coverage (#566) +- Astra DB: Add integration usage tracking (#568) + ## [integrations/astra-v0.4.2] - 2024-02-21 -### FIX +### ๐ŸŒ€ Miscellaneous - Proper name for the sort param (#454) @@ -78,9 +123,7 @@ ### ๐Ÿ› Bug Fixes - Fix order of API docs (#447) - -This PR will also push the docs to Readme -- Fix integration tests (#450) +- Astra: fix integration tests (#450) ## [integrations/astra-v0.4.0] - 2024-02-20 @@ -88,20 +131,35 @@ This PR will also push the docs to Readme - Update category slug (#442) +### ๐ŸŒ€ Miscellaneous + +- Update the Astra DB Integration to fit latest conventions (#428) + ## [integrations/astra-v0.3.0] - 2024-02-15 -## [integrations/astra-v0.2.0] - 2024-02-13 +### ๐ŸŒ€ Miscellaneous -### Astra +- Model_name_or_path > model (#418) +- [Astra] Change authentication parameters (#423) -- Generate api docs (#327) +## [integrations/astra-v0.2.0] - 2024-02-13 -### Refact +### ๐ŸŒ€ Miscellaneous - [**breaking**] Change import paths (#277) +- Generate api docs (#327) +- Astra: rename retriever (#399) ## [integrations/astra-v0.1.1] - 2024-01-18 +### ๐ŸŒ€ Miscellaneous + +- Update the import paths for beta5 (#235) + ## [integrations/astra-v0.1.0] - 2024-01-11 +### ๐ŸŒ€ Miscellaneous + +- Adding AstraDB as a DocumentStore (#144) + diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index 6f2289786..1a3481e0c 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -202,7 +202,7 @@ def _format_query_response(responses, include_metadata, include_values): return QueryResponse(final_res) def _query(self, vector, top_k, filters=None): - query = {"sort": {"$vector": vector}, "options": {"limit": top_k, "includeSimilarity": True}} + query = {"sort": {"$vector": vector}, "limit": top_k, "includeSimilarity": True} if filters is not None: query["filter"] = filters @@ -222,6 +222,7 @@ def find_documents(self, find_query): filter=find_query.get("filter"), sort=find_query.get("sort"), limit=find_query.get("limit"), + include_similarity=find_query.get("includeSimilarity"), projection={"*": 1}, ) diff --git a/integrations/astra/tests/test_embedding_retrieval.py b/integrations/astra/tests/test_embedding_retrieval.py new file mode 100644 index 000000000..bf23fe9f5 --- /dev/null +++ b/integrations/astra/tests/test_embedding_retrieval.py @@ -0,0 +1,48 @@ +import os + +import pytest +from haystack import Document +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.astra import AstraDocumentStore + + +@pytest.mark.integration +@pytest.mark.skipif( + os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" +) 
+@pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", reason="ASTRA_DB_API_ENDPOINT env var not set") +class TestEmbeddingRetrieval: + + @pytest.fixture + def document_store(self) -> AstraDocumentStore: + return AstraDocumentStore( + collection_name="haystack_integration", + duplicates_policy=DuplicatePolicy.OVERWRITE, + embedding_dimension=768, + ) + + @pytest.fixture(autouse=True) + def run_before_and_after_tests(self, document_store: AstraDocumentStore): + """ + Cleaning up document store + """ + document_store.delete_documents(delete_all=True) + assert document_store.count_documents() == 0 + + def test_search_with_top_k(self, document_store): + query_embedding = [0.1] * 768 + common_embedding = [0.8] * 768 + + documents = [Document(content=f"This is document number {i}", embedding=common_embedding) for i in range(0, 3)] + + document_store.write_documents(documents) + + top_k = 2 + + result = document_store.search(query_embedding, top_k) + + assert top_k == len(result) + + for document in result: + assert document.score is not None diff --git a/integrations/azure_ai_search/CHANGELOG.md b/integrations/azure_ai_search/CHANGELOG.md new file mode 100644 index 000000000..6a8d26c9d --- /dev/null +++ b/integrations/azure_ai_search/CHANGELOG.md @@ -0,0 +1,22 @@ +# Changelog + +## [integrations/azure_ai_search-v0.1.1] - 2024-11-22 + +### ๐Ÿ› Bug Fixes + +- Fix error in README file (#1207) + + +## [integrations/azure_ai_search-v0.1.0] - 2024-11-21 + +### ๐Ÿš€ Features + +- Add Azure AI Search integration (#1122) +- Add BM25 and Hybrid Search Retrievers to Azure AI Search Integration (#1175) + +### ๐ŸŒ€ Miscellaneous + +- Enable kwargs in SearchIndex and Embedding Retriever (#1185) +- Fix: Fix tag name for version release (#1206) + + diff --git a/integrations/azure_ai_search/README.md b/integrations/azure_ai_search/README.md index 915a23b63..51cc7720c 100644 --- a/integrations/azure_ai_search/README.md +++ b/integrations/azure_ai_search/README.md @@ -19,7 +19,7 @@ pip install azure-ai-search-haystack ``` ## Examples -You can find a code example showing how to use the Document Store and the Retriever in the documentation or in [this Colab](https://colab.research.google.com/drive/1YpDetI8BRbObPDEVdfqUcwhEX9UUXP-m?usp=sharing). +Refer to the documentation for code examples on utilizing the Document Store and its associated Retrievers. For more usage scenarios, check out the [examples](https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search/example). 
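A condensed sketch of the kind of usage those examples cover, here with the new `AzureAISearchBM25Retriever`; the index name is arbitrary, and the store is assumed to pick up its Azure Search credentials (API key and service endpoint) from the environment:

```python
from haystack import Document

from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchBM25Retriever
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

# Credentials are read from the environment; the index is created if it doesn't exist.
document_store = AzureAISearchDocumentStore(index_name="haystack-docs")
document_store.write_documents(
    [
        Document(content="Use pip to install a development branch"),
        Document(content="The quickest way to be productive is to follow the tutorials"),
    ]
)

# Keyword (BM25) retrieval over the indexed documents.
retriever = AzureAISearchBM25Retriever(document_store=document_store, top_k=2)
result = retriever.run(query="how do I get started?")
print(result["documents"])
```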
## License diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py index 779f28935..92a641717 100644 --- a/integrations/azure_ai_search/example/document_store.py +++ b/integrations/azure_ai_search/example/document_store.py @@ -1,5 +1,4 @@ from haystack import Document -from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -30,7 +29,7 @@ meta={"version": 2.0, "label": "chapter_three"}, ), ] -document_store.write_documents(documents, policy=DuplicatePolicy.SKIP) +document_store.write_documents(documents) filters = { "operator": "AND", diff --git a/integrations/azure_ai_search/example/embedding_retrieval.py b/integrations/azure_ai_search/example/embedding_retrieval.py index 088b08653..188f8525a 100644 --- a/integrations/azure_ai_search/example/embedding_retrieval.py +++ b/integrations/azure_ai_search/example/embedding_retrieval.py @@ -1,7 +1,6 @@ from haystack import Document, Pipeline from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder from haystack.components.writers import DocumentWriter -from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore @@ -38,9 +37,7 @@ # Indexing Pipeline indexing_pipeline = Pipeline() indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder") -indexing_pipeline.add_component( - instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer" -) +indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="doc_writer") indexing_pipeline.connect("doc_embedder", "doc_writer") indexing_pipeline.run({"doc_embedder": {"documents": documents}}) diff --git a/integrations/azure_ai_search/pyproject.toml b/integrations/azure_ai_search/pyproject.toml index 49ca623e7..cb967b1e0 100644 --- a/integrations/azure_ai_search/pyproject.toml +++ b/integrations/azure_ai_search/pyproject.toml @@ -33,11 +33,11 @@ packages = ["src/haystack_integrations"] [tool.hatch.version] source = "vcs" -tag-pattern = 'integrations\/azure-ai-search-v(?P<version>.*)' +tag-pattern = 'integrations\/azure_ai_search-v(?P<version>.*)' [tool.hatch.version.raw-options] root = "../.."
-git_describe_command = 'git describe --tags --match="integrations/azure-ai-search-v[0-9]*"' +git_describe_command = 'git describe --tags --match="integrations/azure_ai_search-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py index eb75ffa6c..56dc30db4 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -1,3 +1,5 @@ +from .bm25_retriever import AzureAISearchBM25Retriever from .embedding_retriever import AzureAISearchEmbeddingRetriever +from .hybrid_retriever import AzureAISearchHybridRetriever -__all__ = ["AzureAISearchEmbeddingRetriever"] +__all__ = ["AzureAISearchBM25Retriever", "AzureAISearchEmbeddingRetriever", "AzureAISearchHybridRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py new file mode 100644 index 000000000..4a1c7f98c --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -0,0 +1,135 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchBM25Retriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using BM25 retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, + ): + """ + Create the AzureAISearchBM25Retriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the BM25 search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI Search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple', 'full', and 'semantic'. + - `semantic_configuration_name`: The name of the semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + :raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore. + :raises RuntimeError: If the query is not valid, or if the document store is not correctly configured.
+ + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._kwargs = kwargs + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise TypeError(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :raises RuntimeError: If an error occurs during the BM25 retrieval process. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + filters = filters or self._filters + if filters: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = _normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._bm25_retrieval( + query=query, + filters=normalized_filters, + top_k=top_k, + **self._kwargs, + ) + except Exception as e: + msg = ( + "An error occurred during the bm25 retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query is valid and the document store is correctly configured." 
+ ) + raise RuntimeError(msg) from e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py index ab649f874..69fad7208 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -5,7 +5,7 @@ from haystack.document_stores.types import FilterPolicy from haystack.document_stores.types.filter_policy import apply_filter_policy -from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters logger = logging.getLogger(__name__) @@ -25,16 +25,23 @@ def __init__( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, ): """ Create the AzureAISearchEmbeddingRetriever component. :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. :param filters: Filters applied when fetching documents from the Document Store. - Filters are applied during the approximate kNN search to ensure the Retriever returns - `top_k` matching documents. :param top_k: Maximum number of documents to return. - :filter_policy: Policy to determine how filters are applied. Possible options: + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI Search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple', 'full', and 'semantic'. + - `semantic_configuration_name`: The name of the semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). """ self._filters = filters or {} @@ -43,6 +50,7 @@ def __init__( self._filter_policy = ( filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) ) + self._kwargs = kwargs if not isinstance(document_store, AzureAISearchDocumentStore): message = "document_store must be an instance of AzureAISearchDocumentStore" @@ -61,6 +69,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, document_store=self._document_store.to_dict(), filter_policy=self._filter_policy.value, + **self._kwargs, ) @classmethod @@ -88,29 +97,32 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever": def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): """Retrieve documents from the AzureAISearchDocumentStore. - :param query_embedding: floats representing the query embedding + :param query_embedding: A list of floats representing the query embedding. :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on - the `filter_policy` chosen at retriever initialization. See init method docstring for more - details. - :param top_k: the maximum number of documents to retrieve.
- :returns: a dictionary with the following keys: - - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more + details. + :param top_k: The maximum number of documents to retrieve. + :returns: Dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. """ top_k = top_k or self._top_k - if filters is not None: + filters = filters or self._filters + if filters: applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) - normalized_filters = normalize_filters(applied_filters) + normalized_filters = _normalize_filters(applied_filters) else: normalized_filters = "" try: docs = self._document_store._embedding_retrieval( - query_embedding=query_embedding, - filters=normalized_filters, - top_k=top_k, + query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs ) except Exception as e: - raise e + msg = ( + "An error occurred during the embedding retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query embedding is valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py new file mode 100644 index 000000000..79282933f --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -0,0 +1,139 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchHybridRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using hybrid (vector + BM25) retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, + ): + """ + Create the AzureAISearchHybridRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the hybrid search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI Search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple', 'full', and 'semantic'. + - `semantic_configuration_name`: The name of the semantic configuration to be used when + processing semantic queries.
+ For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + :raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore. + :raises RuntimeError: If query or query_embedding are invalid, or if document store is not correctly configured. + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._kwargs = kwargs + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise TypeError(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + ): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param query_embedding: A list of floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more + details. + :param top_k: The maximum number of documents to retrieve. + :raises RuntimeError: If an error occurs during the hybrid retrieval process. + :returns: A dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + filters = filters or self._filters + if filters: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = _normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._hybrid_retrieval( + query=query, query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs + ) + except Exception as e: + msg = ( + "An error occurred during the hybrid retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query and query_embedding are valid and the document store is correctly configured." 
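As a quick orientation for reviewers, here is a minimal usage sketch of the new hybrid retriever (not part of the patch). The index name is a placeholder, and the 768-float embedding simply matches the document store's default `embedding_dimension`; in a real pipeline the embedding would come from a text embedder.

```python
# Hypothetical usage sketch of AzureAISearchHybridRetriever; "haystack-docs" is
# a placeholder index name, and AZURE_SEARCH_SERVICE_ENDPOINT / AZURE_SEARCH_API_KEY
# are read from the environment, per the document store's defaults.
from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchHybridRetriever
from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

document_store = AzureAISearchDocumentStore(index_name="haystack-docs")
retriever = AzureAISearchHybridRetriever(document_store=document_store, top_k=5)

# `query` feeds the BM25 leg and `query_embedding` feeds the vector leg;
# Azure AI Search fuses both rankings on the service side.
result = retriever.run(query="What is hybrid retrieval?", query_embedding=[0.1] * 768)
print(result["documents"])
```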
diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py
index 635878a38..dcee0e622 100644
--- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py
+++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py
@@ -2,6 +2,6 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore
-from .filters import normalize_filters
+from .filters import _normalize_filters
 
-__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"]
+__all__ = ["DEFAULT_VECTOR_SEARCH", "AzureAISearchDocumentStore", "_normalize_filters"]
diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py
index 0b59b6e37..137ff621c 100644
--- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py
+++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py
@@ -31,7 +31,7 @@
 from haystack.utils import Secret, deserialize_secrets_inplace
 
 from .errors import AzureAISearchDocumentStoreConfigError
-from .filters import normalize_filters
+from .filters import _normalize_filters
 
 type_mapping = {
     str: "Edm.String",
@@ -70,7 +70,7 @@ def __init__(
         embedding_dimension: int = 768,
         metadata_fields: Optional[Dict[str, type]] = None,
         vector_search_configuration: VectorSearch = None,
-        **kwargs,
+        **index_creation_kwargs,
     ):
         """
         A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/)
@@ -87,19 +87,20 @@
         :param vector_search_configuration: Configuration option related to vector search.
             Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches.
 
-        :param kwargs: Optional keyword parameters for Azure AI Search.
-            Some of the supported parameters:
-            - `api_version`: The Search API version to use for requests.
-            - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD).
-              The audience is not considered when using a shared key. If audience is not provided,
-              the public cloud audience will be assumed.
+        :param index_creation_kwargs: Optional keyword parameters to be passed to the `SearchIndex` class
+            during index creation. Some of the supported parameters:
+            - `semantic_search`: Defines the semantic configuration of the search index. This parameter is needed
+            to enable semantic search capabilities in the index.
+            - `similarity`: The type of similarity algorithm to be used when scoring and ranking the documents
+            matching a search query. The similarity algorithm can only be defined at index creation time and
+            cannot be modified on existing indexes.
 
-        For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/)
+        For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/).
         """
         azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None
         if not azure_endpoint:
-            msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT."
+            msg = "Please provide an Azure endpoint or set the environment variable AZURE_SEARCH_SERVICE_ENDPOINT."
             raise ValueError(msg)
 
         api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None
@@ -114,7 +115,7 @@
         self._dummy_vector = [-10.0] * self._embedding_dimension
         self._metadata_fields = metadata_fields
         self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
-        self._kwargs = kwargs
+        self._index_creation_kwargs = index_creation_kwargs
 
     @property
     def client(self) -> SearchClient:
@@ -128,7 +129,10 @@
         credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential()
         try:
             if not self._index_client:
-                self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs)
+                self._index_client = SearchIndexClient(
+                    resolved_endpoint,
+                    credential,
+                )
                 if not self._index_exists(self._index_name):
                     # Create a new index if it does not exist
                     logger.debug(
@@ -151,7 +155,7 @@
 
         return self._client
 
-    def _create_index(self, index_name: str, **kwargs) -> None:
+    def _create_index(self, index_name: str) -> None:
         """
         Creates a new search index.
         :param index_name: Name of the index to create. If None, the index name from the constructor is used.
@@ -177,7 +181,10 @@
         if self._metadata_fields:
             default_fields.extend(self._create_metadata_index_fields(self._metadata_fields))
         index = SearchIndex(
-            name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs
+            name=index_name,
+            fields=default_fields,
+            vector_search=self._vector_search_configuration,
+            **self._index_creation_kwargs,
         )
         if self._index_client:
             self._index_client.create_index(index)
@@ -194,13 +201,13 @@ def to_dict(self) -> Dict[str, Any]:
         """
         return default_to_dict(
             self,
-            azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None,
-            api_key=self._api_key.to_dict() if self._api_key is not None else None,
+            azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None,
+            api_key=self._api_key.to_dict() if self._api_key else None,
             index_name=self._index_name,
             embedding_dimension=self._embedding_dimension,
             metadata_fields=self._metadata_fields,
             vector_search_configuration=self._vector_search_configuration.as_dict(),
-            **self._kwargs,
+            **self._index_creation_kwargs,
         )
 
     @classmethod
@@ -233,6 +240,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
         Writes the provided documents to search index.
 
         :param documents: documents to write to the index.
+        :param policy: Policy to determine how duplicates are handled.
+        :raises ValueError: If the documents are not of type Document.
+        :raises TypeError: If the document ids are not strings.
         :return: the number of documents added to the index.
         """
@@ -240,7 +250,7 @@ def _convert_input_document(documents: Document):
             document_dict = asdict(documents)
             if not isinstance(document_dict["id"], str):
                 msg = f"Document id {document_dict['id']} is not a string, "
-                raise Exception(msg)
+                raise TypeError(msg)
             index_document = self._convert_haystack_documents_to_azure(document_dict)
             return index_document
 
@@ -298,7 +308,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
         :returns: A list of Documents that match the given filters.
         """
         if filters:
-            normalized_filters = normalize_filters(filters)
+            normalized_filters = _normalize_filters(filters)
             result = self.client.search(filter=normalized_filters)
             return self._convert_search_result_to_documents(result)
         else:
@@ -409,12 +419,12 @@ def _embedding_retrieval(
         query_embedding: List[float],
         *,
         top_k: int = 10,
-        fields: Optional[List[str]] = None,
         filters: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> List[Document]:
         """
         Retrieves documents that are most similar to the query embedding using a vector similarity metric.
-        It uses the vector configuration of the document store. By default it uses the HNSW algorithm
+        It uses the vector configuration specified in the document store. By default, it uses the HNSW algorithm
         with cosine similarity.
 
         This method is not meant to be part of the public interface of
@@ -422,12 +432,12 @@
         `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it.
 
         :param query_embedding: Embedding of the query.
-        :param filters: Filters applied to the retrieved Documents. Defaults to None.
-            Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
-        :param top_k: Maximum number of Documents to return, defaults to 10
+        :param top_k: Maximum number of Documents to return.
+        :param filters: Filters applied to the retrieved Documents.
+        :param kwargs: Optional keyword arguments to pass to the Azure AI Search endpoint.
 
-        :raises ValueError: If `query_embedding` is an empty list
-        :returns: List of Document that are most similar to `query_embedding`
+        :raises ValueError: If `query_embedding` is an empty list.
+        :returns: List of Documents most similar to `query_embedding`.
 
         """
         if not query_embedding:
@@ -435,6 +445,83 @@
             raise ValueError(msg)
 
         vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding")
-        result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters)
+        result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs)
+        azure_docs = list(result)
+        return self._convert_search_result_to_documents(azure_docs)
+
+    def _bm25_retrieval(
+        self,
+        query: str,
+        top_k: int = 10,
+        filters: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> List[Document]:
+        """
+        Retrieves documents that are most similar to `query`, using the BM25 algorithm.
+
+        This method is not meant to be part of the public interface of
+        `AzureAISearchDocumentStore` nor called directly.
+        `AzureAISearchBM25Retriever` uses this method directly and is the public interface for it.
+
+        :param query: Text of the query.
+        :param filters: Filters applied to the retrieved Documents.
+        :param top_k: Maximum number of Documents to return.
+        :param kwargs: Optional keyword arguments to pass to the Azure AI Search endpoint.
+
+        :raises ValueError: If `query` is `None`.
+        :returns: List of Documents most similar to `query`.
+        """
+
+        if query is None:
+            msg = "query must not be None"
+            raise ValueError(msg)
+
+        result = self.client.search(search_text=query, filter=filters, top=top_k, query_type="simple", **kwargs)
+        azure_docs = list(result)
+        return self._convert_search_result_to_documents(azure_docs)
+
+    def _hybrid_retrieval(
+        self,
+        query: str,
+        query_embedding: List[float],
+        top_k: int = 10,
+        filters: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> List[Document]:
+        """
+        Retrieves documents similar to `query` using the vector configuration in the document store and
+        the BM25 algorithm. This method combines vector similarity and BM25 for improved retrieval.
+
+        This method is not meant to be part of the public interface of
+        `AzureAISearchDocumentStore` nor called directly.
+        `AzureAISearchHybridRetriever` uses this method directly and is the public interface for it.
+
+        :param query: Text of the query.
+        :param query_embedding: Embedding of the query.
+        :param filters: Filters applied to the retrieved Documents.
+        :param top_k: Maximum number of Documents to return.
+        :param kwargs: Optional keyword arguments to pass to the Azure AI Search endpoint.
+
+        :raises ValueError: If `query` or `query_embedding` is empty.
+        :returns: List of Documents most similar to `query`.
+        """
+
+        if query is None:
+            msg = "query must not be None"
+            raise ValueError(msg)
+        if not query_embedding:
+            msg = "query_embedding must be a non-empty list of floats"
+            raise ValueError(msg)
+
+        vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding")
+        result = self.client.search(
+            search_text=query,
+            vector_queries=[vector_query],
+            filter=filters,
+            top=top_k,
+            query_type="simple",
+            **kwargs,
+        )
         azure_docs = list(result)
         return self._convert_search_result_to_documents(azure_docs)
diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py
index 650e3f8be..0f105bc91 100644
--- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py
+++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py
@@ -7,7 +7,7 @@
 LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"}
 
 
-def normalize_filters(filters: Dict[str, Any]) -> str:
+def _normalize_filters(filters: Dict[str, Any]) -> str:
     """
     Converts Haystack filters into Azure AI Search compatible filters.
     """
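For context on the rename above: the helper is now private because it is an implementation detail shared by the retrievers rather than public API. A small sketch of its observable behavior, with the expected OData string taken from the retriever tests later in this diff:

```python
# Behavior sketch for the now-private helper; the expected output mirrors the
# assertions in the updated retriever tests (e.g. "type eq 'article'").
from haystack_integrations.document_stores.azure_ai_search import _normalize_filters

haystack_filter = {"field": "type", "operator": "==", "value": "article"}
odata_filter = _normalize_filters(haystack_filter)
assert odata_filter == "type eq 'article'"  # the string handed to SearchClient.search(filter=...)
```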
""" diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py index 3017c79c2..89369c87e 100644 --- a/integrations/azure_ai_search/tests/conftest.py +++ b/integrations/azure_ai_search/tests/conftest.py @@ -6,12 +6,14 @@ from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ResourceNotFoundError from azure.search.documents.indexes import SearchIndexClient +from haystack import logging from haystack.document_stores.types import DuplicatePolicy from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore # This is the approximate time in seconds it takes for the documents to be available in Azure Search index -SLEEP_TIME_IN_SECONDS = 5 +SLEEP_TIME_IN_SECONDS = 10 +MAX_WAIT_TIME_FOR_INDEX_DELETION = 5 @pytest.fixture() @@ -46,23 +48,35 @@ def document_store(request): # Override some methods to wait for the documents to be available original_write_documents = store.write_documents + original_delete_documents = store.delete_documents def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE): written_docs = original_write_documents(documents, policy) time.sleep(SLEEP_TIME_IN_SECONDS) return written_docs - original_delete_documents = store.delete_documents - def delete_documents_and_wait(filters): original_delete_documents(filters) time.sleep(SLEEP_TIME_IN_SECONDS) + # Helper function to wait for the index to be deleted, needed to cover latency + def wait_for_index_deletion(client, index_name): + start_time = time.time() + while time.time() - start_time < MAX_WAIT_TIME_FOR_INDEX_DELETION: + if index_name not in client.list_index_names(): + return True + time.sleep(1) + return False + store.write_documents = write_documents_and_wait store.delete_documents = delete_documents_and_wait yield store try: client.delete_index(index_name) + if not wait_for_index_deletion(client, index_name): + logging.error(f"Index {index_name} was not properly deleted.") except ResourceNotFoundError: - pass + logging.info(f"Index {index_name} was already deleted or not found.") + except Exception as e: + logging.error(f"Unexpected error when deleting index {index_name}: {e}") diff --git a/integrations/azure_ai_search/tests/test_bm25_retriever.py b/integrations/azure_ai_search/tests/test_bm25_retriever.py new file mode 100644 index 000000000..6ebb20949 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_bm25_retriever.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import Mock + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchBM25Retriever +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchBM25Retriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="unknown") + + +def 
test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "metadata_fields": None, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchBM25Retriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever(document_store=mock_store) + res = retriever.run(query="Test query") + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever( + document_store=mock_store, filters={"field": "type", "operator": "==", "value": "article"}, top_k=11 + ) + res = retriever.run(query="Test query") + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": 
"article"}, + top_k=11, + select="name", + ) + res = retriever.run(query="Test query", filters={"field": "type", "operator": "==", "value": "book"}, top_k=5) + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", filters="type eq 'book'", top_k=5, select="name" + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1", content="Test document")] + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.run(query="Test document") + assert res["documents"] == docs + + def test_document_retrieval(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document(content="This is first document"), + Document(content="This is second document"), + Document(content="This is third document"), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + results = retriever.run(query="This is first document") + assert results["documents"][0].content == "This is first document" diff --git a/integrations/azure_ai_search/tests/test_embedding_retriever.py b/integrations/azure_ai_search/tests/test_embedding_retriever.py index d4615ec44..576ecda08 100644 --- a/integrations/azure_ai_search/tests/test_embedding_retriever.py +++ b/integrations/azure_ai_search/tests/test_embedding_retriever.py @@ -103,6 +103,66 @@ def test_from_dict(): assert retriever._filter_policy == FilterPolicy.REPLACE +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + ) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run( + query_embedding=[0.5, 0.7], 
filters={"field": "type", "operator": "==", "value": "book"}, top_k=9 + ) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="type eq 'book'", + top_k=9, + select="name", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + @pytest.mark.skipif( not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", diff --git a/integrations/azure_ai_search/tests/test_hybrid_retriever.py b/integrations/azure_ai_search/tests/test_hybrid_retriever.py new file mode 100644 index 000000000..bf305c4fe --- /dev/null +++ b/integrations/azure_ai_search/tests/test_hybrid_retriever.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchHybridRetriever +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchHybridRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + 
"top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchHybridRetriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.5, 0.7], query="Test query") + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + ) + res = retriever.run(query_embedding=[0.5, 0.7], query="Test query") + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run( + query_embedding=[0.5, 0.7], + query="Test query", + filters={"field": "type", "operator": "==", "value": "book"}, + top_k=9, + ) + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="type eq 'book'", + top_k=9, + select="name", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.run(query="Test 
document", query_embedding=[0.1] * 768) + assert res["documents"] == docs + + def test_hybrid_retrieval(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="This is first document", embedding=most_similar_embedding), + Document(content="This is second document", embedding=second_best_embedding), + Document(content="This is third document", embedding=another_embedding), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + results = retriever.run(query="This is first document", query_embedding=query_embedding) + assert results["documents"][0].content == "This is first document" + + def test_empty_query_embedding(self, document_store: AzureAISearchDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(HttpResponseError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index cfe7a606e..c91cc6cb0 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -22,7 +22,12 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "chromadb>=0.5.17", "typing_extensions>=4.8.0"] +dependencies = [ + "haystack-ai", + "chromadb>=0.5.17", + "typing_extensions>=4.8.0", + "tokenizers>=0.13.2,<=0.20.3" # TODO: remove when Chroma pins tokenizers internally +] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma#readme" diff --git a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py index 53120c97c..e240ba136 100644 --- a/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py +++ b/integrations/chroma/src/haystack_integrations/components/retrievers/chroma/__init__.py @@ -1,3 +1,3 @@ from .retriever import ChromaEmbeddingRetriever, ChromaQueryTextRetriever -__all__ = ["ChromaQueryTextRetriever", "ChromaEmbeddingRetriever"] +__all__ = ["ChromaEmbeddingRetriever", "ChromaQueryTextRetriever"] diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index 3201168a8..d311662fe 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -146,7 +146,7 @@ def run(self, documents: List[Document]): - `meta`: metadata about the embedding process. :raises TypeError: if the input is not a list of `Documents`. 
""" - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "CohereDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the CohereTextEmbedder." diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py index 93c0947e4..7d50682e8 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/__init__.py @@ -4,4 +4,4 @@ from .chat.chat_generator import CohereChatGenerator from .generator import CohereGenerator -__all__ = ["CohereGenerator", "CohereChatGenerator"] +__all__ = ["CohereChatGenerator", "CohereGenerator"] diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index 0eb65b368..e4eaf8670 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional from haystack import component -from haystack.dataclasses import ChatMessage, ChatRole +from haystack.dataclasses import ChatMessage from haystack.utils import Secret from .chat.chat_generator import CohereChatGenerator @@ -64,7 +64,7 @@ def run(self, prompt: str): - `replies`: A list of replies generated by the model. - `meta`: Information about the request. """ - chat_message = ChatMessage(content=prompt, role=ChatRole.USER, name="", meta={}) + chat_message = ChatMessage.from_user(prompt) # Note we have to call super() like this because of the way components are dynamically built with the decorator results = super(CohereGenerator, self).run([chat_message]) # noqa return {"replies": [results["replies"][0].content], "meta": [results["replies"][0].meta]} diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 175a6d14b..b7cc0534a 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -27,7 +27,7 @@ def streaming_chunk(text: str): @pytest.fixture def chat_messages(): - return [ChatMessage(content="What's the capital of France", role=ChatRole.ASSISTANT, name=None)] + return [ChatMessage.from_assistant(content="What's the capital of France")] class TestCohereChatGenerator: @@ -164,7 +164,7 @@ def test_message_to_dict(self, chat_messages): ) @pytest.mark.integration def test_live_run(self): - chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] + chat_messages = [ChatMessage.from_user(content="What's the capital of France")] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages) assert len(results["replies"]) == 1 @@ -201,9 +201,7 @@ def __call__(self, chunk: StreamingChunk) -> None: callback = Callback() component = CohereChatGenerator(streaming_callback=callback) - results = component.run( - [ChatMessage(content="What's the capital of France? 
answer in a word", role=ChatRole.USER, name=None)] - ) + results = component.run([ChatMessage.from_user(content="What's the capital of France? answer in a word")]) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] @@ -224,7 +222,7 @@ def __call__(self, chunk: StreamingChunk) -> None: ) @pytest.mark.integration def test_live_run_with_connector(self): - chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] + chat_messages = [ChatMessage.from_user(content="What's the capital of France")] component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) assert len(results["replies"]) == 1 @@ -249,7 +247,7 @@ def __call__(self, chunk: StreamingChunk) -> None: self.responses += chunk.content if chunk.content else "" callback = Callback() - chat_messages = [ChatMessage(content="What's the capital of France? answer in a word", role=None, name=None)] + chat_messages = [ChatMessage.from_user(content="What's the capital of France? answer in a word")] component = CohereChatGenerator(streaming_callback=callback) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) diff --git a/integrations/fastembed/CHANGELOG.md b/integrations/fastembed/CHANGELOG.md index 5dd62d130..841781660 100644 --- a/integrations/fastembed/CHANGELOG.md +++ b/integrations/fastembed/CHANGELOG.md @@ -1,11 +1,23 @@ # Changelog +## [integrations/fastembed-v1.4.1] - 2024-11-19 + +### ๐ŸŒ€ Miscellaneous + +- Add new example with the ranker and fix the old ones (#1198) +- Fix: Fastembed - Change default Sparse model as the used one is deprecated due to a typo (#1201) + ## [integrations/fastembed-v1.4.0] - 2024-11-13 -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Adopt uv as installer (#1142) +### ๐ŸŒ€ Miscellaneous + +- Chore: fastembed - pin `onnxruntime<1.20.0` (#1164) +- Feat: Fastembed - add FastembedRanker (#1178) + ## [integrations/fastembed-v1.3.0] - 2024-10-07 ### ๐Ÿš€ Features @@ -16,20 +28,35 @@ - Do not retry tests in `hatch run test` command (#954) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Retry tests to reduce flakyness (#836) + +### ๐Ÿงน Chores + - Update ruff invocation to include check parameter (#853) - Update ruff linting scripts and settings (#1105) -### Fix +### ๐ŸŒ€ Miscellaneous +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) - Typo on Sparse embedders. The parameter should be "progress_bar" โ€ฆ (#814) +- Chore: fastembed - ruff update, don't ruff tests (#997) ## [integrations/fastembed-v1.1.0] - 2024-05-15 +### ๐ŸŒ€ Miscellaneous + +- Chore: change the pydoc renderer class (#718) +- Use the local_files_only option available as of fastembed==0.2.7. 
It โ€ฆ (#736) + ## [integrations/fastembed-v1.0.0] - 2024-05-06 +### ๐ŸŒ€ Miscellaneous + +- Chore: add license classifiers (#680) +- `FastembedSparseTextEmbedder` - remove `batch_size` (#688) + ## [integrations/fastembed-v0.1.0] - 2024-04-10 ### ๐Ÿš€ Features @@ -40,6 +67,10 @@ - Disable-class-def (#556) +### ๐ŸŒ€ Miscellaneous + +- Remove references to Python 3.7 (#601) + ## [integrations/fastembed-v0.0.6] - 2024-03-07 ### ๐Ÿ“š Documentation @@ -47,20 +78,32 @@ - Review and normalize docstrings - `integrations.fastembed` (#519) - Small consistency improvements (#536) +### ๐ŸŒ€ Miscellaneous + +- Docs: Fix `integrations.fastembed` API docs (#540) +- Improvements to FastEmbed integration (#558) + ## [integrations/fastembed-v0.0.5] - 2024-02-20 ### ๐Ÿ› Bug Fixes - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### ๐Ÿ“š Documentation - Update category slug (#442) +### ๐ŸŒ€ Miscellaneous + +- Fastembed integration new parameters (#446) + ## [integrations/fastembed-v0.0.4] - 2024-02-16 +### ๐ŸŒ€ Miscellaneous + +- Fastembed integration: add example (#401) +- Fastembed fix: add parallel (#403) + ## [integrations/fastembed-v0.0.3] - 2024-02-12 ### ๐Ÿ› Bug Fixes @@ -73,6 +116,15 @@ This PR will also push the docs to Readme ## [integrations/fastembed-v0.0.2] - 2024-02-11 +### ๐ŸŒ€ Miscellaneous + +- Updated labeler and readme (#389) +- Fastembed fix: added prefix and suffix (#390) + ## [integrations/fastembed-v0.0.1] - 2024-02-10 +### ๐ŸŒ€ Miscellaneous + +- Add Fastembed Embeddings integration (#383) + diff --git a/integrations/fastembed/README.md b/integrations/fastembed/README.md index c021dec3b..f3c2bb135 100644 --- a/integrations/fastembed/README.md +++ b/integrations/fastembed/README.md @@ -8,6 +8,7 @@ **Table of Contents** - [Installation](#installation) +- [Usage](#Usage) - [License](#license) ## Installation @@ -33,7 +34,7 @@ embedding = text_embedder.run(text)["embedding"] ```python from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder -from haystack.dataclasses import Document +from haystack import Document embedder = FastembedDocumentEmbedder( model="BAAI/bge-small-en-v1.5", @@ -50,24 +51,50 @@ from haystack_integrations.components.embedders.fastembed import FastembedSparse text = "fastembed is supported by and maintained by Qdrant." text_embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1" + model="prithivida/Splade_PP_en_v1" ) text_embedder.warm_up() -embedding = text_embedder.run(text)["embedding"] +embedding = text_embedder.run(text)["sparse_embedding"] ``` ```python from haystack_integrations.components.embedders.fastembed import FastembedSparseDocumentEmbedder -from haystack.dataclasses import Document +from haystack import Document embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", ) embedder.warm_up() doc = Document(content="fastembed is supported by and maintained by Qdrant.", meta={"long_answer": "no",}) result = embedder.run(documents=[doc]) ``` +You can use `FastembedRanker` by importing as: + +```python +from haystack import Document + +from haystack_integrations.components.rankers.fastembed import FastembedRanker + +query = "Who is maintaining Qdrant?" +documents = [ + Document( + content="This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc." 
+ ), + Document(content="fastembed is supported by and maintained by Qdrant."), +] + +ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2") +ranker.warm_up() +reranked_documents = ranker.run(query=query, documents=documents)["documents"] + +print(reranked_documents[0]) + +# Document(id=..., +# content: 'fastembed is supported by and maintained by Qdrant.', +# score: 5.472434997558594..) +``` + ## License `fastembed-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/fastembed/examples/ranker_example.py b/integrations/fastembed/examples/ranker_example.py index 7a31e4646..593334e90 100644 --- a/integrations/fastembed/examples/ranker_example.py +++ b/integrations/fastembed/examples/ranker_example.py @@ -15,7 +15,7 @@ reranked_documents = ranker.run(query=query, documents=documents)["documents"] -print(reranked_documents["documents"][0]) +print(reranked_documents[0]) # Document(id=..., # content: 'fastembed is supported by and maintained by Qdrant.', diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/__init__.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/__init__.py index e943a8ca1..d73c29766 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/__init__.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/__init__.py @@ -8,7 +8,7 @@ __all__ = [ "FastembedDocumentEmbedder", - "FastembedTextEmbedder", "FastembedSparseDocumentEmbedder", "FastembedSparseTextEmbedder", + "FastembedTextEmbedder", ] diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py index 8b63582c5..b064173fe 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py @@ -158,7 +158,7 @@ def run(self, documents: List[Document]): :returns: A dictionary with the following keys: - `documents`: List of Documents with each Document's `embedding` field set to the computed embeddings. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "FastembedDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a list of strings, please use the FastembedTextEmbedder." 
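The parenthesization changes in this diff (here, in the Cohere document embedder, and in the FastembedRanker below) are purely cosmetic: `and` binds tighter than `or` in Python, so the guard already parsed this way. A self-contained check with a stand-in `Doc` class (not the real haystack `Document`):

```python
# `A or B and C` parses as `A or (B and C)`, so the added parentheses only aid
# readability; behavior is unchanged. Doc is a stand-in type for illustration.
class Doc:
    pass

def validate(documents):
    # Same guard shape as the embedders/ranker in this diff.
    if not isinstance(documents, list) or (documents and not isinstance(documents[0], Doc)):
        raise TypeError("expects a list of Documents")

validate([Doc()])  # a list of Doc passes
validate([])       # an empty list still passes the guard
try:
    validate("not a list")
except TypeError as err:
    print(err)  # expects a list of Documents
```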
diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py
index f79f08c90..fb3df9162 100644
--- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py
+++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py
@@ -16,7 +16,7 @@ class FastembedSparseDocumentEmbedder:
     from haystack.dataclasses import Document
 
     sparse_doc_embedder = FastembedSparseDocumentEmbedder(
-        model="prithvida/Splade_PP_en_v1",
+        model="prithivida/Splade_PP_en_v1",
         batch_size=32,
     )
 
@@ -53,7 +53,7 @@ class FastembedSparseDocumentEmbedder:
 
     def __init__(
         self,
-        model: str = "prithvida/Splade_PP_en_v1",
+        model: str = "prithivida/Splade_PP_en_v1",
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
         batch_size: int = 32,
@@ -68,7 +68,7 @@ def __init__(
         Create a FastembedDocumentEmbedder component.
 
         :param model: Local path or name of the model in Hugging Face's model hub,
-            such as `prithvida/Splade_PP_en_v1`.
+            such as `prithivida/Splade_PP_en_v1`.
         :param cache_dir: The path to the cache directory.
                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                 Defaults to `fastembed_cache` in the system's temp directory.
@@ -150,7 +150,7 @@ def run(self, documents: List[Document]):
             - `documents`: List of Documents with each Document's `sparse_embedding` field set to the computed
             embeddings.
         """
-        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
+        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
             msg = (
                 "FastembedSparseDocumentEmbedder expects a list of Documents as input. "
                 "In case you want to embed a list of strings, please use the FastembedTextEmbedder."
diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py
index 2ebab35b4..c7296525f 100644
--- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py
+++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py
@@ -19,7 +19,7 @@ class FastembedSparseTextEmbedder:
         "The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!")
 
     sparse_text_embedder = FastembedSparseTextEmbedder(
-        model="prithvida/Splade_PP_en_v1"
+        model="prithivida/Splade_PP_en_v1"
     )
 
     sparse_text_embedder.warm_up()
@@ -29,7 +29,7 @@ class FastembedSparseTextEmbedder:
 
     def __init__(
         self,
-        model: str = "prithvida/Splade_PP_en_v1",
+        model: str = "prithivida/Splade_PP_en_v1",
         cache_dir: Optional[str] = None,
         threads: Optional[int] = None,
         progress_bar: bool = True,
@@ -40,7 +40,7 @@ def __init__(
         """
         Create a FastembedSparseTextEmbedder component.
 
-        :param model: Local path or name of the model in Fastembed's model hub, such as `prithvida/Splade_PP_en_v1`
+        :param model: Local path or name of the model in Fastembed's model hub, such as `prithivida/Splade_PP_en_v1`
         :param cache_dir: The path to the cache directory.
                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                 Defaults to `fastembed_cache` in the system's temp directory.
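A minimal usage sketch with the corrected model name, mirroring the README changes earlier in this diff; note that the sparse embedders write their output under `sparse_embedding`, not `embedding`:

```python
# Mirrors the corrected README snippet: fixed `prithivida/...` model name and
# the `sparse_embedding` output key used by the sparse embedders.
from haystack_integrations.components.embedders.fastembed import FastembedSparseTextEmbedder

embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1")
embedder.warm_up()  # loads the backend once; repeated warm_up() calls reuse it

result = embedder.run(text="fastembed is supported by and maintained by Qdrant.")
sparse = result["sparse_embedding"]  # a SparseEmbedding object, not a dense list of floats
```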
diff --git a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py
index 8f077a30c..370344df5 100644
--- a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py
+++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py
@@ -157,7 +157,7 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None
 
         :raises ValueError: If `top_k` is not > 0.
         """
-        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
+        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
             msg = "FastembedRanker expects a list of Documents as input. "
             raise TypeError(msg)
         if query == "":
diff --git a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py
index 90e94908d..7c0de196a 100644
--- a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py
+++ b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py
@@ -15,8 +15,8 @@ def test_init_default(self):
         """
         Test default initialization parameters for FastembedSparseDocumentEmbedder.
         """
-        embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1")
-        assert embedder.model_name == "prithvida/Splade_PP_en_v1"
+        embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1")
+        assert embedder.model_name == "prithivida/Splade_PP_en_v1"
         assert embedder.cache_dir is None
         assert embedder.threads is None
         assert embedder.batch_size == 32
@@ -31,7 +31,7 @@ def test_init_with_parameters(self):
         Test custom initialization parameters for FastembedSparseDocumentEmbedder.
         """
         embedder = FastembedSparseDocumentEmbedder(
-            model="prithvida/Splade_PP_en_v1",
+            model="prithivida/Splade_PP_en_v1",
             cache_dir="fake_dir",
             threads=2,
             batch_size=64,
@@ -41,7 +41,7 @@ def test_init_with_parameters(self):
             meta_fields_to_embed=["test_field"],
             embedding_separator=" | ",
         )
-        assert embedder.model_name == "prithvida/Splade_PP_en_v1"
+        assert embedder.model_name == "prithivida/Splade_PP_en_v1"
         assert embedder.cache_dir == "fake_dir"
         assert embedder.threads == 2
         assert embedder.batch_size == 64
@@ -55,12 +55,12 @@ def test_to_dict(self):
         """
         Test serialization of FastembedSparseDocumentEmbedder to a dictionary, using default initialization parameters.
         """
-        embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1")
+        embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1")
         embedder_dict = embedder.to_dict()
         assert embedder_dict == {
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder",  # noqa
             "init_parameters": {
-                "model": "prithvida/Splade_PP_en_v1",
+                "model": "prithivida/Splade_PP_en_v1",
                 "cache_dir": None,
                 "threads": None,
                 "batch_size": 32,
@@ -78,7 +78,7 @@ def test_to_dict_with_custom_init_parameters(self):
         Test serialization of FastembedSparseDocumentEmbedder to a dictionary, using custom initialization parameters.
         """
         embedder = FastembedSparseDocumentEmbedder(
-            model="prithvida/Splade_PP_en_v1",
+            model="prithivida/Splade_PP_en_v1",
             cache_dir="fake_dir",
             threads=2,
             batch_size=64,
@@ -92,7 +92,7 @@ def test_to_dict_with_custom_init_parameters(self):
         assert embedder_dict == {
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder",  # noqa
             "init_parameters": {
-                "model": "prithvida/Splade_PP_en_v1",
+                "model": "prithivida/Splade_PP_en_v1",
                 "cache_dir": "fake_dir",
                 "threads": 2,
                 "batch_size": 64,
@@ -113,7 +113,7 @@ def test_from_dict(self):
         embedder_dict = {
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder",  # noqa
             "init_parameters": {
-                "model": "prithvida/Splade_PP_en_v1",
+                "model": "prithivida/Splade_PP_en_v1",
                 "cache_dir": None,
                 "threads": None,
                 "batch_size": 32,
@@ -125,7 +125,7 @@ def test_from_dict(self):
             },
         }
         embedder = default_from_dict(FastembedSparseDocumentEmbedder, embedder_dict)
-        assert embedder.model_name == "prithvida/Splade_PP_en_v1"
+        assert embedder.model_name == "prithivida/Splade_PP_en_v1"
         assert embedder.cache_dir is None
         assert embedder.threads is None
         assert embedder.batch_size == 32
@@ -143,7 +143,7 @@ def test_from_dict_with_custom_init_parameters(self):
         embedder_dict = {
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder",  # noqa
             "init_parameters": {
-                "model": "prithvida/Splade_PP_en_v1",
+                "model": "prithivida/Splade_PP_en_v1",
                 "cache_dir": "fake_dir",
                 "threads": 2,
                 "batch_size": 64,
@@ -155,7 +155,7 @@ def test_from_dict_with_custom_init_parameters(self):
             },
         }
         embedder = default_from_dict(FastembedSparseDocumentEmbedder, embedder_dict)
-        assert embedder.model_name == "prithvida/Splade_PP_en_v1"
+        assert embedder.model_name == "prithivida/Splade_PP_en_v1"
         assert embedder.cache_dir == "fake_dir"
         assert embedder.threads == 2
         assert embedder.batch_size == 64
@@ -172,11 +172,11 @@ def test_warmup(self, mocked_factory):
         """
         Test for checking embedder instances after warm-up.
         """
-        embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1")
+        embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1")
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
         mocked_factory.get_embedding_backend.assert_called_once_with(
-            model_name="prithvida/Splade_PP_en_v1",
+            model_name="prithivida/Splade_PP_en_v1",
             cache_dir=None,
             threads=None,
             local_files_only=False,
@@ -190,7 +190,7 @@ def test_warmup_does_not_reload(self, mocked_factory):
         """
         Test for checking backend instances after multiple warm-ups.
         """
-        embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1")
+        embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1")
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
         embedder.warm_up()
@@ -211,7 +211,7 @@ def test_embed(self):
         """
         Test for checking output dimensions and embedding dimensions.
         """
-        embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1")
+        embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1")
         embedder.embedding_backend = MagicMock()
         embedder.embedding_backend.embed = lambda x, **kwargs: self._generate_mocked_sparse_embedding(  # noqa: ARG005
             len(x)
@@ -235,7 +235,7 @@ def test_embed_incorrect_input_format(self):
         """
         Test for checking incorrect input format when creating embedding.
         """
-        embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1")
+        embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1")
 
         string_input = "text"
         list_integers_input = [1, 2, 3]
@@ -330,7 +330,7 @@ def test_run_with_model_kwargs(self):
     @pytest.mark.integration
     def test_run(self):
         embedder = FastembedSparseDocumentEmbedder(
-            model="prithvida/Splade_PP_en_v1",
+            model="prithivida/Splade_PP_en_v1",
         )
         embedder.warm_up()
 
diff --git a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py
index 4f438fd15..9b73f5f3a 100644
--- a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py
+++ b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py
@@ -15,8 +15,8 @@ def test_init_default(self):
         """
         Test default initialization parameters for FastembedSparseTextEmbedder.
         """
-        embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1")
-        assert embedder.model_name == "prithvida/Splade_PP_en_v1"
+        embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1")
+        assert embedder.model_name == "prithivida/Splade_PP_en_v1"
         assert embedder.cache_dir is None
         assert embedder.threads is None
         assert embedder.progress_bar is True
@@ -27,13 +27,13 @@ def test_init_with_parameters(self):
         Test custom initialization parameters for FastembedSparseTextEmbedder.
         """
         embedder = FastembedSparseTextEmbedder(
-            model="prithvida/Splade_PP_en_v1",
+            model="prithivida/Splade_PP_en_v1",
             cache_dir="fake_dir",
             threads=2,
             progress_bar=False,
             parallel=1,
         )
-        assert embedder.model_name == "prithvida/Splade_PP_en_v1"
+        assert embedder.model_name == "prithivida/Splade_PP_en_v1"
         assert embedder.cache_dir == "fake_dir"
         assert embedder.threads == 2
         assert embedder.progress_bar is False
@@ -43,12 +43,12 @@ def test_to_dict(self):
         """
         Test serialization of FastembedSparseTextEmbedder to a dictionary, using default initialization parameters.
         """
-        embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1")
+        embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1")
         embedder_dict = embedder.to_dict()
         assert embedder_dict == {
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder",  # noqa
             "init_parameters": {
-                "model": "prithvida/Splade_PP_en_v1",
+                "model": "prithivida/Splade_PP_en_v1",
                 "cache_dir": None,
                 "threads": None,
                 "progress_bar": True,
@@ -63,7 +63,7 @@ def test_to_dict_with_custom_init_parameters(self):
         Test serialization of FastembedSparseTextEmbedder to a dictionary, using custom initialization parameters.
         """
         embedder = FastembedSparseTextEmbedder(
-            model="prithvida/Splade_PP_en_v1",
+            model="prithivida/Splade_PP_en_v1",
             cache_dir="fake_dir",
             threads=2,
             progress_bar=False,
@@ -74,7 +74,7 @@ def test_to_dict_with_custom_init_parameters(self):
         assert embedder_dict == {
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder",  # noqa
             "init_parameters": {
-                "model": "prithvida/Splade_PP_en_v1",
+                "model": "prithivida/Splade_PP_en_v1",
                 "cache_dir": "fake_dir",
                 "threads": 2,
                 "progress_bar": False,
@@ -91,7 +91,7 @@ def test_from_dict(self):
         embedder_dict = {
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder",  # noqa
             "init_parameters": {
-                "model": "prithvida/Splade_PP_en_v1",
+                "model": "prithivida/Splade_PP_en_v1",
                 "cache_dir": None,
                 "threads": None,
                 "progress_bar": True,
@@ -99,7 +99,7 @@ def test_from_dict(self):
             },
         }
         embedder = default_from_dict(FastembedSparseTextEmbedder, embedder_dict)
-        assert embedder.model_name == "prithvida/Splade_PP_en_v1"
+        assert embedder.model_name == "prithivida/Splade_PP_en_v1"
         assert embedder.cache_dir is None
         assert embedder.threads is None
         assert embedder.progress_bar is True
@@ -112,7 +112,7 @@ def test_from_dict_with_custom_init_parameters(self):
         embedder_dict = {
             "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder",  # noqa
             "init_parameters": {
-                "model": "prithvida/Splade_PP_en_v1",
+                "model": "prithivida/Splade_PP_en_v1",
                 "cache_dir": "fake_dir",
                 "threads": 2,
                 "progress_bar": False,
@@ -120,7 +120,7 @@ def test_from_dict_with_custom_init_parameters(self):
             },
         }
         embedder = default_from_dict(FastembedSparseTextEmbedder, embedder_dict)
-        assert embedder.model_name == "prithvida/Splade_PP_en_v1"
+        assert embedder.model_name == "prithivida/Splade_PP_en_v1"
         assert embedder.cache_dir == "fake_dir"
         assert embedder.threads == 2
         assert embedder.progress_bar is False
@@ -133,11 +133,11 @@ def test_warmup(self, mocked_factory):
         """
         Test for checking embedder instances after warm-up.
         """
-        embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1")
+        embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1")
         mocked_factory.get_embedding_backend.assert_not_called()
         embedder.warm_up()
         mocked_factory.get_embedding_backend.assert_called_once_with(
-            model_name="prithvida/Splade_PP_en_v1",
+            model_name="prithivida/Splade_PP_en_v1",
             cache_dir=None,
             threads=None,
             local_files_only=False,
@@ -151,7 +151,7 @@ def test_warmup_does_not_reload(self, mocked_factory):
         """
         Test for checking backend instances after multiple warm-ups.
""" - embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() embedder.warm_up() @@ -252,7 +252,7 @@ def test_run_with_model_kwargs(self): @pytest.mark.integration def test_run(self): embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", ) embedder.warm_up() diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 8f09db79a..7171b0069 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,15 +1,23 @@ # Changelog +## [integrations/google_ai-v3.0.2] - 2024-11-19 + +### ๐Ÿ› Bug Fixes + +- Fix missing usage metadata in GoogleAIGeminiChatGenerator (#1195) + + ## [integrations/google_ai-v3.0.0] - 2024-11-12 ### ๐Ÿ› Bug Fixes - `GoogleAIGeminiGenerator` - remove support for tools and change output type (#1177) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Adopt uv as installer (#1142) + ## [integrations/google_ai-v2.0.1] - 2024-10-15 ### ๐Ÿš€ Features @@ -26,16 +34,22 @@ - Do not retry tests in `hatch run test` command (#954) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Retry tests to reduce flakyness (#836) + +### ๐Ÿงน Chores + - Update ruff invocation to include check parameter (#853) - Update ruff linting scripts and settings (#1105) -### Docs +### ๐ŸŒ€ Miscellaneous +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Fix Google AI tests failing (#885) - Update GeminiGenerator docstrings (#964) - Update GoogleChatGenerator docstrings (#962) +- Feat: enable streaming in GoogleAIGemini (#1016) ## [integrations/google_ai-v1.1.0] - 2024-06-05 @@ -43,31 +57,50 @@ - Handle `TypeError: Could not create Blob` in `GoogleAIGeminiChatGenerator` (#772) +### ๐ŸŒ€ Miscellaneous + +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Fix Google AI integration tests (#786) +- Test: Fix tests skipping for Google AI integration (#788) + ## [integrations/google_ai-v1.0.0] - 2024-03-27 ### ๐Ÿ› Bug Fixes - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### ๐Ÿ“š Documentation - Update category slug (#442) - Disable-class-def (#556) +### ๐ŸŒ€ Miscellaneous + +- Google AI - review docstrings (#533) +- Make tests show coverage (#566) +- Remove references to Python 3.7 (#601) +- Google Generators: change `answers` to `replies` (#626) + ## [integrations/google_ai-v0.2.0] - 2024-02-15 -### Google_ai +### ๐ŸŒ€ Miscellaneous - Create api docs (#354) +- Google AI - new secrets management (#424) ## [integrations/google_ai-v0.1.0] - 2024-01-25 -### Refact +### ๐ŸŒ€ Miscellaneous +- Add docstrings for `GoogleAIGeminiGenerator` and `GoogleAIGeminiChatGenerator` (#175) - [**breaking**] Adjust import paths (#268) ## [integrations/google_ai-v0.0.1] - 2024-01-03 +### ๐ŸŒ€ Miscellaneous + +- Gemini with Makersuite (#156) +- Fix google_ai integration versioning + diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py index 2b77c813f..c62129f9d 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/__init__.py @@ -4,4 +4,4 @@ from 
.chat.gemini import GoogleAIGeminiChatGenerator from .gemini import GoogleAIGeminiGenerator -__all__ = ["GoogleAIGeminiGenerator", "GoogleAIGeminiChatGenerator"] +__all__ = ["GoogleAIGeminiChatGenerator", "GoogleAIGeminiGenerator"] diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 8efa8cda7..ef7d583be 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -313,25 +313,35 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess """ replies: List[ChatMessage] = [] metadata = response_body.to_dict() + + # currently Google only supports one candidate and usage metadata reflects this + # this should be refactored when multiple candidates are supported + usage_metadata_openai_format = {} + + usage_metadata = metadata.get("usage_metadata") + if usage_metadata: + usage_metadata_openai_format = { + "prompt_tokens": usage_metadata["prompt_token_count"], + "completion_tokens": usage_metadata["candidates_token_count"], + "total_tokens": usage_metadata["total_token_count"], + } + for idx, candidate in enumerate(response_body.candidates): candidate_metadata = metadata["candidates"][idx] candidate_metadata.pop("content", None) # we remove content from the metadata + if usage_metadata_openai_format: + candidate_metadata["usage"] = usage_metadata_openai_format for part in candidate.content.parts: if part.text != "": - replies.append( - ChatMessage(content=part.text, role=ChatRole.ASSISTANT, name=None, meta=candidate_metadata) - ) + replies.append(ChatMessage.from_assistant(content=part.text, meta=candidate_metadata)) elif part.function_call: candidate_metadata["function_call"] = part.function_call - replies.append( - ChatMessage( - content=dict(part.function_call.args.items()), - role=ChatRole.ASSISTANT, - name=part.function_call.name, - meta=candidate_metadata, - ) + new_message = ChatMessage.from_assistant( + content=dict(part.function_call.args.items()), meta=candidate_metadata ) + new_message.name = part.function_call.name + replies.append(new_message) return replies def _get_stream_response( @@ -353,18 +363,13 @@ def _get_stream_response( for part in candidate["content"]["parts"]: if "text" in part and part["text"] != "": content = part["text"] - replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None)) + replies.append(ChatMessage.from_assistant(content=content, meta=metadata)) elif "function_call" in part and len(part["function_call"]) > 0: metadata["function_call"] = part["function_call"] content = part["function_call"]["args"] - replies.append( - ChatMessage( - content=content, - role=ChatRole.ASSISTANT, - name=part["function_call"]["name"], - meta=metadata, - ) - ) + new_message = ChatMessage.from_assistant(content=content, meta=metadata) + new_message.name = part["function_call"]["name"] + replies.append(new_message) streaming_callback(StreamingChunk(content=content, meta=metadata)) return replies diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index c4372db0d..cb42f0ff8 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ 
-295,5 +295,11 @@ def test_past_conversation(): ] response = gemini_chat.run(messages=messages) assert "replies" in response - assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + replies = response["replies"] + assert len(replies) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in replies) + + assert all("usage" in reply.meta for reply in replies) + assert all("prompt_tokens" in reply.meta["usage"] for reply in replies) + assert all("completion_tokens" in reply.meta["usage"] for reply in replies) + assert all("total_tokens" in reply.meta["usage"] for reply in replies) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index ed2cc3c3b..ea2a8fb18 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/google_vertex-v3.0.0] - 2024-11-14 + +### 🐛 Bug Fixes + +- VertexAIGeminiGenerator - remove support for tools and change output type (#1180) + +### ⚙️ Miscellaneous Tasks + +- Fix Vertex tests (#1163) + ## [integrations/google_vertex-v2.2.0] - 2024-10-23 ### 🐛 Bug Fixes diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py index 07c2a5260..e5f556637 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/__init__.py @@ -11,8 +11,8 @@ __all__ = [ "VertexAICodeGenerator", - "VertexAIGeminiGenerator", "VertexAIGeminiChatGenerator", + "VertexAIGeminiGenerator", "VertexAIImageCaptioner", "VertexAIImageGenerator", "VertexAIImageQA", diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index c52f76dc6..c94367b41 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -279,19 +279,14 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]: # Remove content from metadata metadata.pop("content", None) if part._raw_part.text != "": - replies.append( - ChatMessage(content=part._raw_part.text, role=ChatRole.ASSISTANT, name=None, meta=metadata) - ) + replies.append(ChatMessage.from_assistant(content=part._raw_part.text, meta=metadata)) elif part.function_call: metadata["function_call"] = part.function_call - replies.append( - ChatMessage( - content=dict(part.function_call.args.items()), - role=ChatRole.ASSISTANT, - name=part.function_call.name, - meta=metadata, - ) + new_message = ChatMessage.from_assistant( + content=dict(part.function_call.args.items()), meta=metadata ) + new_message.name = part.function_call.name + replies.append(new_message) return replies def _get_stream_response( @@ -313,18 +308,13 @@ def _get_stream_response( for part in candidate.content.parts: if part._raw_part.text: content = chunk.text - replies.append(ChatMessage(content, role=ChatRole.ASSISTANT, name=None, meta=metadata)) + replies.append(ChatMessage.from_assistant(content, meta=metadata)) elif part.function_call:
metadata["function_call"] = part.function_call content = dict(part.function_call.args.items()) - replies.append( - ChatMessage( - content=content, - role=ChatRole.ASSISTANT, - name=part.function_call.name, - meta=metadata, - ) - ) + new_message = ChatMessage.from_assistant(content, meta=metadata) + new_message.name = part.function_call.name + replies.append(new_message) streaming_callback(StreamingChunk(content=content, meta=metadata)) return replies diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 737f2e668..c9473b428 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -15,8 +15,6 @@ HarmBlockThreshold, HarmCategory, Part, - Tool, - ToolConfig, ) logger = logging.getLogger(__name__) @@ -50,6 +48,16 @@ class VertexAIGeminiGenerator: ``` """ + def __new__(cls, *_, **kwargs): + if "tools" in kwargs or "tool_config" in kwargs: + msg = ( + "VertexAIGeminiGenerator does not support `tools` and `tool_config` parameters. " + "Use VertexAIGeminiChatGenerator instead." + ) + raise TypeError(msg) + return super(VertexAIGeminiGenerator, cls).__new__(cls) # noqa: UP008 + # super(__class__, cls) is needed because of the component decorator + def __init__( self, *, @@ -58,8 +66,6 @@ def __init__( location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, - tools: Optional[List[Tool]] = None, - tool_config: Optional[ToolConfig] = None, system_instruction: Optional[Union[str, ByteStream, Part]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): @@ -86,10 +92,6 @@ def __init__( for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold) and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory) for more details. - :param tools: List of tools to use when generating content. See the documentation for - [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool) - the list of supported arguments. - :param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig) :param system_instruction: Default system instruction to use for generating content. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. 
@@ -105,8 +107,6 @@ def __init__( # model parameters self._generation_config = generation_config self._safety_settings = safety_settings - self._tools = tools - self._tool_config = tool_config self._system_instruction = system_instruction self._streaming_callback = streaming_callback @@ -115,8 +115,6 @@ def __init__( self._model_name, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, - tool_config=self._tool_config, system_instruction=self._system_instruction, ) @@ -132,18 +130,6 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A "stop_sequences": config._raw_generation_config.stop_sequences, } - def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]: - """Serializes the ToolConfig object into a dictionary.""" - - mode = tool_config._gapic_tool_config.function_calling_config.mode - allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names - config_dict = {"function_calling_config": {"mode": mode}} - - if allowed_function_names: - config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names - - return config_dict - def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -160,15 +146,10 @@ def to_dict(self) -> Dict[str, Any]: location=self._location, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, - tool_config=self._tool_config, system_instruction=self._system_instruction, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] - if (tool_config := data["init_parameters"].get("tool_config")) is not None: - data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config) + if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -184,22 +165,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": Deserialized component. 
""" - def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig: - """Deserializes the ToolConfig object from a dictionary.""" - function_calling_config = config_dict["function_calling_config"] - return ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=function_calling_config["mode"], - allowed_function_names=function_calling_config.get("allowed_function_names"), - ) - ) - - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) - if (tool_config := data["init_parameters"].get("tool_config")) is not None: - data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config) if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @@ -215,7 +182,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: msg = f"Unsupported type {type(part)} for part {part}" raise ValueError(msg) - @component.output_types(replies=List[Union[str, Dict[str, str]]]) + @component.output_types(replies=List[str]) def run( self, parts: Variadic[Union[str, ByteStream, Part]], @@ -257,12 +224,6 @@ def _get_response(self, response_body: GenerationResponse) -> List[str]: for part in candidate.content.parts: if part._raw_part.text != "": replies.append(part.text) - elif part.function_call is not None: - function_call = { - "name": part.function_call.name, - "args": dict(part.function_call.args.items()), - } - replies.append(function_call) return replies def _get_stream_response( diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 73c99fe2f..614b83909 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -161,13 +161,13 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": { - "type_": "OBJECT", + "type": "OBJECT", "properties": { "location": { - "type_": "STRING", + "type": "STRING", "description": "The city and state, e.g. 
San Francisco, CA", }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + "unit": {"type": "STRING", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], "property_ordering": ["location", "unit"], diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 277851224..ff692c6f4 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -1,38 +1,17 @@ from unittest.mock import MagicMock, Mock, patch +import pytest from haystack import Pipeline from haystack.components.builders import PromptBuilder from haystack.dataclasses import StreamingChunk from vertexai.generative_models import ( - FunctionDeclaration, GenerationConfig, HarmBlockThreshold, HarmCategory, - Tool, - ToolConfig, ) from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, -) - @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") @@ -48,32 +27,28 @@ def test_init(mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) mock_vertexai_init.assert_called() assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] - assert gemini._tool_config == tool_config assert gemini._system_instruction == "Please provide brief answers." 
+def test_init_fails_with_tools_or_tool_config(): + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tools=["tool1", "tool2"]) + + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tool_config={"custom": "config"}) + + @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): @@ -88,8 +63,6 @@ def test_to_dict(_mock_vertexai_init, _mock_generative_model): "generation_config": None, "safety_settings": None, "streaming_callback": None, - "tools": None, - "tool_config": None, "system_instruction": None, }, } @@ -108,21 +81,11 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) assert gemini.to_dict() == { @@ -141,34 +104,6 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, "streaming_callback": None, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. 
San Francisco, CA", - }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - "property_ordering": ["location", "unit"], - }, - } - ] - } - ], - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -186,9 +121,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, - "tools": None, "streaming_callback": None, - "tool_config": None, "system_instruction": None, }, } @@ -198,8 +131,6 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id is None assert gemini._location is None assert gemini._safety_settings is None - assert gemini._tools is None - assert gemini._tool_config is None assert gemini._system_instruction is None assert gemini._generation_config is None @@ -223,40 +154,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "stop_sequences": ["stop"], }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - "description": "Get the current weather in a given location", - } - ] - } - ], "streaming_callback": None, - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -266,13 +164,8 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._project_id == "TestID123" assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) - assert isinstance(gemini._tool_config, ToolConfig) assert gemini._system_instruction == "Please provide brief answers." - assert ( - gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY - ) @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") diff --git a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py index 734798f46..c05c37733 100644 --- a/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py +++ b/integrations/instructor_embedders/src/haystack_integrations/components/embedders/instructor_embedders/instructor_document_embedder.py @@ -158,7 +158,7 @@ def run(self, documents: List[Document]): param documents: A list of Documents to embed. 
""" - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "InstructorDocumentEmbedder expects a list of Documents as input. " "In case you want to embed a list of strings, please use the InstructorTextEmbedder." diff --git a/integrations/jina/CHANGELOG.md b/integrations/jina/CHANGELOG.md index 918a764f0..01de2abc1 100644 --- a/integrations/jina/CHANGELOG.md +++ b/integrations/jina/CHANGELOG.md @@ -1,17 +1,59 @@ # Changelog +## [integrations/jina-v0.5.1] - 2024-11-26 + +### ๐Ÿงน Chores + +- Fix linting/isort (#1215) + +### ๐ŸŒ€ Miscellaneous + +- Fix: `JinaReaderConnector` - fix the name of the output edge (#1217) + +## [integrations/jina-v0.5.0] - 2024-11-21 + +### ๐Ÿš€ Features + +- Add `JinaReaderConnector` (#1150) + +### ๐Ÿ“š Documentation + +- Update docstrings of JinaDocumentEmbedder and JinaTextEmbedder (#1092) + +### โš™๏ธ CI + +- Adopt uv as installer (#1142) + +### ๐Ÿงน Chores + +- Update ruff linting scripts and settings (#1105) + + ## [integrations/jina-v0.4.0] - 2024-09-18 ### ๐Ÿงช Testing - Do not retry tests in `hatch run test` command (#954) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Retry tests to reduce flakyness (#836) + +### ๐Ÿงน Chores + - Update ruff invocation to include check parameter (#853) - Update Jina Embedder usage for V3 release (#1077) +### ๐ŸŒ€ Miscellaneous + +- Remove references to Python 3.7 (#601) +- Jina - add missing ranker to API reference (#610) +- Jina ranker: fix wrong URL in docstring (#628) +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Chore: Jina - ruff update, don't ruff tests (#982) + ## [integrations/jina-v0.3.0] - 2024-03-19 ### ๐Ÿš€ Features @@ -22,13 +64,17 @@ - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### ๐Ÿ“š Documentation - Update category slug (#442) - Disable-class-def (#556) +### ๐ŸŒ€ Miscellaneous + +- Jina - remove dead code (#422) +- Jina - review docstrings (#504) +- Make tests show coverage (#566) + ## [integrations/jina-v0.2.0] - 2024-02-14 ### ๐Ÿš€ Features @@ -39,7 +85,7 @@ This PR will also push the docs to Readme - Update paths and titles (#397) -### Jina +### ๐ŸŒ€ Miscellaneous - Update secrets management (#411) @@ -47,18 +93,22 @@ This PR will also push the docs to Readme ### ๐Ÿ› Bug Fixes -- Fix project urls (#96) - - +- Fix project URLs (#96) ### ๐Ÿšœ Refactor - Use `hatch_vcs` to manage integrations versioning (#103) -### โš™๏ธ Miscellaneous Tasks +### ๐Ÿงน Chores - [**breaking**] Rename model_name to model in the Jina integration (#230) +### ๐ŸŒ€ Miscellaneous + +- Change metadata to meta (#152) +- Optimize API key reading (#162) +- Refact!:change import paths (#254) + ## [integrations/jina-v0.0.1] - 2023-12-11 ### ๐Ÿš€ Features diff --git a/integrations/jina/examples/jina_reader_connector.py b/integrations/jina/examples/jina_reader_connector.py new file mode 100644 index 000000000..24b6f5db3 --- /dev/null +++ b/integrations/jina/examples/jina_reader_connector.py @@ -0,0 +1,47 @@ +# to make use of the JinaReaderConnector, we first need to install the Haystack integration +# pip install jina-haystack + +# then we must set the JINA_API_KEY environment variable +# export JINA_API_KEY= + + +from haystack_integrations.components.connectors.jina import JinaReaderConnector + +# we can use the 
JinaReaderConnector to process a URL and return the textual content of the page +reader = JinaReaderConnector(mode="read") +query = "https://example.com" +result = reader.run(query=query) + +print(result) +# {'documents': [Document(id=fa3e51e4ca91828086dca4f359b6e1ea2881e358f83b41b53c84616cb0b2f7cf, +# content: 'This domain is for use in illustrative examples in documents. You may use this domain in literature ...', +# meta: {'title': 'Example Domain', 'description': '', 'url': 'https://example.com/', 'usage': {'tokens': 42}})]} + + +# we can perform a web search by setting the mode to "search" +reader = JinaReaderConnector(mode="search") +query = "UEFA Champions League 2024" +result = reader.run(query=query) + +print(result) +# {'documents': [Document(id=6a71abf9955594232037321a476d39a835c0cb7bc575d886ee0087c973c95940, +# content: '2024/25 UEFA Champions League: Matches, draw, final, key dates | UEFA Champions League | UEFA.com...', +# meta: {'title': '2024/25 UEFA Champions League: Matches, draw, final, key dates', +# 'description': 'What are the match dates? Where is the 2025 final? How will the competition work?', +# 'url': 'https://www.uefa.com/uefachampionsleague/news/...', +# 'usage': {'tokens': 5581}}), ...]} + + +# finally, we can perform fact-checking by setting the mode to "ground" (experimental) +reader = JinaReaderConnector(mode="ground") +query = "ChatGPT was launched in 2017" +result = reader.run(query=query) + +print(result) +# {'documents': [Document(id=f0c964dbc1ebb2d6584c8032b657150b9aa6e421f714cc1b9f8093a159127f0c, +# content: 'The statement that ChatGPT was launched in 2017 is incorrect. Multiple references confirm that ChatG...', +# meta: {'factuality': 0, 'result': False, 'references': [ +# {'url': 'https://en.wikipedia.org/wiki/ChatGPT', +# 'keyQuote': 'ChatGPT is a generative artificial intelligence (AI) chatbot developed by OpenAI and launched in 2022.', +# 'isSupportive': False}, ...], +# 'usage': {'tokens': 10188}})]} diff --git a/integrations/jina/pydoc/config.yml b/integrations/jina/pydoc/config.yml index 8c7a241f6..2d0ef4f87 100644 --- a/integrations/jina/pydoc/config.yml +++ b/integrations/jina/pydoc/config.yml @@ -6,6 +6,7 @@ loaders: "haystack_integrations.components.embedders.jina.document_embedder", "haystack_integrations.components.embedders.jina.text_embedder", "haystack_integrations.components.rankers.jina.ranker", + "haystack_integrations.components.connectors.jina.reader", ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index c89eeacb4..e3af086d0 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -132,18 +132,23 @@ ban-relative-imports = "parents" [tool.ruff.lint.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] +# examples can contain "print" commands +"examples/**/*" = ["T201"] [tool.coverage.run] source = ["haystack_integrations"] branch = true parallel = false - [tool.coverage.report] omit = ["*/tests/*", "*/__init__.py"] show_missing = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + [[tool.mypy.overrides]] module = ["haystack.*", "haystack_integrations.*", "pytest.*"] ignore_missing_imports = true diff --git a/integrations/jina/src/haystack_integrations/components/connectors/jina/__init__.py
b/integrations/jina/src/haystack_integrations/components/connectors/jina/__init__.py new file mode 100644 index 000000000..95368df21 --- /dev/null +++ b/integrations/jina/src/haystack_integrations/components/connectors/jina/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .reader import JinaReaderConnector +from .reader_mode import JinaReaderMode + +__all__ = ["JinaReaderConnector", "JinaReaderMode"] diff --git a/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py new file mode 100644 index 000000000..618cacb4e --- /dev/null +++ b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import json +from typing import Any, Dict, List, Optional, Union +from urllib.parse import quote + +import requests +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace + +from .reader_mode import JinaReaderMode + +READER_ENDPOINT_URL_BY_MODE = { + JinaReaderMode.READ: "https://r.jina.ai/", + JinaReaderMode.SEARCH: "https://s.jina.ai/", + JinaReaderMode.GROUND: "https://g.jina.ai/", +} + + +@component +class JinaReaderConnector: + """ + A component that interacts with Jina AI's reader service to process queries and return documents. + + This component supports different modes of operation: `read`, `search`, and `ground`. + + Usage example: + ```python + from haystack_integrations.components.connectors.jina import JinaReaderConnector + + reader = JinaReaderConnector(mode="read") + query = "https://example.com" + result = reader.run(query=query) + document = result["documents"][0] + print(document.content) + + >>> "This domain is for use in illustrative examples..." + ``` + """ + + def __init__( + self, + mode: Union[JinaReaderMode, str], + api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008 + json_response: bool = True, + ): + """ + Initialize a JinaReader instance. + + :param mode: The operation mode for the reader (`read`, `search` or `ground`). + - `read`: process a URL and return the textual content of the page. + - `search`: search the web and return textual content of the most relevant pages. + - `ground`: call the grounding engine to perform fact checking. + For more information on the modes, see the [Jina Reader documentation](https://jina.ai/reader/). + :param api_key: The Jina API key. It can be explicitly provided or automatically read from the + environment variable JINA_API_KEY (recommended). + :param json_response: Controls the response format from the Jina Reader API. + If `True`, requests a JSON response, resulting in Documents with rich structured metadata. + If `False`, requests a raw response, resulting in one Document with minimal metadata. + """ + self.api_key = api_key + self.json_response = json_response + + if isinstance(mode, str): + mode = JinaReaderMode.from_str(mode) + self.mode = mode + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + :returns: + Dictionary with serialized data. 
+ """ + return default_to_dict( + self, + api_key=self.api_key.to_dict(), + mode=str(self.mode), + json_response=self.json_response, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "JinaReaderConnector": + """ + Deserializes the component from a dictionary. + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + + def _json_to_document(self, data: dict) -> Document: + """ + Convert a JSON response/record to a Document, depending on the reader mode. + """ + if self.mode == JinaReaderMode.GROUND: + content = data.pop("reason") + else: + content = data.pop("content") + document = Document(content=content, meta=data) + return document + + @component.output_types(documents=List[Document]) + def run(self, query: str, headers: Optional[Dict[str, str]] = None): + """ + Process the query/URL using the Jina AI reader service. + + :param query: The query string or URL to process. + :param headers: Optional headers to include in the request for customization. Refer to the + [Jina Reader documentation](https://jina.ai/reader/) for more information. + + :returns: + A dictionary with the following keys: + - `documents`: A list of `Document` objects. + """ + headers = headers or {} + headers["Authorization"] = f"Bearer {self.api_key.resolve_value()}" + + if self.json_response: + headers["Accept"] = "application/json" + + endpoint_url = READER_ENDPOINT_URL_BY_MODE[self.mode] + encoded_target = quote(query, safe="") + url = f"{endpoint_url}{encoded_target}" + + response = requests.get(url, headers=headers, timeout=60) + + # raw response: we just return a single Document with text + if not self.json_response: + meta = {"content_type": response.headers["Content-Type"], "query": query} + return {"documents": [Document(content=response.content, meta=meta)]} + + response_json = json.loads(response.content).get("data", {}) + if self.mode == JinaReaderMode.SEARCH: + documents = [self._json_to_document(record) for record in response_json] + return {"documents": documents} + + return {"documents": [self._json_to_document(response_json)]} diff --git a/integrations/jina/src/haystack_integrations/components/connectors/jina/reader_mode.py b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader_mode.py new file mode 100644 index 000000000..2ccf7250b --- /dev/null +++ b/integrations/jina/src/haystack_integrations/components/connectors/jina/reader_mode.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from enum import Enum + + +class JinaReaderMode(Enum): + """ + Enum representing modes for the Jina Reader. + + Modes: + READ: Process a URL and return the textual content of the page. + SEARCH: Search the web and return the textual content of the most relevant pages. + GROUND: Call the grounding engine to perform fact checking. + + """ + + READ = "read" + SEARCH = "search" + GROUND = "ground" + + def __str__(self): + return self.value + + @classmethod + def from_str(cls, string: str) -> "JinaReaderMode": + """ + Create the reader mode from a string. + + :param string: + String to convert. + :returns: + Reader mode. + """ + enum_map = {e.value: e for e in JinaReaderMode} + reader_mode = enum_map.get(string) + if reader_mode is None: + msg = f"Unknown reader mode '{string}'. 
Supported modes are: {list(enum_map.keys())}" + raise ValueError(msg) + return reader_mode diff --git a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py index 715092b8a..103132faf 100644 --- a/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py +++ b/integrations/jina/src/haystack_integrations/components/embedders/jina/document_embedder.py @@ -200,7 +200,7 @@ def run(self, documents: List[Document]): - `meta`: A dictionary with metadata including the model name and usage statistics. :raises TypeError: If the input is not a list of Documents. """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "JinaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the JinaTextEmbedder." diff --git a/integrations/jina/tests/test_reader_connector.py b/integrations/jina/tests/test_reader_connector.py new file mode 100644 index 000000000..449f73df8 --- /dev/null +++ b/integrations/jina/tests/test_reader_connector.py @@ -0,0 +1,141 @@ +import json +import os +from unittest.mock import patch + +import pytest +from haystack import Document +from haystack.utils import Secret + +from haystack_integrations.components.connectors.jina import JinaReaderConnector, JinaReaderMode + + +class TestJinaReaderConnector: + def test_init_with_custom_parameters(self, monkeypatch): + monkeypatch.setenv("TEST_KEY", "test-api-key") + reader = JinaReaderConnector(mode="read", api_key=Secret.from_env_var("TEST_KEY"), json_response=False) + + assert reader.mode == JinaReaderMode.READ + assert reader.api_key.resolve_value() == "test-api-key" + assert reader.json_response is False + + def test_init_with_invalid_mode(self): + with pytest.raises(ValueError): + JinaReaderConnector(mode="INVALID") + + def test_to_dict(self, monkeypatch): + monkeypatch.setenv("TEST_KEY", "test-api-key") + reader = JinaReaderConnector(mode="search", api_key=Secret.from_env_var("TEST_KEY"), json_response=True) + + serialized = reader.to_dict() + + assert serialized["type"] == "haystack_integrations.components.connectors.jina.reader.JinaReaderConnector" + assert "init_parameters" in serialized + + init_params = serialized["init_parameters"] + assert init_params["mode"] == "search" + assert init_params["json_response"] is True + assert "api_key" in init_params + assert init_params["api_key"]["type"] == "env_var" + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "test-api-key") + component_dict = { + "type": "haystack_integrations.components.connectors.jina.reader.JinaReaderConnector", + "init_parameters": { + "api_key": {"type": "env_var", "env_vars": ["JINA_API_KEY"], "strict": True}, + "mode": "read", + "json_response": True, + }, + } + + reader = JinaReaderConnector.from_dict(component_dict) + + assert isinstance(reader, JinaReaderConnector) + assert reader.mode == JinaReaderMode.READ + assert reader.json_response is True + assert reader.api_key.resolve_value() == "test-api-key" + + def test_json_to_document_read_mode(self, monkeypatch): + monkeypatch.setenv("TEST_KEY", "test-api-key") + reader = JinaReaderConnector(mode="read") + + data = {"content": "Mocked content", "title": "Mocked Title", "url": "https://example.com"} + document = 
reader._json_to_document(data) + + assert isinstance(document, Document) + assert document.content == "Mocked content" + assert document.meta["title"] == "Mocked Title" + assert document.meta["url"] == "https://example.com" + + def test_json_to_document_ground_mode(self, monkeypatch): + monkeypatch.setenv("TEST_KEY", "test-api-key") + reader = JinaReaderConnector(mode="ground") + + data = { + "factuality": 0, + "result": False, + "reason": "The statement is contradicted by...", + "references": [{"url": "https://example.com", "keyQuote": "Mocked key quote", "isSupportive": False}], + } + + document = reader._json_to_document(data) + assert isinstance(document, Document) + assert document.content == "The statement is contradicted by..." + assert document.meta["factuality"] == 0 + assert document.meta["result"] is False + assert document.meta["references"] == [ + {"url": "https://example.com", "keyQuote": "Mocked key quote", "isSupportive": False} + ] + + @patch("requests.get") + def test_run_with_mocked_response(self, mock_get, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "test-api-key") + mock_json_response = { + "data": {"content": "Mocked content", "title": "Mocked Title", "url": "https://example.com"} + } + mock_get.return_value.content = json.dumps(mock_json_response).encode("utf-8") + mock_get.return_value.headers = {"Content-Type": "application/json"} + + reader = JinaReaderConnector(mode="read") + result = reader.run(query="https://example.com") + + assert mock_get.call_count == 1 + assert mock_get.call_args[0][0] == "https://r.jina.ai/https%3A%2F%2Fexample.com" + assert mock_get.call_args[1]["headers"] == { + "Authorization": "Bearer test-api-key", + "Accept": "application/json", + } + + assert len(result) == 1 + document = result["documents"][0] + assert isinstance(document, Document) + assert document.content == "Mocked content" + assert document.meta["title"] == "Mocked Title" + assert document.meta["url"] == "https://example.com" + + @pytest.mark.skipif(not os.environ.get("JINA_API_KEY", None), reason="JINA_API_KEY env var not set") + @pytest.mark.integration + def test_run_reader_mode(self): + reader = JinaReaderConnector(mode="read") + result = reader.run(query="https://example.com") + + assert len(result) == 1 + document = result["documents"][0] + assert isinstance(document, Document) + assert "This domain is for use in illustrative examples" in document.content + assert document.meta["title"] == "Example Domain" + assert document.meta["url"] == "https://example.com/" + + @pytest.mark.skipif(not os.environ.get("JINA_API_KEY", None), reason="JINA_API_KEY env var not set") + @pytest.mark.integration + def test_run_search_mode(self): + reader = JinaReaderConnector(mode="search") + result = reader.run(query="When was Jina AI founded?") + + assert len(result) >= 1 + for doc in result["documents"]: + assert isinstance(doc, Document) + assert doc.content + assert "title" in doc.meta + assert "url" in doc.meta + assert "description" in doc.meta diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index 29be7f838..7cf1cc0c4 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/langfuse-v0.6.0] - 2024-11-18 + +### 🚀 Features + +- Add support for ttft (#1161) + +### ⚙️ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + ## [integrations/langfuse-v0.5.0] - 2024-10-01 ### ⚙️ Miscellaneous Tasks diff --git
a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index c9c8a354e..d6f2535c7 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -1,9 +1,10 @@ import contextlib -import logging import os +from contextvars import ContextVar from datetime import datetime -from typing import Any, Dict, Iterator, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Union +from haystack import logging from haystack.components.generators.openai_utils import _convert_message_to_openai_format from haystack.dataclasses import ChatMessage from haystack.tracing import Span, Tracer, tracer @@ -32,6 +33,17 @@ ] _ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS +# These are the keys used by Haystack for traces and spans. +# We keep them here to avoid making typos when using them. +_PIPELINE_RUN_KEY = "haystack.pipeline.run" +_COMPONENT_NAME_KEY = "haystack.component.name" +_COMPONENT_TYPE_KEY = "haystack.component.type" +_COMPONENT_OUTPUT_KEY = "haystack.component.output" + +# Context var used to keep track of tracing related info. +# This is mainly useful for parent spans. +tracing_context_var: ContextVar[Dict[Any, Any]] = ContextVar("tracing_context") + class LangfuseSpan(Span): """ @@ -86,7 +98,7 @@ def set_content_tag(self, key: str, value: Any) -> None: self._data[key] = value - def raw_span(self) -> Any: + def raw_span(self) -> "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]": """ Return the underlying span instance. @@ -115,41 +127,58 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: and only accessible to the Langfuse account owner. """ self._tracer = tracer - self._context: list[LangfuseSpan] = [] + self._context: List[LangfuseSpan] = [] self._name = name self._public = public self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true" @contextlib.contextmanager - def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> Iterator[Span]: - """ - Start and manage a new trace span. - :param operation_name: The name of the operation. - :param tags: A dictionary of tags to attach to the span. - :return: A context manager yielding the span.
- """ + def trace( + self, operation_name: str, tags: Optional[Dict[str, Any]] = None, parent_span: Optional[Span] = None + ) -> Iterator[Span]: tags = tags or {} - span_name = tags.get("haystack.component.name", operation_name) - - if tags.get("haystack.component.type") in _ALL_SUPPORTED_GENERATORS: - span = LangfuseSpan(self.current_span().raw_span().generation(name=span_name)) + span_name = tags.get(_COMPONENT_NAME_KEY, operation_name) + + # Create new span depending whether there's a parent span or not + if not parent_span: + if operation_name != _PIPELINE_RUN_KEY: + logger.warning( + "Creating a new trace without a parent span is not recommended for operation '{operation_name}'.", + operation_name=operation_name, + ) + # Create a new trace if no parent span is provided + context = tracing_context_var.get({}) + span = LangfuseSpan( + self._tracer.trace( + name=self._name, + public=self._public, + id=context.get("trace_id"), + user_id=context.get("user_id"), + session_id=context.get("session_id"), + tags=context.get("tags"), + version=context.get("version"), + ) + ) + elif tags.get(_COMPONENT_TYPE_KEY) in _ALL_SUPPORTED_GENERATORS: + span = LangfuseSpan(parent_span.raw_span().generation(name=span_name)) else: - span = LangfuseSpan(self.current_span().raw_span().span(name=span_name)) + span = LangfuseSpan(parent_span.raw_span().span(name=span_name)) self._context.append(span) span.set_tags(tags) yield span - if tags.get("haystack.component.type") in _SUPPORTED_GENERATORS: - meta = span._data.get("haystack.component.output", {}).get("meta") + # Update span metadata based on component type + if tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_GENERATORS: + # Haystack returns one meta dict for each message, but the 'usage' value + # is always the same, let's just pick the first item + meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta") if meta: - # Haystack returns one meta dict for each message, but the 'usage' value - # is always the same, let's just pick the first item m = meta[0] span._span.update(usage=m.get("usage") or None, model=m.get("model")) - elif tags.get("haystack.component.type") in _SUPPORTED_CHAT_GENERATORS: - replies = span._data.get("haystack.component.output", {}).get("replies") + elif tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_CHAT_GENERATORS: + replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies") if replies: meta = replies[0].meta completion_start_time = meta.get("completion_start_time") @@ -165,36 +194,24 @@ def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> I completion_start_time=completion_start_time, ) - pipeline_input = tags.get("haystack.pipeline.input_data", None) - if pipeline_input: - span._span.update(input=tags["haystack.pipeline.input_data"]) - pipeline_output = tags.get("haystack.pipeline.output_data", None) - if pipeline_output: - span._span.update(output=tags["haystack.pipeline.output_data"]) - - span.raw_span().end() + raw_span = span.raw_span() + if isinstance(raw_span, langfuse.client.StatefulSpanClient): + raw_span.end() self._context.pop() - if len(self._context) == 1: - # The root span has to be a trace, which need to be removed from the context after the pipeline run - self._context.pop() - - if self.enforce_flush: - self.flush() + if self.enforce_flush: + self.flush() def flush(self): self._tracer.flush() - def current_span(self) -> Span: + def current_span(self) -> Optional[Span]: """ - Return the currently active span. + Return the current active span. - :return: The currently active span. 
+ :return: The current span if available, else None. """ - if not self._context: - # The root span has to be a trace - self._context.append(LangfuseSpan(self._tracer.trace(name=self._name, public=self._public))) - return self._context[-1] + return self._context[-1] if self._context else None def get_trace_url(self) -> str: """ diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py index 9ee8e5dc4..42ae1d07d 100644 --- a/integrations/langfuse/tests/test_tracer.py +++ b/integrations/langfuse/tests/test_tracer.py @@ -69,7 +69,7 @@ def test_create_new_span(self): tracer = LangfuseTracer(tracer=mock_tracer, name="Haystack", public=False) with tracer.trace("operation_name", tags={"tag1": "value1", "tag2": "value2"}) as span: - assert len(tracer._context) == 2, "The trace span should have been added to the the root context span" + assert len(tracer._context) == 1, "The trace span should have been added to the root context span" assert span.raw_span().operation_name == "operation_name" assert span.raw_span().metadata == {"tag1": "value1", "tag2": "value2"} diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 657b6eae1..e5737b861 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -52,25 +52,28 @@ def test_tracing_integration(llm_class, env_var, expected_trace): assert "Berlin" in response["llm"]["replies"][0].content assert response["tracer"]["trace_url"] - # add a random delay between 1 and 3 seconds to make sure the trace is flushed - # and that the trace is available in Langfuse when we fetch it below - time.sleep(random.uniform(1, 3)) - - url = "https://cloud.langfuse.com/api/public/traces/" trace_url = response["tracer"]["trace_url"] uuid = os.path.basename(urlparse(trace_url).path) + url = f"https://cloud.langfuse.com/api/public/traces/{uuid}" - try: - response = requests.get( - url + uuid, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) + # Poll the Langfuse API a bit as the trace might not be ready right away + attempts = 5 + delay = 1 + while attempts >= 0: + res = requests.get( + url, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) ) - assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}" + if attempts > 0 and res.status_code != 200: + attempts -= 1 + time.sleep(delay) + delay *= 2 + continue + assert res.status_code == 200, f"Failed to retrieve data from Langfuse API: {res.status_code}" # check if the trace contains the expected LLM name - assert expected_trace in str(response.content) + assert expected_trace in str(res.content) # check if the trace contains the expected generation span - assert "GENERATION" in str(response.content) + assert "GENERATION" in str(res.content) # check if the trace contains the expected user_id - assert "user_42" in str(response.content) - except requests.exceptions.RequestException as e: - pytest.fail(f"Failed to retrieve data from Langfuse API: {e}") + assert "user_42" in str(res.content) + break diff --git a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py index 10b20d363..a85dbfd88 100644 --- a/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py +++ 
b/integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/__init__.py @@ -5,4 +5,4 @@ from .chat.chat_generator import LlamaCppChatGenerator from .generator import LlamaCppGenerator -__all__ = ["LlamaCppGenerator", "LlamaCppChatGenerator"] +__all__ = ["LlamaCppChatGenerator", "LlamaCppGenerator"] diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md index 75b31d033..a536e431d 100644 --- a/integrations/nvidia/CHANGELOG.md +++ b/integrations/nvidia/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/nvidia-v0.1.1] - 2024-11-14 + +### 🐛 Bug Fixes + +- Fixes to NvidiaRanker (#1191) + ## [integrations/nvidia-v0.1.0] - 2024-11-13 ### 🚀 Features diff --git a/integrations/nvidia/README.md b/integrations/nvidia/README.md index e28f0ede9..558c34d28 100644 --- a/integrations/nvidia/README.md +++ b/integrations/nvidia/README.md @@ -38,7 +38,7 @@ hatch run test To only run unit tests: ``` -hatch run test -m"not integration" +hatch run test -m "not integration" ``` To run the linters `ruff` and `mypy`: diff --git a/integrations/nvidia/pyproject.toml b/integrations/nvidia/pyproject.toml index 7f0048c1b..586b50848 100644 --- a/integrations/nvidia/pyproject.toml +++ b/integrations/nvidia/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "requests"] +dependencies = ["haystack-ai", "requests", "tqdm"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme" diff --git a/integrations/nvidia/src/haystack_integrations/__init__.py b/integrations/nvidia/src/haystack_integrations/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/__init__.py b/integrations/nvidia/src/haystack_integrations/components/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py index bc2d9372c..c6ecea7b1 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py @@ -1,5 +1,9 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .document_embedder import NvidiaDocumentEmbedder from .text_embedder import NvidiaTextEmbedder from .truncate import EmbeddingTruncateMode -__all__ = ["NvidiaDocumentEmbedder", 
"NvidiaTextEmbedder", "EmbeddingTruncateMode"] +__all__ = ["EmbeddingTruncateMode", "NvidiaDocumentEmbedder", "NvidiaTextEmbedder"] diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index d746a75f4..b417fa737 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -1,13 +1,19 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os import warnings from typing import Any, Dict, List, Optional, Tuple, Union -from haystack import Document, component, default_from_dict, default_to_dict +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace from tqdm import tqdm +from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from .truncate import EmbeddingTruncateMode +logger = logging.getLogger(__name__) _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @@ -44,6 +50,7 @@ def __init__( meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", truncate: Optional[Union[EmbeddingTruncateMode, str]] = None, + timeout: Optional[float] = None, ): """ Create a NvidiaTextEmbedder component. @@ -71,8 +78,11 @@ def __init__( :param embedding_separator: Separator used to concatenate the meta fields to the Document text. :param truncate: - Specifies how inputs longer that the maximum token length should be truncated. + Specifies how inputs longer than the maximum token length should be truncated. If None the behavior is model-dependent, see the official documentation for more information. + :param timeout: + Timeout for request calls, if not set it is inferred from the `NVIDIA_TIMEOUT` environment variable + or set to 60 by default. """ self.api_key = api_key @@ -95,6 +105,10 @@ def __init__( if is_hosted(api_url) and not self.model: # manually set default model self.model = "nvidia/nv-embedqa-e5-v5" + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", 60.0)) + self.timeout = timeout + def default_model(self): """Set default model in local NIM mode.""" valid_models = [ @@ -125,10 +139,11 @@ def warm_up(self): if self.truncate is not None: model_kwargs["truncate"] = str(self.truncate) self.backend = NimBackend( - self.model, + model=self.model, api_url=self.api_url, api_key=self.api_key, model_kwargs=model_kwargs, + timeout=self.timeout, ) self._initialized = True @@ -155,6 +170,7 @@ def to_dict(self) -> Dict[str, Any]: meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, truncate=str(self.truncate) if self.truncate is not None else None, + timeout=self.timeout, ) @classmethod @@ -167,7 +183,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaDocumentEmbedder": :returns: The deserialized component. 
""" - deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: @@ -224,7 +242,7 @@ def run(self, documents: List[Document]): if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - elif not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + elif not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "NvidiaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a string, please use the NvidiaTextEmbedder." @@ -233,8 +251,7 @@ def run(self, documents: List[Document]): for doc in documents: if not doc.content: - msg = f"Document '{doc.id}' has no content to embed." - raise ValueError(msg) + logger.warning(f"Document '{doc.id}' has no content to embed.") texts_to_embed = self._prepare_texts_to_embed(documents) embeddings, metadata = self._embed_batch(texts_to_embed, self.batch_size) diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 22bed8197..a93aa8caa 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -1,12 +1,18 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os import warnings from typing import Any, Dict, List, Optional, Union -from haystack import component, default_from_dict, default_to_dict +from haystack import component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from .truncate import EmbeddingTruncateMode +logger = logging.getLogger(__name__) _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @@ -41,6 +47,7 @@ def __init__( prefix: str = "", suffix: str = "", truncate: Optional[Union[EmbeddingTruncateMode, str]] = None, + timeout: Optional[float] = None, ): """ Create a NvidiaTextEmbedder component. @@ -61,6 +68,9 @@ def __init__( :param truncate: Specifies how inputs longer that the maximum token length should be truncated. If None the behavior is model-dependent, see the official documentation for more information. + :param timeout: + Timeout for request calls, if not set it is inferred from the `NVIDIA_TIMEOUT` environment variable + or set to 60 by default. """ self.api_key = api_key @@ -79,6 +89,10 @@ def __init__( if is_hosted(api_url) and not self.model: # manually set default model self.model = "nvidia/nv-embedqa-e5-v5" + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", 60.0)) + self.timeout = timeout + def default_model(self): """Set default model in local NIM mode.""" valid_models = [ @@ -86,6 +100,12 @@ def default_model(self): ] name = next(iter(valid_models), None) if name: + logger.warning( + "Default model is set as: {model_name}. \n" + "Set model using model parameter. 
\n" + "To get available models use available_models property.", + model_name=name, + ) warnings.warn( f"Default model is set as: {name}. \n" "Set model using model parameter. \n" @@ -109,10 +129,11 @@ def warm_up(self): if self.truncate is not None: model_kwargs["truncate"] = str(self.truncate) self.backend = NimBackend( - self.model, + model=self.model, api_url=self.api_url, api_key=self.api_key, model_kwargs=model_kwargs, + timeout=self.timeout, ) self._initialized = True @@ -135,6 +156,7 @@ def to_dict(self) -> Dict[str, Any]: prefix=self.prefix, suffix=self.suffix, truncate=str(self.truncate) if self.truncate is not None else None, + timeout=self.timeout, ) @classmethod @@ -147,7 +169,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaTextEmbedder": :returns: The deserialized component. """ - deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) @component.output_types(embedding=List[float], meta=Dict[str, Any]) @@ -159,7 +183,7 @@ def run(self, text: str): The text to embed. :returns: A dictionary with the following keys and values: - - `embedding` - Embeddng of the text. + - `embedding` - Embedding of the text. - `meta` - Metadata on usage statistics, etc. :raises RuntimeError: If the component was not initialized. diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py index 3a8eb9d07..931c3cce3 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py index 18354ea17..b809d83b9 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from .generator import NvidiaGenerator __all__ = ["NvidiaGenerator"] diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py index 3eadcc5df..5047d0682 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + +import os import 
warnings from typing import Any, Dict, List, Optional @@ -48,6 +50,7 @@ def __init__( api_url: str = _DEFAULT_API_URL, api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), model_arguments: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, ): """ Create a NvidiaGenerator component. @@ -69,6 +72,9 @@ def __init__( specific to a model. Search your model in the [NVIDIA NIM](https://ai.nvidia.com) to find the arguments it accepts. + :param timeout: + Timeout for request calls. If not set, it is inferred from the `NVIDIA_TIMEOUT` environment variable + or defaults to 60. """ self._model = model self._api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/chat/completions"]) @@ -78,6 +84,9 @@ def __init__( self._backend: Optional[Any] = None self.is_hosted = is_hosted(api_url) + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", 60.0)) + self.timeout = timeout def default_model(self): """Set default model in local NIM mode.""" @@ -109,10 +118,11 @@ def warm_up(self): msg = "API key is required for hosted NVIDIA NIMs." raise ValueError(msg) self._backend = NimBackend( - self._model, + model=self._model, api_url=self._api_url, api_key=self._api_key, model_kwargs=self._model_arguments, + timeout=self.timeout, ) if not self.is_hosted and not self._model: diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py b/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py index 29cb2f7f5..05daa1c54 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .ranker import NvidiaRanker __all__ = ["NvidiaRanker"] diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py index 1553d1ac3..66203a490 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ -1,12 +1,18 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os import warnings from typing import Any, Dict, List, Optional, Union -from haystack import Document, component, default_from_dict, default_to_dict +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.components.rankers.nvidia.truncate import RankerTruncateMode from haystack_integrations.utils.nvidia import NimBackend, url_validation -from .truncate import RankerTruncateMode +logger = logging.getLogger(__name__) _DEFAULT_MODEL = "nvidia/nv-rerankqa-mistral-4b-v3" @@ -51,8 +57,13 @@ def __init__( model: Optional[str] = None, truncate: Optional[Union[RankerTruncateMode, str]] = None, api_url: 
Optional[str] = None, - api_key: Optional[Secret] = None, + api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), top_k: int = 5, + query_prefix: str = "", + document_prefix: str = "", + meta_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + timeout: Optional[float] = None, ): """ Create a NvidiaRanker component. @@ -67,6 +78,19 @@ def __init__( Custom API URL for the NVIDIA NIM. :param top_k: Number of documents to return. + :param query_prefix: + A string to add at the beginning of the query text before ranking. + Use it to prepend the text with an instruction, as required by reranking models like `bge`. + :param document_prefix: + A string to add at the beginning of each document before ranking. You can use it to prepend the document + with an instruction, as required by embedding models like `bge`. + :param meta_fields_to_embed: + List of metadata fields to embed with the document. + :param embedding_separator: + Separator to concatenate metadata fields to the document. + :param timeout: + Timeout for request calls. If not set, it is inferred from the `NVIDIA_TIMEOUT` environment variable + or defaults to 60. """ if model is not None and not isinstance(model, str): msg = "Ranker expects the `model` parameter to be a string." raise TypeError(msg) @@ -81,25 +105,34 @@ raise TypeError(msg) # todo: detect default in non-hosted case (when api_url is provided) - self._model = model or _DEFAULT_MODEL - self._truncate = truncate - self._api_key = api_key + self.model = model or _DEFAULT_MODEL + self.truncate = truncate + self.api_key = api_key # if no api_url is provided, we're using a hosted model and can # - assume the default url will work, because there's only one model # - assume we won't call backend.models() if api_url is not None: - self._api_url = url_validation(api_url, None, ["v1/ranking"]) - self._endpoint = None # we let backend.rank() handle the endpoint + self.api_url = url_validation(api_url, None, ["v1/ranking"]) + self.endpoint = None # we let backend.rank() handle the endpoint else: - if self._model not in _MODEL_ENDPOINT_MAP: + if self.model not in _MODEL_ENDPOINT_MAP: msg = f"Model '{model}' is unknown. Please provide an api_url to access it." 
raise ValueError(msg) - self._api_url = None # we handle the endpoint - self._endpoint = _MODEL_ENDPOINT_MAP[self._model] + self.api_url = None # we handle the endpoint + self.endpoint = _MODEL_ENDPOINT_MAP[self.model] if api_key is None: self._api_key = Secret.from_env_var("NVIDIA_API_KEY") - self._top_k = top_k + self.top_k = top_k self._initialized = False + self._backend: Optional[Any] = None + + self.query_prefix = query_prefix + self.document_prefix = document_prefix + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.embedding_separator = embedding_separator + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", 60.0)) + self.timeout = timeout def to_dict(self) -> Dict[str, Any]: """ @@ -109,11 +142,16 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - model=self._model, - top_k=self._top_k, - truncate=self._truncate, - api_url=self._api_url, - api_key=self._api_key, + model=self.model, + top_k=self.top_k, + truncate=self.truncate, + api_url=self.api_url, + api_key=self.api_key.to_dict() if self.api_key else None, + query_prefix=self.query_prefix, + document_prefix=self.document_prefix, + meta_fields_to_embed=self.meta_fields_to_embed, + embedding_separator=self.embedding_separator, + timeout=self.timeout, ) @classmethod @@ -124,7 +162,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaRanker": :param data: A dictionary containing the ranker's attributes. :returns: The deserialized ranker. """ - deserialize_secrets_inplace(data, keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def warm_up(self): @@ -135,18 +175,31 @@ def warm_up(self): """ if not self._initialized: model_kwargs = {} - if self._truncate is not None: - model_kwargs.update(truncate=str(self._truncate)) + if self.truncate is not None: + model_kwargs.update(truncate=str(self.truncate)) self._backend = NimBackend( - self._model, - api_url=self._api_url, - api_key=self._api_key, + model=self.model, + api_url=self.api_url, + api_key=self.api_key, model_kwargs=model_kwargs, + timeout=self.timeout, ) - if not self._model: - self._model = _DEFAULT_MODEL + if not self.model: + self.model = _DEFAULT_MODEL self._initialized = True + def _prepare_documents_to_embed(self, documents: List[Document]) -> List[str]: + document_texts = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) + for key in self.meta_fields_to_embed + if key in doc.meta and doc.meta[key] # noqa: RUF019 + ] + text_to_embed = self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + document_texts.append(self.document_prefix + text_to_embed) + return document_texts + @component.output_types(documents=List[Document]) def run( self, @@ -170,32 +223,37 @@ def run( msg = "The ranker has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) if not isinstance(query, str): - msg = "Ranker expects the `query` parameter to be a string." + msg = "NvidiaRanker expects the `query` parameter to be a string." raise TypeError(msg) if not isinstance(documents, list): - msg = "Ranker expects the `documents` parameter to be a list." + msg = "NvidiaRanker expects the `documents` parameter to be a list." raise TypeError(msg) if not all(isinstance(doc, Document) for doc in documents): - msg = "Ranker expects the `documents` parameter to be a list of Document objects." 
+ msg = "NvidiaRanker expects the `documents` parameter to be a list of Document objects." raise TypeError(msg) if top_k is not None and not isinstance(top_k, int): - msg = "Ranker expects the `top_k` parameter to be an integer." + msg = "NvidiaRanker expects the `top_k` parameter to be an integer." raise TypeError(msg) if len(documents) == 0: return {"documents": []} - top_k = top_k if top_k is not None else self._top_k + top_k = top_k if top_k is not None else self.top_k if top_k < 1: + logger.warning("top_k should be at least 1, returning nothing") warnings.warn("top_k should be at least 1, returning nothing", stacklevel=2) return {"documents": []} assert self._backend is not None + + query_text = self.query_prefix + query + document_texts = self._prepare_documents_to_embed(documents=documents) + # rank result is list[{index: int, logit: float}] sorted by logit sorted_indexes_and_scores = self._backend.rank( - query, - documents, - endpoint=self._endpoint, + query_text=query_text, + document_texts=document_texts, + endpoint=self.endpoint, ) sorted_documents = [] for item in sorted_indexes_and_scores[:top_k]: diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py index 3b5d7f40a..649ceaf9d 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum diff --git a/integrations/nvidia/src/haystack_integrations/utils/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py index da301d29d..0b69c8d24 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py @@ -1,4 +1,8 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .nim_backend import Model, NimBackend from .utils import is_hosted, url_validation -__all__ = ["NimBackend", "Model", "is_hosted", "url_validation"] +__all__ = ["Model", "NimBackend", "is_hosted", "url_validation"] diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py index cbb6b7c3f..15b35e4b2 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -1,11 +1,18 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +import os from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple import requests -from haystack import Document +from haystack import logging from haystack.utils import Secret -REQUEST_TIMEOUT = 60 +logger = logging.getLogger(__name__) + +REQUEST_TIMEOUT = 60.0 @dataclass @@ -31,6 +38,7 @@ def __init__( api_url: str, 
api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), model_kwargs: Optional[Dict[str, Any]] = None, + timeout: Optional[float] = None, ): headers = { "Content-Type": "application/json", @@ -46,6 +54,9 @@ self.model = model self.api_url = api_url self.model_kwargs = model_kwargs or {} + if timeout is None: + timeout = float(os.environ.get("NVIDIA_TIMEOUT", REQUEST_TIMEOUT)) + self.timeout = timeout def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: url = f"{self.api_url}/embeddings" @@ -58,10 +69,11 @@ def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: "input": texts, **self.model_kwargs, }, - timeout=REQUEST_TIMEOUT, + timeout=self.timeout, ) res.raise_for_status() except requests.HTTPError as e: + logger.error("Error when calling NIM embedding endpoint: Error - {error}", error=e.response.text) msg = f"Failed to query embedding endpoint: Error - {e.response.text}" raise ValueError(msg) from e @@ -90,10 +102,11 @@ def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: ], **self.model_kwargs, }, - timeout=REQUEST_TIMEOUT, + timeout=self.timeout, ) res.raise_for_status() except requests.HTTPError as e: + logger.error("Error when calling NIM chat completion endpoint: Error - {error}", error=e.response.text) msg = f"Failed to query chat completion endpoint: Error - {e.response.text}" raise ValueError(msg) from e @@ -128,21 +141,22 @@ def models(self) -> List[Model]: res = self.session.get( url, - timeout=REQUEST_TIMEOUT, + timeout=self.timeout, ) res.raise_for_status() data = res.json()["data"] models = [Model(element["id"]) for element in data if "id" in element] if not models: + logger.error("No hosted models were found at URL '{u}'.", u=url) msg = f"No hosted model were found at URL '{url}'." 
raise ValueError(msg) return models def rank( self, - query: str, - documents: List[Document], + query_text: str, + document_texts: List[str], endpoint: Optional[str] = None, ) -> List[Dict[str, Any]]: url = endpoint or f"{self.api_url}/ranking" @@ -152,18 +166,22 @@ def rank( url, json={ "model": self.model, - "query": {"text": query}, - "passages": [{"text": doc.content} for doc in documents], + "query": {"text": query_text}, + "passages": [{"text": text} for text in document_texts], **self.model_kwargs, }, - timeout=REQUEST_TIMEOUT, + timeout=self.timeout, ) res.raise_for_status() except requests.HTTPError as e: + logger.error("Error when calling NIM ranking endpoint: Error - {error}", error=e.response.text) msg = f"Failed to rank endpoint: Error - {e.response.text}" raise ValueError(msg) from e data = res.json() - assert "rankings" in data, f"Expected 'rankings' in response, got {data}" + if "rankings" not in data: + logger.error("Expected 'rankings' in response, got {d}", d=data) + msg = f"Expected 'rankings' in response, got {data}" + raise ValueError(msg) return data["rankings"] diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py index 7d4dfc3b4..f07989405 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py @@ -1,9 +1,13 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings -from typing import List +from typing import List, Optional from urllib.parse import urlparse, urlunparse -def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str: +def url_validation(api_url: str, default_api_url: Optional[str], allowed_paths: List[str]) -> str: """ Validate and normalize an API URL. 
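The timeout plumbing added throughout the NVIDIA embedders, generator, ranker, and NimBackend above follows a single pattern: an explicit `timeout` argument wins, otherwise the `NVIDIA_TIMEOUT` environment variable is read and coerced to float, otherwise the 60-second default applies; the resolved value is then passed to every `requests` call in place of the old module-level `REQUEST_TIMEOUT` constant. A minimal sketch of that resolution logic, with a hypothetical helper name that is not part of the patch:

import os
from typing import Optional

def resolve_timeout(timeout: Optional[float] = None) -> float:
    # Precedence: explicit argument > NVIDIA_TIMEOUT env var > 60s default.
    if timeout is None:
        timeout = float(os.environ.get("NVIDIA_TIMEOUT", 60.0))
    return timeout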
diff --git a/integrations/nvidia/tests/__init__.py b/integrations/nvidia/tests/__init__.py index 47611e0b9..38adc654d 100644 --- a/integrations/nvidia/tests/__init__.py +++ b/integrations/nvidia/tests/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from .conftest import MockBackend __all__ = ["MockBackend"] diff --git a/integrations/nvidia/tests/conftest.py b/integrations/nvidia/tests/conftest.py index a6c78ba4e..b6346c672 100644 --- a/integrations/nvidia/tests/conftest.py +++ b/integrations/nvidia/tests/conftest.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from typing import Any, Dict, List, Optional, Tuple import pytest diff --git a/integrations/nvidia/tests/test_base_url.py b/integrations/nvidia/tests/test_base_url.py index 426bacc25..506fbc385 100644 --- a/integrations/nvidia/tests/test_base_url.py +++ b/integrations/nvidia/tests/test_base_url.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import pytest from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index db69053e7..8c01f0759 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import pytest @@ -71,6 +75,7 @@ def test_to_dict(self, monkeypatch): "meta_fields_to_embed": [], "embedding_separator": "\n", "truncate": None, + "timeout": 60.0, }, } @@ -86,6 +91,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): meta_fields_to_embed=["test_field"], embedding_separator=" | ", truncate=EmbeddingTruncateMode.END, + timeout=45.0, ) data = component.to_dict() assert data == { @@ -101,10 +107,11 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", "truncate": "END", + "timeout": 45.0, }, } - def from_dict(self, monkeypatch): + def test_from_dict(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", @@ -119,18 +126,38 @@ def from_dict(self, monkeypatch): "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", "truncate": "START", + "timeout": 45.0, }, } component = NvidiaDocumentEmbedder.from_dict(data) - assert component.model == "nvolveqa_40k" + assert component.model == "playground_nvolveqa_40k" assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" + assert component.batch_size == 10 + assert component.progress_bar is False + assert component.meta_fields_to_embed == ["test_field"] + assert component.embedding_separator == " | " + assert component.truncate == EmbeddingTruncateMode.START + assert component.timeout == 45.0 + + def test_from_dict_defaults(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", + "init_parameters": {}, + } + component = NvidiaDocumentEmbedder.from_dict(data) + assert component.model == 
"nvidia/nv-embedqa-e5-v5" + assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" + assert component.prefix == "" + assert component.suffix == "" assert component.batch_size == 32 assert component.progress_bar assert component.meta_fields_to_embed == [] assert component.embedding_separator == "\n" - assert component.truncate == EmbeddingTruncateMode.START + assert component.truncate is None + assert component.timeout == 60.0 def test_prepare_texts_to_embed_w_metadata(self): documents = [ @@ -326,7 +353,7 @@ def test_run_wrong_input_format(self): with pytest.raises(TypeError, match="NvidiaDocumentEmbedder expects a list of Documents as input"): embedder.run(documents=list_integers_input) - def test_run_empty_document(self): + def test_run_empty_document(self, caplog): model = "playground_nvolveqa_40k" api_key = Secret.from_token("fake-api-key") embedder = NvidiaDocumentEmbedder(model, api_key=api_key) @@ -334,8 +361,10 @@ def test_run_empty_document(self): embedder.warm_up() embedder.backend = MockBackend(model=model, api_key=api_key) - with pytest.raises(ValueError, match="no content to embed"): + # Write check using caplog that a logger.warning is raised + with caplog.at_level("WARNING"): embedder.run(documents=[Document(content="")]) + assert "has no content to embed." in caplog.text def test_run_on_empty_list(self): model = "playground_nvolveqa_40k" @@ -351,6 +380,19 @@ def test_run_on_empty_list(self): assert result["documents"] is not None assert not result["documents"] # empty list + def test_setting_timeout(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + embedder = NvidiaDocumentEmbedder(timeout=10.0) + embedder.warm_up() + assert embedder.backend.timeout == 10.0 + + def test_setting_timeout_env(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + monkeypatch.setenv("NVIDIA_TIMEOUT", "45") + embedder = NvidiaDocumentEmbedder() + embedder.warm_up() + assert embedder.backend.timeout == 45.0 + @pytest.mark.skipif( not os.environ.get("NVIDIA_API_KEY", None), reason="Export an env var called NVIDIA_API_KEY containing the Nvidia API key to run this test.", diff --git a/integrations/nvidia/tests/test_embedding_truncate_mode.py b/integrations/nvidia/tests/test_embedding_truncate_mode.py index e74d0308c..16f9112ea 100644 --- a/integrations/nvidia/tests/test_embedding_truncate_mode.py +++ b/integrations/nvidia/tests/test_embedding_truncate_mode.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import pytest from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py index 0bd8b1fc6..414de4884 100644 --- a/integrations/nvidia/tests/test_generator.py +++ b/integrations/nvidia/tests/test_generator.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import os import pytest @@ -123,6 +124,19 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): }, } + def test_setting_timeout(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + generator = NvidiaGenerator(timeout=10.0) + generator.warm_up() + assert generator._backend.timeout == 10.0 + + def test_setting_timeout_env(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + monkeypatch.setenv("NVIDIA_TIMEOUT", "45") + generator = NvidiaGenerator() + generator.warm_up() 
+ assert generator._backend.timeout == 45.0 + @pytest.mark.skipif( not os.environ.get("NVIDIA_NIM_GENERATOR_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), reason="Export an env var called NVIDIA_NIM_GENERATOR_MODEL containing the hosted model name and " diff --git a/integrations/nvidia/tests/test_ranker.py b/integrations/nvidia/tests/test_ranker.py index 566fd18a8..3d93dc028 100644 --- a/integrations/nvidia/tests/test_ranker.py +++ b/integrations/nvidia/tests/test_ranker.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import re from typing import Any, Optional, Union @@ -15,8 +19,8 @@ class TestNvidiaRanker: def test_init_default(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") client = NvidiaRanker() - assert client._model == _DEFAULT_MODEL - assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert client.model == _DEFAULT_MODEL + assert client.api_key == Secret.from_env_var("NVIDIA_API_KEY") def test_init_with_parameters(self): client = NvidiaRanker( @@ -25,10 +29,10 @@ def test_init_with_parameters(self): top_k=3, truncate="END", ) - assert client._api_key == Secret.from_token("fake-api-key") - assert client._model == _DEFAULT_MODEL - assert client._top_k == 3 - assert client._truncate == RankerTruncateMode.END + assert client.api_key == Secret.from_token("fake-api-key") + assert client.model == _DEFAULT_MODEL + assert client.top_k == 3 + assert client.truncate == RankerTruncateMode.END def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("NVIDIA_API_KEY", raising=False) @@ -39,7 +43,7 @@ def test_init_fail_wo_api_key(self, monkeypatch): def test_init_pass_wo_api_key_w_api_url(self): url = "https://url.bogus/v1" client = NvidiaRanker(api_url=url) - assert client._api_url == url + assert client.api_url == url def test_warm_up_required(self): client = NvidiaRanker() @@ -256,3 +260,104 @@ def test_warm_up_once(self, monkeypatch) -> None: backend = client._backend client.warm_up() assert backend == client._backend + + def test_to_dict(self) -> None: + client = NvidiaRanker() + assert client.to_dict() == { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": { + "model": "nvidia/nv-rerankqa-mistral-4b-v3", + "top_k": 5, + "truncate": None, + "api_url": None, + "api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + "query_prefix": "", + "document_prefix": "", + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "timeout": 60.0, + }, + } + + def test_from_dict(self) -> None: + client = NvidiaRanker.from_dict( + { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": { + "model": "nvidia/nv-rerankqa-mistral-4b-v3", + "top_k": 5, + "truncate": None, + "api_url": None, + "api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + "query_prefix": "", + "document_prefix": "", + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "timeout": 45.0, + }, + } + ) + assert client.model == "nvidia/nv-rerankqa-mistral-4b-v3" + assert client.top_k == 5 + assert client.truncate is None + assert client.api_url is None + assert client.api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert client.query_prefix == "" + assert client.document_prefix == "" + assert client.meta_fields_to_embed == [] + assert client.embedding_separator == "\n" + assert client.timeout == 45.0 + + def 
test_from_dict_defaults(self) -> None: + client = NvidiaRanker.from_dict( + { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": {}, + } + ) + assert client.model == "nvidia/nv-rerankqa-mistral-4b-v3" + assert client.top_k == 5 + assert client.truncate is None + assert client.api_url is None + assert client.api_key == Secret.from_env_var("NVIDIA_API_KEY") + assert client.query_prefix == "" + assert client.document_prefix == "" + assert client.meta_fields_to_embed == [] + assert client.embedding_separator == "\n" + assert client.timeout == 60.0 + + def test_setting_timeout(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + client = NvidiaRanker(timeout=10.0) + client.warm_up() + assert client._backend.timeout == 10.0 + + def test_setting_timeout_env(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + monkeypatch.setenv("NVIDIA_TIMEOUT", "45") + client = NvidiaRanker() + client.warm_up() + assert client._backend.timeout == 45.0 + + def test_prepare_texts_to_embed_w_metadata(self): + documents = [ + Document(content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) for i in range(5) + ] + + ranker = NvidiaRanker( + model=None, + api_key=Secret.from_token("fake-api-key"), + meta_fields_to_embed=["meta_field"], + embedding_separator=" | ", + ) + + prepared_texts = ranker._prepare_documents_to_embed(documents) + + # meta fields are prepended to the document content, joined by the separator + assert prepared_texts == [ + "meta_value 0 | document number 0:\ncontent", + "meta_value 1 | document number 1:\ncontent", + "meta_value 2 | document number 2:\ncontent", + "meta_value 3 | document number 3:\ncontent", + "meta_value 4 | document number 4:\ncontent", + ] diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index 8690de6b1..b572cc046 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import pytest @@ -52,6 +56,7 @@ def test_to_dict(self, monkeypatch): "prefix": "", "suffix": "", "truncate": None, + "timeout": 60.0, }, } @@ -63,6 +68,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): prefix="prefix", suffix="suffix", truncate=EmbeddingTruncateMode.START, + timeout=10.0, ) data = component.to_dict() assert data == { @@ -74,10 +80,11 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "prefix": "prefix", "suffix": "suffix", "truncate": "START", + "timeout": 10.0, }, } - def from_dict(self, monkeypatch): + def test_from_dict(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", "init_parameters": { "model": "nvolveqa_40k", "api_url": "https://example.com/v1", "prefix": "prefix", "suffix": "suffix", "truncate": "START", + "timeout": 10.0, }, } component = NvidiaTextEmbedder.from_dict(data) @@ -95,7 +103,21 @@ def from_dict(self, monkeypatch): assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" - assert component.truncate == "START" + assert component.truncate == EmbeddingTruncateMode.START + assert component.timeout == 10.0 + + def test_from_dict_defaults(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": 
"haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", + "init_parameters": {}, + } + component = NvidiaTextEmbedder.from_dict(data) + assert component.model == "nvidia/nv-embedqa-e5-v5" + assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" + assert component.prefix == "" + assert component.suffix == "" + assert component.truncate is None @pytest.mark.usefixtures("mock_local_models") def test_run_default_model(self): @@ -158,6 +180,19 @@ def test_run_empty_string(self): with pytest.raises(ValueError, match="empty string"): embedder.run(text="") + def test_setting_timeout(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + embedder = NvidiaTextEmbedder(timeout=10.0) + embedder.warm_up() + assert embedder.backend.timeout == 10.0 + + def test_setting_timeout_env(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + monkeypatch.setenv("NVIDIA_TIMEOUT", "45") + embedder = NvidiaTextEmbedder() + embedder.warm_up() + assert embedder.backend.timeout == 45.0 + @pytest.mark.skipif( not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 55c6aa7b7..9e2e0a0cb 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,21 +1,45 @@ # Changelog +## [integrations/ollama-v2.1.0] - 2024-11-28 + +### ๐Ÿš€ Features + +- `OllamaDocumentEmbedder` - allow batching embeddings (#1224) + +### ๐ŸŒ€ Miscellaneous + +- Chore: update changelog for `ollama-haystack==2.0.0` (#1214) +- Chore: use class methods to create `ChatMessage` (#1222) + +## [integrations/ollama-v2.0.0] - 2024-11-22 + +### ๐Ÿ› Bug Fixes + +- Adapt to Ollama client 0.4.0 (#1209) + +### โš™๏ธ CI + +- Adopt uv as installer (#1142) + + ## [integrations/ollama-v1.1.0] - 2024-10-11 ### ๐Ÿš€ Features - Add `keep_alive` parameter to Ollama Generators (#1131) -### โš™๏ธ Miscellaneous Tasks +### ๐Ÿงน Chores - Update ruff linting scripts and settings (#1105) + ## [integrations/ollama-v1.0.1] - 2024-09-26 ### ๐Ÿ› Bug Fixes - Ollama Chat Generator - add missing `to_dict` and `from_dict` methods (#1110) + ## [integrations/ollama-v1.0.0] - 2024-09-07 ### ๐Ÿ› Bug Fixes @@ -30,31 +54,47 @@ - Do not retry tests in `hatch run test` command (#954) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Retry tests to reduce flakyness (#836) + +### ๐Ÿงน Chores + - Update ruff invocation to include check parameter (#853) +### ๐ŸŒ€ Miscellaneous + +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Chore: ollama - ruff update, don't ruff tests (#985) + ## [integrations/ollama-v0.0.7] - 2024-05-31 ### ๐Ÿš€ Features - Add streaming support to OllamaChatGenerator (#757) +### ๐ŸŒ€ Miscellaneous + +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) + ## [integrations/ollama-v0.0.6] - 2024-04-18 ### ๐Ÿ“š Documentation - Disable-class-def (#556) -### โš™๏ธ Miscellaneous Tasks +### ๐Ÿงน Chores - Update docstrings (#499) -### Ollama +### ๐ŸŒ€ Miscellaneous +- Update API docs adding embedders (#494) - Change testing workflow (#551) +- Remove references to Python 3.7 (#601) - Add ollama embedder example (#669) +- Fix: change ollama output name to 'meta' (#670) ## [integrations/ollama-v0.0.5] - 2024-02-28 @@ -62,24 +102,42 @@ - Fix 
order of API docs (#447) -This PR will also push the docs to Readme - ### 📚 Documentation - Update category slug (#442) -### ⚙️ Miscellaneous Tasks +### 🧹 Chores - Use `serialize_callable` instead of `serialize_callback_handler` in Ollama (#461) +### 🌀 Miscellaneous + +- Ollama document embedder (#400) +- Changed Default Ollama Embedding models to supported model: nomic-embed-text (#490) + ## [integrations/ollama-v0.0.4] - 2024-02-12 -### Ollama +### 🌀 Miscellaneous +- Ollama: add license (#219) - Generate api docs (#332) +- Ollama Text Embedder with new format (#252) +- Support for streaming ollama generator (#280) ## [integrations/ollama-v0.0.3] - 2024-01-16 +### 🌀 Miscellaneous + +- Docs: Ollama docstrings update (#171) +- Add example of OllamaGenerator (#170) +- Ollama Chat Generator (#176) +- Ollama: improve test (#191) +- Mount Ollama in haystack_integrations (#216) + ## [integrations/ollama-v0.0.1] - 2024-01-03 +### 🌀 Miscellaneous + +- Ollama Integration (#132) + diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index 598d1d214..c9fc22f3d 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -19,7 +19,6 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -27,7 +26,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "ollama"] +dependencies = ["haystack-ai", "ollama>=0.4.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme" @@ -63,7 +62,7 @@ cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11", "3.12"] +python = ["3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py index 46042a1c9..822b3d0aa 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/__init__.py @@ -1,4 +1,4 @@ from .document_embedder import OllamaDocumentEmbedder from .text_embedder import OllamaTextEmbedder -__all__ = ["OllamaTextEmbedder", "OllamaDocumentEmbedder"] +__all__ = ["OllamaDocumentEmbedder", "OllamaTextEmbedder"] diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py index ac8f38f35..8d2f5f505 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py @@ -36,6 +36,7 @@ def __init__( progress_bar: bool = True, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + batch_size: int = 32, ): """ :param model: @@ -48,12 +49,24 @@ [Ollama 
docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values). :param timeout: The number of seconds before throwing a timeout error from the Ollama API. + :param prefix: + A string to add at the beginning of each text. + :param suffix: + A string to add at the end of each text. + :param progress_bar: + If `True`, shows a progress bar when running. + :param meta_fields_to_embed: + List of metadata fields to embed along with the document text. + :param embedding_separator: + Separator used to concatenate the metadata fields to the document text. + :param batch_size: + Number of documents to process at once. """ self.timeout = timeout self.generation_kwargs = generation_kwargs or {} self.url = url self.model = model - self.batch_size = 1 # API only supports a single call at the moment + self.batch_size = batch_size self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed self.embedding_separator = embedding_separator @@ -88,24 +101,19 @@ def _embed_batch( self, texts_to_embed: List[str], batch_size: int, generation_kwargs: Optional[Dict[str, Any]] = None ): """ - Ollama Embedding only allows single uploads, not batching. Currently the batch size is set to 1. - If this changes in the future, line 86 (the first line within the for loop), can contain: - batch = texts_to_embed[i + i + batch_size] + Internal method to embed a batch of texts. """ all_embeddings = [] - meta: Dict[str, Any] = {"model": ""} for i in tqdm( range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" ): - batch = texts_to_embed[i] # Single batch only - result = self._client.embeddings(model=self.model, prompt=batch, options=generation_kwargs) - all_embeddings.append(result["embedding"]) + batch = texts_to_embed[i : i + batch_size] + result = self._client.embed(model=self.model, input=batch, options=generation_kwargs) + all_embeddings.extend(result["embeddings"]) - meta["model"] = self.model - - return all_embeddings, meta + return all_embeddings @component.output_types(documents=List[Document], meta=Dict[str, Any]) def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, Any]] = None): @@ -122,19 +130,21 @@ def run(self, documents: List[Document], generation_kwargs: Optional[Dict[str, A - `documents`: Documents with embedding information attached - `meta`: The metadata collected during the embedding process """ - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "OllamaDocumentEmbedder expects a list of Documents as input." "In case you want to embed a list of strings, please use the OllamaTextEmbedder." 
) raise TypeError(msg) + generation_kwargs = generation_kwargs or self.generation_kwargs + texts_to_embed = self._prepare_texts_to_embed(documents=documents) - embeddings, meta = self._embed_batch( + embeddings = self._embed_batch( texts_to_embed=texts_to_embed, batch_size=self.batch_size, generation_kwargs=generation_kwargs ) for doc, emb in zip(documents, embeddings): doc.embedding = emb - return {"documents": documents, "meta": meta} + return {"documents": documents, "meta": {"model": self.model}} diff --git a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py index 7779c6d6e..b08b8bef3 100644 --- a/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py +++ b/integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py @@ -62,7 +62,7 @@ def run(self, text: str, generation_kwargs: Optional[Dict[str, Any]] = None): - `embedding`: The computed embeddings - `meta`: The metadata collected during the embedding process """ - result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs) + result = self._client.embeddings(model=self.model, prompt=text, options=generation_kwargs).model_dump() result["meta"] = {"model": self.model} return result diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py index 41a02d0ac..24e4d2edb 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/__init__.py @@ -1,4 +1,4 @@ from .chat.chat_generator import OllamaChatGenerator from .generator import OllamaGenerator -__all__ = ["OllamaGenerator", "OllamaChatGenerator"] +__all__ = ["OllamaChatGenerator", "OllamaGenerator"] diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 558fd593e..b1be7a2db 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -4,7 +4,7 @@ from haystack.dataclasses import ChatMessage, StreamingChunk from haystack.utils.callable_serialization import deserialize_callable, serialize_callable -from ollama import Client +from ollama import ChatResponse, Client @component @@ -111,12 +111,13 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaChatGenerator": def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: return {"role": message.role.value, "content": message.content} - def _build_message_from_ollama_response(self, ollama_response: Dict[str, Any]) -> ChatMessage: + def _build_message_from_ollama_response(self, ollama_response: ChatResponse) -> ChatMessage: """ Converts the non-streaming response from the Ollama API to a ChatMessage. 
""" - message = ChatMessage.from_assistant(content=ollama_response["message"]["content"]) - message.meta.update({key: value for key, value in ollama_response.items() if key != "message"}) + response_dict = ollama_response.model_dump() + message = ChatMessage.from_assistant(content=response_dict["message"]["content"]) + message.meta.update({key: value for key, value in response_dict.items() if key != "message"}) return message def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: @@ -133,9 +134,11 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk: """ Converts the response from the Ollama API to a StreamingChunk. """ - content = chunk_response["message"]["content"] - meta = {key: value for key, value in chunk_response.items() if key != "message"} - meta["role"] = chunk_response["message"]["role"] + chunk_response_dict = chunk_response.model_dump() + + content = chunk_response_dict["message"]["content"] + meta = {key: value for key, value in chunk_response_dict.items() if key != "message"} + meta["role"] = chunk_response_dict["message"]["role"] chunk_message = StreamingChunk(content, meta) return chunk_message diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index 058948e8a..dad671c94 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -4,7 +4,7 @@ from haystack.dataclasses import StreamingChunk from haystack.utils.callable_serialization import deserialize_callable, serialize_callable -from ollama import Client +from ollama import Client, GenerateResponse @component @@ -118,15 +118,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "OllamaGenerator": data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) - def _convert_to_response(self, ollama_response: Dict[str, Any]) -> Dict[str, List[Any]]: + def _convert_to_response(self, ollama_response: GenerateResponse) -> Dict[str, List[Any]]: """ Converts a response from the Ollama API to the required Haystack format. """ + reply = ollama_response.response + meta = {key: value for key, value in ollama_response.model_dump().items() if key != "response"} - replies = [ollama_response["response"]] - meta = {key: value for key, value in ollama_response.items() if key != "response"} - - return {"replies": replies, "meta": [meta]} + return {"replies": [reply], "meta": [meta]} def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[str, List[Any]]: """ @@ -154,8 +153,9 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk: """ Converts the response from the Ollama API to a StreamingChunk. 
""" - content = chunk_response["response"] - meta = {key: value for key, value in chunk_response.items() if key != "response"} + chunk_response_dict = chunk_response.model_dump() + content = chunk_response_dict["response"] + meta = {key: value for key, value in chunk_response_dict.items() if key != "response"} chunk_message = StreamingChunk(content, meta) return chunk_message diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index 5ac9289aa..0308f42ec 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -3,8 +3,8 @@ import pytest from haystack.components.generators.utils import print_streaming_chunk -from haystack.dataclasses import ChatMessage, ChatRole -from ollama._types import ResponseError +from haystack.dataclasses import ChatMessage +from ollama._types import ChatResponse, ResponseError from haystack_integrations.components.generators.ollama import OllamaChatGenerator @@ -86,18 +86,18 @@ def test_from_dict(self): def test_build_message_from_ollama_response(self): model = "some_model" - ollama_response = { - "model": model, - "created_at": "2023-12-12T14:13:43.416799Z", - "message": {"role": "assistant", "content": "Hello! How are you today?"}, - "done": True, - "total_duration": 5191566416, - "load_duration": 2154458, - "prompt_eval_count": 26, - "prompt_eval_duration": 383809000, - "eval_count": 298, - "eval_duration": 4799921000, - } + ollama_response = ChatResponse( + model=model, + created_at="2023-12-12T14:13:43.416799Z", + message={"role": "assistant", "content": "Hello! How are you today?"}, + done=True, + total_duration=5191566416, + load_duration=2154458, + prompt_eval_count=26, + prompt_eval_duration=383809000, + eval_count=298, + eval_duration=4799921000, + ) observed = OllamaChatGenerator(model=model)._build_message_from_ollama_response(ollama_response) @@ -128,16 +128,12 @@ def test_run_with_chat_history(self): chat_generator = OllamaChatGenerator() chat_history = [ - {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, - {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, - {"role": "user", "content": "And what is the second largest?"}, + ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), + ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"), + ChatMessage.from_user("And what is the second largest?"), ] - chat_messages = [ - ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) - for message in chat_history - ] - response = chat_generator.run(chat_messages) + response = chat_generator.run(chat_history) assert isinstance(response, dict) assert isinstance(response["replies"], list) @@ -159,17 +155,12 @@ def test_run_with_streaming(self): chat_generator = OllamaChatGenerator(streaming_callback=streaming_callback) chat_history = [ - {"role": "user", "content": "What is the largest city in the United Kingdom by population?"}, - {"role": "assistant", "content": "London is the largest city in the United Kingdom by population"}, - {"role": "user", "content": "And what is the second largest?"}, - ] - - chat_messages = [ - ChatMessage(role=ChatRole(message["role"]), content=message["content"], name=None) - for message in chat_history + ChatMessage.from_user("What is the largest city in the United Kingdom by population?"), + ChatMessage.from_assistant("London 
is the largest city in the United Kingdom by population"), + ChatMessage.from_user("And what is the second largest?"), ] - response = chat_generator.run(chat_messages) + response = chat_generator.run(chat_history) streaming_callback.assert_called() diff --git a/integrations/ollama/tests/test_document_embedder.py b/integrations/ollama/tests/test_document_embedder.py index 4fe3cfbb3..7d972e898 100644 --- a/integrations/ollama/tests/test_document_embedder.py +++ b/integrations/ollama/tests/test_document_embedder.py @@ -43,10 +43,14 @@ def import_text_in_embedder(self): @pytest.mark.integration def test_run(self): - embedder = OllamaDocumentEmbedder(model="nomic-embed-text") - list_of_docs = [Document(content="This is a document containing some text.")] - reply = embedder.run(list_of_docs) - - assert isinstance(reply, dict) - assert all(isinstance(element, float) for element in reply["documents"][0].embedding) - assert reply["meta"]["model"] == "nomic-embed-text" + embedder = OllamaDocumentEmbedder(model="nomic-embed-text", batch_size=2) + list_of_docs = [ + Document(content="Llamas are amazing animals known for their soft wool and gentle demeanor."), + Document(content="The Andes mountains are the natural habitat of many llamas."), + Document(content="Llamas have been used as pack animals for centuries, especially in South America."), + ] + result = embedder.run(list_of_docs) + assert result["meta"]["model"] == "nomic-embed-text" + documents = result["documents"] + assert len(documents) == 3 + assert all(isinstance(element, float) for document in documents for element in document.embedding) diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py index 02e56b34c..ec0ecdef1 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py @@ -10,10 +10,10 @@ __all__ = [ "OptimumDocumentEmbedder", - "OptimumEmbedderOptimizationMode", "OptimumEmbedderOptimizationConfig", + "OptimumEmbedderOptimizationMode", "OptimumEmbedderPooling", - "OptimumEmbedderQuantizationMode", "OptimumEmbedderQuantizationConfig", + "OptimumEmbedderQuantizationMode", "OptimumTextEmbedder", ] diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py index 27f533430..2016f3ffe 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py @@ -208,7 +208,7 @@ def run(self, documents: List[Document]): if not self._initialized: msg = "The embedding model has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) - if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): msg = ( "OptimumDocumentEmbedder expects a list of Documents as input." " In case you want to embed a string, please use the OptimumTextEmbedder." 
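
The embedder changes above replace the per-text `client.embeddings(...)` call with batched `client.embed(...)` calls, in line with the new `ollama>=0.4.0` pin in `pyproject.toml`. A minimal standalone sketch of the same batching pattern, assuming a local Ollama server on the default port with `nomic-embed-text` already pulled:

from ollama import Client

# Assumed setup: `ollama serve` running locally and
# `ollama pull nomic-embed-text` already done.
client = Client(host="http://localhost:11434")

texts = [
    "Llamas are amazing animals known for their soft wool and gentle demeanor.",
    "The Andes mountains are the natural habitat of many llamas.",
    "Llamas have been used as pack animals for centuries.",
]

batch_size = 2
all_embeddings = []
for i in range(0, len(texts), batch_size):
    batch = texts[i : i + batch_size]
    # embed() takes a list of inputs and returns one vector per input,
    # unlike the older embeddings() call, which took a single prompt.
    result = client.embed(model="nomic-embed-text", input=batch)
    all_embeddings.extend(result["embeddings"])

assert len(all_embeddings) == len(texts)

Because `embed` returns embeddings in input order, the `zip(documents, embeddings)` in `run()` stays aligned with the original document list.
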
diff --git a/integrations/pgvector/CHANGELOG.md b/integrations/pgvector/CHANGELOG.md
index 0fe5f4fa4..f3821f1d3 100644
--- a/integrations/pgvector/CHANGELOG.md
+++ b/integrations/pgvector/CHANGELOG.md
@@ -1,24 +1,50 @@
# Changelog

-## [integrations/pgvector-v1.0.0] - 2024-09-12
+## [integrations/pgvector-v1.2.0] - 2024-11-22
+
+### 🚀 Features
+
+- Add `create_extension` parameter to control vector extension creation (#1213)
+
+
+## [integrations/pgvector-v1.1.0] - 2024-11-21

### 🚀 Features

- Add filter_policy to pgvector integration (#820)
+- Add schema support to pgvector document store. (#1095)
+- Pgvector - recreate the connection if it is no longer valid (#1202)

### 🐛 Bug Fixes

- `PgVector` - Fallback to default filter policy when deserializing retrievers without the init parameter (#900)

+### 📚 Documentation
+
+- Explain different connection string formats in the docstring (#1132)
+
### 🧪 Testing

- Do not retry tests in `hatch run test` command (#954)

-### ⚙️ Miscellaneous Tasks
+### ⚙️ CI

- Retry tests to reduce flakiness (#836)
+- Adopt uv as installer (#1142)
+
+### 🧹 Chores
+
- Update ruff invocation to include check parameter (#853)
- PgVector - remove legacy filter support (#1068)
+- Update changelog after removing legacy filters (#1083)
+- Update ruff linting scripts and settings (#1105)
+
+### 🌀 Miscellaneous
+
+- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845)
+- Chore: Minor retriever pydoc fix (#884)
+- Chore: Update pgvector test for the new `apply_filter_policy` usage (#970)
+- Chore: pgvector ruff update, don't ruff tests (#984)

## [integrations/pgvector-v0.4.0] - 2024-06-20

@@ -27,6 +53,11 @@
- Defer the database connection to when it's needed (#773)
- Add customizable index names for pgvector (#818)

+### 🌀 Miscellaneous
+
+- Docs: add missing api references (#728)
+- [deepset-ai/haystack-core-integrations#727] (#738)
+
## [integrations/pgvector-v0.2.0] - 2024-05-08

### 🚀 Features

@@ -38,19 +69,35 @@
- Fix order of API docs (#447)
-This PR will also push the docs to Readme
-
### 📚 Documentation

- Update category slug (#442)
- Disable-class-def (#556)

+### 🌀 Miscellaneous
+
+- Pgvector - review docstrings and API reference (#502)
+- Refactor tests (#574)
+- Remove references to Python 3.7 (#601)
+- Make Document Stores initially skip `SparseEmbedding` (#606)
+- Chore: add license classifiers (#680)
+- Type hints in pgvector document store updated for 3.8 compatibility (#704)
+- Chore: change the pydoc renderer class (#718)
+
## [integrations/pgvector-v0.1.0] - 2024-02-14

### 🐛 Bug Fixes

-- Fix linting (#328)
+- Pgvector: fix linting (#328)

### 🌀 Miscellaneous
+- Pgvector Document Store - minimal implementation (#239)
+- Pgvector - filters (#257)
+- Pgvector - embedding retrieval (#298)
+- Pgvector - Embedding Retriever (#320)
+- Pgvector: generate API docs (#325)
+- Pgvector: add an example (#334)
+- Adopt `Secret` to pgvector (#402)
diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py
index 1b1333f5c..87655a5ec 100644
--- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py
+++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py
@@ -24,7 +24,7 @@
logger = logging.getLogger(__name__)

CREATE_TABLE_STATEMENT = """
-CREATE TABLE IF NOT EXISTS {table_name} (
+CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ( id VARCHAR(128) PRIMARY KEY, embedding VECTOR({embedding_dimension}), content TEXT, @@ -36,7 +36,7 @@ """ INSERT_STATEMENT = """ -INSERT INTO {table_name} +INSERT INTO {schema_name}.{table_name} (id, embedding, content, dataframe, blob_data, blob_meta, blob_mime_type, meta) VALUES (%(id)s, %(embedding)s, %(content)s, %(dataframe)s, %(blob_data)s, %(blob_meta)s, %(blob_mime_type)s, %(meta)s) """ @@ -54,7 +54,7 @@ KEYWORD_QUERY = """ SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score -FROM {table_name}, plainto_tsquery({language}, %s) query +FROM {schema_name}.{table_name}, plainto_tsquery({language}, %s) query WHERE to_tsvector({language}, content) @@ query """ @@ -78,6 +78,8 @@ def __init__( self, *, connection_string: Secret = Secret.from_env_var("PG_CONN_STR"), + create_extension: bool = True, + schema_name: str = "public", table_name: str = "haystack_documents", language: str = "english", embedding_dimension: int = 768, @@ -101,6 +103,11 @@ def __init__( e.g.: `PG_CONN_STR="host=HOST port=PORT dbname=DBNAME user=USER password=PASSWORD"` See [PostgreSQL Documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) for more details. + :param create_extension: Whether to create the pgvector extension if it doesn't exist. + Set this to `True` (default) to automatically create the extension if it is missing. + Creating the extension may require superuser privileges. + If set to `False`, ensure the extension is already installed; otherwise, an error will be raised. + :param schema_name: The name of the schema the table is created in. The schema must already exist. :param table_name: The name of the table to use to store Haystack documents. :param language: The language to be used to parse query and document content in keyword retrieval. To see the list of available languages, you can run the following SQL query in your PostgreSQL database: @@ -136,7 +143,9 @@ def __init__( """ self.connection_string = connection_string + self.create_extension = create_extension self.table_name = table_name + self.schema_name = schema_name self.embedding_dimension = embedding_dimension if vector_function not in VALID_VECTOR_FUNCTIONS: msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" @@ -153,49 +162,86 @@ def __init__( self._connection = None self._cursor = None self._dict_cursor = None + self._table_initialized = False @property def cursor(self): - if self._cursor is None: + if self._cursor is None or not self._connection_is_valid(self._connection): self._create_connection() return self._cursor @property def dict_cursor(self): - if self._dict_cursor is None: + if self._dict_cursor is None or not self._connection_is_valid(self._connection): self._create_connection() return self._dict_cursor @property def connection(self): - if self._connection is None: + if self._connection is None or not self._connection_is_valid(self._connection): self._create_connection() return self._connection def _create_connection(self): + """ + Internal method to create a connection to the PostgreSQL database. 
+ """
+
+ # close the connection if it already exists
+ if self._connection:
+ try:
+ self._connection.close()
+ except Error as e:
+ logger.debug("Failed to close connection: %s", str(e))
+
 conn_str = self.connection_string.resolve_value() or ""
 connection = connect(conn_str)
 connection.autocommit = True
- connection.execute("CREATE EXTENSION IF NOT EXISTS vector")
+ if self.create_extension:
+ connection.execute("CREATE EXTENSION IF NOT EXISTS vector")
 register_vector(connection)  # Note: this must be called before creating the cursors.
 self._connection = connection
 self._cursor = self._connection.cursor()
 self._dict_cursor = self._connection.cursor(row_factory=dict_row)

- # Init schema
+ if not self._table_initialized:
+ self._initialize_table()
+
+ return self._connection
+
+ def _initialize_table(self):
+ """
+ Internal method to initialize the table.
+ """
 if self.recreate_table:
 self.delete_table()
+
 self._create_table_if_not_exists()
 self._create_keyword_index_if_not_exists()
 if self.search_strategy == "hnsw":
 self._handle_hnsw()

- return self._connection
+ self._table_initialized = True
+
+ @staticmethod
+ def _connection_is_valid(connection):
+ """
+ Internal method to check if the connection is still valid.
+ """
+
+ # implementation inspired by psycopg pool
+ # https://github.com/psycopg/psycopg/blob/d38cf7798b0c602ff43dac9f20bbab96237a9c38/psycopg_pool/psycopg_pool/pool.py#L528
+
+ try:
+ connection.execute("")
+ except Error:
+ return False
+ return True

 def to_dict(self) -> Dict[str, Any]:
 """
@@ -207,6 +253,8 @@ def to_dict(self) -> Dict[str, Any]:
 return default_to_dict(
 self,
 connection_string=self.connection_string.to_dict(),
+ create_extension=self.create_extension,
+ schema_name=self.schema_name,
 table_name=self.table_name,
 embedding_dimension=self.embedding_dimension,
 vector_function=self.vector_function,
@@ -266,7 +314,9 @@ def _create_table_if_not_exists(self):
 """

 create_sql = SQL(CREATE_TABLE_STATEMENT).format(
- table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension)
+ schema_name=Identifier(self.schema_name),
+ table_name=Identifier(self.table_name),
+ embedding_dimension=SQLLiteral(self.embedding_dimension),
 )

 self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore")
@@ -274,12 +324,18 @@ def _create_table_if_not_exists(self):
 def delete_table(self):
 """
 Deletes the table used to store Haystack documents.
- The name of the table (`table_name`) is defined when initializing the `PgvectorDocumentStore`.
+ The name of the schema (`schema_name`) and the name of the table (`table_name`)
+ are defined when initializing the `PgvectorDocumentStore`.
""" + delete_sql = SQL("DROP TABLE IF EXISTS {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + ) - delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name)) - - self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore") + self._execute_sql( + delete_sql, + error_msg=f"Could not delete table {self.schema_name}.{self.table_name} in PgvectorDocumentStore", + ) def _create_keyword_index_if_not_exists(self): """ @@ -287,15 +343,16 @@ def _create_keyword_index_if_not_exists(self): """ index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.keyword_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.keyword_index_name), "Could not check if keyword index exists", ).fetchone() ) sql_create_index = SQL( - "CREATE INDEX {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))" + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING GIN (to_tsvector({language}, content))" ).format( + schema_name=Identifier(self.schema_name), index_name=Identifier(self.keyword_index_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), @@ -318,8 +375,8 @@ def _handle_hnsw(self): index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.hnsw_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.hnsw_index_name), "Could not check if HNSW index exists", ).fetchone() ) @@ -349,8 +406,13 @@ def _create_hnsw_index(self): if key in HNSW_INDEX_CREATION_VALID_KWARGS } - sql_create_index = SQL("CREATE INDEX {index_name} ON {table_name} USING hnsw (embedding {ops}) ").format( - index_name=Identifier(self.hnsw_index_name), table_name=Identifier(self.table_name), ops=SQL(pg_ops) + sql_create_index = SQL( + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING hnsw (embedding {ops}) " + ).format( + schema_name=Identifier(self.schema_name), + index_name=Identifier(self.hnsw_index_name), + table_name=Identifier(self.table_name), + ops=SQL(pg_ops), ) if actual_hnsw_index_creation_kwargs: @@ -369,7 +431,9 @@ def count_documents(self) -> int: Returns how many documents are present in the document store. """ - sql_count = SQL("SELECT COUNT(*) FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_count = SQL("SELECT COUNT(*) FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[ 0 @@ -395,7 +459,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." 
raise ValueError(msg) - sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_filter = SQL("SELECT * FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) params = () if filters: @@ -434,7 +500,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D db_documents = self._from_haystack_to_pg_documents(documents) - sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name)) + sql_insert = SQL(INSERT_STATEMENT).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) if policy == DuplicatePolicy.OVERWRITE: sql_insert += SQL(UPDATE_STATEMENT) @@ -543,8 +611,10 @@ def delete_documents(self, document_ids: List[str]) -> None: document_ids_str = ", ".join(f"'{document_id}'" for document_id in document_ids) - delete_sql = SQL("DELETE FROM {table_name} WHERE id IN ({document_ids_str})").format( - table_name=Identifier(self.table_name), document_ids_str=SQL(document_ids_str) + delete_sql = SQL("DELETE FROM {schema_name}.{table_name} WHERE id IN ({document_ids_str})").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + document_ids_str=SQL(document_ids_str), ) self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") @@ -570,6 +640,7 @@ def _keyword_retrieval( raise ValueError(msg) sql_select = SQL(KEYWORD_QUERY).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), query=SQLLiteral(query), @@ -643,7 +714,8 @@ def _embedding_retrieval( elif vector_function == "l2_distance": score_definition = f"embedding <-> {query_embedding_for_postgres} AS score" - sql_select = SQL("SELECT *, {score} FROM {table_name}").format( + sql_select = SQL("SELECT *, {score} FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), score=SQL(score_definition), ) diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 93514b71c..baa921137 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -41,12 +41,33 @@ def test_write_dataframe(self, document_store: PgvectorDocumentStore): retrieved_docs = document_store.filter_documents() assert retrieved_docs == docs + def test_connection_check_and_recreation(self, document_store: PgvectorDocumentStore): + original_connection = document_store.connection + + with patch.object(PgvectorDocumentStore, "_connection_is_valid", return_value=False): + new_connection = document_store.connection + + # verify that a new connection is created + assert new_connection is not original_connection + assert document_store._connection == new_connection + assert original_connection.closed + + assert document_store._cursor is not None + assert document_store._dict_cursor is not None + + # test with new connection + with patch.object(PgvectorDocumentStore, "_connection_is_valid", return_value=True): + same_connection = document_store.connection + assert same_connection is document_store._connection + @pytest.mark.usefixtures("patches_for_unit_tests") def test_init(monkeypatch): monkeypatch.setenv("PG_CONN_STR", "some_connection_string") document_store = PgvectorDocumentStore( + create_extension=True, + schema_name="my_schema", 
table_name="my_table", embedding_dimension=512, vector_function="l2_distance", @@ -59,6 +80,8 @@ def test_init(monkeypatch): keyword_index_name="my_keyword_index", ) + assert document_store.create_extension + assert document_store.schema_name == "my_schema" assert document_store.table_name == "my_table" assert document_store.embedding_dimension == 512 assert document_store.vector_function == "l2_distance" @@ -76,6 +99,7 @@ def test_to_dict(monkeypatch): monkeypatch.setenv("PG_CONN_STR", "some_connection_string") document_store = PgvectorDocumentStore( + create_extension=False, table_name="my_table", embedding_dimension=512, vector_function="l2_distance", @@ -92,7 +116,9 @@ def test_to_dict(monkeypatch): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": False, "table_name": "my_table", + "schema_name": "public", "embedding_dimension": 512, "vector_function": "l2_distance", "recreate_table": True, diff --git a/integrations/pgvector/tests/test_retrievers.py b/integrations/pgvector/tests/test_retrievers.py index 290891307..11be71ab1 100644 --- a/integrations/pgvector/tests/test_retrievers.py +++ b/integrations/pgvector/tests/test_retrievers.py @@ -50,6 +50,8 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": True, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -81,6 +83,7 @@ def test_from_dict(self, monkeypatch): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": False, "table_name": "haystack_test_to_dict", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -105,6 +108,7 @@ def test_from_dict(self, monkeypatch): assert isinstance(document_store, PgvectorDocumentStore) assert isinstance(document_store.connection_string, EnvVarSecret) + assert not document_store.create_extension assert document_store.table_name == "haystack_test_to_dict" assert document_store.embedding_dimension == 768 assert document_store.vector_function == "cosine_similarity" @@ -175,6 +179,8 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": True, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -205,6 +211,7 @@ def test_from_dict(self, monkeypatch): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "create_extension": False, "table_name": "haystack_test_to_dict", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -228,6 +235,7 @@ def test_from_dict(self, monkeypatch): assert isinstance(document_store, PgvectorDocumentStore) assert isinstance(document_store.connection_string, EnvVarSecret) + assert not 
document_store.create_extension assert document_store.table_name == "haystack_test_to_dict" assert document_store.embedding_dimension == 768 assert document_store.vector_function == "cosine_similarity" diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py index ed6422bfe..bbb7251d0 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/__init__.py @@ -4,4 +4,4 @@ from .retriever import QdrantEmbeddingRetriever, QdrantHybridRetriever, QdrantSparseEmbeddingRetriever -__all__ = ("QdrantEmbeddingRetriever", "QdrantSparseEmbeddingRetriever", "QdrantHybridRetriever") +__all__ = ("QdrantEmbeddingRetriever", "QdrantHybridRetriever", "QdrantSparseEmbeddingRetriever") diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py index 87c7b6b01..db084502b 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py @@ -5,10 +5,10 @@ from .document_store import WeaviateDocumentStore __all__ = [ - "WeaviateDocumentStore", "AuthApiKey", "AuthBearerToken", "AuthClientCredentials", "AuthClientPassword", "AuthCredentials", + "WeaviateDocumentStore", ]
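
To round off the pgvector changes, a minimal usage sketch of the new `create_extension` and `schema_name` options exercised by the tests above; the connection string follows the key/value format from the docstring, and the local host and credentials are assumptions:

import os

from haystack import Document
from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore

# Assumed local instance; adjust host, user, and password to your setup.
os.environ["PG_CONN_STR"] = "host=localhost port=5432 dbname=postgres user=postgres password=postgres"

store = PgvectorDocumentStore(
    create_extension=False,  # skip CREATE EXTENSION; assumes pgvector is already installed
    schema_name="public",    # the schema must already exist; the store does not create it
    table_name="haystack_documents",
    embedding_dimension=768,
)

store.write_documents([Document(content="Hello, pgvector!")])
print(store.count_documents())  # the connection is created lazily and revalidated before each use

Setting `create_extension=False` is useful when the database user lacks superuser privileges, since `CREATE EXTENSION` may require them; the extension must then be installed beforehand.
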