diff --git a/.github/workflows/CI_pypi_release.yml b/.github/workflows/CI_pypi_release.yml index e29d3ea36..1162ca3b3 100644 --- a/.github/workflows/CI_pypi_release.yml +++ b/.github/workflows/CI_pypi_release.yml @@ -54,7 +54,7 @@ jobs: with: config: cliff.toml args: > - --include-path "${{ steps.pathfinder.outputs.project_path }}/*" + --include-path "${{ steps.pathfinder.outputs.project_path }}/**/*" --tag-pattern "${{ steps.pathfinder.outputs.project_path }}-v*" - name: Commit changelog diff --git a/.github/workflows/CI_stale.yml b/.github/workflows/CI_stale.yml new file mode 100644 index 000000000..5a4b3b467 --- /dev/null +++ b/.github/workflows/CI_stale.yml @@ -0,0 +1,15 @@ +name: 'Stalebot' +on: + schedule: + - cron: '30 1 * * *' + +jobs: + makestale: + runs-on: ubuntu-latest + steps: + - uses: actions/stale@v9 + with: + any-of-labels: 'community-triage' + stale-pr-message: 'This issue is stale because it has been open 30 days with no activity. Remove stale label or comment or this will be closed in 10 days.' + days-before-stale: 30 + days-before-close: 10 \ No newline at end of file diff --git a/.github/workflows/azure_ai_search.yml b/.github/workflows/azure_ai_search.yml new file mode 100644 index 000000000..1c10edc91 --- /dev/null +++ b/.github/workflows/azure_ai_search.yml @@ -0,0 +1,72 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / azure_ai_search + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/azure_ai_search/**" + - ".github/workflows/azure_ai_search.yml" + +concurrency: + group: azure_ai_search-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + AZURE_SEARCH_API_KEY: ${{ secrets.AZURE_SEARCH_API_KEY }} + AZURE_SEARCH_SERVICE_ENDPOINT: ${{ secrets.AZURE_SEARCH_SERVICE_ENDPOINT }} + +defaults: + run: + working-directory: integrations/azure_ai_search + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + max-parallel: 3 + matrix: + os: [ubuntu-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + + - name: Run tests + run: hatch run cov-retry + + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run cov-retry -m "not integration" + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/.github/workflows/weaviate.yml b/.github/workflows/weaviate.yml index 06a4bc289..36c30f069 100644 --- a/.github/workflows/weaviate.yml +++ b/.github/workflows/weaviate.yml @@ -30,7 +30,7 @@ 
jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 32f8ca677..e0ba3d036 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,14 +48,14 @@ By participating, you are expected to uphold this code. Please report unacceptab ## I Have a Question > [!TIP] -> If you want to ask a question, we assume that you have read the available [Documentation](https://docs.haystack.deepset.ai/v2.0/docs/intro). +> If you want to ask a question, we assume that you have read the available [documentation](https://docs.haystack.deepset.ai/docs/intro). -Before you ask a question, it is best to search for existing [Issues](/issues) that might help you. In case you have +Before you ask a question, it is best to search for existing [issues](/../../issues) that might help you. In case you have found a suitable issue and still need clarification, you can write your question in this issue. It is also advisable to search the internet for answers first. If you then still feel the need to ask a question and need clarification, you can use one of our -[Community Channels](https://haystack.deepset.ai/community), Discord in particular is often very helpful. +[community channels](https://haystack.deepset.ai/community). Discord in particular is often very helpful. ## Reporting Bugs @@ -67,8 +67,8 @@ investigate carefully, collect information and describe the issue in detail in y following steps in advance to help us fix any potential bug as fast as possible. - Make sure that you are using the latest version. -- Determine if your bug is really a bug and not an error on your side e.g. using incompatible environment components/versions (Make sure that you have read the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/intro). If you are looking for support, you might want to check [this section](#i-have-a-question)). -- To see if other users have experienced (and potentially already solved) the same issue you are having, check if there is not already a bug report existing for your bug or error in the [bug tracker](/issues). +- Determine if your bug is really a bug and not an error on your side e.g. using incompatible environment components/versions (Make sure that you have read the [documentation](https://docs.haystack.deepset.ai/docs/intro). If you are looking for support, you might want to check [this section](#i-have-a-question)). +- To see if other users have experienced (and potentially already solved) the same issue you are having, check if there is not already a bug report existing for your bug or error in the [bug tracker](/../../issues?labels=bug). - Also make sure to search the internet (including Stack Overflow) to see if users outside of the GitHub community have discussed the issue. - Collect information about the bug: - OS, Platform and Version (Windows, Linux, macOS, x86, ARM) @@ -85,7 +85,7 @@ following steps in advance to help us fix any potential bug as fast as possible. We use GitHub issues to track bugs and errors. If you run into an issue with the project: -- Open an [Issue of type Bug Report](/issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=). +- Open an [issue of type Bug Report](/../../issues/new?assignees=&labels=bug&projects=&template=bug_report.md&title=). - Explain the behavior you would expect and the actual behavior. 
- Please provide as much context as possible and describe the *reproduction steps* that someone else can follow to recreate the issue on their own. This usually includes your code. For good bug reports you should isolate the problem and create a reduced test case. - Provide the information you collected in the previous section. @@ -94,7 +94,7 @@ Once it's filed: - The project team will label the issue accordingly. - A team member will try to reproduce the issue with your provided steps. If there are no reproduction steps or no obvious way to reproduce the issue, the team will ask you for those steps. -- If the team is able to reproduce the issue, the issue will scheduled for a fix, or left to be [implemented by someone](#your-first-code-contribution). +- If the team can reproduce the issue, it will either be scheduled for a fix or made available for [community contribution](#contribute-code). ## Suggesting Enhancements @@ -106,14 +106,14 @@ to existing ones. Following these guidelines will help maintainers and the commu ### Before Submitting an Enhancement - Make sure that you are using the latest version. -- Read the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/intro) carefully and find out if the functionality is already covered, maybe by an individual configuration. -- Perform a [search](/issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one. +- Read the [documentation](https://docs.haystack.deepset.ai/docs/intro) carefully and find out if the functionality is already covered, maybe by an individual configuration. +- Perform a [search](/../../issues) to see if the enhancement has already been suggested. If it has, add a comment to the existing issue instead of opening a new one. - Find out whether your idea fits with the scope and aims of the project. It's up to you to make a strong case to convince the project's developers of the merits of this feature. Keep in mind that we want features that will be useful to the majority of our users and not just a small subset. If you're just targeting a minority of users, consider writing and distributing the integration on your own. ### How Do I Submit a Good Enhancement Suggestion? -Enhancement suggestions are tracked as GitHub issues of type [Feature request for existing integrations](/issues/new?assignees=&labels=feature+request&projects=&template=feature-request-for-existing-integrations.md&title=). +Enhancement suggestions are tracked as GitHub issues of type [Feature request for existing integrations](/../../issues/new?assignees=&labels=feature+request&projects=&template=feature-request-for-existing-integrations.md&title=). - Use a **clear and descriptive title** for the issue to identify the suggestion. - Fill the issue following the template @@ -129,8 +129,8 @@ Enhancement suggestions are tracked as GitHub issues of type [Feature request fo If this is your first contribution, a good starting point is looking for an open issue that's marked with the label ["good first issue"](https://github.com/deepset-ai/haystack-core-integrations/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). The core contributors periodically mark certain issues as good for first-time contributors. 
Those issues are usually -limited in scope, easy fixable and low priority, so there is absolutely no reason why you should not try fixing them, -it's a good excuse to start looking into the project and a safe space for experimenting failure: if you don't get the +limited in scope, easy fixable and low priority, so there is absolutely no reason why you should not try fixing them. +It's also a good excuse to start looking into the project and a safe space for experimenting failure: if you don't get the grasp of something, pick another one! ### Setting up your development environment @@ -279,7 +279,7 @@ The Python API docs detail the source code: classes, functions, and parameters t This type of documentation is extracted from the source code itself, and contributors should pay attention when they change the code to also change relevant comments and docstrings. This type of documentation is mostly useful to developers, but it can be handy for users at times. You can browse it on the dedicated section in the -[documentation website](https://docs.haystack.deepset.ai/v2.0/reference/integrations-chroma). +[documentation website](https://docs.haystack.deepset.ai/reference/integrations-chroma). We use `pydoc-markdown` to convert docstrings into properly formatted Markdown files, and while the CI takes care of generating and publishing the updated documentation at every merge on the `main` branch, you can generate the docs diff --git a/README.md b/README.md index 2b4a83253..af83d045d 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [cohere-haystack](integrations/cohere/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | | [deepeval-haystack](integrations/deepeval/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/deepeval-haystack.svg)](https://pypi.org/project/deepeval-haystack) | [![Test / deepeval](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml) | | [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | -| [fastembed-haystack](integrations/fastembed/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | +| [fastembed-haystack](integrations/fastembed/) | Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / 
fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | | [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | | [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | | [instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | diff --git a/cliff.toml b/cliff.toml index 29228543e..45e4c647b 100644 --- a/cliff.toml +++ b/cliff.toml @@ -19,11 +19,43 @@ body = """ ## [unreleased] {% endif %}\ {% for group, commits in commits | group_by(attribute="group") %} + {# + Skip the whole section if it contains only a single commit + and it's the commit that updated the changelog. + If we don't do this we get an empty section since we don't show + commits that update the changelog + #}\ + {% if commits | length == 1 and commits[0].message == 'Update the changelog' %}\ + {% continue %}\ + {% endif %}\ ### {{ group | striptags | trim | upper_first }} - {% for commit in commits %} + {% for commit in commits %}\ + {# + Skip commits that update the changelog, they're not useful to the user + #}\ + {% if commit.message == 'Update the changelog' %}\ + {% continue %}\ + {% endif %} - {% if commit.scope %}*({{ commit.scope }})* {% endif %}\ {% if commit.breaking %}[**breaking**] {% endif %}\ - {{ commit.message | upper_first }}\ + {# + We first try to render the conventional commit message if present. + If it's not a conventional commit we get the PR title if present. + If the commit is neither conventional, nor has a PR title set + we fallback to whatever the commit message is. + + We do this cause when merging PRs with multiple commits that don't + have a title following conventional commit guidelines we might get + a commit message that is multiple lines. That makes the changelog + look a bit funky so we handle it like so. 
+ #}\ + {% if commit.conventional %}\ + {{ commit.message | upper_first }}\ + {% elif commit.remote.pr_title %}\ + {{ commit.remote.pr_title | upper_first }} (#{{ commit.remote.pr_number }})\ + {% else %}\ + {{ commit.message | upper_first }}\ + {% endif %}\ {% endfor %} {% endfor %}\n """ @@ -35,7 +67,7 @@ footer = """ trim = true # postprocessors postprocessors = [ - # { pattern = '<REPO>', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL + # { pattern = '<REPO>', replace = "https://github.com/orhun/git-cliff" }, # replace repository URL ] [git] @@ -47,24 +79,26 @@ filter_unconventional = false split_commits = false # regex for preprocessing the commit messages commit_preprocessors = [ - # Replace issue numbers - #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](<REPO>/issues/${2}))"}, - # Check spelling of the commit with https://github.com/crate-ci/typos - # If the spelling is incorrect, it will be automatically fixed. - #{ pattern = '.*', replace_command = 'typos --write-changes -' }, + # Replace issue numbers + #{ pattern = '\((\w+\s)?#([0-9]+)\)', replace = "([#${2}](<REPO>/issues/${2}))"}, + # Check spelling of the commit with https://github.com/crate-ci/typos + # If the spelling is incorrect, it will be automatically fixed. + #{ pattern = '.*', replace_command = 'typos --write-changes -' }, ] # regex for parsing and grouping commits commit_parsers = [ - { message = "^feat", group = "🚀 Features" }, - { message = "^fix", group = "🐛 Bug Fixes" }, - { message = "^doc", group = "📚 Documentation" }, - { message = "^perf", group = "⚡ Performance" }, - { message = "^refactor", group = "🚜 Refactor" }, - { message = "^style", group = "🎨 Styling" }, - { message = "^test", group = "🧪 Testing" }, - { message = "^chore|^ci", group = "⚙️ Miscellaneous Tasks" }, - { body = ".*security", group = "🛡️ Security" }, - { message = "^revert", group = "◀️ Revert" }, + { message = "^feat", group = "🚀 Features" }, + { message = "^fix", group = "🐛 Bug Fixes" }, + { message = "^refactor", group = "🚜 Refactor" }, + { message = "^doc", group = "📚 Documentation" }, + { message = "^perf", group = "⚡ Performance" }, + { message = "^style", group = "🎨 Styling" }, + { message = "^test", group = "🧪 Testing" }, + { body = ".*security", group = "🛡️ Security" }, + { message = "^revert", group = "◀️ Revert" }, + { message = "^ci", group = "⚙️ CI" }, + { message = "^chore", group = "🧹 Chores" }, + { message = ".*", group = "🌀 Miscellaneous" }, ] # protect breaking changes from being skipped due to matching a skipping commit_parser protect_breaking_commits = false @@ -82,3 +116,7 @@ topo_order = false sort_commits = "oldest" # limit the number of commits included in the changelog. 
# limit_commits = 42 + +[remote.github] +owner = "deepset-ai" +repo = "haystack-core-integrations" diff --git a/integrations/amazon_bedrock/CHANGELOG.md b/integrations/amazon_bedrock/CHANGELOG.md index 417c661fe..1068e870a 100644 --- a/integrations/amazon_bedrock/CHANGELOG.md +++ b/integrations/amazon_bedrock/CHANGELOG.md @@ -1,15 +1,42 @@ # Changelog -## [unreleased] +## [integrations/amazon_bedrock-v1.1.0] - 2024-10-23 + +### 🚜 Refactor + +- Avoid downloading tokenizer if `truncate` is `False` (#1152) + +### ⚙️ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + +## [integrations/amazon_bedrock-v1.0.5] - 2024-10-17 + +### 🚀 Features + +- Add prefixes to supported model patterns to allow cross region model ids (#1127) + +## [integrations/amazon_bedrock-v1.0.4] - 2024-10-16 + +### 🐛 Bug Fixes + +- Avoid bedrock read timeout (add boto3_config param) (#1135) + +## [integrations/amazon_bedrock-v1.0.3] - 2024-10-04 ### 🐛 Bug Fixes - *(Bedrock)* Allow tools kwargs for AWS Bedrock Claude model (#976) +- Chat roles for model responses in chat generators (#1030) ### 🚜 Refactor - Remove usage of deprecated `ChatMessage.to_openai_format` (#1007) +### ⚙️ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) + ## [integrations/amazon_bedrock-v1.0.1] - 2024-08-19 ### 🚀 Features diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 1298abfab..872d4933b 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/amazon_bedrock-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py index 719130f0b..2404488a3 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py @@ -58,9 +58,9 @@ class AmazonBedrockChatGenerator: """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelChatAdapter]]] = { - r"anthropic.claude.*": AnthropicClaudeChatAdapter, - r"meta.llama2.*": MetaLlama2ChatAdapter, - r"mistral.*": MistralChatAdapter, + r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeChatAdapter, + r"([a-z]{2}\.)?meta.llama2.*": MetaLlama2ChatAdapter, + r"([a-z]{2}\.)?mistral.*": MistralChatAdapter, } def __init__( diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index 6ef0a4765..941fdbf71 100644 --- 
a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -3,6 +3,7 @@ import re from typing import Any, Callable, ClassVar, Dict, List, Optional, Type +from botocore.config import Config from botocore.exceptions import ClientError from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk @@ -65,13 +66,13 @@ class AmazonBedrockGenerator: """ SUPPORTED_MODEL_PATTERNS: ClassVar[Dict[str, Type[BedrockModelAdapter]]] = { - r"amazon.titan-text.*": AmazonTitanAdapter, - r"ai21.j2.*": AI21LabsJurassic2Adapter, - r"cohere.command-[^r].*": CohereCommandAdapter, - r"cohere.command-r.*": CohereCommandRAdapter, - r"anthropic.claude.*": AnthropicClaudeAdapter, - r"meta.llama.*": MetaLlamaAdapter, - r"mistral.*": MistralAdapter, + r"([a-z]{2}\.)?amazon.titan-text.*": AmazonTitanAdapter, + r"([a-z]{2}\.)?ai21.j2.*": AI21LabsJurassic2Adapter, + r"([a-z]{2}\.)?cohere.command-[^r].*": CohereCommandAdapter, + r"([a-z]{2}\.)?cohere.command-r.*": CohereCommandRAdapter, + r"([a-z]{2}\.)?anthropic.claude.*": AnthropicClaudeAdapter, + r"([a-z]{2}\.)?meta.llama.*": MetaLlamaAdapter, + r"([a-z]{2}\.)?mistral.*": MistralAdapter, } def __init__( @@ -87,6 +88,7 @@ def __init__( max_length: Optional[int] = 100, truncate: Optional[bool] = True, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + boto3_config: Optional[Dict[str, Any]] = None, **kwargs, ): """ @@ -102,6 +104,7 @@ def __init__( :param truncate: Whether to truncate the prompt or not. :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param boto3_config: The configuration for the boto3 client. :param kwargs: Additional keyword arguments to be passed to the model. These arguments are specific to the model. You can find them in the model's documentation. :raises ValueError: If the model name is empty or None. @@ -120,6 +123,7 @@ def __init__( self.aws_region_name = aws_region_name self.aws_profile_name = aws_profile_name self.streaming_callback = streaming_callback + self.boto3_config = boto3_config self.kwargs = kwargs def resolve_secret(secret: Optional[Secret]) -> Optional[str]: @@ -133,7 +137,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: aws_region_name=resolve_secret(aws_region_name), aws_profile_name=resolve_secret(aws_profile_name), ) - self.client = session.client("bedrock-runtime") + config: Optional[Config] = None + if self.boto3_config: + config = Config(**self.boto3_config) + self.client = session.client("bedrock-runtime", config=config) except Exception as exception: msg = ( "Could not connect to Amazon Bedrock. Make sure the AWS environment is configured correctly. 
" @@ -145,15 +152,16 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]: # We pop the model_max_length as it is not sent to the model but used to truncate the prompt if needed model_max_length = kwargs.get("model_max_length", 4096) - # Truncate prompt if prompt tokens > model_max_length-max_length - # (max_length is the length of the generated text) - # we use GPT2 tokenizer which will likely provide good token count approximation - - self.prompt_handler = DefaultPromptHandler( - tokenizer="gpt2", - model_max_length=model_max_length, - max_length=self.max_length or 100, - ) + # we initialize the prompt handler only if truncate is True: we avoid unnecessarily downloading the tokenizer + if self.truncate: + # Truncate prompt if prompt tokens > model_max_length-max_length + # (max_length is the length of the generated text) + # we use GPT2 tokenizer which will likely provide good token count approximation + self.prompt_handler = DefaultPromptHandler( + tokenizer="gpt2", + model_max_length=model_max_length, + max_length=self.max_length or 100, + ) model_adapter_cls = self.get_model_adapter(model=model) if not model_adapter_cls: @@ -273,6 +281,7 @@ def to_dict(self) -> Dict[str, Any]: max_length=self.max_length, truncate=self.truncate, streaming_callback=callback_name, + boto3_config=self.boto3_config, **self.kwargs, ) diff --git a/integrations/amazon_bedrock/tests/test_chat_generator.py b/integrations/amazon_bedrock/tests/test_chat_generator.py index df3fb4381..a60976ef2 100644 --- a/integrations/amazon_bedrock/tests/test_chat_generator.py +++ b/integrations/amazon_bedrock/tests/test_chat_generator.py @@ -243,7 +243,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.run(messages=messages) # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called(), + mock_ensure_token_limit.assert_not_called() # Check the prompt passed to prepare_body generator.model_adapter.prepare_body.assert_called_with(messages=messages, stop_words=[]) @@ -254,11 +254,16 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): [ ("anthropic.claude-v1", AnthropicClaudeChatAdapter), ("anthropic.claude-v2", AnthropicClaudeChatAdapter), + ("eu.anthropic.claude-v1", AnthropicClaudeChatAdapter), # cross-region inference + ("us.anthropic.claude-v2", AnthropicClaudeChatAdapter), # cross-region inference ("anthropic.claude-instant-v1", AnthropicClaudeChatAdapter), ("anthropic.claude-super-v5", AnthropicClaudeChatAdapter), # artificial ("meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), ("meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), ("meta.llama2-130b-v5", MetaLlama2ChatAdapter), # artificial + ("us.meta.llama2-13b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference + ("eu.meta.llama2-70b-chat-v1", MetaLlama2ChatAdapter), # cross-region inference + ("de.meta.llama2-130b-v5", MetaLlama2ChatAdapter), # cross-region inference ("unknown_model", None), ], ) @@ -515,7 +520,6 @@ def test_get_responses(self) -> None: @pytest.mark.parametrize("model_name", MODELS_TO_TEST) @pytest.mark.integration def test_default_inference_params(self, model_name, chat_messages): - client = AmazonBedrockChatGenerator(model=model_name) response = client.run(chat_messages) diff --git a/integrations/amazon_bedrock/tests/test_generator.py b/integrations/amazon_bedrock/tests/test_generator.py index f0233888c..be645218e 100644 --- a/integrations/amazon_bedrock/tests/test_generator.py +++ b/integrations/amazon_bedrock/tests/test_generator.py 
@@ -36,6 +36,7 @@ def test_to_dict(mock_boto3_session): "truncate": False, "temperature": 10, "streaming_callback": None, + "boto3_config": None, }, } @@ -57,12 +58,16 @@ def test_from_dict(mock_boto3_session): "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, "model": "anthropic.claude-v2", "max_length": 99, + "boto3_config": { + "read_timeout": 1000, + }, }, } ) assert generator.max_length == 99 assert generator.model == "anthropic.claude-v2" + assert generator.boto3_config == {"read_timeout": 1000} def test_default_constructor(mock_boto3_session, set_env_variables): @@ -103,6 +108,14 @@ def test_constructor_prompt_handler_initialized(mock_boto3_session, mock_prompt_ assert layer.prompt_handler.model_max_length == 4096 +def test_prompt_handler_absent_when_truncate_false(mock_boto3_session): + """ + Test that the prompt_handler is not initialized when truncate is set to False. + """ + generator = AmazonBedrockGenerator(model="anthropic.claude-v2", truncate=False) + assert not hasattr(generator, "prompt_handler") + + def test_constructor_with_model_kwargs(mock_boto3_session): """ Test that model_kwargs are correctly set in the constructor @@ -220,7 +233,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): generator.run(prompt=long_prompt_text) # Ensure _ensure_token_limit was not called - mock_ensure_token_limit.assert_not_called(), + mock_ensure_token_limit.assert_not_called() # Check the prompt passed to prepare_body generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text, stream=False) @@ -231,6 +244,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): [ ("anthropic.claude-v1", AnthropicClaudeAdapter), ("anthropic.claude-v2", AnthropicClaudeAdapter), + ("eu.anthropic.claude-v1", AnthropicClaudeAdapter), # cross-region inference + ("us.anthropic.claude-v2", AnthropicClaudeAdapter), # cross-region inference ("anthropic.claude-instant-v1", AnthropicClaudeAdapter), ("anthropic.claude-super-v5", AnthropicClaudeAdapter), # artificial ("cohere.command-text-v14", CohereCommandAdapter), @@ -244,10 +259,13 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("ai21.j2-mega-v5", AI21LabsJurassic2Adapter), # artificial ("amazon.titan-text-lite-v1", AmazonTitanAdapter), ("amazon.titan-text-express-v1", AmazonTitanAdapter), + ("us.amazon.titan-text-express-v1", AmazonTitanAdapter), # cross-region inference ("amazon.titan-text-agile-v1", AmazonTitanAdapter), ("amazon.titan-text-lightning-v8", AmazonTitanAdapter), # artificial ("meta.llama2-13b-chat-v1", MetaLlamaAdapter), ("meta.llama2-70b-chat-v1", MetaLlamaAdapter), + ("eu.meta.llama2-13b-chat-v1", MetaLlamaAdapter), # cross-region inference + ("us.meta.llama2-70b-chat-v1", MetaLlamaAdapter), # cross-region inference ("meta.llama2-130b-v5", MetaLlamaAdapter), # artificial ("meta.llama3-8b-instruct-v1:0", MetaLlamaAdapter), ("meta.llama3-70b-instruct-v1:0", MetaLlamaAdapter), @@ -255,6 +273,8 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session): ("mistral.mistral-7b-instruct-v0:2", MistralAdapter), ("mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), ("mistral.mistral-large-2402-v1:0", MistralAdapter), + ("eu.mistral.mixtral-8x7b-instruct-v0:1", MistralAdapter), # cross-region inference + ("us.mistral.mistral-large-2402-v1:0", MistralAdapter), # cross-region inference ("mistral.mistral-medium-v8:0", MistralAdapter), # artificial ("unknown_model", None), ], diff --git 
a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml index a25b806f6..219b4c2df 100644 --- a/integrations/amazon_sagemaker/pyproject.toml +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -45,6 +45,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/amazon_sagemaker-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -65,8 +66,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/anthropic/pyproject.toml b/integrations/anthropic/pyproject.toml index 987f017be..21e23fbb4 100644 --- a/integrations/anthropic/pyproject.toml +++ b/integrations/anthropic/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/anthropic-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py index c2c1ee40d..0bd29898e 100644 --- a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 from .chat.chat_generator import AnthropicChatGenerator +from .chat.vertex_chat_generator import AnthropicVertexChatGenerator from .generator import AnthropicGenerator -__all__ = ["AnthropicGenerator", "AnthropicChatGenerator"] +__all__ = ["AnthropicGenerator", "AnthropicChatGenerator", "AnthropicVertexChatGenerator"] diff --git a/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py new file mode 100644 index 000000000..4ece944cd --- /dev/null +++ b/integrations/anthropic/src/haystack_integrations/components/generators/anthropic/chat/vertex_chat_generator.py @@ -0,0 +1,135 @@ +import os +from typing import Any, Callable, Dict, Optional + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import StreamingChunk +from haystack.utils import deserialize_callable, serialize_callable + +from anthropic import AnthropicVertex + +from .chat_generator import AnthropicChatGenerator + +logger = logging.getLogger(__name__) + + +@component +class AnthropicVertexChatGenerator(AnthropicChatGenerator): + """ + + Enables text generation using 
state-of-the-art Claude 3 LLMs via the Anthropic Vertex AI API. + It supports models such as `Claude 3.5 Sonnet`, `Claude 3 Opus`, `Claude 3 Sonnet`, and `Claude 3 Haiku`, + accessible through the Vertex AI API endpoint. + + To use AnthropicVertexChatGenerator, you must have a GCP project with Vertex AI enabled. + Additionally, ensure that the desired Anthropic model is activated in the Vertex AI Model Garden. + Before making requests, you may need to authenticate with GCP using `gcloud auth login`. + For more details, refer to the [guide] (https://docs.anthropic.com/en/api/claude-on-vertex-ai). + + Any valid text generation parameters for the Anthropic messaging API can be passed to + the AnthropicVertex API. Users can provide these parameters directly to the component via + the `generation_kwargs` parameter in `__init__` or the `run` method. + + For more details on the parameters supported by the Anthropic API, refer to the + Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages). + + ```python + from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + from haystack.dataclasses import ChatMessage + + messages = [ChatMessage.from_user("What's Natural Language Processing?")] + client = AnthropicVertexChatGenerator( + model="claude-3-sonnet@20240229", + project_id="your-project-id", region="your-region" + ) + response = client.run(messages) + print(response) + + >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that + >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing + >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and + >> communicate in natural languages like English, Spanish, or Chinese.', role=<ChatRole.ASSISTANT: 'assistant'>, + >> name=None, meta={'model': 'claude-3-sonnet@20240229', 'index': 0, 'finish_reason': 'end_turn', + >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]} + ``` + + For more details on supported models and their capabilities, refer to the Anthropic + [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude). + + """ + + def __init__( + self, + region: str, + project_id: str, + model: str = "claude-3-5-sonnet@20240620", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ignore_tools_thinking_messages: bool = True, + ): + """ + Creates an instance of AnthropicVertexChatGenerator. + + :param region: The region where the Anthropic model is deployed. Defaults to "us-central1". + :param project_id: The GCP project ID where the Anthropic model is deployed. + :param model: The name of the model to use. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function accepts StreamingChunk as an argument. + :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to + the AnthropicVertex endpoint. See Anthropic [documentation](https://docs.anthropic.com/claude/reference/messages_post) + for more details. + + Supported generation_kwargs parameters are: + - `system`: The system message to be passed to the model. + - `max_tokens`: The maximum number of tokens to generate. + - `metadata`: A dictionary of metadata to be passed to the model. + - `stop_sequences`: A list of strings that the model should stop generating at. 
+ - `temperature`: The temperature to use for sampling. + - `top_p`: The top_p value to use for nucleus sampling. + - `top_k`: The top_k value to use for top-k sampling. + - `extra_headers`: A dictionary of extra headers to be passed to the model (i.e. for beta features). + :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a + "chain of thought" messages before returning the actual function names and parameters in a message. If + `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool + use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use) + for more details. + """ + self.region = region or os.environ.get("REGION") + self.project_id = project_id or os.environ.get("PROJECT_ID") + self.model = model + self.generation_kwargs = generation_kwargs or {} + self.streaming_callback = streaming_callback + self.client = AnthropicVertex(region=self.region, project_id=self.project_id) + self.ignore_tools_thinking_messages = ignore_tools_thinking_messages + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + The serialized component as a dictionary. + """ + callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + region=self.region, + project_id=self.project_id, + model=self.model, + streaming_callback=callback_name, + generation_kwargs=self.generation_kwargs, + ignore_tools_thinking_messages=self.ignore_tools_thinking_messages, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator": + """ + Deserialize this component from a dictionary. + + :param data: The dictionary representation of this component. + :returns: + The deserialized component instance. 
+ """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + return default_from_dict(cls, data) diff --git a/integrations/anthropic/tests/test_vertex_chat_generator.py b/integrations/anthropic/tests/test_vertex_chat_generator.py new file mode 100644 index 000000000..a67e801ad --- /dev/null +++ b/integrations/anthropic/tests/test_vertex_chat_generator.py @@ -0,0 +1,197 @@ +import os + +import anthropic +import pytest +from haystack.components.generators.utils import print_streaming_chunk +from haystack.dataclasses import ChatMessage, ChatRole + +from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses."), + ChatMessage.from_user("What's the capital of France?"), + ] + + +class TestAnthropicVertexChatGenerator: + def test_init_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is None + assert not component.generation_kwargs + assert component.ignore_tools_thinking_messages + + def test_init_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ignore_tools_thinking_messages=False, + ) + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.model == "claude-3-5-sonnet@20240620" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.ignore_tools_thinking_messages is False + + def test_to_dict_default(self): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": None, + "generation_kwargs": {}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_parameters(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + streaming_callback=print_streaming_chunk, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." 
+ "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_to_dict_with_lambda_streaming_callback(self): + component = AnthropicVertexChatGenerator( + region="us-central1", + project_id="test-project-id", + model="claude-3-5-sonnet@20240620", + streaming_callback=lambda x: x, + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "tests.test_vertex_chat_generator.<lambda>", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + + def test_from_dict(self): + data = { + "type": ( + "haystack_integrations.components.generators." + "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator" + ), + "init_parameters": { + "region": "us-central1", + "project_id": "test-project-id", + "model": "claude-3-5-sonnet@20240620", + "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + "ignore_tools_thinking_messages": True, + }, + } + component = AnthropicVertexChatGenerator.from_dict(data) + assert component.model == "claude-3-5-sonnet@20240620" + assert component.region == "us-central1" + assert component.project_id == "test-project-id" + assert component.streaming_callback is print_streaming_chunk + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + def test_run(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id") + response = component.run(chat_messages) + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_with_params(self, chat_messages, mock_chat_completion): + component = AnthropicVertexChatGenerator( + region="us-central1", project_id="test-project-id", generation_kwargs={"max_tokens": 10, "temperature": 0.5} + ) + response = component.run(chat_messages) + + # check that the component calls the Anthropic API with the correct parameters + _, kwargs = mock_chat_completion.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run 
this test.", + ) + @pytest.mark.integration + def test_live_run_wrong_model(self, chat_messages): + component = AnthropicVertexChatGenerator( + model="something-obviously-wrong", region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID") + ) + with pytest.raises(anthropic.NotFoundError): + component.run(chat_messages) + + @pytest.mark.skipif( + not (os.environ.get("REGION", None) or os.environ.get("PROJECT_ID", None)), + reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.", + ) + @pytest.mark.integration + def test_default_inference_params(self, chat_messages): + client = AnthropicVertexChatGenerator( + region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID"), model="claude-3-sonnet@20240229" + ) + response = client.run(chat_messages) + + assert "replies" in response, "Response does not contain 'replies' key" + replies = response["replies"] + assert isinstance(replies, list), "Replies is not a list" + assert len(replies) > 0, "No replies received" + + first_reply = replies[0] + assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance" + assert first_reply.content, "First reply has no content" + assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant" + assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'" + assert first_reply.meta, "First reply has no metadata" + + # Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint, + # remaining tests are skipped for AnthropicVertexChatGenerator as they are already tested in AnthropicChatGenerator. diff --git a/integrations/astra/CHANGELOG.md b/integrations/astra/CHANGELOG.md index 79bb9e35d..fff6cb65f 100644 --- a/integrations/astra/CHANGELOG.md +++ b/integrations/astra/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## [integrations/astra-v0.10.0] - 2024-10-22 + +### 🚀 Features + +- Update astradb integration for latest client library (#1145) + +### ⚙️ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + ## [integrations/astra-v0.9.3] - 2024-09-12 ### 🐛 Bug Fixes @@ -23,9 +34,7 @@ ### 🐛 Bug Fixes - Fix astra nightly - - Fix typing checks - - `Astra` - Fallback to default filter policy when deserializing retrievers without the init parameter (#896) ### ⚙️ Miscellaneous Tasks @@ -50,8 +59,6 @@ - Fix haystack-ai pin (#649) - - ## [integrations/astra-v0.5.0] - 2024-03-18 ### 📚 Documentation @@ -75,8 +82,6 @@ This PR will also push the docs to Readme - Fix integration tests (#450) - - ## [integrations/astra-v0.4.0] - 2024-02-20 ### 📚 Documentation diff --git a/integrations/astra/README.md b/integrations/astra/README.md index f679b7207..9ee47b8c9 100644 --- a/integrations/astra/README.md +++ b/integrations/astra/README.md @@ -6,17 +6,18 @@ ```bash pip install astra-haystack - ``` ### Local Development + install astra-haystack package locally to run integration tests: Open in gitpod: [![Open in Gitpod](https://gitpod.io/button/open-in-gitpod.svg)](https://gitpod.io/#https://github.com/Anant/astra-haystack/tree/main) -Switch Python version to 3.9 (Requires 3.8+ but not 3.12) -``` +Switch Python version to 3.9 (Requires 3.9+ but not 3.12) + +```bash pyenv install 3.9 pyenv local 3.9 ``` @@ -33,7 +34,8 @@ Install requirements `pip install -r requirements.txt` Export environment variables -``` + +```bash export ASTRA_DB_API_ENDPOINT="https://-.apps.astra.datastax.com" export 
ASTRA_DB_APPLICATION_TOKEN="AstraCS:..." export COLLECTION_NAME="my_collection" @@ -49,22 +51,25 @@ or This package includes Astra Document Store and Astra Embedding Retriever classes that integrate with Haystack, allowing you to easily perform document retrieval or RAG with Astra, and include those functions in Haystack pipelines. -### In order to use the Document Store directly: +### Use the Document Store Directly Import the Document Store: -``` + +```python from haystack_integrations.document_stores.astra import AstraDocumentStore from haystack.document_stores.types.policy import DuplicatePolicy ``` Load in environment variables: -``` + +```python namespace = os.environ.get("ASTRA_DB_KEYSPACE") collection_name = os.environ.get("COLLECTION_NAME", "haystack_vector_search") ``` Create the Document Store object (API Endpoint and Token are read off the environment): -``` + +```python document_store = AstraDocumentStore( collection_name=collection_name, namespace=namespace, @@ -80,7 +85,7 @@ Then you can use the document store functions like count_document below: Create the Document Store object like above, then import and create the Pipeline: -``` +```python from haystack import Pipeline pipeline = Pipeline() ``` @@ -101,7 +106,6 @@ or, > Astra DB collection '...' is detected as having the following indexing policy: {...}. This does not match the requested indexing policy for this object: {...}. In particular, there may be stricter limitations on the amount of text each string in a document can store. Consider indexing anew on a fresh collection to be able to store longer texts. - The reason for the warning is that the requested collection already exists on the database, and it is configured to [index all of its fields for search](https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option), possibly implicitly, by default. When the Haystack object tries to create it, it attempts to enforce, instead, an indexing policy tailored to the prospected usage: this is both to enable storing very long texts and to avoid indexing fields that will never be used in filtering a search (indexing those would also have a slight performance cost for writes). 
Typically there are two reasons why you may encounter the warning: diff --git a/integrations/astra/examples/requirements.txt b/integrations/astra/examples/requirements.txt index 710749bbe..221138666 100644 --- a/integrations/astra/examples/requirements.txt +++ b/integrations/astra/examples/requirements.txt @@ -1,4 +1,4 @@ haystack-ai sentence_transformers==2.2.2 openai==1.6.1 -astrapy>=0.7.7 \ No newline at end of file +astrapy>=1.5.0,<2.0 diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index 25bcf20b8..5645cd5d3 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -7,7 +7,7 @@ name = "astra-haystack" dynamic = ["version"] description = '' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "Apache-2.0" keywords = [] authors = [{ name = "Anant Corporation", email = "support@anant.us" }] @@ -15,14 +15,13 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "pydantic", "typing_extensions", "astrapy"] +dependencies = ["haystack-ai", "pydantic", "typing_extensions", "astrapy>=1.5.0,<2.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/astra#readme" @@ -41,6 +40,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/astra-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -56,11 +56,12 @@ cov = ["test-cov", "cov-report"] cov-retry = ["test-cov-retry", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] -python = ["3.8", "3.9", "3.10", "3.11"] +python = ["3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index b594f87d3..6f2289786 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -3,8 +3,9 @@ from typing import Dict, List, Optional, Union from warnings import warn -from astrapy.api import APIRequestError -from astrapy.db import AstraDB +from astrapy import DataAPIClient as AstraDBClient +from astrapy.constants import ReturnDocument +from astrapy.exceptions import CollectionAlreadyExistsException from haystack.version import __version__ as integration_version from pydantic.dataclasses import dataclass @@ -65,83 +66,78 @@ def __init__( self.similarity_function = similarity_function self.namespace = namespace - # Build the Astra DB object - self._astra_db = AstraDB( + # Get the keyspace from the collection name + my_client = AstraDBClient( + 
callers=[(CALLER_NAME, integration_version)], + ) + + # Get the database object + self._astra_db = my_client.get_database( api_endpoint=api_endpoint, token=token, - namespace=namespace, - caller_name=CALLER_NAME, - caller_version=integration_version, + keyspace=namespace, ) - indexing_options = {"indexing": {"deny": NON_INDEXED_FIELDS}} + indexing_options = {"deny": NON_INDEXED_FIELDS} try: # Create and connect to the newly created collection self._astra_db_collection = self._astra_db.create_collection( - collection_name=collection_name, + name=collection_name, dimension=embedding_dimension, - options=indexing_options, + indexing=indexing_options, ) - except APIRequestError: + except CollectionAlreadyExistsException as _: # possibly the collection is preexisting and has legacy # indexing settings: verify - get_coll_response = self._astra_db.get_collections(options={"explain": True}) - - collections = (get_coll_response["status"] or {}).get("collections") or [] - - preexisting = [collection for collection in collections if collection["name"] == collection_name] + preexisting = [ + coll_descriptor + for coll_descriptor in self._astra_db.list_collections() + if coll_descriptor.name == collection_name + ] if preexisting: - pre_collection = preexisting[0] # if it has no "indexing", it is a legacy collection; - # otherwise it's unexpected warn and proceed at user's risk - pre_col_options = pre_collection.get("options") or {} - if "indexing" not in pre_col_options: + # otherwise it's unexpected: warn and proceed at user's risk + pre_col_idx_opts = preexisting[0].options.indexing or {} + if not pre_col_idx_opts: warn( ( - f"Astra DB collection '{collection_name}' is " - "detected as having indexing turned on for all " - "fields (either created manually or by older " - "versions of this plugin). This implies stricter " - "limitations on the amount of text each string in a " - "document can store. Consider indexing anew on a " - "fresh collection to be able to store longer texts. " - "See https://github.com/deepset-ai/haystack-core-" - "integrations/blob/main/integrations/astra/README" - ".md#warnings-about-indexing for more details." + f"Collection '{collection_name}' is detected as " + "having indexing turned on for all fields " + "(either created manually or by older versions " + "of this plugin). This implies stricter " + "limitations on the amount of text" + " each entry can store. Consider indexing anew on a" + " fresh collection to be able to store longer texts." ), UserWarning, stacklevel=2, ) - self._astra_db_collection = self._astra_db.collection( - collection_name=collection_name, + self._astra_db_collection = self._astra_db.get_collection( + collection_name, + ) + # check if the indexing options match entirely + elif pre_col_idx_opts == indexing_options: + self._astra_db_collection = self._astra_db.get_collection( + collection_name, ) - elif pre_col_options["indexing"] != indexing_options["indexing"]: - detected_options_json = json.dumps(pre_col_options["indexing"]) - indexing_options_json = json.dumps(indexing_options["indexing"]) + else: + options_json = json.dumps(pre_col_idx_opts) warn( ( - f"Astra DB collection '{collection_name}' is " - "detected as having the following indexing policy: " - f"{detected_options_json}. This does not match the requested " - f"indexing policy for this object: {indexing_options_json}. " - "In particular, there may be stricter " - "limitations on the amount of text each string in a " - "document can store. 
Consider indexing anew on a " - "fresh collection to be able to store longer texts. " - "See https://github.com/deepset-ai/haystack-core-" - "integrations/blob/main/integrations/astra/README" - ".md#warnings-about-indexing for more details." + f"Collection '{collection_name}' has unexpected 'indexing'" + f" settings (options.indexing = {options_json})." + " This can result in odd behaviour when running " + " metadata filtering and/or unwarranted limitations" + " on storing long texts. Consider indexing anew on a" + " fresh collection." ), UserWarning, stacklevel=2, ) - self._astra_db_collection = self._astra_db.collection( - collection_name=collection_name, + self._collection = self._astra_db.get_collection( + collection_name, ) - else: - # the collection mismatch lies elsewhere than the indexing - raise else: # other exception raise @@ -180,7 +176,7 @@ def query( return formatted_response def _query_without_vector(self, top_k, filters=None): - query = {"filter": filters, "options": {"limit": top_k}} + query = {"filter": filters, "limit": top_k} return self.find_documents(query) @@ -196,8 +192,11 @@ def _format_query_response(responses, include_metadata, include_values): score = response.pop("$similarity", None) text = response.pop("content", None) values = response.pop("$vector", None) if include_values else [] + metadata = response if include_metadata else {} # Add all remaining fields to the metadata + rsp = Response(_id, text, values, metadata, score) + final_res.append(rsp) return QueryResponse(final_res) @@ -219,17 +218,21 @@ def find_documents(self, find_query): :param find_query: a dictionary with the query options :returns: the documents found in the index """ - response_dict = self._astra_db_collection.find( + find_cursor = self._astra_db_collection.find( filter=find_query.get("filter"), sort=find_query.get("sort"), - options=find_query.get("options"), + limit=find_query.get("limit"), projection={"*": 1}, ) - if "data" in response_dict and "documents" in response_dict["data"]: - return response_dict["data"]["documents"] - else: - logger.warning(f"No documents found: {response_dict}") + find_results = [] + for result in find_cursor: + find_results.append(result) + + if not find_results: + logger.warning("No documents found.") + + return find_results def find_one_document(self, find_query): """ @@ -238,16 +241,15 @@ def find_one_document(self, find_query): :param find_query: a dictionary with the query options :returns: the document found in the index """ - response_dict = self._astra_db_collection.find_one( + find_result = self._astra_db_collection.find_one( filter=find_query.get("filter"), - options=find_query.get("options"), projection={"*": 1}, ) - if "data" in response_dict and "document" in response_dict["data"]: - return response_dict["data"]["document"] - else: - logger.warning(f"No document found: {response_dict}") + if not find_result: + logger.warning("No document found.") + + return find_result def get_documents(self, ids: List[str], batch_size: int = 20) -> QueryResponse: """ @@ -281,15 +283,8 @@ def insert(self, documents: List[Dict]): :param documents: a list of documents to insert :returns: the IDs of the inserted documents """ - response_dict = self._astra_db_collection.insert_many(documents=documents) - - inserted_ids = ( - response_dict["status"]["insertedIds"] - if "status" in response_dict and "insertedIds" in response_dict["status"] - else [] - ) - if "errors" in response_dict: - logger.error(response_dict["errors"]) + insert_result = 
self._astra_db_collection.insert_many(documents=documents) + inserted_ids = [str(_id) for _id in insert_result.inserted_ids] return inserted_ids @@ -303,23 +298,21 @@ def update_document(self, document: Dict, id_key: str): """ document_id = document.pop(id_key) - response_dict = self._astra_db_collection.find_one_and_update( + update_result = self._astra_db_collection.find_one_and_update( filter={id_key: document_id}, update={"$set": document}, - options={"returnDocument": "after"}, + return_document=ReturnDocument.AFTER, projection={"*": 1}, ) document[id_key] = document_id - if "status" in response_dict and "errors" not in response_dict: - if "matchedCount" in response_dict["status"] and "modifiedCount" in response_dict["status"]: - if response_dict["status"]["matchedCount"] == 1 and response_dict["status"]["modifiedCount"] == 1: - return True + if update_result is None: + logger.warning(f"Documents {document_id} not updated in Astra DB.") - logger.warning(f"Documents {document_id} not updated in Astra DB.") + return False - return False + return True def delete( self, @@ -345,23 +338,13 @@ def delete( if "filter" in query["deleteMany"]: filter_dict = query["deleteMany"]["filter"] - deletion_counter = 0 - moredata = True - while moredata: - response_dict = self._astra_db_collection.delete_many(filter=filter_dict) - - if "moreData" not in response_dict.get("status", {}): - moredata = False + delete_result = self._astra_db_collection.delete_many(filter=filter_dict) - deletion_counter += int(response_dict["status"].get("deletedCount", 0)) + return delete_result.deleted_count - return deletion_counter - - def count_documents(self) -> int: + def count_documents(self, upper_bound: int = 10000) -> int: """ Count the number of documents in the Astra index. 
:returns: the number of documents in the index """ - documents_count = self._astra_db_collection.count_documents() - - return documents_count["status"]["count"] + return self._astra_db_collection.count_documents({}, upper_bound=upper_bound) diff --git a/integrations/astra/tests/test_document_store.py b/integrations/astra/tests/test_document_store.py index c4d1b6347..ef00b6b25 100644 --- a/integrations/astra/tests/test_document_store.py +++ b/integrations/astra/tests/test_document_store.py @@ -20,25 +20,14 @@ def mock_auth(monkeypatch): monkeypatch.setenv("ASTRA_DB_APPLICATION_TOKEN", "test_token") -@mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") +@mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDBClient") def test_init_is_lazy(_mock_client, mock_auth): # noqa _ = AstraDocumentStore() _mock_client.assert_not_called() -def test_namespace_init(mock_auth): # noqa - with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB") as client: - _ = AstraDocumentStore().index - assert "namespace" in client.call_args.kwargs - assert client.call_args.kwargs["namespace"] is None - - _ = AstraDocumentStore(namespace="foo").index - assert "namespace" in client.call_args.kwargs - assert client.call_args.kwargs["namespace"] == "foo" - - def test_to_dict(mock_auth): # noqa - with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDB"): + with mock.patch("haystack_integrations.document_stores.astra.astra_client.AstraDBClient"): ds = AstraDocumentStore() result = ds.to_dict() assert result["type"] == "haystack_integrations.document_stores.astra.document_store.AstraDocumentStore" @@ -206,6 +195,17 @@ def test_filter_documents_by_id(self, document_store): result = document_store.filter_documents(filters={"field": "id", "operator": "==", "value": "1"}) self.assert_documents_are_equal(result, [docs[0]]) + def test_filter_documents_by_in_operator(self, document_store): + docs = [Document(id="3", content="test doc 3"), Document(id="4", content="test doc 4")] + document_store.write_documents(docs) + result = document_store.filter_documents(filters={"field": "id", "operator": "in", "value": ["3", "4"]}) + + # Sort the result in place by the id field + result.sort(key=lambda x: x.id) + + self.assert_documents_are_equal([result[0]], [docs[0]]) + self.assert_documents_are_equal([result[1]], [docs[1]]) + @pytest.mark.skip(reason="Unsupported filter operator not.") def test_not_operator(self, document_store, filterable_docs): pass diff --git a/integrations/azure_ai_search/LICENSE b/integrations/azure_ai_search/LICENSE new file mode 100644 index 000000000..de4c7f39f --- /dev/null +++ b/integrations/azure_ai_search/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 deepset GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/integrations/azure_ai_search/README.md b/integrations/azure_ai_search/README.md new file mode 100644 index 000000000..915a23b63 --- /dev/null +++ b/integrations/azure_ai_search/README.md @@ -0,0 +1,26 @@ +# Azure AI Search Document Store for Haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/azure-ai-search-haystack.svg)](https://pypi.org/project/azure-ai-search-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/azure-ai-search-haystack.svg)](https://pypi.org/project/azure-ai-search-haystack) + +----- + +**Table of Contents** + +- [Azure AI Search Document Store for Haystack](#azure-ai-search-document-store-for-haystack) + - [Installation](#installation) + - [Examples](#examples) + - [License](#license) + +## Installation + +```console +pip install azure-ai-search-haystack +``` + +## Examples +You can find a code example showing how to use the Document Store and the Retriever in the documentation or in [this Colab](https://colab.research.google.com/drive/1YpDetI8BRbObPDEVdfqUcwhEX9UUXP-m?usp=sharing). + +## License + +`azure-ai-search-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py new file mode 100644 index 000000000..92a641717 --- /dev/null +++ b/integrations/azure_ai_search/example/document_store.py @@ -0,0 +1,43 @@ +from haystack import Document + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +""" +This example demonstrates how to use the AzureAISearchDocumentStore to write and filter documents. +To run this example, you'll need an Azure Search service endpoint and API key, which can either be +set as environment variables (AZURE_SEARCH_SERVICE_ENDPOINT and AZURE_SEARCH_API_KEY) or +provided directly to AzureAISearchDocumentStore(as params "api_key", "azure_endpoint"). +Otherwise you can use DefaultAzureCredential to authenticate with Azure services. 
+See more details at https://learn.microsoft.com/en-us/azure/search/keyless-connections?tabs=python%2Cazure-cli +""" +document_store = AzureAISearchDocumentStore( + metadata_fields={"version": float, "label": str}, + index_name="document-store-example", +) + +documents = [ + Document( + content="This is an introduction to using Python for data analysis.", + meta={"version": 1.0, "label": "chapter_one"}, + ), + Document( + content="Learn how to use Python libraries for machine learning.", + meta={"version": 1.5, "label": "chapter_two"}, + ), + Document( + content="Advanced Python techniques for data visualization.", + meta={"version": 2.0, "label": "chapter_three"}, + ), +] +document_store.write_documents(documents) + +filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.version", "operator": ">", "value": 1.2}, + {"field": "meta.label", "operator": "in", "value": ["chapter_one", "chapter_three"]}, + ], +} + +results = document_store.filter_documents(filters) +print(results) diff --git a/integrations/azure_ai_search/example/embedding_retrieval.py b/integrations/azure_ai_search/example/embedding_retrieval.py new file mode 100644 index 000000000..188f8525a --- /dev/null +++ b/integrations/azure_ai_search/example/embedding_retrieval.py @@ -0,0 +1,55 @@ +from haystack import Document, Pipeline +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.writers import DocumentWriter + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +""" +This example demonstrates how to use the AzureAISearchEmbeddingRetriever to retrieve documents +using embeddings based on a query. To run this example, you'll need an Azure Search service endpoint +and API key, which can either be +set as environment variables (AZURE_SEARCH_SERVICE_ENDPOINT and AZURE_SEARCH_API_KEY) or +provided directly to AzureAISearchDocumentStore(as params "api_key", "azure_endpoint"). +Otherwise you can use DefaultAzureCredential to authenticate with Azure services. 
+See more details at https://learn.microsoft.com/en-us/azure/search/keyless-connections?tabs=python%2Cazure-cli +""" + +document_store = AzureAISearchDocumentStore(index_name="retrieval-example") + +model = "sentence-transformers/all-mpnet-base-v2" + +documents = [ + Document(content="There are over 7,000 languages spoken around the world today."), + Document( + content="""Elephants have been observed to behave in a way that indicates a + high level of self-awareness, such as recognizing themselves in mirrors.""" + ), + Document( + content="""In certain parts of the world, like the Maldives, Puerto Rico, and + San Diego, you can witness the phenomenon of bioluminescent waves.""" + ), +] + +document_embedder = SentenceTransformersDocumentEmbedder(model=model) +document_embedder.warm_up() + +# Indexing Pipeline +indexing_pipeline = Pipeline() +indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder") +indexing_pipeline.add_component(instance=DocumentWriter(document_store=document_store), name="doc_writer") +indexing_pipeline.connect("doc_embedder", "doc_writer") + +indexing_pipeline.run({"doc_embedder": {"documents": documents}}) + +# Query Pipeline +query_pipeline = Pipeline() +query_pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder(model=model)) +query_pipeline.add_component("retriever", AzureAISearchEmbeddingRetriever(document_store=document_store)) +query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") + +query = "How many languages are there?" + +result = query_pipeline.run({"text_embedder": {"text": query}}) + +print(result["retriever"]["documents"][0]) diff --git a/integrations/azure_ai_search/pydoc/config.yml b/integrations/azure_ai_search/pydoc/config.yml new file mode 100644 index 000000000..ec411af60 --- /dev/null +++ b/integrations/azure_ai_search/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever", + "haystack_integrations.document_stores.azure_ai_search.document_store", + "haystack_integrations.document_stores.azure_ai_search.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: Azure AI Search integration for Haystack + category_slug: integrations-api + title: Azure AI Search + slug: integrations-azure_ai_search + order: 180 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_azure_ai_search.md diff --git a/integrations/azure_ai_search/pyproject.toml b/integrations/azure_ai_search/pyproject.toml new file mode 100644 index 000000000..49ca623e7 --- /dev/null +++ b/integrations/azure_ai_search/pyproject.toml @@ -0,0 +1,163 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "azure-ai-search-haystack" +dynamic = ["version"] +description = 'Haystack 2.x Document Store for Azure AI Search' +readme = "README.md" +requires-python = ">=3.8,<3.13" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "deepset", email = "info@deepset.ai" }] +classifiers = [ + "License :: OSI Approved :: Apache Software 
License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai", "azure-search-documents>=11.5", "azure-identity"] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/azure-ai-search-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/azure-ai-search-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-rerunfailures", + "pytest-xdist", + "haystack-pydoc-tools", +] + +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = ["ruff check {args:src/}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] +all = ["style", "typing"] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.black] +target-version = ["py38"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py38" +line-length = 120 + +[tool.ruff.lint] +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... 
True)` + "FBT003", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] +exclude = ["example"] + +[tool.ruff.lint.isort] +known-first-party = ["src"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.lint.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252", "S311"] +"example/**/*" = ["T201"] + +[tool.coverage.run] +source = ["haystack_integrations"] +branch = true +parallel = false + + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + + +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + +[[tool.mypy.overrides]] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "azure.identity.*", "mypy.*", "azure.core.*", "azure.search.documents.*"] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py new file mode 100644 index 000000000..56dc30db4 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -0,0 +1,5 @@ +from .bm25_retriever import AzureAISearchBM25Retriever +from .embedding_retriever import AzureAISearchEmbeddingRetriever +from .hybrid_retriever import AzureAISearchHybridRetriever + +__all__ = ["AzureAISearchBM25Retriever", "AzureAISearchEmbeddingRetriever", "AzureAISearchHybridRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py new file mode 100644 index 000000000..4a1c7f98c --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/bm25_retriever.py @@ -0,0 +1,135 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchBM25Retriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using BM25 retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, + ): + """ + Create the AzureAISearchBM25Retriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the BM25 search to ensure the Retriever returns + `top_k` matching documents. 
+ :param top_k: Maximum number of documents to return. + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + :raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore. + :raises RuntimeError: If the query is not valid, or if the document store is not correctly configured. + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._kwargs = kwargs + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise TypeError(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchBM25Retriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :raises RuntimeError: If an error occurs during the BM25 retrieval process. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + filters = filters or self._filters + if filters: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = _normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._bm25_retrieval( + query=query, + filters=normalized_filters, + top_k=top_k, + **self._kwargs, + ) + except Exception as e: + msg = ( + "An error occurred during the bm25 retrieval process from the AzureAISearchDocumentStore. 
" + "Ensure that the query is valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py new file mode 100644 index 000000000..69fad7208 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -0,0 +1,128 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchEmbeddingRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a vector similarity metric. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, + ): + """ + Create the AzureAISearchEmbeddingRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._kwargs = kwargs + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. 
+ """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query_embedding: A list of floats representing the query embedding. + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more + details. + :param top_k: The maximum number of documents to retrieve. + :returns: Dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + + top_k = top_k or self._top_k + filters = filters or self._filters + if filters: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = _normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._embedding_retrieval( + query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs + ) + except Exception as e: + msg = ( + "An error occurred during the embedding retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query embedding is valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py new file mode 100644 index 000000000..79282933f --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/hybrid_retriever.py @@ -0,0 +1,139 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, _normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchHybridRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a hybrid (vector + BM25) retrieval. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + **kwargs, + ): + """ + Create the AzureAISearchHybridRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. 
+ Filters are applied during the hybrid search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :param filter_policy: Policy to determine how filters are applied. + :param kwargs: Additional keyword arguments to pass to the Azure AI's search endpoint. + Some of the supported parameters: + - `query_type`: A string indicating the type of query to perform. Possible values are + 'simple','full' and 'semantic'. + - `semantic_configuration_name`: The name of semantic configuration to be used when + processing semantic queries. + For more information on parameters, see the + [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + :raises TypeError: If the document store is not an instance of AzureAISearchDocumentStore. + :raises RuntimeError: If query or query_embedding are invalid, or if document store is not correctly configured. + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + self._kwargs = kwargs + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise TypeError(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchHybridRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query: str, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + ): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query: Text of the query. + :param query_embedding: A list of floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See `__init__` method docstring for more + details. + :param top_k: The maximum number of documents to retrieve. + :raises RuntimeError: If an error occurs during the hybrid retrieval process. + :returns: A dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. 
+ """ + + top_k = top_k or self._top_k + filters = filters or self._filters + if filters: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = _normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._hybrid_retrieval( + query=query, query_embedding=query_embedding, filters=normalized_filters, top_k=top_k, **self._kwargs + ) + except Exception as e: + msg = ( + "An error occurred during the hybrid retrieval process from the AzureAISearchDocumentStore. " + "Ensure that the query and query_embedding are valid and the document store is correctly configured." + ) + raise RuntimeError(msg) from e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py new file mode 100644 index 000000000..ca0ea7554 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore +from .filters import _normalize_filters + +__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "_normalize_filters"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py new file mode 100644 index 000000000..137ff621c --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -0,0 +1,527 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import logging +import os +from dataclasses import asdict +from datetime import datetime +from typing import Any, Dict, List, Optional + +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.search.documents import SearchClient +from azure.search.documents.indexes import SearchIndexClient +from azure.search.documents.indexes.models import ( + HnswAlgorithmConfiguration, + HnswParameters, + SearchableField, + SearchField, + SearchFieldDataType, + SearchIndex, + SimpleField, + VectorSearch, + VectorSearchAlgorithmMetric, + VectorSearchProfile, +) +from azure.search.documents.models import VectorizedQuery +from haystack import default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import DuplicatePolicy +from haystack.utils import Secret, deserialize_secrets_inplace + +from .errors import AzureAISearchDocumentStoreConfigError +from .filters import _normalize_filters + +type_mapping = { + str: "Edm.String", + bool: "Edm.Boolean", + int: "Edm.Int32", + float: "Edm.Double", + datetime: "Edm.DateTimeOffset", +} + +DEFAULT_VECTOR_SEARCH = VectorSearch( + profiles=[ + VectorSearchProfile(name="default-vector-config", algorithm_configuration_name="cosine-algorithm-config") + ], + algorithms=[ + HnswAlgorithmConfiguration( + name="cosine-algorithm-config", + parameters=HnswParameters( + metric=VectorSearchAlgorithmMetric.COSINE, + ), + ) + ], +) + +logger = 
logging.getLogger(__name__) +logging.getLogger("azure").setLevel(logging.ERROR) +logging.getLogger("azure.identity").setLevel(logging.DEBUG) + + +class AzureAISearchDocumentStore: + def __init__( + self, + *, + api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False), # noqa: B008 + azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=True), # noqa: B008 + index_name: str = "default", + embedding_dimension: int = 768, + metadata_fields: Optional[Dict[str, type]] = None, + vector_search_configuration: VectorSearch = None, + **index_creation_kwargs, + ): + """ + A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/) + as the backend. + + :param azure_endpoint: The URL endpoint of an Azure AI Search service. + :param api_key: The API key to use for authentication. + :param index_name: Name of index in Azure AI Search, if it doesn't exist it will be created. + :param embedding_dimension: Dimension of the embeddings. + :param metadata_fields: A dictionary of metadata keys and their types to create + additional fields in index schema. As fields in Azure SearchIndex cannot be dynamic, + it is necessary to specify the metadata fields in advance. + (e.g. metadata_fields = {"author": str, "date": datetime}) + :param vector_search_configuration: Configuration option related to vector search. + Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. + + :param index_creation_kwargs: Optional keyword parameters to be passed to `SearchIndex` class + during index creation. Some of the supported parameters: + - `semantic_search`: Defines semantic configuration of the search index. This parameter is needed + to enable semantic search capabilities in index. + - `similarity`: The type of similarity algorithm to be used when scoring and ranking the documents + matching a search query. The similarity algorithm can only be defined at index creation time and + cannot be modified on existing indexes. + + For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/). + """ + + azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None + if not azure_endpoint: + msg = "Please provide an Azure endpoint or set the environment variable AZURE_SEARCH_SERVICE_ENDPOINT." 
+ raise ValueError(msg) + + api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None + + self._client = None + self._index_client = None + self._index_fields = [] # type: List[Any] # stores all fields in the final schema of index + self._api_key = api_key + self._azure_endpoint = azure_endpoint + self._index_name = index_name + self._embedding_dimension = embedding_dimension + self._dummy_vector = [-10.0] * self._embedding_dimension + self._metadata_fields = metadata_fields + self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH + self._index_creation_kwargs = index_creation_kwargs + + @property + def client(self) -> SearchClient: + + # resolve secrets for authentication + resolved_endpoint = ( + self._azure_endpoint.resolve_value() if isinstance(self._azure_endpoint, Secret) else self._azure_endpoint + ) + resolved_key = self._api_key.resolve_value() if isinstance(self._api_key, Secret) else self._api_key + + credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential() + try: + if not self._index_client: + self._index_client = SearchIndexClient( + resolved_endpoint, + credential, + ) + if not self._index_exists(self._index_name): + # Create a new index if it does not exist + logger.debug( + "The index '%s' does not exist. A new index will be created.", + self._index_name, + ) + self._create_index(self._index_name) + except (HttpResponseError, ClientAuthenticationError) as error: + msg = f"Failed to authenticate with Azure Search: {error}" + raise AzureAISearchDocumentStoreConfigError(msg) from error + + if self._index_client: + # Get the search client, if index client is initialized + index_fields = self._index_client.get_index(self._index_name).fields + self._index_fields = [field.name for field in index_fields] + self._client = self._index_client.get_search_client(self._index_name) + else: + msg = "Search Index Client is not initialized." + raise AzureAISearchDocumentStoreConfigError(msg) + + return self._client + + def _create_index(self, index_name: str) -> None: + """ + Creates a new search index. + :param index_name: Name of the index to create. If None, the index name from the constructor is used. + :param kwargs: Optional keyword parameters. + """ + + # default fields to create index based on Haystack Document (id, content, embedding) + default_fields = [ + SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True), + SearchableField(name="content", type=SearchFieldDataType.String), + SearchField( + name="embedding", + type=SearchFieldDataType.Collection(SearchFieldDataType.Single), + searchable=True, + hidden=False, + vector_search_dimensions=self._embedding_dimension, + vector_search_profile_name="default-vector-config", + ), + ] + + if not index_name: + index_name = self._index_name + if self._metadata_fields: + default_fields.extend(self._create_metadata_index_fields(self._metadata_fields)) + index = SearchIndex( + name=index_name, + fields=default_fields, + vector_search=self._vector_search_configuration, + **self._index_creation_kwargs, + ) + if self._index_client: + self._index_client.create_index(index) + + def to_dict(self) -> Dict[str, Any]: + # This is not the best solution to serialise this class but is the fastest to implement. + # Not all kwargs types can be serialised to text so this can fail. We must serialise each + # type explicitly to handle this properly. + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. 
+        """
+        return default_to_dict(
+            self,
+            azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None,
+            api_key=self._api_key.to_dict() if self._api_key else None,
+            index_name=self._index_name,
+            embedding_dimension=self._embedding_dimension,
+            metadata_fields=self._metadata_fields,
+            vector_search_configuration=self._vector_search_configuration.as_dict(),
+            **self._index_creation_kwargs,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+
+        :returns:
+            Deserialized component.
+        """
+
+        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"])
+        if (vector_search_configuration := data["init_parameters"].get("vector_search_configuration")) is not None:
+            data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration)
+        return default_from_dict(cls, data)
+
+    def count_documents(self) -> int:
+        """
+        Returns how many documents are present in the search index.
+
+        :returns: the number of documents in the search index.
+        """
+        return self.client.get_document_count()
+
+    def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
+        """
+        Writes the provided documents to the search index.
+
+        :param documents: documents to write to the index.
+        :param policy: Policy to determine how duplicates are handled.
+        :raises ValueError: If the documents are not of type Document.
+        :raises TypeError: If the document ids are not strings.
+        :returns: the number of documents added to the index.
+        """
+
+        def _convert_input_document(document: Document):
+            document_dict = asdict(document)
+            if not isinstance(document_dict["id"], str):
+                msg = f"Document id {document_dict['id']} is not a string."
+                raise TypeError(msg)
+            index_document = self._convert_haystack_documents_to_azure(document_dict)
+
+            return index_document
+
+        if len(documents) > 0:
+            if not isinstance(documents[0], Document):
+                msg = "param 'documents' must contain a list of objects of type Document"
+                raise ValueError(msg)
+
+            if policy not in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]:
+                logger.warning(
+                    f"AzureAISearchDocumentStore only supports `DuplicatePolicy.OVERWRITE` "
+                    f"but got {policy}. Overwriting duplicates is enabled by default."
+                )
+        client = self.client
+        documents_to_write = [(_convert_input_document(doc)) for doc in documents]
+
+        if documents_to_write:
+            client.upload_documents(documents_to_write)
+        return len(documents_to_write)
+
+    def delete_documents(self, document_ids: List[str]) -> None:
+        """
+        Deletes all documents with matching document_ids from the search index.
+
+        :param document_ids: ids of the documents to be deleted.
+        """
+        if self.count_documents() == 0:
+            return
+        documents = self._get_raw_documents_by_id(document_ids)
+        if documents:
+            self.client.delete_documents(documents)
+
+    def get_documents_by_id(self, document_ids: List[str]) -> List[Document]:
+        return self._convert_search_result_to_documents(self._get_raw_documents_by_id(document_ids))
+
+    def search_documents(self, search_text: str = "*", top_k: int = 10) -> List[Document]:
+        """
+        Returns all documents that match the provided search_text.
+        If search_text is None, returns all documents.
+
+        :param search_text: the text to search for in the documents.
+        :param top_k: Maximum number of documents to return.
+ :returns: A list of Documents that match the given search_text. + """ + result = self.client.search(search_text=search_text, top=top_k) + return self._convert_search_result_to_documents(list(result)) + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the provided filters. + Filters should be given as a dictionary supporting filtering by metadata. For details on + filters, see the [metadata filtering documentation](https://docs.haystack.deepset.ai/docs/metadata-filtering). + + :param filters: the filters to apply to the document list. + :returns: A list of Documents that match the given filters. + """ + if filters: + normalized_filters = _normalize_filters(filters) + result = self.client.search(filter=normalized_filters) + return self._convert_search_result_to_documents(result) + else: + return self.search_documents() + + def _convert_search_result_to_documents(self, azure_docs: List[Dict[str, Any]]) -> List[Document]: + """ + Converts Azure search results to Haystack Documents. + """ + documents = [] + + for azure_doc in azure_docs: + embedding = azure_doc.get("embedding") + if embedding == self._dummy_vector: + embedding = None + + # Anything besides default fields (id, content, and embedding) is considered metadata + meta = { + key: value + for key, value in azure_doc.items() + if key not in ["id", "content", "embedding"] and key in self._index_fields and value is not None + } + + # Create the document with meta only if it's non-empty + doc = Document( + id=azure_doc["id"], content=azure_doc["content"], embedding=embedding, meta=meta if meta else {} + ) + + documents.append(doc) + return documents + + def _index_exists(self, index_name: Optional[str]) -> bool: + """ + Check if the index exists in the Azure AI Search service. + + :param index_name: The name of the index to check. + :returns bool: whether the index exists. + """ + + if self._index_client and index_name: + return index_name in self._index_client.list_index_names() + else: + msg = "Index name is required to check if the index exists." + raise ValueError(msg) + + def _get_raw_documents_by_id(self, document_ids: List[str]): + """ + Retrieves all Azure documents with a matching document_ids from the document store. + + :param document_ids: ids of the documents to be retrieved. + :returns: list of retrieved Azure documents. 
+ """ + azure_documents = [] + for doc_id in document_ids: + try: + document = self.client.get_document(doc_id) + azure_documents.append(document) + except ResourceNotFoundError: + logger.warning(f"Document with ID {doc_id} not found.") + return azure_documents + + def _convert_haystack_documents_to_azure(self, document: Dict[str, Any]) -> Dict[str, Any]: + """Map the document keys to fields of search index""" + + # Because Azure Search does not allow dynamic fields, we only include fields that are part of the schema + index_document = {k: v for k, v in {**document, **document.get("meta", {})}.items() if k in self._index_fields} + if index_document["embedding"] is None: + index_document["embedding"] = self._dummy_vector + + return index_document + + def _create_metadata_index_fields(self, metadata: Dict[str, Any]) -> List[SimpleField]: + """Create a list of index fields for storing metadata values.""" + + index_fields = [] + metadata_field_mapping = self._map_metadata_field_types(metadata) + + for key, field_type in metadata_field_mapping.items(): + index_fields.append(SimpleField(name=key, type=field_type, filterable=True)) + + return index_fields + + def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]: + """Map metadata field types to Azure Search field types.""" + + metadata_field_mapping = {} + + for key, value_type in metadata.items(): + + if not key[0].isalpha(): + msg = ( + f"Azure Search index only allows field names starting with letters. " + f"Invalid key: {key} will be dropped." + ) + logger.warning(msg) + continue + + field_type = type_mapping.get(value_type) + if not field_type: + error_message = f"Unsupported field type for key '{key}': {value_type}" + raise ValueError(error_message) + metadata_field_mapping[key] = field_type + + return metadata_field_mapping + + def _embedding_retrieval( + self, + query_embedding: List[float], + *, + top_k: int = 10, + filters: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + It uses the vector configuration specified in the document store. By default, it uses the HNSW algorithm + with cosine similarity. + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. + + :param query_embedding: Embedding of the query. + :param top_k: Maximum number of Documents to return. + :param filters: Filters applied to the retrieved Documents. + :param kwargs: Optional keyword arguments to pass to the Azure AI's search endpoint. + + :raises ValueError: If `query_embedding` is an empty list. + :returns: List of Document that are most similar to `query_embedding`. + """ + + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") + result = self.client.search(vector_queries=[vector_query], filter=filters, **kwargs) + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) + + def _bm25_retrieval( + self, + query: str, + top_k: int = 10, + filters: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> List[Document]: + """ + Retrieves documents that are most similar to `query`, using the BM25 algorithm. 
+
+        This method is not meant to be part of the public interface of
+        `AzureAISearchDocumentStore` nor to be called directly.
+        `AzureAISearchBM25Retriever` uses this method and is the public interface for it.
+
+        :param query: Text of the query.
+        :param filters: Filters applied to the retrieved Documents.
+        :param top_k: Maximum number of Documents to return.
+        :param kwargs: Optional keyword arguments to pass to the Azure AI Search endpoint.
+
+        :raises ValueError: If `query` is None.
+        :returns: List of Documents most similar to `query`.
+        """
+
+        if query is None:
+            msg = "query must not be None"
+            raise ValueError(msg)
+
+        result = self.client.search(search_text=query, filter=filters, top=top_k, query_type="simple", **kwargs)
+        azure_docs = list(result)
+        return self._convert_search_result_to_documents(azure_docs)
+
+    def _hybrid_retrieval(
+        self,
+        query: str,
+        query_embedding: List[float],
+        top_k: int = 10,
+        filters: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> List[Document]:
+        """
+        Retrieves documents similar to `query` using the vector configuration in the document store and
+        the BM25 algorithm. This method combines vector similarity and BM25 for improved retrieval.
+
+        This method is not meant to be part of the public interface of
+        `AzureAISearchDocumentStore` nor to be called directly.
+        `AzureAISearchHybridRetriever` uses this method and is the public interface for it.
+
+        :param query: Text of the query.
+        :param query_embedding: Embedding of the query.
+        :param filters: Filters applied to the retrieved Documents.
+        :param top_k: Maximum number of Documents to return.
+        :param kwargs: Optional keyword arguments to pass to the Azure AI Search endpoint.
+
+        :raises ValueError: If `query` is None or `query_embedding` is empty.
+        :returns: List of Documents most similar to `query`.
+ """ + + if query is None: + msg = "query must not be None" + raise ValueError(msg) + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") + result = self.client.search( + search_text=query, + vector_queries=[vector_query], + filter=filters, + top=top_k, + query_type="simple", + **kwargs, + ) + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py new file mode 100644 index 000000000..0fbc80696 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py @@ -0,0 +1,20 @@ +from haystack.document_stores.errors import DocumentStoreError +from haystack.errors import FilterError + + +class AzureAISearchDocumentStoreError(DocumentStoreError): + """Parent class for all AzureAISearchDocumentStore exceptions.""" + + pass + + +class AzureAISearchDocumentStoreConfigError(AzureAISearchDocumentStoreError): + """Raised when a configuration is not valid for a AzureAISearchDocumentStore.""" + + pass + + +class AzureAISearchDocumentStoreFilterError(FilterError): + """Raised when filter is not valid for AzureAISearchDocumentStore.""" + + pass diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py new file mode 100644 index 000000000..0f105bc91 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py @@ -0,0 +1,112 @@ +from typing import Any, Dict + +from dateutil import parser + +from .errors import AzureAISearchDocumentStoreFilterError + +LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"} + + +def _normalize_filters(filters: Dict[str, Any]) -> str: + """ + Converts Haystack filters in Azure AI Search compatible filters. + """ + if not isinstance(filters, dict): + msg = """Filters must be a dictionary. + See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + if "field" in filters: + return _parse_comparison_condition(filters) + return _parse_logical_condition(filters) + + +def _parse_logical_condition(condition: Dict[str, Any]) -> str: + missing_keys = [key for key in ("operator", "conditions") if key not in condition] + if missing_keys: + msg = f"""Missing key(s) {missing_keys} in {condition}. 
+ See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + operator = condition["operator"] + if operator not in LOGICAL_OPERATORS: + msg = f"Unknown operator {operator}" + raise AzureAISearchDocumentStoreFilterError(msg) + conditions = [] + for c in condition["conditions"]: + # Recursively parse if the condition itself is a logical condition + if isinstance(c, dict) and "operator" in c and c["operator"] in LOGICAL_OPERATORS: + conditions.append(_parse_logical_condition(c)) + else: + # Otherwise, parse it as a comparison condition + conditions.append(_parse_comparison_condition(c)) + + # Format the result based on the operator + if operator == "NOT": + return f"not ({' and '.join([f'({c})' for c in conditions])})" + else: + return f" {LOGICAL_OPERATORS[operator]} ".join([f"({c})" for c in conditions]) + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> str: + missing_keys = [key for key in ("field", "operator", "value") if key not in condition] + if missing_keys: + msg = f"""Missing key(s) {missing_keys} in {condition}. + See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + # Remove the "meta." prefix from the field name if present + field = condition["field"][5:] if condition["field"].startswith("meta.") else condition["field"] + operator = condition["operator"] + value = "null" if condition["value"] is None else condition["value"] + + if operator not in COMPARISON_OPERATORS: + msg = f"Unknown operator {operator}. Valid operators are: {list(COMPARISON_OPERATORS.keys())}" + raise AzureAISearchDocumentStoreFilterError(msg) + + return COMPARISON_OPERATORS[operator](field, value) + + +def _eq(field: str, value: Any) -> str: + return f"{field} eq '{value}'" if isinstance(value, str) and value != "null" else f"{field} eq {value}" + + +def _ne(field: str, value: Any) -> str: + return f"not ({field} eq '{value}')" if isinstance(value, str) and value != "null" else f"not ({field} eq {value})" + + +def _in(field: str, value: Any) -> str: + if not isinstance(value, list) or any(not isinstance(v, str) for v in value): + msg = "Azure AI Search only supports a list of strings for 'in' comparators" + raise AzureAISearchDocumentStoreFilterError(msg) + values = ", ".join(map(str, value)) + return f"search.in({field},'{values}')" + + +def _comparison_operator(field: str, value: Any, operator: str) -> str: + _validate_type(value, operator) + return f"{field} {operator} {value}" + + +def _validate_type(value: Any, operator: str) -> None: + """Validates that the value is either an integer, float, or ISO 8601 string.""" + msg = f"Invalid value type for '{operator}' comparator. Supported types are: int, float, or ISO 8601 string." 
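+    # Illustrative examples of this check: an int such as 42 or an ISO 8601 string like
+    # "1969-07-21T20:17:40Z" passes validation below, while a non-date string such as "foo"
+    # raises AzureAISearchDocumentStoreFilterError.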
+ + if isinstance(value, str): + try: + parser.isoparse(value) + except ValueError as e: + raise AzureAISearchDocumentStoreFilterError(msg) from e + elif not isinstance(value, (int, float)): + raise AzureAISearchDocumentStoreFilterError(msg) + + +COMPARISON_OPERATORS = { + "==": _eq, + "!=": _ne, + "in": _in, + ">": lambda f, v: _comparison_operator(f, v, "gt"), + ">=": lambda f, v: _comparison_operator(f, v, "ge"), + "<": lambda f, v: _comparison_operator(f, v, "lt"), + "<=": lambda f, v: _comparison_operator(f, v, "le"), +} diff --git a/integrations/azure_ai_search/tests/__init__.py b/integrations/azure_ai_search/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/azure_ai_search/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py new file mode 100644 index 000000000..89369c87e --- /dev/null +++ b/integrations/azure_ai_search/tests/conftest.py @@ -0,0 +1,82 @@ +import os +import time +import uuid + +import pytest +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ResourceNotFoundError +from azure.search.documents.indexes import SearchIndexClient +from haystack import logging +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +# This is the approximate time in seconds it takes for the documents to be available in Azure Search index +SLEEP_TIME_IN_SECONDS = 10 +MAX_WAIT_TIME_FOR_INDEX_DELETION = 5 + + +@pytest.fixture() +def sleep_time(): + return SLEEP_TIME_IN_SECONDS + + +@pytest.fixture +def document_store(request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. 
+ """ + index_name = f"haystack_test_{uuid.uuid4().hex}" + metadata_fields = getattr(request, "param", {}).get("metadata_fields", None) + + azure_endpoint = os.environ["AZURE_SEARCH_SERVICE_ENDPOINT"] + api_key = os.environ["AZURE_SEARCH_API_KEY"] + + client = SearchIndexClient(azure_endpoint, AzureKeyCredential(api_key)) + if index_name in client.list_index_names(): + client.delete_index(index_name) + + store = AzureAISearchDocumentStore( + api_key=api_key, + azure_endpoint=azure_endpoint, + index_name=index_name, + create_index=True, + embedding_dimension=768, + metadata_fields=metadata_fields, + ) + + # Override some methods to wait for the documents to be available + original_write_documents = store.write_documents + original_delete_documents = store.delete_documents + + def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE): + written_docs = original_write_documents(documents, policy) + time.sleep(SLEEP_TIME_IN_SECONDS) + return written_docs + + def delete_documents_and_wait(filters): + original_delete_documents(filters) + time.sleep(SLEEP_TIME_IN_SECONDS) + + # Helper function to wait for the index to be deleted, needed to cover latency + def wait_for_index_deletion(client, index_name): + start_time = time.time() + while time.time() - start_time < MAX_WAIT_TIME_FOR_INDEX_DELETION: + if index_name not in client.list_index_names(): + return True + time.sleep(1) + return False + + store.write_documents = write_documents_and_wait + store.delete_documents = delete_documents_and_wait + + yield store + try: + client.delete_index(index_name) + if not wait_for_index_deletion(client, index_name): + logging.error(f"Index {index_name} was not properly deleted.") + except ResourceNotFoundError: + logging.info(f"Index {index_name} was already deleted or not found.") + except Exception as e: + logging.error(f"Unexpected error when deleting index {index_name}: {e}") diff --git a/integrations/azure_ai_search/tests/test_bm25_retriever.py b/integrations/azure_ai_search/tests/test_bm25_retriever.py new file mode 100644 index 000000000..6ebb20949 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_bm25_retriever.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import Mock + +import pytest +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchBM25Retriever +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchBM25Retriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchBM25Retriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": 
"haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.bm25_retriever.AzureAISearchBM25Retriever", + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "metadata_fields": None, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchBM25Retriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever(document_store=mock_store) + res = retriever.run(query="Test query") + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever( + document_store=mock_store, filters={"field": "type", "operator": "==", "value": "article"}, top_k=11 + ) + res = retriever.run(query="Test query") + mock_store._bm25_retrieval.assert_called_once_with( + query="Test query", + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] + retriever = AzureAISearchBM25Retriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run(query="Test query", filters={"field": "type", "operator": "==", "value": "book"}, top_k=5) + mock_store._bm25_retrieval.assert_called_once_with( + 
query="Test query", filters="type eq 'book'", top_k=5, select="name" + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1", content="Test document")] + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + res = retriever.run(query="Test document") + assert res["documents"] == docs + + def test_document_retrieval(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document(content="This is first document"), + Document(content="This is second document"), + Document(content="This is third document"), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchBM25Retriever(document_store=document_store) + results = retriever.run(query="This is first document") + assert results["documents"][0].content == "This is first document" diff --git a/integrations/azure_ai_search/tests/test_document_store.py b/integrations/azure_ai_search/tests/test_document_store.py new file mode 100644 index 000000000..1bcd967c6 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_document_store.py @@ -0,0 +1,410 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +import random +from datetime import datetime, timezone +from typing import List +from unittest.mock import patch + +import pytest +from haystack.dataclasses.document import Document +from haystack.errors import FilterError +from haystack.testing.document_store import ( + CountDocumentsTest, + DeleteDocumentsTest, + FilterDocumentsTest, + WriteDocumentsTest, +) +from haystack.utils.auth import EnvVarSecret, Secret + +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_to_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") + document_store = AzureAISearchDocumentStore() + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + }, + } + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_from_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") 
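+    # The env vars set above back the `Secret.from_env_var` defaults of the store; they are
+    # resolved lazily, only when the `client` property is first accessed (see test_init_is_lazy).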
+ + data = { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "embedding_dimension": 768, + "index_name": "default", + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + }, + } + document_store = AzureAISearchDocumentStore.from_dict(data) + assert isinstance(document_store._api_key, EnvVarSecret) + assert isinstance(document_store._azure_endpoint, EnvVarSecret) + assert document_store._index_name == "default" + assert document_store._embedding_dimension == 768 + assert document_store._metadata_fields is None + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init_is_lazy(_mock_azure_search_client): + AzureAISearchDocumentStore(azure_endpoint=Secret.from_token("test_endpoint")) + _mock_azure_search_client.assert_not_called() + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init(_mock_azure_search_client): + + document_store = AzureAISearchDocumentStore( + api_key=Secret.from_token("fake-api-key"), + azure_endpoint=Secret.from_token("fake_endpoint"), + index_name="my_index", + embedding_dimension=15, + metadata_fields={"Title": str, "Pages": int}, + ) + + assert document_store._index_name == "my_index" + assert document_store._embedding_dimension == 15 + assert document_store._metadata_fields == {"Title": str, "Pages": int} + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): + + def test_write_documents(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + assert document_store.write_documents(docs) == 1 + + # Parametrize the test with metadata fields + @pytest.mark.parametrize( + "document_store", + [ + {"metadata_fields": {"author": str, "publication_year": int, "rating": float}}, + ], + indirect=True, + ) + def test_write_documents_with_meta(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document( + id="1", + meta={"author": "Tom", "publication_year": 2021, "rating": 4.5}, + content="This is a test document.", + ) + ] + document_store.write_documents(docs) + doc = document_store.get_documents_by_id(["1"]) + assert doc[0] == docs[0] + + @pytest.mark.skip(reason="Azure AI search index overwrites duplicate documents by default") + def test_write_documents_duplicate_fail(self, document_store: AzureAISearchDocumentStore): ... + + @pytest.mark.skip(reason="Azure AI search index overwrites duplicate documents by default") + def test_write_documents_duplicate_skip(self, document_store: AzureAISearchDocumentStore): ... 
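+
+# Illustrative note: _normalize_filters (filters.py) converts Haystack filters into OData
+# expressions, e.g. {"field": "meta.page", "operator": "==", "value": "100"} is expected to
+# become "page eq '100'" (the "meta." prefix is stripped). The TestFilters suite below
+# exercises this mapping against a live index.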
+ + +def _random_embeddings(n): + return [round(random.random(), 7) for _ in range(n)] # nosec: S311 + + +TEST_EMBEDDING_1 = _random_embeddings(768) +TEST_EMBEDDING_2 = _random_embeddings(768) + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.parametrize( + "document_store", + [ + {"metadata_fields": {"name": str, "page": str, "chapter": str, "number": int, "date": datetime}}, + ], + indirect=True, +) +class TestFilters(FilterDocumentsTest): + + # Overriding to change "date" to compatible ISO 8601 format + # and remove incompatible fields (dataframes) for Azure search index + @pytest.fixture + def filterable_docs(self) -> List[Document]: + """Fixture that returns a list of Documents that can be used to test filtering.""" + documents = [] + for i in range(3): + documents.append( + Document( + content=f"A Foo Document {i}", + meta={ + "name": f"name_{i}", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40Z", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Bar Document {i}", + meta={ + "name": f"name_{i}", + "page": "123", + "chapter": "abstract", + "number": -2, + "date": "1972-12-11T19:54:58Z", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Foobar Document {i}", + meta={ + "name": f"name_{i}", + "page": "90", + "chapter": "conclusion", + "number": -10, + "date": "1989-11-09T17:53:00Z", + }, + embedding=_random_embeddings(768), + ) + ) + + documents.append( + Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) + ) + documents.append( + Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) + ) + return documents + + # Overriding to compare the documents with the same order + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + Assert that two lists of Documents are equal. + + This is used in every test, if a Document Store implementation has a different behaviour + it should override this method. This can happen for example when the Document Store sets + a score to returned Documents. Since we can't know what the score will be, we can't compare + the Documents reliably. + """ + sorted_recieved = sorted(received, key=lambda doc: doc.id) + sorted_expected = sorted(expected, key=lambda doc: doc.id) + assert sorted_recieved == sorted_expected + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_less_than_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs): ... 
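+
+    # The ISO 8601 date comparisons below rely on filters.py accepting date strings:
+    # _validate_type() parses them with dateutil's isoparse before the OData filter is built.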
+ + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs): ... + + # Azure search index supports UTC datetime in ISO 8601 format + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with > comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">", "value": "1972-12-11T19:54:58Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + > datetime.strptime("1972-12-11T19:54:58Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with >= comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">=", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + >= datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with < comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + < datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with <= comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<=", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + <= datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + # Override as comparison operators with None/null raise errors + def test_comparison_greater_than_with_none(self, document_store, filterable_docs): + """Test filter_documents() with > comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": ">", "value": None}) + + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): + """Test filter_documents() with >= comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": ">=", "value": None}) + + 
def test_comparison_less_than_with_none(self, document_store, filterable_docs): + """Test filter_documents() with < comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": "<", "value": None}) + + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): + """Test filter_documents() with <= comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": "<=", "value": None}) + + # Override as Azure AI Search supports 'in' operator only for strings + def test_comparison_in(self, document_store, filterable_docs): + """Test filter_documents() with 'in' comparator""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents({"field": "meta.page", "operator": "in", "value": ["100", "123"]}) + assert len(result) + expected = [d for d in filterable_docs if d.meta.get("page") is not None and d.meta["page"] in ["100", "123"]] + self.assert_documents_are_equal(result, expected) + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in_with_with_non_list(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in_with_with_non_list_iterable(self, document_store, filterable_docs): ... + + def test_missing_condition_operator_key(self, document_store, filterable_docs): + """Test filter_documents() with missing operator key""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents( + filters={"conditions": [{"field": "meta.name", "operator": "eq", "value": "test"}]} + ) + + def test_nested_logical_filters(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + {"field": "meta.name", "operator": "==", "value": "name_0"}, + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "!=", "value": 0}, + {"field": "meta.page", "operator": "==", "value": "123"}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + {"field": "meta.page", "operator": "==", "value": "90"}, + ], + }, + ], + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + # Ensure all required fields are present in doc.meta + ("name" in doc.meta and doc.meta.get("name") == "name_0") + or ( + all(key in doc.meta for key in ["number", "page"]) + and doc.meta.get("number") != 0 + and doc.meta.get("page") == "123" + ) + or ( + all(key in doc.meta for key in ["page", "chapter"]) + and doc.meta.get("chapter") == "conclusion" + and doc.meta.get("page") == "90" + ) + ) + ], + ) diff --git a/integrations/azure_ai_search/tests/test_embedding_retriever.py b/integrations/azure_ai_search/tests/test_embedding_retriever.py new file mode 100644 index 000000000..576ecda08 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_embedding_retriever.py @@ -0,0 +1,205 @@ +# 
SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchEmbeddingRetriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever.AzureAISearchEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever.AzureAISearchEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchEmbeddingRetriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == 
FilterPolicy.REPLACE + + +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + ) + res = retriever.run(query_embedding=[0.5, 0.7]) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchEmbeddingRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run( + query_embedding=[0.5, 0.7], filters={"field": "type", "operator": "==", "value": "book"}, top_k=9 + ) + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.5, 0.7], + filters="type eq 'book'", + top_k=9, + select="name", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + document_store.write_documents(docs) + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + res = retriever.run(query_embedding=[0.1] * 768) + assert res["documents"] == docs + + def test_embedding_retrieval(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="This is first document", embedding=most_similar_embedding), + Document(content="This is second document", embedding=second_best_embedding), + Document(content="This is thrid document", embedding=another_embedding), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + results = retriever.run(query_embedding=query_embedding) + assert results["documents"][0].content == "This is first document" + + def test_empty_query_embedding(self, document_store: AzureAISearchDocumentStore): + query_embedding: List[float] = [] + with 
pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(HttpResponseError): + document_store._embedding_retrieval(query_embedding=query_embedding) diff --git a/integrations/azure_ai_search/tests/test_hybrid_retriever.py b/integrations/azure_ai_search/tests/test_hybrid_retriever.py new file mode 100644 index 000000000..bf305c4fe --- /dev/null +++ b/integrations/azure_ai_search/tests/test_hybrid_retriever.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchHybridRetriever +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchHybridRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchHybridRetriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.hybrid_retriever.AzureAISearchHybridRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": 
["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchHybridRetriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + +def test_run(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever(document_store=mock_store) + res = retriever.run(query_embedding=[0.5, 0.7], query="Test query") + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="", + top_k=10, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_init_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + ) + res = retriever.run(query_embedding=[0.5, 0.7], query="Test query") + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="type eq 'article'", + top_k=11, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +def test_run_time_params(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + mock_store._hybrid_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] + retriever = AzureAISearchHybridRetriever( + document_store=mock_store, + filters={"field": "type", "operator": "==", "value": "article"}, + top_k=11, + select="name", + ) + res = retriever.run( + query_embedding=[0.5, 0.7], + query="Test query", + filters={"field": "type", "operator": "==", "value": "book"}, + top_k=9, + ) + mock_store._hybrid_retrieval.assert_called_once_with( + query="Test query", + query_embedding=[0.5, 0.7], + filters="type eq 'book'", + top_k=9, + select="name", + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].content == "Test doc" + assert res["documents"][0].embedding == [0.1, 0.2] + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + res = retriever.run(query="Test document", query_embedding=[0.1] * 768) + assert res["documents"] == docs + + def test_hybrid_retrieval(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + 
second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="This is first document", embedding=most_similar_embedding), + Document(content="This is second document", embedding=second_best_embedding), + Document(content="This is third document", embedding=another_embedding), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchHybridRetriever(document_store=document_store) + results = retriever.run(query="This is first document", query_embedding=query_embedding) + assert results["documents"][0].content == "This is first document" + + def test_empty_query_embedding(self, document_store: AzureAISearchDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(HttpResponseError): + document_store._hybrid_retrieval(query="", query_embedding=query_embedding) diff --git a/integrations/chroma/CHANGELOG.md b/integrations/chroma/CHANGELOG.md index f6a23d84a..591c0ec39 100644 --- a/integrations/chroma/CHANGELOG.md +++ b/integrations/chroma/CHANGELOG.md @@ -1,5 +1,45 @@ # Changelog +## [integrations/chroma-v1.0.0] - 2024-11-06 + +### ๐Ÿ› Bug Fixes + +- Fixing Chroma tests due `chromadb` update behaviour change (#1148) +- Adapt our implementation to breaking changes in Chroma 0.5.17 (#1165) + +### โš™๏ธ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + +## [integrations/chroma-v0.22.1] - 2024-09-30 + +### Chroma + +- Empty filters should behave as no filters (#1117) + +## [integrations/chroma-v0.22.0] - 2024-09-30 + +### ๐Ÿš€ Features + +- Chroma - allow remote HTTP connection (#1094) +- Chroma - defer the DB connection (#1107) + +### ๐Ÿ› Bug Fixes + +- Fix chroma linting; rm numpy (#1063) + +Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> +- Filters in chroma integration (#1072) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + +### โš™๏ธ Miscellaneous Tasks + +- Chroma - ruff update, don't ruff tests (#983) +- Update ruff linting scripts and settings (#1105) + ## [integrations/chroma-v0.21.1] - 2024-07-17 ### ๐Ÿ› Bug Fixes @@ -76,8 +116,6 @@ This PR will also push the docs to Readme - Fix project urls (#96) - - ### ๐Ÿšœ Refactor - Use `hatch_vcs` to manage integrations versioning (#103) @@ -88,13 +126,10 @@ This PR will also push the docs to Readme - Fix import and increase version (#77) - - ## [integrations/chroma-v0.8.0] - 2023-12-04 ### ๐Ÿ› Bug Fixes - Fix license headers - diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index 2bffabfd8..cfe7a606e 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "chromadb>=0.5.0", "typing_extensions>=4.8.0"] +dependencies = ["haystack-ai", "chromadb>=0.5.17", "typing_extensions>=4.8.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma#readme" @@ -41,6 +41,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/chroma-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -61,8 +62,10 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.9", "3.10"] [tool.hatch.envs.lint] +installer = "uv" detached = true dependencies = [ + "pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", @@ -131,6 +134,8 @@ ignore = [ "PLR0915", # Ignore unused params "ARG002", + # Allow assertions + "S101", ] unfixable = [ # Don't touch unused imports diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py index 359ace58d..439e4b144 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/document_store.py @@ -40,9 +40,8 @@ def __init__( **embedding_function_params, ): """ - Initializes the store. The __init__ constructor is not part of the Store Protocol - and the signature can be customized to your needs. For example, parameters needed - to set up a database client would be passed to this method. + Creates a new ChromaDocumentStore instance. + It is meant to be connected to a Chroma collection. Note: for the component to be part of a serializable pipeline, the __init__ parameters must be serializable, reason why we use a registry to configure the @@ -65,7 +64,6 @@ def __init__( :param metadata: a dictionary of chromadb collection parameters passed directly to chromadb's client method `create_collection`. If it contains the key `"hnsw:space"`, the value will take precedence over the `distance_function` parameter above. - :param embedding_function_params: additional parameters to pass to the embedding function. """ @@ -79,53 +77,61 @@ def __init__( # Store the params for marshalling self._collection_name = collection_name self._embedding_function = embedding_function + self._embedding_func = get_embedding_function(embedding_function, **embedding_function_params) self._embedding_function_params = embedding_function_params self._distance_function = distance_function + self._metadata = metadata + self._collection = None self._persist_path = persist_path self._host = host self._port = port - # Create the client instance - if persist_path and (host or port is not None): - error_message = ( - "You must specify `persist_path` for local persistent storage or, " - "alternatively, `host` and `port` for remote HTTP client connection. " - "You cannot specify both options." - ) - raise ValueError(error_message) - if host and port is not None: - # Remote connection via HTTP client - self._chroma_client = chromadb.HttpClient( - host=host, - port=port, - ) - elif persist_path is None: - # In-memory storage - self._chroma_client = chromadb.Client() - else: - # Local persistent storage - self._chroma_client = chromadb.PersistentClient(path=persist_path) + self._initialized = False - embedding_func = get_embedding_function(embedding_function, **embedding_function_params) + def _ensure_initialized(self): + if not self._initialized: + # Create the client instance + if self._persist_path and (self._host or self._port is not None): + error_message = ( + "You must specify `persist_path` for local persistent storage or, " + "alternatively, `host` and `port` for remote HTTP client connection. " + "You cannot specify both options." 
+ ) + raise ValueError(error_message) + if self._host and self._port is not None: + # Remote connection via HTTP client + client = chromadb.HttpClient( + host=self._host, + port=self._port, + ) + elif self._persist_path is None: + # In-memory storage + client = chromadb.Client() + else: + # Local persistent storage + client = chromadb.PersistentClient(path=self._persist_path) - metadata = metadata or {} - if "hnsw:space" not in metadata: - metadata["hnsw:space"] = distance_function + self._metadata = self._metadata or {} + if "hnsw:space" not in self._metadata: + self._metadata["hnsw:space"] = self._distance_function - if collection_name in [c.name for c in self._chroma_client.list_collections()]: - self._collection = self._chroma_client.get_collection(collection_name, embedding_function=embedding_func) + if self._collection_name in [c.name for c in client.list_collections()]: + self._collection = client.get_collection(self._collection_name, embedding_function=self._embedding_func) - if metadata != self._collection.metadata: - logger.warning( - "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." + if self._metadata != self._collection.metadata: + logger.warning( + "Collection already exists. " + "The `distance_function` and `metadata` parameters will be ignored." + ) + else: + self._collection = client.create_collection( + name=self._collection_name, + metadata=self._metadata, + embedding_function=self._embedding_func, ) - else: - self._collection = self._chroma_client.create_collection( - name=collection_name, - metadata=metadata, - embedding_function=embedding_func, - ) + + self._initialized = True def count_documents(self) -> int: """ @@ -133,6 +139,8 @@ def count_documents(self) -> int: :returns: how many documents are present in the document store. """ + self._ensure_initialized() + assert self._collection is not None return self._collection.count() def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: @@ -197,6 +205,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :param filters: the filters to apply to the document list. :returns: a list of Documents that match the given filters. """ + self._ensure_initialized() + assert self._collection is not None + if filters: chroma_filter = _convert_filters(filters) kwargs: Dict[str, Any] = {"where": chroma_filter.where} @@ -227,6 +238,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D :returns: The number of documents written """ + self._ensure_initialized() + assert self._collection is not None + for doc in documents: if not isinstance(doc, Document): msg = "param 'documents' must contain a list of objects of type Document" @@ -234,9 +248,12 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D if doc.content is None: logger.warning( - "ChromaDocumentStore can only store the text field of Documents: " - "'array', 'dataframe' and 'blob' will be dropped." + "ChromaDocumentStore cannot store documents with `content=None`. " + "`array`, `dataframe` and `blob` are not supported. " + "Document with id %s will be skipped.", + doc.id, ) + continue data = {"ids": [doc.id], "documents": [doc.content]} if doc.meta: @@ -280,8 +297,11 @@ def delete_documents(self, document_ids: List[str]) -> None: """ Deletes all documents with a matching document_ids from the document store. 
- :param document_ids: the object_ids to delete + :param document_ids: the document ids to delete """ + self._ensure_initialized() + assert self._collection is not None + self._collection.delete(ids=document_ids) def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any]] = None) -> List[List[Document]]: @@ -292,7 +312,10 @@ def search(self, queries: List[str], top_k: int, filters: Optional[Dict[str, Any :param filters: a dictionary of filters to apply to the search. Accepts filters in haystack format. :returns: matching documents for each query. """ - if filters is None: + self._ensure_initialized() + assert self._collection is not None + + if not filters: results = self._collection.query( query_texts=queries, n_results=top_k, @@ -323,7 +346,10 @@ def search_embeddings( :returns: a list of lists of documents that match the given filters. """ - if filters is None: + self._ensure_initialized() + assert self._collection is not None + + if not filters: results = self._collection.query( query_embeddings=query_embeddings, n_results=top_k, diff --git a/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py b/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py index ef5c920a7..df49da673 100644 --- a/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py +++ b/integrations/chroma/src/haystack_integrations/document_stores/chroma/filters.py @@ -1,6 +1,6 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from chromadb.api.types import validate_where, validate_where_document @@ -27,15 +27,15 @@ class ChromaFilter: """ Dataclass to store the converted filter structure used in Chroma queries. - Following filter criterias are supported: + Following filter criteria are supported: - `ids`: A list of document IDs to filter by in Chroma collection. - `where`: A dictionary of metadata filters applied to the documents. - `where_document`: A dictionary of content-based filters applied to the documents' content. 
""" ids: List[str] - where: Dict[str, Any] - where_document: Dict[str, Any] + where: Optional[Dict[str, Any]] + where_document: Optional[Dict[str, Any]] def _convert_filters(filters: Dict[str, Any]) -> ChromaFilter: @@ -80,7 +80,7 @@ def _convert_filters(filters: Dict[str, Any]) -> ChromaFilter: msg = f"Invalid '{test_clause}' : {e}" raise ChromaDocumentStoreFilterError(msg) from e - return ChromaFilter(ids=ids, where=where, where_document=where_document) + return ChromaFilter(ids=ids, where=where or None, where_document=where_document or None) def _convert_filter_clause(filters: Dict[str, Any]) -> Dict[str, Any]: diff --git a/integrations/chroma/tests/test_document_store.py b/integrations/chroma/tests/test_document_store.py index d33086945..ed815251e 100644 --- a/integrations/chroma/tests/test_document_store.py +++ b/integrations/chroma/tests/test_document_store.py @@ -13,9 +13,12 @@ from chromadb.api.types import Documents, EmbeddingFunction, Embeddings from haystack import Document from haystack.testing.document_store import ( + TEST_EMBEDDING_1, + TEST_EMBEDDING_2, CountDocumentsTest, DeleteDocumentsTest, FilterDocumentsTest, + _random_embeddings, ) from haystack_integrations.document_stores.chroma import ChromaDocumentStore @@ -51,6 +54,67 @@ def document_store(self) -> ChromaDocumentStore: get_func.return_value = _TestEmbeddingFunction() return ChromaDocumentStore(embedding_function="test_function", collection_name=str(uuid.uuid1())) + @pytest.fixture + def filterable_docs(self) -> List[Document]: + """ + This fixture has been copied from haystack/testing/document_store.py and modified to + remove the documents that don't have textual content, as Chroma does not support writing them. + """ + documents = [] + for i in range(3): + documents.append( + Document( + content=f"A Foo Document {i}", + meta={ + "name": f"name_{i}", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Bar Document {i}", + meta={ + "name": f"name_{i}", + "page": "123", + "chapter": "abstract", + "number": -2, + "date": "1972-12-11T19:54:58", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Foobar Document {i}", + meta={ + "name": f"name_{i}", + "page": "90", + "chapter": "conclusion", + "number": -10, + "date": "1989-11-09T17:53:00", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"Document {i} without embedding", + meta={"name": f"name_{i}", "no_embedding": True, "chapter": "conclusion"}, + ) + ) + documents.append( + Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) + ) + documents.append( + Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) + ) + return documents + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ Assert that two lists of Documents are equal. @@ -98,7 +162,8 @@ def test_invalid_initialization_both_host_and_persist_path(self): Test that providing both host and persist_path raises an error. 
""" with pytest.raises(ValueError): - ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") + store = ChromaDocumentStore(persist_path="./path/to/local/store", host="localhost") + store._ensure_initialized() def test_delete_empty(self, document_store: ChromaDocumentStore): """ @@ -136,6 +201,10 @@ def test_search(self): assert isinstance(doc.embedding, list) assert all(isinstance(el, float) for el in doc.embedding) + # check that empty filters behave as no filters + result_empty_filters = document_store.search(["Third"], filters={}, top_k=1) + assert result == result_empty_filters + def test_write_documents_unsupported_meta_values(self, document_store: ChromaDocumentStore): """ Unsupported meta values should be removed from the documents before writing them to the database @@ -207,6 +276,7 @@ def test_same_collection_name_reinitialization(self): @pytest.mark.integration def test_distance_metric_initialization(self): store = ChromaDocumentStore("test_2", distance_function="cosine") + store._ensure_initialized() assert store._collection.metadata["hnsw:space"] == "cosine" with pytest.raises(ValueError): @@ -215,9 +285,11 @@ def test_distance_metric_initialization(self): @pytest.mark.integration def test_distance_metric_reinitialization(self, caplog): store = ChromaDocumentStore("test_4", distance_function="cosine") + store._ensure_initialized() with caplog.at_level(logging.WARNING): new_store = ChromaDocumentStore("test_4", distance_function="ip") + new_store._ensure_initialized() assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." @@ -238,6 +310,8 @@ def test_metadata_initialization(self, caplog): "hnsw:M": 103, }, ) + store._ensure_initialized() + assert store._collection.metadata["hnsw:space"] == "ip" assert store._collection.metadata["hnsw:search_ef"] == 101 assert store._collection.metadata["hnsw:construction_ef"] == 102 @@ -254,6 +328,8 @@ def test_metadata_initialization(self, caplog): }, ) + new_store._ensure_initialized() + assert ( "Collection already exists. The `distance_function` and `metadata` parameters will be ignored." 
in caplog.text @@ -330,25 +406,6 @@ def test_nested_logical_filters(self, document_store: ChromaDocumentStore, filte ], ) - # Override inequality tests from FilterDocumentsTest - # because chroma doesn't return documents with absent meta fields - - def test_comparison_not_equal(self, document_store, filterable_docs): - """Test filter_documents() with != comparator""" - document_store.write_documents(filterable_docs) - result = document_store.filter_documents({"field": "meta.number", "operator": "!=", "value": 100}) - self.assert_documents_are_equal( - result, [d for d in filterable_docs if "number" in d.meta and d.meta.get("number") != 100] - ) - - def test_comparison_not_in(self, document_store, filterable_docs): - """Test filter_documents() with 'not in' comparator""" - document_store.write_documents(filterable_docs) - result = document_store.filter_documents({"field": "meta.number", "operator": "not in", "value": [2, 9]}) - self.assert_documents_are_equal( - result, [d for d in filterable_docs if "number" in d.meta and d.meta.get("number") not in [2, 9]] - ) - @pytest.mark.skip(reason="Filter on dataframe contents is not supported.") def test_comparison_equal_with_dataframe( self, document_store: ChromaDocumentStore, filterable_docs: List[Document] diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index d86165668..262b1612d 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/cohere-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ diff --git a/integrations/deepeval/CHANGELOG.md b/integrations/deepeval/CHANGELOG.md new file mode 100644 index 000000000..a296c7cfa --- /dev/null +++ b/integrations/deepeval/CHANGELOG.md @@ -0,0 +1,35 @@ +# Changelog + +## [integrations/deepeval-v0.1.2] - 2024-11-14 + +### ๐Ÿš€ Features + +- Implement `DeepEvalEvaluator` (#346) + +### ๐Ÿ› Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Deepeval - pin indirect dependencies based on python version (#1187) + +### ๐Ÿ“š Documentation + +- Update paths and titles (#397) +- Update category slug (#442) +- Update `deepeval-haystack` docstrings (#527) +- Disable-class-def (#556) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + +### โš™๏ธ Miscellaneous Tasks + +- Exculde evaluator private classes in API docs (#392) +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + + diff --git a/integrations/deepeval/pyproject.toml b/integrations/deepeval/pyproject.toml index 5d81fa0a5..78cc2542a 100644 --- a/integrations/deepeval/pyproject.toml +++ b/integrations/deepeval/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", 
"deepeval==0.20.57"] +dependencies = ["haystack-ai", "deepeval==0.20.57", "langchain<0.3; python_version < '3.10'"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/deepeval" @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/deepeval-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -55,8 +56,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/elasticsearch/CHANGELOG.md b/integrations/elasticsearch/CHANGELOG.md index 5d2b66470..bd8bff63c 100644 --- a/integrations/elasticsearch/CHANGELOG.md +++ b/integrations/elasticsearch/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [integrations/elasticsearch-v1.0.1] - 2024-10-28 + +### โš™๏ธ Miscellaneous Tasks + +- Update changelog after removing legacy filters (#1083) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + ## [integrations/elasticsearch-v1.0.0] - 2024-09-12 ### ๐Ÿš€ Features @@ -69,8 +77,6 @@ This PR will also push the docs to Readme - Fix project urls (#96) - - ### ๐Ÿšœ Refactor - Use `hatch_vcs` to manage integrations versioning (#103) @@ -81,15 +87,12 @@ This PR will also push the docs to Readme - Fix import and increase version (#77) - - ## [integrations/elasticsearch-v0.1.0] - 2023-12-04 ### ๐Ÿ› Bug Fixes - Fix license headers - ## [integrations/elasticsearch-v0.0.2] - 2023-11-29 diff --git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index 47b168f30..8bf01cc65 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/elasticsearch-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -61,8 +62,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] diff --git a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py index 734e2d2b8..8dfb07919 100644 --- a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py +++ b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py @@ -105,9 +105,12 @@ def __init__( @property def client(self) -> Elasticsearch: if self._client is None: + headers = self._kwargs.pop("headers", {}) + headers["user-agent"] = f"haystack-py-ds/{haystack_version}" + client = Elasticsearch( self._hosts, - headers={"user-agent": f"haystack-py-ds/{haystack_version}"}, + headers=headers, **self._kwargs, ) # Check client connection, this will raise if not connected diff --git a/integrations/elasticsearch/tests/test_document_store.py b/integrations/elasticsearch/tests/test_document_store.py index 51a19b641..d636ff027 100644 --- a/integrations/elasticsearch/tests/test_document_store.py +++ b/integrations/elasticsearch/tests/test_document_store.py @@ -22,6 +22,20 @@ def test_init_is_lazy(_mock_es_client): _mock_es_client.assert_not_called() +@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") +def test_headers_are_supported(_mock_es_client): + _ = ElasticsearchDocumentStore(hosts="testhost", headers={"header1": "value1", "header2": "value2"}).client + + assert _mock_es_client.call_count == 1 + _, kwargs = _mock_es_client.call_args + + headers_found = kwargs["headers"] + assert headers_found["header1"] == "value1" + assert headers_found["header2"] == "value2" + + assert headers_found["user-agent"].startswith("haystack-py-ds/") + + @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") def test_to_dict(_mock_elasticsearch_client): document_store = ElasticsearchDocumentStore(hosts="some hosts") diff --git a/integrations/fastembed/CHANGELOG.md b/integrations/fastembed/CHANGELOG.md index 9ae3da929..841781660 100644 --- a/integrations/fastembed/CHANGELOG.md +++ b/integrations/fastembed/CHANGELOG.md @@ -1,20 +1,62 @@ # Changelog -## [unreleased] +## [integrations/fastembed-v1.4.1] - 2024-11-19 -### โš™๏ธ Miscellaneous Tasks +### ๐ŸŒ€ Miscellaneous + +- Add new example with the ranker and fix the old ones (#1198) +- Fix: Fastembed - Change default Sparse model as the used one is deprecated due to a typo (#1201) + +## [integrations/fastembed-v1.4.0] - 2024-11-13 + +### โš™๏ธ CI + +- Adopt uv as installer (#1142) + +### ๐ŸŒ€ Miscellaneous + +- Chore: fastembed - pin `onnxruntime<1.20.0` (#1164) +- Feat: Fastembed - add FastembedRanker (#1178) + +## [integrations/fastembed-v1.3.0] - 2024-10-07 + +### ๐Ÿš€ Features + +- Introduce `model_kwargs` in Sparse Embedders (can be 
used for BM25 parameters) (#1126) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + +### โš™๏ธ CI - Retry tests to reduce flakyness (#836) + +### ๐Ÿงน Chores + - Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) -### Fix +### ๐ŸŒ€ Miscellaneous +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) - Typo on Sparse embedders. The parameter should be "progress_bar" โ€ฆ (#814) +- Chore: fastembed - ruff update, don't ruff tests (#997) ## [integrations/fastembed-v1.1.0] - 2024-05-15 +### ๐ŸŒ€ Miscellaneous + +- Chore: change the pydoc renderer class (#718) +- Use the local_files_only option available as of fastembed==0.2.7. It โ€ฆ (#736) + ## [integrations/fastembed-v1.0.0] - 2024-05-06 +### ๐ŸŒ€ Miscellaneous + +- Chore: add license classifiers (#680) +- `FastembedSparseTextEmbedder` - remove `batch_size` (#688) + ## [integrations/fastembed-v0.1.0] - 2024-04-10 ### ๐Ÿš€ Features @@ -25,6 +67,10 @@ - Disable-class-def (#556) +### ๐ŸŒ€ Miscellaneous + +- Remove references to Python 3.7 (#601) + ## [integrations/fastembed-v0.0.6] - 2024-03-07 ### ๐Ÿ“š Documentation @@ -32,20 +78,32 @@ - Review and normalize docstrings - `integrations.fastembed` (#519) - Small consistency improvements (#536) +### ๐ŸŒ€ Miscellaneous + +- Docs: Fix `integrations.fastembed` API docs (#540) +- Improvements to FastEmbed integration (#558) + ## [integrations/fastembed-v0.0.5] - 2024-02-20 ### ๐Ÿ› Bug Fixes - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### ๐Ÿ“š Documentation - Update category slug (#442) +### ๐ŸŒ€ Miscellaneous + +- Fastembed integration new parameters (#446) + ## [integrations/fastembed-v0.0.4] - 2024-02-16 +### ๐ŸŒ€ Miscellaneous + +- Fastembed integration: add example (#401) +- Fastembed fix: add parallel (#403) + ## [integrations/fastembed-v0.0.3] - 2024-02-12 ### ๐Ÿ› Bug Fixes @@ -58,6 +116,15 @@ This PR will also push the docs to Readme ## [integrations/fastembed-v0.0.2] - 2024-02-11 +### ๐ŸŒ€ Miscellaneous + +- Updated labeler and readme (#389) +- Fastembed fix: added prefix and suffix (#390) + ## [integrations/fastembed-v0.0.1] - 2024-02-10 +### ๐ŸŒ€ Miscellaneous + +- Add Fastembed Embeddings integration (#383) + diff --git a/integrations/fastembed/README.md b/integrations/fastembed/README.md index c021dec3b..f3c2bb135 100644 --- a/integrations/fastembed/README.md +++ b/integrations/fastembed/README.md @@ -8,6 +8,7 @@ **Table of Contents** - [Installation](#installation) +- [Usage](#Usage) - [License](#license) ## Installation @@ -33,7 +34,7 @@ embedding = text_embedder.run(text)["embedding"] ```python from haystack_integrations.components.embedders.fastembed import FastembedDocumentEmbedder -from haystack.dataclasses import Document +from haystack import Document embedder = FastembedDocumentEmbedder( model="BAAI/bge-small-en-v1.5", @@ -50,24 +51,50 @@ from haystack_integrations.components.embedders.fastembed import FastembedSparse text = "fastembed is supported by and maintained by Qdrant." 
text_embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1" + model="prithivida/Splade_PP_en_v1" ) text_embedder.warm_up() -embedding = text_embedder.run(text)["embedding"] +embedding = text_embedder.run(text)["sparse_embedding"] ``` ```python from haystack_integrations.components.embedders.fastembed import FastembedSparseDocumentEmbedder -from haystack.dataclasses import Document +from haystack import Document embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", ) embedder.warm_up() doc = Document(content="fastembed is supported by and maintained by Qdrant.", meta={"long_answer": "no",}) result = embedder.run(documents=[doc]) ``` +You can use `FastembedRanker` by importing as: + +```python +from haystack import Document + +from haystack_integrations.components.rankers.fastembed import FastembedRanker + +query = "Who is maintaining Qdrant?" +documents = [ + Document( + content="This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc." + ), + Document(content="fastembed is supported by and maintained by Qdrant."), +] + +ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2") +ranker.warm_up() +reranked_documents = ranker.run(query=query, documents=documents)["documents"] + +print(reranked_documents[0]) + +# Document(id=..., +# content: 'fastembed is supported by and maintained by Qdrant.', +# score: 5.472434997558594..) +``` + ## License `fastembed-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/fastembed/examples/ranker_example.py b/integrations/fastembed/examples/ranker_example.py new file mode 100644 index 000000000..593334e90 --- /dev/null +++ b/integrations/fastembed/examples/ranker_example.py @@ -0,0 +1,22 @@ +from haystack import Document + +from haystack_integrations.components.rankers.fastembed import FastembedRanker + +query = "Who is maintaining Qdrant?" +documents = [ + Document( + content="This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc." + ), + Document(content="fastembed is supported by and maintained by Qdrant."), +] + +ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2") +ranker.warm_up() +reranked_documents = ranker.run(query=query, documents=documents)["documents"] + + +print(reranked_documents[0]) + +# Document(id=..., +# content: 'fastembed is supported by and maintained by Qdrant.', +# score: 5.472434997558594..) 
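Elsewhere in this change set the sparse embedders gain a `model_kwargs` parameter (#1126), which the tests only exercise through mocks. A minimal usage sketch, assuming the `Qdrant/bm25` model and the BM25 parameters named in the new docstrings (`k`, `b`, `avg_len`); the values below are illustrative only:

```python
from haystack_integrations.components.embedders.fastembed import FastembedSparseTextEmbedder

# Illustrative values; model_kwargs is forwarded unchanged to fastembed's SparseTextEmbedding.
sparse_text_embedder = FastembedSparseTextEmbedder(
    model="Qdrant/bm25",
    model_kwargs={"k": 1.2, "b": 0.75, "avg_len": 300.0},
)
sparse_text_embedder.warm_up()

sparse_embedding = sparse_text_embedder.run(
    "fastembed is supported by and maintained by Qdrant."
)["sparse_embedding"]
```

Because the keyword arguments are passed through to `SparseTextEmbedding` as-is, any parameter accepted by the underlying fastembed model can be supplied this way.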
diff --git a/integrations/fastembed/pydoc/config.yml b/integrations/fastembed/pydoc/config.yml index aad50e52c..8ab538cf8 100644 --- a/integrations/fastembed/pydoc/config.yml +++ b/integrations/fastembed/pydoc/config.yml @@ -6,7 +6,8 @@ loaders: "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder", "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder", "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder", - "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder" + "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder", + "haystack_integrations.components.rankers.fastembed.ranker" ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/fastembed/pyproject.toml b/integrations/fastembed/pyproject.toml index 69aba5562..abae78d8a 100644 --- a/integrations/fastembed/pyproject.toml +++ b/integrations/fastembed/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.0.1", "fastembed>=0.2.5"] +dependencies = ["haystack-ai>=2.0.1", "fastembed>=0.4.2"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" @@ -42,6 +42,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/fastembed-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -62,8 +63,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] @@ -152,6 +154,10 @@ omit = ["*/tests/*", "*/__init__.py"] show_missing = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + [[tool.mypy.overrides]] module = [ "haystack.*", diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index 66f797549..3a68abcfb 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional from haystack.dataclasses.sparse_embedding import SparseEmbedding from tqdm import tqdm @@ -73,14 +73,19 @@ def get_embedding_backend( cache_dir: Optional[str] = None, threads: Optional[int] = None, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, ): - embedding_backend_id = f"{model_name}{cache_dir}{threads}" + embedding_backend_id = f"{model_name}{cache_dir}{threads}{local_files_only}{model_kwargs}" if embedding_backend_id in 
_FastembedSparseEmbeddingBackendFactory._instances: return _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] embedding_backend = _FastembedSparseEmbeddingBackend( - model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + local_files_only=local_files_only, + model_kwargs=model_kwargs, ) _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -97,9 +102,16 @@ def __init__( cache_dir: Optional[str] = None, threads: Optional[int] = None, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, ): + model_kwargs = model_kwargs or {} + self.model = SparseTextEmbedding( - model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + local_files_only=local_files_only, + **model_kwargs, ) def embed(self, data: List[List[str]], progress_bar=True, **kwargs) -> List[SparseEmbedding]: diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py index 4b72389fa..a30d43cf4 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py @@ -16,7 +16,7 @@ class FastembedSparseDocumentEmbedder: from haystack.dataclasses import Document sparse_doc_embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", batch_size=32, ) @@ -53,7 +53,7 @@ class FastembedSparseDocumentEmbedder: def __init__( self, - model: str = "prithvida/Splade_PP_en_v1", + model: str = "prithivida/Splade_PP_en_v1", cache_dir: Optional[str] = None, threads: Optional[int] = None, batch_size: int = 32, @@ -62,12 +62,13 @@ def __init__( local_files_only: bool = False, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Create an FastembedDocumentEmbedder component. :param model: Local path or name of the model in Hugging Face's model hub, - such as `prithvida/Splade_PP_en_v1`. + such as `prithivida/Splade_PP_en_v1`. :param cache_dir: The path to the cache directory. Can be set using the `FASTEMBED_CACHE_PATH` env variable. Defaults to `fastembed_cache` in the system's temp directory. @@ -81,6 +82,7 @@ def __init__( :param local_files_only: If `True`, only use the model files in the `cache_dir`. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. :param embedding_separator: Separator used to concatenate the meta fields to the Document content. + :param model_kwargs: Dictionary containing model parameters such as `k`, `b`, `avg_len`, `language`. 
""" self.model_name = model @@ -92,6 +94,7 @@ def __init__( self.local_files_only = local_files_only self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator + self.model_kwargs = model_kwargs def to_dict(self) -> Dict[str, Any]: """ @@ -110,6 +113,7 @@ def to_dict(self) -> Dict[str, Any]: local_files_only=self.local_files_only, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + model_kwargs=self.model_kwargs, ) def warm_up(self): @@ -122,6 +126,7 @@ def warm_up(self): cache_dir=self.cache_dir, threads=self.threads, local_files_only=self.local_files_only, + model_kwargs=self.model_kwargs, ) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py index 67348b2bd..c7296525f 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py @@ -19,7 +19,7 @@ class FastembedSparseTextEmbedder: "The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!") sparse_text_embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1" + model="prithivida/Splade_PP_en_v1" ) sparse_text_embedder.warm_up() @@ -29,17 +29,18 @@ class FastembedSparseTextEmbedder: def __init__( self, - model: str = "prithvida/Splade_PP_en_v1", + model: str = "prithivida/Splade_PP_en_v1", cache_dir: Optional[str] = None, threads: Optional[int] = None, progress_bar: bool = True, parallel: Optional[int] = None, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Create a FastembedSparseTextEmbedder component. - :param model: Local path or name of the model in Fastembed's model hub, such as `prithvida/Splade_PP_en_v1` + :param model: Local path or name of the model in Fastembed's model hub, such as `prithivida/Splade_PP_en_v1` :param cache_dir: The path to the cache directory. Can be set using the `FASTEMBED_CACHE_PATH` env variable. Defaults to `fastembed_cache` in the system's temp directory. @@ -50,6 +51,7 @@ def __init__( If 0, use all available cores. If None, don't use data-parallel processing, use default onnxruntime threading instead. :param local_files_only: If `True`, only use the model files in the `cache_dir`. + :param model_kwargs: Dictionary containing model parameters such as `k`, `b`, `avg_len`, `language`. 
""" self.model_name = model @@ -58,6 +60,7 @@ def __init__( self.progress_bar = progress_bar self.parallel = parallel self.local_files_only = local_files_only + self.model_kwargs = model_kwargs def to_dict(self) -> Dict[str, Any]: """ @@ -74,6 +77,7 @@ def to_dict(self) -> Dict[str, Any]: progress_bar=self.progress_bar, parallel=self.parallel, local_files_only=self.local_files_only, + model_kwargs=self.model_kwargs, ) def warm_up(self): @@ -86,6 +90,7 @@ def warm_up(self): cache_dir=self.cache_dir, threads=self.threads, local_files_only=self.local_files_only, + model_kwargs=self.model_kwargs, ) @component.output_types(sparse_embedding=SparseEmbedding) diff --git a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py new file mode 100644 index 000000000..ece5e858b --- /dev/null +++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py @@ -0,0 +1,3 @@ +from .ranker import FastembedRanker + +__all__ = ["FastembedRanker"] diff --git a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py new file mode 100644 index 000000000..8f077a30c --- /dev/null +++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py @@ -0,0 +1,202 @@ +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict, logging + +from fastembed.rerank.cross_encoder import TextCrossEncoder + +logger = logging.getLogger(__name__) + + +@component +class FastembedRanker: + """ + Ranks Documents based on their similarity to the query using + [Fastembed models](https://qdrant.github.io/fastembed/examples/Supported_Models/). + + Documents are indexed from most to least semantically relevant to the query. + + Usage example: + ```python + from haystack import Document + from haystack_integrations.components.rankers.fastembed import FastembedRanker + + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2", top_k=2) + + docs = [Document(content="Paris"), Document(content="Berlin")] + query = "What is the capital of germany?" + output = ranker.run(query=query, documents=docs) + print(output["documents"][0].content) + + # Berlin + ``` + """ + + def __init__( + self, + model_name: str = "Xenova/ms-marco-MiniLM-L-6-v2", + top_k: int = 10, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + batch_size: int = 64, + parallel: Optional[int] = None, + local_files_only: bool = False, + meta_fields_to_embed: Optional[List[str]] = None, + meta_data_separator: str = "\n", + ): + """ + Creates an instance of the 'FastembedRanker'. + + :param model_name: Fastembed model name. Check the list of supported models in the [Fastembed documentation](https://qdrant.github.io/fastembed/examples/Supported_Models/). + :param top_k: The maximum number of documents to return. + :param cache_dir: The path to the cache directory. + Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. + :param threads: The number of threads single onnxruntime session can use. Defaults to None. + :param batch_size: Number of strings to encode at once. + :param parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. 
+ If None, don't use data-parallel processing, use default onnxruntime threading instead. + :param local_files_only: If `True`, only use the model files in the `cache_dir`. + :param meta_fields_to_embed: List of meta fields that should be concatenated + with the document content for reranking. + :param meta_data_separator: Separator used to concatenate the meta fields + to the Document content. + """ + if top_k <= 0: + msg = f"top_k must be > 0, but got {top_k}" + raise ValueError(msg) + + self.model_name = model_name + self.top_k = top_k + self.cache_dir = cache_dir + self.threads = threads + self.batch_size = batch_size + self.parallel = parallel + self.local_files_only = local_files_only + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.meta_data_separator = meta_data_separator + self._model = None + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model_name=self.model_name, + top_k=self.top_k, + cache_dir=self.cache_dir, + threads=self.threads, + batch_size=self.batch_size, + parallel=self.parallel, + local_files_only=self.local_files_only, + meta_fields_to_embed=self.meta_fields_to_embed, + meta_data_separator=self.meta_data_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FastembedRanker": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + return default_from_dict(cls, data) + + def warm_up(self): + """ + Initializes the component. + """ + if self._model is None: + self._model = TextCrossEncoder( + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, + ) + + def _prepare_fastembed_input_docs(self, documents: List[Document]) -> List[str]: + """ + Prepare the input by concatenating the document text with the metadata fields specified. + :param documents: The list of Document objects. + + :return: A list of strings to be given as input to Fastembed model. + """ + concatenated_input_list = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta.get(key) + ] + concatenated_input = self.meta_data_separator.join([*meta_values_to_embed, doc.content or ""]) + concatenated_input_list.append(concatenated_input) + + return concatenated_input_list + + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): + """ + Returns a list of documents ranked by their similarity to the given query, using FastEmbed. + + :param query: + The input query to compare the documents to. + :param documents: + A list of documents to be ranked. + :param top_k: + The maximum number of documents to return. + + :returns: + A dictionary with the following keys: + - `documents`: A list of documents closest to the query, sorted from most similar to least similar. + + :raises ValueError: If `top_k` is not > 0. + """ + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + msg = "FastembedRanker expects a list of Documents as input. 
" + raise TypeError(msg) + if query == "": + msg = "No query provided" + raise ValueError(msg) + + if not documents: + return {"documents": []} + + top_k = top_k or self.top_k + if top_k <= 0: + msg = f"top_k must be > 0, but got {top_k}" + raise ValueError(msg) + + if self._model is None: + msg = "The ranker model has not been loaded. Please call warm_up() before running." + raise RuntimeError(msg) + + fastembed_input_docs = self._prepare_fastembed_input_docs(documents) + + scores = list( + self._model.rerank( + query=query, + documents=fastembed_input_docs, + batch_size=self.batch_size, + parallel=self.parallel, + ) + ) + + # Combine the two lists into a single list of tuples + doc_scores = list(zip(documents, scores)) + + # Sort the list of tuples by the score in descending order + sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True) + + # Get the top_k documents + top_k_documents = [] + for doc, score in sorted_doc_scores[:top_k]: + doc.score = score + top_k_documents.append(doc) + + return {"documents": top_k_documents} diff --git a/integrations/fastembed/tests/test_fastembed_backend.py b/integrations/fastembed/tests/test_fastembed_backend.py index 631d9f1e0..994a6f883 100644 --- a/integrations/fastembed/tests/test_fastembed_backend.py +++ b/integrations/fastembed/tests/test_fastembed_backend.py @@ -2,6 +2,7 @@ from haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend import ( _FastembedEmbeddingBackendFactory, + _FastembedSparseEmbeddingBackendFactory, ) @@ -44,3 +45,28 @@ def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 embedding_backend.model.embed.assert_called_once_with(data) # restore the factory stateTrue _FastembedEmbeddingBackendFactory._instances = {} + + +@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.SparseTextEmbedding") +def test_model_kwargs_initialization(mock_instructor): + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 40, + } + + # Invoke the backend factory with the BM25 configuration + _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( + model_name="Qdrant/bm25", + model_kwargs=bm25_config, + ) + + # Check if SparseTextEmbedding was called with the correct arguments + mock_instructor.assert_called_once_with( + model_name="Qdrant/bm25", cache_dir=None, threads=None, local_files_only=False, **bm25_config + ) + + # Restore factory state after the test + _FastembedSparseEmbeddingBackendFactory._instances = {} diff --git a/integrations/fastembed/tests/test_fastembed_ranker.py b/integrations/fastembed/tests/test_fastembed_ranker.py new file mode 100644 index 000000000..e38229c87 --- /dev/null +++ b/integrations/fastembed/tests/test_fastembed_ranker.py @@ -0,0 +1,292 @@ +from unittest.mock import MagicMock + +import pytest +from haystack import Document, default_from_dict + +from haystack_integrations.components.rankers.fastembed.ranker import ( + FastembedRanker, +) + + +class TestFastembedRanker: + def test_init_default(self): + """ + Test default initialization parameters for FastembedRanker. 
+ """ + ranker = FastembedRanker(model_name="BAAI/bge-reranker-base") + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.top_k == 10 + assert ranker.cache_dir is None + assert ranker.threads is None + assert ranker.batch_size == 64 + assert ranker.parallel is None + assert not ranker.local_files_only + assert ranker.meta_fields_to_embed == [] + assert ranker.meta_data_separator == "\n" + + def test_init_with_parameters(self): + """ + Test custom initialization parameters for FastembedRanker. + """ + ranker = FastembedRanker( + model_name="BAAI/bge-reranker-base", + top_k=64, + cache_dir="fake_dir", + threads=2, + batch_size=50, + parallel=1, + local_files_only=True, + meta_fields_to_embed=["test_field"], + meta_data_separator=" | ", + ) + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.top_k == 64 + assert ranker.cache_dir == "fake_dir" + assert ranker.threads == 2 + assert ranker.batch_size == 50 + assert ranker.parallel == 1 + assert ranker.local_files_only + assert ranker.meta_fields_to_embed == ["test_field"] + assert ranker.meta_data_separator == " | " + + def test_init_with_incorrect_input(self): + """ + Test for checking incorrect input format on init + """ + with pytest.raises( + ValueError, + match="top_k must be > 0, but got 0", + ): + FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2", top_k=0) + + with pytest.raises( + ValueError, + match="top_k must be > 0, but got -3", + ): + FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2", top_k=-3) + + def test_to_dict(self): + """ + Test serialization of FastembedRanker to a dictionary, using default initialization parameters. + """ + ranker = FastembedRanker(model_name="BAAI/bge-reranker-base") + ranker_dict = ranker.to_dict() + assert ranker_dict == { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "top_k": 10, + "cache_dir": None, + "threads": None, + "batch_size": 64, + "parallel": None, + "local_files_only": False, + "meta_fields_to_embed": [], + "meta_data_separator": "\n", + }, + } + + def test_to_dict_with_custom_init_parameters(self): + """ + Test serialization of FastembedRanker to a dictionary, using custom initialization parameters. + """ + ranker = FastembedRanker( + model_name="BAAI/bge-reranker-base", + cache_dir="fake_dir", + threads=2, + top_k=5, + batch_size=50, + parallel=1, + local_files_only=True, + meta_fields_to_embed=["test_field"], + meta_data_separator=" | ", + ) + ranker_dict = ranker.to_dict() + assert ranker_dict == { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "cache_dir": "fake_dir", + "threads": 2, + "top_k": 5, + "batch_size": 50, + "parallel": 1, + "local_files_only": True, + "meta_fields_to_embed": ["test_field"], + "meta_data_separator": " | ", + }, + } + + def test_from_dict(self): + """ + Test deserialization of FastembedRanker from a dictionary, using default initialization parameters. 
+ """ + ranker_dict = { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "cache_dir": None, + "threads": None, + "top_k": 5, + "batch_size": 50, + "parallel": None, + "local_files_only": False, + "meta_fields_to_embed": [], + "meta_data_separator": "\n", + }, + } + ranker = default_from_dict(FastembedRanker, ranker_dict) + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.cache_dir is None + assert ranker.threads is None + assert ranker.top_k == 5 + assert ranker.batch_size == 50 + assert ranker.parallel is None + assert not ranker.local_files_only + assert ranker.meta_fields_to_embed == [] + assert ranker.meta_data_separator == "\n" + + def test_from_dict_with_custom_init_parameters(self): + """ + Test deserialization of FastembedRanker from a dictionary, using custom initialization parameters. + """ + ranker_dict = { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "cache_dir": "fake_dir", + "threads": 2, + "top_k": 5, + "batch_size": 50, + "parallel": 1, + "local_files_only": True, + "meta_fields_to_embed": ["test_field"], + "meta_data_separator": " | ", + }, + } + ranker = default_from_dict(FastembedRanker, ranker_dict) + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.cache_dir == "fake_dir" + assert ranker.threads == 2 + assert ranker.top_k == 5 + assert ranker.batch_size == 50 + assert ranker.parallel == 1 + assert ranker.local_files_only + assert ranker.meta_fields_to_embed == ["test_field"] + assert ranker.meta_data_separator == " | " + + def test_run_incorrect_input_format(self): + """ + Test for checking incorrect input format. + """ + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2") + ranker._model = "mock_model" + + query = "query" + string_input = "text" + list_integers_input = [1, 2, 3] + list_document = [Document("Document 1")] + + with pytest.raises( + TypeError, + match="FastembedRanker expects a list of Documents as input.", + ): + ranker.run(query=query, documents=string_input) + + with pytest.raises( + TypeError, + match="FastembedRanker expects a list of Documents as input.", + ): + ranker.run(query=query, documents=list_integers_input) + + with pytest.raises( + ValueError, + match="No query provided", + ): + ranker.run(query="", documents=list_document) + + with pytest.raises( + ValueError, + match="top_k must be > 0, but got -3", + ): + ranker.run(query=query, documents=list_document, top_k=-3) + + def test_run_no_warmup(self): + """ + Test for checking error when calling without a warmup. + """ + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2") + + query = "query" + list_document = [Document("Document 1")] + + with pytest.raises( + RuntimeError, + ): + ranker.run(query=query, documents=list_document) + + def test_run_empty_document_list(self): + """ + Test for no error when sending no documents. + """ + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2") + ranker._model = "mock_model" + + query = "query" + list_document = [] + + result = ranker.run(query=query, documents=list_document) + assert len(result["documents"]) == 0 + + def test_embed_metadata(self): + """ + Tests the embedding of metadata fields in document content for ranking. 
+ """ + ranker = FastembedRanker( + model_name="model_name", + meta_fields_to_embed=["meta_field"], + ) + ranker._model = MagicMock() + + documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] + query = "test" + ranker.run(query=query, documents=documents) + + ranker._model.rerank.assert_called_once_with( + query=query, + documents=[ + "meta_value 0\ndocument-number 0", + "meta_value 1\ndocument-number 1", + "meta_value 2\ndocument-number 2", + "meta_value 3\ndocument-number 3", + "meta_value 4\ndocument-number 4", + ], + batch_size=64, + parallel=None, + ) + + @pytest.mark.integration + def test_run(self): + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2", top_k=2) + ranker.warm_up() + + query = "Who is maintaining Qdrant?" + documents = [ + Document( + content="This is built to be faster and lighter than other embedding \ +libraries e.g. Transformers, Sentence-Transformers, etc." + ), + Document(content="This is some random input"), + Document(content="fastembed is supported by and maintained by Qdrant."), + ] + + result = ranker.run(query=query, documents=documents) + + assert len(result["documents"]) == 2 + first_document = result["documents"][0] + second_document = result["documents"][1] + + assert isinstance(first_document, Document) + assert isinstance(second_document, Document) + assert first_document.content == "fastembed is supported by and maintained by Qdrant." + assert first_document.score > second_document.score diff --git a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py index d3f2023b8..7c0de196a 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py @@ -15,8 +15,8 @@ def test_init_default(self): """ Test default initialization parameters for FastembedSparseDocumentEmbedder. """ - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None assert embedder.batch_size == 32 @@ -31,7 +31,7 @@ def test_init_with_parameters(self): Test custom initialization parameters for FastembedSparseDocumentEmbedder. """ embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, batch_size=64, @@ -41,7 +41,7 @@ def test_init_with_parameters(self): meta_fields_to_embed=["test_field"], embedding_separator=" | ", ) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 assert embedder.batch_size == 64 @@ -55,12 +55,12 @@ def test_to_dict(self): """ Test serialization of FastembedSparseDocumentEmbedder to a dictionary, using default initialization parameters. 
""" - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") embedder_dict = embedder.to_dict() assert embedder_dict == { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": None, "threads": None, "batch_size": 32, @@ -69,6 +69,7 @@ def test_to_dict(self): "local_files_only": False, "embedding_separator": "\n", "meta_fields_to_embed": [], + "model_kwargs": None, }, } @@ -77,7 +78,7 @@ def test_to_dict_with_custom_init_parameters(self): Test serialization of FastembedSparseDocumentEmbedder to a dictionary, using custom initialization parameters. """ embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, batch_size=64, @@ -91,7 +92,7 @@ def test_to_dict_with_custom_init_parameters(self): assert embedder_dict == { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, "batch_size": 64, @@ -100,6 +101,7 @@ def test_to_dict_with_custom_init_parameters(self): "local_files_only": True, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", + "model_kwargs": None, }, } @@ -111,7 +113,7 @@ def test_from_dict(self): embedder_dict = { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": None, "threads": None, "batch_size": 32, @@ -123,7 +125,7 @@ def test_from_dict(self): }, } embedder = default_from_dict(FastembedSparseDocumentEmbedder, embedder_dict) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None assert embedder.batch_size == 32 @@ -141,7 +143,7 @@ def test_from_dict_with_custom_init_parameters(self): embedder_dict = { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder.FastembedSparseDocumentEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, "batch_size": 64, @@ -153,7 +155,7 @@ def test_from_dict_with_custom_init_parameters(self): }, } embedder = default_from_dict(FastembedSparseDocumentEmbedder, embedder_dict) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 assert embedder.batch_size == 64 @@ -170,11 +172,15 @@ def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. 
""" - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False + model_name="prithivida/Splade_PP_en_v1", + cache_dir=None, + threads=None, + local_files_only=False, + model_kwargs=None, ) @patch( @@ -184,7 +190,7 @@ def test_warmup_does_not_reload(self, mocked_factory): """ Test for checking backend instances after multiple warm-ups. """ - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() embedder.warm_up() @@ -205,7 +211,7 @@ def test_embed(self): """ Test for checking output dimensions and embedding dimensions. """ - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") embedder.embedding_backend = MagicMock() embedder.embedding_backend.embed = lambda x, **kwargs: self._generate_mocked_sparse_embedding( # noqa: ARG005 len(x) @@ -229,7 +235,7 @@ def test_embed_incorrect_input_format(self): """ Test for checking incorrect input format when creating embedding. """ - embedder = FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseDocumentEmbedder(model="prithivida/Splade_PP_en_v1") string_input = "text" list_integers_input = [1, 2, 3] @@ -275,10 +281,56 @@ def test_embed_metadata(self): parallel=None, ) + def test_init_with_model_kwargs_parameters(self): + """ + Test initialization of FastembedSparseDocumentEmbedder with model_kwargs parameters. + """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 50, + } + + embedder = FastembedSparseDocumentEmbedder( + model="Qdrant/bm25", + model_kwargs=bm25_config, + ) + + assert embedder.model_kwargs == bm25_config + + @pytest.mark.integration + def test_run_with_model_kwargs(self): + """ + Integration test to check the embedding with model_kwargs parameters. 
+ """ + bm42_config = { + "alpha": 0.2, + } + + embedder = FastembedSparseDocumentEmbedder( + model="Qdrant/bm42-all-minilm-l6-v2-attentions", + model_kwargs=bm42_config, + ) + embedder.warm_up() + + doc = Document(content="Example content using BM42") + + result = embedder.run(documents=[doc]) + embedding = result["documents"][0].sparse_embedding + embedding_dict = embedding.to_dict() + + assert isinstance(embedding, SparseEmbedding) + assert isinstance(embedding_dict["indices"], list) + assert isinstance(embedding_dict["values"], list) + assert isinstance(embedding_dict["indices"][0], int) + assert isinstance(embedding_dict["values"][0], float) + @pytest.mark.integration def test_run(self): embedder = FastembedSparseDocumentEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", ) embedder.warm_up() diff --git a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py index 7e9197493..9b73f5f3a 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py @@ -15,8 +15,8 @@ def test_init_default(self): """ Test default initialization parameters for FastembedSparseTextEmbedder. """ - embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1") - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1") + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None assert embedder.progress_bar is True @@ -27,13 +27,13 @@ def test_init_with_parameters(self): Test custom initialization parameters for FastembedSparseTextEmbedder. """ embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, progress_bar=False, parallel=1, ) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 assert embedder.progress_bar is False @@ -43,17 +43,18 @@ def test_to_dict(self): """ Test serialization of FastembedSparseTextEmbedder to a dictionary, using default initialization parameters. """ - embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1") embedder_dict = embedder.to_dict() assert embedder_dict == { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": None, "threads": None, "progress_bar": True, "parallel": None, "local_files_only": False, + "model_kwargs": None, }, } @@ -62,7 +63,7 @@ def test_to_dict_with_custom_init_parameters(self): Test serialization of FastembedSparseTextEmbedder to a dictionary, using custom initialization parameters. 
""" embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", cache_dir="fake_dir", threads=2, progress_bar=False, @@ -73,12 +74,13 @@ def test_to_dict_with_custom_init_parameters(self): assert embedder_dict == { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, "progress_bar": False, "parallel": 1, "local_files_only": True, + "model_kwargs": None, }, } @@ -89,7 +91,7 @@ def test_from_dict(self): embedder_dict = { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": None, "threads": None, "progress_bar": True, @@ -97,7 +99,7 @@ def test_from_dict(self): }, } embedder = default_from_dict(FastembedSparseTextEmbedder, embedder_dict) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir is None assert embedder.threads is None assert embedder.progress_bar is True @@ -110,7 +112,7 @@ def test_from_dict_with_custom_init_parameters(self): embedder_dict = { "type": "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder.FastembedSparseTextEmbedder", # noqa "init_parameters": { - "model": "prithvida/Splade_PP_en_v1", + "model": "prithivida/Splade_PP_en_v1", "cache_dir": "fake_dir", "threads": 2, "progress_bar": False, @@ -118,7 +120,7 @@ def test_from_dict_with_custom_init_parameters(self): }, } embedder = default_from_dict(FastembedSparseTextEmbedder, embedder_dict) - assert embedder.model_name == "prithvida/Splade_PP_en_v1" + assert embedder.model_name == "prithivida/Splade_PP_en_v1" assert embedder.cache_dir == "fake_dir" assert embedder.threads == 2 assert embedder.progress_bar is False @@ -131,11 +133,15 @@ def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. """ - embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False + model_name="prithivida/Splade_PP_en_v1", + cache_dir=None, + threads=None, + local_files_only=False, + model_kwargs=None, ) @patch( @@ -145,7 +151,7 @@ def test_warmup_does_not_reload(self, mocked_factory): """ Test for checking backend instances after multiple warm-ups. """ - embedder = FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1") + embedder = FastembedSparseTextEmbedder(model="prithivida/Splade_PP_en_v1") mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() embedder.warm_up() @@ -195,10 +201,58 @@ def test_run_wrong_incorrect_format(self): with pytest.raises(TypeError, match="FastembedSparseTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) + def test_init_with_model_kwargs_parameters(self): + """ + Test initialization of FastembedSparseTextEmbedder with model_kwargs parameters. 
+ """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 50, + } + + embedder = FastembedSparseTextEmbedder( + model="Qdrant/bm25", + model_kwargs=bm25_config, + ) + + assert embedder.model_kwargs == bm25_config + + @pytest.mark.integration + def test_run_with_model_kwargs(self): + """ + Integration test to check the embedding with model_kwargs parameters. + """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 256.0, + } + + embedder = FastembedSparseTextEmbedder( + model="Qdrant/bm25", + model_kwargs=bm25_config, + ) + embedder.warm_up() + + text = "Example content using BM25" + + result = embedder.run(text=text) + embedding = result["sparse_embedding"] + embedding_dict = embedding.to_dict() + + assert isinstance(embedding, SparseEmbedding) + assert isinstance(embedding_dict["indices"], list) + assert isinstance(embedding_dict["values"], list) + assert isinstance(embedding_dict["indices"][0], int) + assert isinstance(embedding_dict["values"][0], float) + @pytest.mark.integration def test_run(self): embedder = FastembedSparseTextEmbedder( - model="prithvida/Splade_PP_en_v1", + model="prithivida/Splade_PP_en_v1", ) embedder.warm_up() diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 0fc7ce0ab..7171b0069 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,24 +1,55 @@ # Changelog -## [unreleased] +## [integrations/google_ai-v3.0.2] - 2024-11-19 + +### ๐Ÿ› Bug Fixes + +- Fix missing usage metadata in GoogleAIGeminiChatGenerator (#1195) + + +## [integrations/google_ai-v3.0.0] - 2024-11-12 + +### ๐Ÿ› Bug Fixes + +- `GoogleAIGeminiGenerator` - remove support for tools and change output type (#1177) + +### โš™๏ธ CI + +- Adopt uv as installer (#1142) + + +## [integrations/google_ai-v2.0.1] - 2024-10-15 + +### ๐Ÿš€ Features + +- Add chatrole tests and meta for GeminiChatGenerators (#1090) ### ๐Ÿ› Bug Fixes - Remove the use of deprecated gemini models (#1032) +- Chat roles for model responses in chat generators (#1030) +- Make sure that streaming works with function calls - (drop python3.8) (#1137) ### ๐Ÿงช Testing - Do not retry tests in `hatch run test` command (#954) -### โš™๏ธ Miscellaneous Tasks +### โš™๏ธ CI - Retry tests to reduce flakyness (#836) + +### ๐Ÿงน Chores + - Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) -### Docs +### ๐ŸŒ€ Miscellaneous +- Ci: install `pytest-rerunfailures` where needed; add retry config to `test-cov` script (#845) +- Fix Google AI tests failing (#885) - Update GeminiGenerator docstrings (#964) - Update GoogleChatGenerator docstrings (#962) +- Feat: enable streaming in GoogleAIGemini (#1016) ## [integrations/google_ai-v1.1.0] - 2024-06-05 @@ -26,31 +57,50 @@ - Handle `TypeError: Could not create Blob` in `GoogleAIGeminiChatGenerator` (#772) +### ๐ŸŒ€ Miscellaneous + +- Chore: add license classifiers (#680) +- Chore: change the pydoc renderer class (#718) +- Fix Google AI integration tests (#786) +- Test: Fix tests skipping for Google AI integration (#788) + ## [integrations/google_ai-v1.0.0] - 2024-03-27 ### ๐Ÿ› Bug Fixes - Fix order of API docs (#447) -This PR will also push the docs to Readme - ### ๐Ÿ“š Documentation - Update category slug (#442) - Disable-class-def (#556) +### ๐ŸŒ€ Miscellaneous + +- Google AI - review docstrings (#533) +- Make tests show coverage (#566) +- Remove references to Python 3.7 (#601) +- Google Generators: change 
`answers` to `replies` (#626) + ## [integrations/google_ai-v0.2.0] - 2024-02-15 -### Google_ai +### ๐ŸŒ€ Miscellaneous - Create api docs (#354) +- Google AI - new secrets management (#424) ## [integrations/google_ai-v0.1.0] - 2024-01-25 -### Refact +### ๐ŸŒ€ Miscellaneous +- Add docstrings for `GoogleAIGeminiGenerator` and `GoogleAIGeminiChatGenerator` (#175) - [**breaking**] Adjust import paths (#268) ## [integrations/google_ai-v0.0.1] - 2024-01-03 +### ๐ŸŒ€ Miscellaneous + +- Gemini with Makersuite (#156) +- Fix google_ai integration versioning + diff --git a/integrations/google_ai/pyproject.toml b/integrations/google_ai/pyproject.toml index d06e0a53f..9a4a070e7 100644 --- a/integrations/google_ai/pyproject.toml +++ b/integrations/google_ai/pyproject.toml @@ -7,7 +7,7 @@ name = "google-ai-haystack" dynamic = ["version"] description = 'Use models like Gemini via Makersuite' readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "Apache-2.0" keywords = [] authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] @@ -15,7 +15,6 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 4 - Beta", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -41,6 +40,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/google_ai-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -59,8 +59,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 56c84968b..dbcab619d 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -313,9 +313,24 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess """ replies: List[ChatMessage] = [] metadata = response_body.to_dict() + + # currently Google only supports one candidate and usage metadata reflects this + # this should be refactored when multiple candidates are supported + usage_metadata_openai_format = {} + + usage_metadata = metadata.get("usage_metadata") + if usage_metadata: + usage_metadata_openai_format = { + "prompt_tokens": usage_metadata["prompt_token_count"], + "completion_tokens": usage_metadata["candidates_token_count"], + "total_tokens": usage_metadata["total_token_count"], + } + for idx, candidate in enumerate(response_body.candidates): candidate_metadata = metadata["candidates"][idx] candidate_metadata.pop("content", None) # we remove content from the metadata + if usage_metadata_openai_format: + candidate_metadata["usage"] = usage_metadata_openai_format for part in candidate.content.parts: if part.text != "": @@ -347,20 +362,21 @@ def 
_get_stream_response( replies: List[ChatMessage] = [] for chunk in stream: content: Union[str, Dict[str, Any]] = "" - metadata = chunk.to_dict() # we store whole chunk as metadata in streaming calls - for candidate in chunk.candidates: - for part in candidate.content.parts: - if part.text != "": - content = part.text + dict_chunk = chunk.to_dict() + metadata = dict(dict_chunk) # we copy and store the whole chunk as metadata in streaming calls + for candidate in dict_chunk["candidates"]: + for part in candidate["content"]["parts"]: + if "text" in part and part["text"] != "": + content = part["text"] replies.append(ChatMessage(content=content, role=ChatRole.ASSISTANT, meta=metadata, name=None)) - elif part.function_call is not None: - metadata["function_call"] = part.function_call - content = dict(part.function_call.args.items()) + elif "function_call" in part and len(part["function_call"]) > 0: + metadata["function_call"] = part["function_call"] + content = part["function_call"]["args"] replies.append( ChatMessage( content=content, role=ChatRole.ASSISTANT, - name=part.function_call.name, + name=part["function_call"]["name"], meta=metadata, ) ) diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index 218e16c4c..b032169df 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import google.generativeai as genai -from google.ai.generativelanguage import Content, Part, Tool +from google.ai.generativelanguage import Content, Part from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory from haystack.core.component import component @@ -62,6 +62,16 @@ class GoogleAIGeminiGenerator: ``` """ + def __new__(cls, *_, **kwargs): + if "tools" in kwargs: + msg = ( + "GoogleAIGeminiGenerator does not support the `tools` parameter. " + " Use GoogleAIGeminiChatGenerator instead." + ) + raise TypeError(msg) + return super(GoogleAIGeminiGenerator, cls).__new__(cls) # noqa: UP008 + # super(__class__, cls) is needed because of the component decorator + def __init__( self, *, @@ -69,7 +79,6 @@ def __init__( model: str = "gemini-1.5-flash", generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, - tools: Optional[List[Tool]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -86,7 +95,6 @@ def __init__( :param safety_settings: The safety settings to use. A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. For more information, see [the API reference](https://ai.google.dev/api) - :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. 
""" @@ -96,8 +104,7 @@ def __init__( self._model_name = model self._generation_config = generation_config self._safety_settings = safety_settings - self._tools = tools - self._model = GenerativeModel(self._model_name, tools=self._tools) + self._model = GenerativeModel(self._model_name) self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: @@ -126,11 +133,8 @@ def to_dict(self) -> Dict[str, Any]: model=self._model_name, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -149,8 +153,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator": """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -178,7 +180,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: msg = f"Unsupported type {type(part)} for part {part}" raise ValueError(msg) - @component.output_types(replies=List[Union[str, Dict[str, str]]]) + @component.output_types(replies=List[str]) def run( self, parts: Variadic[Union[str, ByteStream, Part]], @@ -192,7 +194,7 @@ def run( :param streaming_callback: A callback function that is called when a new token is received from the stream. :returns: A dictionary containing the following key: - - `replies`: A list of strings or dictionaries with function calls. + - `replies`: A list of strings containing the generated responses. 
""" # check if streaming_callback is passed @@ -221,12 +223,6 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[str]: for part in candidate.content.parts: if part.text != "": replies.append(part.text) - elif part.function_call is not None: - function_call = { - "name": part.function_call.name, - "args": dict(part.function_call.args.items()), - } - replies.append(function_call) return replies def _get_stream_response( diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py index c4372db0d..cb42f0ff8 100644 --- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py +++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py @@ -295,5 +295,11 @@ def test_past_conversation(): ] response = gemini_chat.run(messages=messages) assert "replies" in response - assert len(response["replies"]) > 0 - assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"]) + replies = response["replies"] + assert len(replies) > 0 + assert all(reply.role == ChatRole.ASSISTANT for reply in replies) + + assert all("usage" in reply.meta for reply in replies) + assert all("prompt_tokens" in reply.meta["usage"] for reply in replies) + assert all("completion_tokens" in reply.meta["usage"] for reply in replies) + assert all("total_tokens" in reply.meta["usage"] for reply in replies) diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py index 7206b7a43..07d194a59 100644 --- a/integrations/google_ai/tests/generators/test_gemini.py +++ b/integrations/google_ai/tests/generators/test_gemini.py @@ -2,32 +2,12 @@ from unittest.mock import patch import pytest -from google.ai.generativelanguage import FunctionDeclaration, Tool from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import HarmBlockThreshold, HarmCategory from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, -) - def test_init(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -41,40 +21,24 @@ def test_init(monkeypatch): top_k=0.5, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. 
San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - tool = Tool(function_declarations=[get_current_weather_func]) with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure") as mock_genai_configure: gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], ) mock_genai_configure.assert_called_once_with(api_key="test") assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] assert isinstance(gemini._model, GenerativeModel) +def test_init_fails_with_tools(): + with pytest.raises(TypeError, match="GoogleAIGeminiGenerator does not support the `tools` parameter."): + GoogleAIGeminiGenerator(tools=["tool1", "tool2"]) + + def test_to_dict(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -88,7 +52,6 @@ def test_to_dict(monkeypatch): "generation_config": None, "safety_settings": None, "streaming_callback": None, - "tools": None, }, } @@ -105,32 +68,11 @@ def test_to_dict_with_param(monkeypatch): top_k=2, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - - tool = Tool(function_declarations=[get_current_weather_func]) with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], ) assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", @@ -147,11 +89,6 @@ def test_to_dict_with_param(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" - ], }, } @@ -175,11 +112,6 @@ def test_from_dict_with_param(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. 
San Francisco, CAB\x08location" - ], }, } ) @@ -194,7 +126,6 @@ def test_from_dict_with_param(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel) @@ -217,11 +148,6 @@ def test_from_dict(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" - ], }, } ) @@ -236,7 +162,6 @@ def test_from_dict(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel) diff --git a/integrations/google_vertex/CHANGELOG.md b/integrations/google_vertex/CHANGELOG.md index 17a730b60..ea2a8fb18 100644 --- a/integrations/google_vertex/CHANGELOG.md +++ b/integrations/google_vertex/CHANGELOG.md @@ -1,11 +1,34 @@ # Changelog -## [unreleased] +## [integrations/google_vertex-v3.0.0] - 2024-11-14 + +### ๐Ÿ› Bug Fixes + +- VertexAIGeminiGenerator - remove support for tools and change output type (#1180) + +### โš™๏ธ Miscellaneous Tasks + +- Fix Vertex tests (#1163) + +## [integrations/google_vertex-v2.2.0] - 2024-10-23 + +### ๐Ÿ› Bug Fixes + +- Make "project-id" parameter optional during initialization (#1141) +- Make project-id optional in all VertexAI generators (#1147) + +### โš™๏ธ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + +## [integrations/google_vertex-v2.1.0] - 2024-10-04 ### ๐Ÿš€ Features - Enable streaming for VertexAIGeminiChatGenerator (#1014) - Add tests for VertexAIGeminiGenerator and enable streaming (#1012) +- Add chatrole tests and meta for GeminiChatGenerators (#1090) +- Add custom params to VertexAIGeminiGenerator and VertexAIGeminiChatGenerator (#1100) ### ๐Ÿ› Bug Fixes @@ -21,6 +44,7 @@ - Retry tests to reduce flakyness (#836) - Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) ## [integrations/google_vertex-v1.1.0] - 2024-03-28 diff --git a/integrations/google_vertex/pyproject.toml b/integrations/google_vertex/pyproject.toml index a0cefbcd4..d8b7b3408 100644 --- a/integrations/google_vertex/pyproject.toml +++ b/integrations/google_vertex/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "google-cloud-aiplatform>=1.38", "pyarrow>3"] +dependencies = ["haystack-ai", "google-cloud-aiplatform>=1.61", "pyarrow>3"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google_vertex#readme" @@ -41,6 +41,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/google_vertex-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -59,8 +60,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py index 14102eb4b..ff8ce497b 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/captioner.py @@ -25,7 +25,7 @@ class VertexAIImageCaptioner: from haystack.dataclasses.byte_stream import ByteStream from haystack_integrations.components.generators.google_vertex import VertexAIImageCaptioner - captioner = VertexAIImageCaptioner(project_id=project_id) + captioner = VertexAIImageCaptioner() image = ByteStream( data=requests.get( @@ -41,14 +41,16 @@ class VertexAIImageCaptioner: ``` """ - def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, *, model: str = "imagetext", project_id: Optional[str] = None, location: Optional[str] = None, **kwargs + ): """ Generate image captions using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. 
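The captioner hunk above makes `project_id` optional so the component can fall back to the project resolved from Application Default Credentials. A minimal usage sketch of that new construction path, assuming ADC is already configured in the environment; the model name and keyword arguments mirror the values used in the updated captioner tests:

```python
from haystack_integrations.components.generators.google_vertex import VertexAIImageCaptioner

# Assumes `gcloud auth application-default login` (or equivalent ADC setup) was run beforehand,
# so no explicit project_id is needed.
captioner = VertexAIImageCaptioner(model="imagetext", number_of_results=1, language="it")

# Serialization now records project_id as None instead of a hard-coded project.
assert captioner.to_dict()["init_parameters"]["project_id"] is None
```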
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py index f09692daf..c52f76dc6 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py @@ -36,7 +36,7 @@ class VertexAIGeminiChatGenerator: from haystack.dataclasses import ChatMessage from haystack_integrations.components.generators.google_vertex import VertexAIGeminiChatGenerator - gemini_chat = VertexAIGeminiChatGenerator(project_id=project_id) + gemini_chat = VertexAIGeminiChatGenerator() messages = [ChatMessage.from_user("Tell me the name of a movie")] res = gemini_chat.run(messages) @@ -50,7 +50,7 @@ def __init__( self, *, model: str = "gemini-1.5-flash", - project_id: str, + project_id: Optional[str] = None, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, @@ -65,7 +65,7 @@ def __init__( Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py index c39c7f88b..096e642dd 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/code_generator.py @@ -20,7 +20,7 @@ class VertexAICodeGenerator: ```python from haystack_integrations.components.generators.google_vertex import VertexAICodeGenerator - generator = VertexAICodeGenerator(project_id=project_id) + generator = VertexAICodeGenerator() result = generator.run(prefix="def to_json(data):") @@ -45,14 +45,16 @@ class VertexAICodeGenerator: ``` """ - def __init__(self, *, model: str = "code-bison", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, *, model: str = "code-bison", project_id: Optional[str] = None, location: Optional[str] = None, **kwargs + ): """ Generate code using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. :param kwargs: Additional keyword arguments to pass to the model. 
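`VertexAICodeGenerator` gets the same treatment in the hunk above: `project_id` becomes optional and is resolved through ADC when omitted. A short sketch under that assumption, reusing the prefix and parameters from the updated code-generator tests; the `replies` output key is taken from the generator docstrings and is an assumption here:

```python
from haystack_integrations.components.generators.google_vertex import VertexAICodeGenerator

# Assumes ADC is configured; project_id is intentionally omitted after this change.
generator = VertexAICodeGenerator(model="code-bison", candidate_count=1, temperature=0.5)

result = generator.run(prefix="def print_json(data):\n")
print(result["replies"][0])  # "replies" key assumed from the class docstrings
```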
diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py index 2b1c1b477..c9473b428 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/gemini.py @@ -15,8 +15,6 @@ HarmBlockThreshold, HarmCategory, Part, - Tool, - ToolConfig, ) logger = logging.getLogger(__name__) @@ -32,7 +30,7 @@ class VertexAIGeminiGenerator: from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator - gemini = VertexAIGeminiGenerator(project_id=project_id) + gemini = VertexAIGeminiGenerator() result = gemini.run(parts = ["What is the most interesting thing you know?"]) for answer in result["replies"]: print(answer) @@ -50,16 +48,24 @@ class VertexAIGeminiGenerator: ``` """ + def __new__(cls, *_, **kwargs): + if "tools" in kwargs or "tool_config" in kwargs: + msg = ( + "VertexAIGeminiGenerator does not support `tools` and `tool_config` parameters. " + "Use VertexAIGeminiChatGenerator instead." + ) + raise TypeError(msg) + return super(VertexAIGeminiGenerator, cls).__new__(cls) # noqa: UP008 + # super(__class__, cls) is needed because of the component decorator + def __init__( self, *, model: str = "gemini-1.5-flash", - project_id: str, + project_id: Optional[str] = None, location: Optional[str] = None, generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, - tools: Optional[List[Tool]] = None, - tool_config: Optional[ToolConfig] = None, system_instruction: Optional[Union[str, ByteStream, Part]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): @@ -69,7 +75,7 @@ def __init__( Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. For available models, see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models. :param location: The default location to use when making API calls, if not set uses us-central-1. :param generation_config: The generation config to use. @@ -86,10 +92,6 @@ def __init__( for [HarmBlockThreshold](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmBlockThreshold) and [HarmCategory](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.HarmCategory) for more details. - :param tools: List of tools to use when generating content. See the documentation for - [Tool](https://cloud.google.com/python/docs/reference/aiplatform/latest/vertexai.generative_models.Tool) - the list of supported arguments. - :param tool_config: The tool config to use. See the documentation for [ToolConfig](https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.ToolConfig) :param system_instruction: Default system instruction to use for generating content. :param streaming_callback: A callback function that is called when a new token is received from the stream. 
The callback function accepts StreamingChunk as an argument. @@ -105,8 +107,6 @@ def __init__( # model parameters self._generation_config = generation_config self._safety_settings = safety_settings - self._tools = tools - self._tool_config = tool_config self._system_instruction = system_instruction self._streaming_callback = streaming_callback @@ -115,8 +115,6 @@ def __init__( self._model_name, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, - tool_config=self._tool_config, system_instruction=self._system_instruction, ) @@ -132,18 +130,6 @@ def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, A "stop_sequences": config._raw_generation_config.stop_sequences, } - def _tool_config_to_dict(self, tool_config: ToolConfig) -> Dict[str, Any]: - """Serializes the ToolConfig object into a dictionary.""" - - mode = tool_config._gapic_tool_config.function_calling_config.mode - allowed_function_names = tool_config._gapic_tool_config.function_calling_config.allowed_function_names - config_dict = {"function_calling_config": {"mode": mode}} - - if allowed_function_names: - config_dict["function_calling_config"]["allowed_function_names"] = allowed_function_names - - return config_dict - def to_dict(self) -> Dict[str, Any]: """ Serializes the component to a dictionary. @@ -160,15 +146,10 @@ def to_dict(self) -> Dict[str, Any]: location=self._location, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, - tool_config=self._tool_config, system_instruction=self._system_instruction, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.to_dict(t) for t in tools] - if (tool_config := data["init_parameters"].get("tool_config")) is not None: - data["init_parameters"]["tool_config"] = self._tool_config_to_dict(tool_config) + if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) return data @@ -184,22 +165,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "VertexAIGeminiGenerator": Deserialized component. 
""" - def _tool_config_from_dict(config_dict: Dict[str, Any]) -> ToolConfig: - """Deserializes the ToolConfig object from a dictionary.""" - function_calling_config = config_dict["function_calling_config"] - return ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=function_calling_config["mode"], - allowed_function_names=function_calling_config.get("allowed_function_names"), - ) - ) - - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) - if (tool_config := data["init_parameters"].get("tool_config")) is not None: - data["init_parameters"]["tool_config"] = _tool_config_from_dict(tool_config) if (serialized_callback_handler := data["init_parameters"].get("streaming_callback")) is not None: data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) return default_from_dict(cls, data) @@ -215,7 +182,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: msg = f"Unsupported type {type(part)} for part {part}" raise ValueError(msg) - @component.output_types(replies=List[Union[str, Dict[str, str]]]) + @component.output_types(replies=List[str]) def run( self, parts: Variadic[Union[str, ByteStream, Part]], @@ -257,12 +224,6 @@ def _get_response(self, response_body: GenerationResponse) -> List[str]: for part in candidate.content.parts: if part._raw_part.text != "": replies.append(part.text) - elif part.function_call is not None: - function_call = { - "name": part.function_call.name, - "args": dict(part.function_call.args.items()), - } - replies.append(function_call) return replies def _get_stream_response( diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py index 0534a20f2..9301221b5 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/image_generator.py @@ -24,20 +24,27 @@ class VertexAIImageGenerator: from haystack_integrations.components.generators.google_vertex import VertexAIImageGenerator - generator = VertexAIImageGenerator(project_id=project_id) + generator = VertexAIImageGenerator() result = generator.run(prompt="Generate an image of a cute cat") result["images"][0].to_file(Path("my_image.png")) ``` """ - def __init__(self, *, model: str = "imagegeneration", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, + *, + model: str = "imagegeneration", + project_id: Optional[str] = None, + location: Optional[str] = None, + **kwargs, + ): """ Generates images using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. 
:param kwargs: Additional keyword arguments to pass to the model. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py index 392a41e00..38eeb7c62 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/question_answering.py @@ -23,7 +23,7 @@ class VertexAIImageQA: from haystack.dataclasses.byte_stream import ByteStream from haystack_integrations.components.generators.google_vertex import VertexAIImageQA - qa = VertexAIImageQA(project_id=project_id) + qa = VertexAIImageQA() image = ByteStream.from_file_path("dog.jpg") @@ -35,14 +35,16 @@ class VertexAIImageQA: ``` """ - def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, *, model: str = "imagetext", project_id: Optional[str] = None, location: Optional[str] = None, **kwargs + ): """ Answers questions about an image using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. :param kwargs: Additional keyword arguments to pass to the model. diff --git a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py index 59061d91c..4f69dfb18 100644 --- a/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py +++ b/integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/text_generator.py @@ -25,7 +25,7 @@ class VertexAITextGenerator: ```python from haystack_integrations.components.generators.google_vertex import VertexAITextGenerator - generator = VertexAITextGenerator(project_id=project_id) + generator = VertexAITextGenerator() res = generator.run("Tell me a good interview question for a software engineer.") print(res["replies"][0]) @@ -45,14 +45,16 @@ class VertexAITextGenerator: ``` """ - def __init__(self, *, model: str = "text-bison", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, *, model: str = "text-bison", project_id: Optional[str] = None, location: Optional[str] = None, **kwargs + ): """ Generate text using a Google Vertex AI model. Authenticates using Google Cloud Application Default Credentials (ADCs). For more information see the official [Google documentation](https://cloud.google.com/docs/authentication/provide-credentials-adc). - :param project_id: ID of the GCP project to use. + :param project_id: ID of the GCP project to use. By default, it is set during Google Cloud authentication. :param model: Name of the model to use. :param location: The default location to use when making API calls, if not set uses us-central-1. 
:param kwargs: Additional keyword arguments to pass to the model. diff --git a/integrations/google_vertex/tests/chat/test_gemini.py b/integrations/google_vertex/tests/chat/test_gemini.py index 87b43d66b..73c99fe2f 100644 --- a/integrations/google_vertex/tests/chat/test_gemini.py +++ b/integrations/google_vertex/tests/chat/test_gemini.py @@ -22,11 +22,11 @@ name="get_current_weather", description="Get the current weather in a given location", parameters={ - "type_": "OBJECT", + "type": "object", "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, + "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": { - "type_": "STRING", + "type": "string", "enum": [ "celsius", "fahrenheit", @@ -90,14 +90,12 @@ def test_init(mock_vertexai_init, _mock_generative_model): @patch("haystack_integrations.components.generators.google_vertex.chat.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiChatGenerator( - project_id="TestID123", - ) + gemini = VertexAIGeminiChatGenerator() assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { "model": "gemini-1.5-flash", - "project_id": "TestID123", + "project_id": None, "location": None, "generation_config": None, "safety_settings": None, @@ -132,6 +130,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): gemini = VertexAIGeminiChatGenerator( project_id="TestID123", + location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, tools=[tool], @@ -144,7 +143,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "init_parameters": { "model": "gemini-1.5-flash", "project_id": "TestID123", - "location": None, + "location": "TestLocation", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -171,6 +170,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], + "property_ordering": ["location", "unit"], }, } ] @@ -194,7 +194,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): { "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { - "project_id": "TestID123", + "project_id": None, "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, @@ -205,7 +205,7 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): ) assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id == "TestID123" + assert gemini._project_id is None assert gemini._safety_settings is None assert gemini._tools is None assert gemini._tool_config is None @@ -221,6 +221,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "type": "haystack_integrations.components.generators.google_vertex.chat.gemini.VertexAIGeminiChatGenerator", "init_parameters": { "project_id": "TestID123", + "location": "TestLocation", "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, @@ -238,13 +239,19 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": { - "type_": "OBJECT", + "type": "object", "properties": { "location": { - "type_": 
"STRING", + "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, + "unit": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit", + ], + }, }, "required": ["location"], }, @@ -266,6 +273,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-1.5-flash" assert gemini._project_id == "TestID123" + assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._tool_config, ToolConfig) @@ -290,7 +298,7 @@ def test_run(mock_generative_model): ChatMessage.from_system("You are a helpful assistant"), ChatMessage.from_user("What's the capital of France?"), ] - gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None) + gemini = VertexAIGeminiChatGenerator() response = gemini.run(messages=messages) mock_model.send_message.assert_called_once() @@ -315,7 +323,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None: nonlocal streaming_callback_called streaming_callback_called = True - gemini = VertexAIGeminiChatGenerator(project_id="TestID123", location=None, streaming_callback=streaming_callback) + gemini = VertexAIGeminiChatGenerator(streaming_callback=streaming_callback) messages = [ ChatMessage.from_system("You are a helpful assistant"), ChatMessage.from_user("What's the capital of France?"), diff --git a/integrations/google_vertex/tests/test_captioner.py b/integrations/google_vertex/tests/test_captioner.py index 26249dbee..3d849c738 100644 --- a/integrations/google_vertex/tests/test_captioner.py +++ b/integrations/google_vertex/tests/test_captioner.py @@ -22,14 +22,12 @@ def test_init(mock_model_class, mock_vertexai): @patch("haystack_integrations.components.generators.google_vertex.captioner.vertexai") @patch("haystack_integrations.components.generators.google_vertex.captioner.ImageTextModel") def test_to_dict(_mock_model_class, _mock_vertexai): - captioner = VertexAIImageCaptioner( - model="imagetext", project_id="myproject-123456", number_of_results=1, language="it" - ) + captioner = VertexAIImageCaptioner(model="imagetext", number_of_results=1, language="it") assert captioner.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.captioner.VertexAIImageCaptioner", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, "number_of_results": 1, "language": "it", @@ -45,14 +43,15 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.captioner.VertexAIImageCaptioner", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, + "location": None, "number_of_results": 1, "language": "it", }, } ) assert captioner._model_name == "imagetext" - assert captioner._project_id == "myproject-123456" + assert captioner._project_id is None assert captioner._location is None assert captioner._kwargs == {"number_of_results": 1, "language": "it"} assert captioner._model is not None @@ -63,9 +62,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): def test_run_calls_get_captions(mock_model_class, _mock_vertexai): mock_model = Mock() mock_model_class.from_pretrained.return_value = mock_model - captioner = 
VertexAIImageCaptioner( - model="imagetext", project_id="myproject-123456", number_of_results=1, language="it" - ) + captioner = VertexAIImageCaptioner(model="imagetext", number_of_results=1, language="it") image = ByteStream(data=b"image data") captioner.run(image=image) diff --git a/integrations/google_vertex/tests/test_code_generator.py b/integrations/google_vertex/tests/test_code_generator.py index 129954062..132f4c945 100644 --- a/integrations/google_vertex/tests/test_code_generator.py +++ b/integrations/google_vertex/tests/test_code_generator.py @@ -22,14 +22,12 @@ def test_init(mock_model_class, mock_vertexai): @patch("haystack_integrations.components.generators.google_vertex.code_generator.vertexai") @patch("haystack_integrations.components.generators.google_vertex.code_generator.CodeGenerationModel") def test_to_dict(_mock_model_class, _mock_vertexai): - generator = VertexAICodeGenerator( - model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5 - ) + generator = VertexAICodeGenerator(model="code-bison", candidate_count=3, temperature=0.5) assert generator.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.code_generator.VertexAICodeGenerator", "init_parameters": { "model": "code-bison", - "project_id": "myproject-123456", + "project_id": None, "location": None, "candidate_count": 3, "temperature": 0.5, @@ -45,14 +43,15 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.code_generator.VertexAICodeGenerator", "init_parameters": { "model": "code-bison", - "project_id": "myproject-123456", + "project_id": None, + "location": None, "candidate_count": 2, "temperature": 0.5, }, } ) assert generator._model_name == "code-bison" - assert generator._project_id == "myproject-123456" + assert generator._project_id is None assert generator._location is None assert generator._kwargs == {"candidate_count": 2, "temperature": 0.5} assert generator._model is not None @@ -64,9 +63,7 @@ def test_run_calls_predict(mock_model_class, _mock_vertexai): mock_model = Mock() mock_model.predict.return_value = TextGenerationResponse("answer", None) mock_model_class.from_pretrained.return_value = mock_model - generator = VertexAICodeGenerator( - model="code-bison", project_id="myproject-123456", candidate_count=1, temperature=0.5 - ) + generator = VertexAICodeGenerator(model="code-bison", candidate_count=1, temperature=0.5) prefix = "def print_json(data):\n" generator.run(prefix=prefix) diff --git a/integrations/google_vertex/tests/test_gemini.py b/integrations/google_vertex/tests/test_gemini.py index 1543f3ccf..ff692c6f4 100644 --- a/integrations/google_vertex/tests/test_gemini.py +++ b/integrations/google_vertex/tests/test_gemini.py @@ -1,38 +1,17 @@ from unittest.mock import MagicMock, Mock, patch +import pytest from haystack import Pipeline from haystack.components.builders import PromptBuilder from haystack.dataclasses import StreamingChunk from vertexai.generative_models import ( - FunctionDeclaration, GenerationConfig, HarmBlockThreshold, HarmCategory, - Tool, - ToolConfig, ) from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. 
San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, -) - @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") @@ -48,50 +27,42 @@ def test_init(mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) mock_vertexai_init.assert_called() assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] - assert gemini._tool_config == tool_config assert gemini._system_instruction == "Please provide brief answers." +def test_init_fails_with_tools_or_tool_config(): + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tools=["tool1", "tool2"]) + + with pytest.raises(TypeError, match="VertexAIGeminiGenerator does not support `tools`"): + VertexAIGeminiGenerator(tool_config={"custom": "config"}) + + @patch("haystack_integrations.components.generators.google_vertex.gemini.vertexai_init") @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") def test_to_dict(_mock_vertexai_init, _mock_generative_model): - gemini = VertexAIGeminiGenerator( - project_id="TestID123", - ) + gemini = VertexAIGeminiGenerator() assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { "model": "gemini-1.5-flash", - "project_id": "TestID123", + "project_id": None, "location": None, "generation_config": None, "safety_settings": None, "streaming_callback": None, - "tools": None, - "tool_config": None, "system_instruction": None, }, } @@ -110,20 +81,11 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - tool = Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC]) - tool_config = ToolConfig( - function_calling_config=ToolConfig.FunctionCallingConfig( - mode=ToolConfig.FunctionCallingConfig.Mode.ANY, - allowed_function_names=["get_current_weather_func"], - ) - ) - gemini = VertexAIGeminiGenerator( project_id="TestID123", + location="TestLocation", generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], - tool_config=tool_config, system_instruction="Please provide brief answers.", ) assert gemini.to_dict() == { @@ -131,7 +93,7 @@ def test_to_dict_with_params(_mock_vertexai_init, _mock_generative_model): "init_parameters": { "model": "gemini-1.5-flash", "project_id": "TestID123", - "location": None, + "location": "TestLocation", "generation_config": { "temperature": 0.5, "top_p": 0.5, @@ -142,33 +104,6 @@ def test_to_dict_with_params(_mock_vertexai_init, 
_mock_generative_model): }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, "streaming_callback": None, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "description": "Get the current weather in a given location", - "parameters": { - "type_": "OBJECT", - "properties": { - "location": { - "type_": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, - }, - "required": ["location"], - }, - } - ] - } - ], - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -181,23 +116,21 @@ def test_from_dict(_mock_vertexai_init, _mock_generative_model): { "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { - "project_id": "TestID123", + "project_id": None, + "location": None, "model": "gemini-1.5-flash", "generation_config": None, "safety_settings": None, - "tools": None, "streaming_callback": None, - "tool_config": None, "system_instruction": None, }, } ) assert gemini._model_name == "gemini-1.5-flash" - assert gemini._project_id == "TestID123" + assert gemini._project_id is None + assert gemini._location is None assert gemini._safety_settings is None - assert gemini._tools is None - assert gemini._tool_config is None assert gemini._system_instruction is None assert gemini._generation_config is None @@ -210,6 +143,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "type": "haystack_integrations.components.generators.google_vertex.gemini.VertexAIGeminiGenerator", "init_parameters": { "project_id": "TestID123", + "location": "TestLocation", "model": "gemini-1.5-flash", "generation_config": { "temperature": 0.5, @@ -220,34 +154,7 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): "stop_sequences": ["stop"], }, "safety_settings": {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH}, - "tools": [ - { - "function_declarations": [ - { - "name": "get_current_weather", - "parameters": { - "type_": "OBJECT", - "properties": { - "unit": {"type_": "STRING", "enum": ["celsius", "fahrenheit"]}, - "location": { - "type_": "STRING", - "description": "The city and state, e.g. San Francisco, CA", - }, - }, - "required": ["location"], - }, - "description": "Get the current weather in a given location", - } - ] - } - ], "streaming_callback": None, - "tool_config": { - "function_calling_config": { - "mode": ToolConfig.FunctionCallingConfig.Mode.ANY, - "allowed_function_names": ["get_current_weather_func"], - } - }, "system_instruction": "Please provide brief answers.", }, } @@ -255,14 +162,10 @@ def test_from_dict_with_param(_mock_vertexai_init, _mock_generative_model): assert gemini._model_name == "gemini-1.5-flash" assert gemini._project_id == "TestID123" + assert gemini._location == "TestLocation" assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert repr(gemini._tools) == repr([Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])]) assert isinstance(gemini._generation_config, GenerationConfig) - assert isinstance(gemini._tool_config, ToolConfig) assert gemini._system_instruction == "Please provide brief answers." 
- assert ( - gemini._tool_config._gapic_tool_config.function_calling_config.mode == ToolConfig.FunctionCallingConfig.Mode.ANY - ) @patch("haystack_integrations.components.generators.google_vertex.gemini.GenerativeModel") @@ -271,7 +174,7 @@ def test_run(mock_generative_model): mock_model.generate_content.return_value = MagicMock() mock_generative_model.return_value = mock_model - gemini = VertexAIGeminiGenerator(project_id="TestID123", location=None) + gemini = VertexAIGeminiGenerator() response = gemini.run(["What's the weather like today?"]) @@ -297,7 +200,7 @@ def streaming_callback(_chunk: StreamingChunk) -> None: nonlocal streaming_callback_called streaming_callback_called = True - gemini = VertexAIGeminiGenerator(model="gemini-pro", project_id="TestID123", streaming_callback=streaming_callback) + gemini = VertexAIGeminiGenerator(model="gemini-pro", streaming_callback=streaming_callback) gemini.run(["Come on, stream!"]) assert streaming_callback_called diff --git a/integrations/google_vertex/tests/test_image_generator.py b/integrations/google_vertex/tests/test_image_generator.py index 6cd42a11c..860c1ec43 100644 --- a/integrations/google_vertex/tests/test_image_generator.py +++ b/integrations/google_vertex/tests/test_image_generator.py @@ -30,7 +30,6 @@ def test_init(mock_model_class, mock_vertexai): def test_to_dict(_mock_model_class, _mock_vertexai): generator = VertexAIImageGenerator( model="imagetext", - project_id="myproject-123456", guidance_scale=12, number_of_images=3, ) @@ -38,7 +37,7 @@ def test_to_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.image_generator.VertexAIImageGenerator", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, "guidance_scale": 12, "number_of_images": 3, @@ -54,7 +53,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.image_generator.VertexAIImageGenerator", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, "guidance_scale": 12, "number_of_images": 3, @@ -62,7 +61,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): } ) assert generator._model_name == "imagetext" - assert generator._project_id == "myproject-123456" + assert generator._project_id is None assert generator._location is None assert generator._kwargs == { "guidance_scale": 12, @@ -78,7 +77,6 @@ def test_run_calls_generate_images(mock_model_class, _mock_vertexai): mock_model_class.from_pretrained.return_value = mock_model generator = VertexAIImageGenerator( model="imagetext", - project_id="myproject-123456", guidance_scale=12, number_of_images=3, ) diff --git a/integrations/google_vertex/tests/test_question_answering.py b/integrations/google_vertex/tests/test_question_answering.py index 3f414f0e0..a36e47b6f 100644 --- a/integrations/google_vertex/tests/test_question_answering.py +++ b/integrations/google_vertex/tests/test_question_answering.py @@ -26,14 +26,13 @@ def test_init(mock_model_class, mock_vertexai): def test_to_dict(_mock_model_class, _mock_vertexai): generator = VertexAIImageQA( model="imagetext", - project_id="myproject-123456", number_of_results=3, ) assert generator.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.question_answering.VertexAIImageQA", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, 
"number_of_results": 3, }, @@ -48,14 +47,14 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.question_answering.VertexAIImageQA", "init_parameters": { "model": "imagetext", - "project_id": "myproject-123456", + "project_id": None, "location": None, "number_of_results": 3, }, } ) assert generator._model_name == "imagetext" - assert generator._project_id == "myproject-123456" + assert generator._project_id is None assert generator._location is None assert generator._kwargs == {"number_of_results": 3} @@ -68,7 +67,6 @@ def test_run_calls_ask_question(mock_model_class, _mock_vertexai): mock_model_class.from_pretrained.return_value = mock_model generator = VertexAIImageQA( model="imagetext", - project_id="myproject-123456", number_of_results=3, ) diff --git a/integrations/google_vertex/tests/test_text_generator.py b/integrations/google_vertex/tests/test_text_generator.py index 3e5248dc7..cc3f15312 100644 --- a/integrations/google_vertex/tests/test_text_generator.py +++ b/integrations/google_vertex/tests/test_text_generator.py @@ -24,14 +24,12 @@ def test_init(mock_model_class, mock_vertexai): @patch("haystack_integrations.components.generators.google_vertex.text_generator.TextGenerationModel") def test_to_dict(_mock_model_class, _mock_vertexai): grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1") - generator = VertexAITextGenerator( - model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source - ) + generator = VertexAITextGenerator(model="text-bison", temperature=0.2, grounding_source=grounding_source) assert generator.to_dict() == { "type": "haystack_integrations.components.generators.google_vertex.text_generator.VertexAITextGenerator", "init_parameters": { "model": "text-bison", - "project_id": "myproject-123456", + "project_id": None, "location": None, "temperature": 0.2, "grounding_source": { @@ -55,7 +53,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): "type": "haystack_integrations.components.generators.google_vertex.text_generator.VertexAITextGenerator", "init_parameters": { "model": "text-bison", - "project_id": "myproject-123456", + "project_id": None, "location": None, "temperature": 0.2, "grounding_source": { @@ -71,7 +69,7 @@ def test_from_dict(_mock_model_class, _mock_vertexai): } ) assert generator._model_name == "text-bison" - assert generator._project_id == "myproject-123456" + assert generator._project_id is None assert generator._location is None assert generator._kwargs == { "temperature": 0.2, @@ -86,9 +84,7 @@ def test_run_calls_get_captions(mock_model_class, _mock_vertexai): mock_model.predict.return_value = MagicMock() mock_model_class.from_pretrained.return_value = mock_model grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1") - generator = VertexAITextGenerator( - model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source - ) + generator = VertexAITextGenerator(model="text-bison", temperature=0.2, grounding_source=grounding_source) prompt = "What is the answer?" 
generator.run(prompt=prompt) diff --git a/integrations/instructor_embedders/CHANGELOG.md b/integrations/instructor_embedders/CHANGELOG.md new file mode 100644 index 000000000..2c22fa90c --- /dev/null +++ b/integrations/instructor_embedders/CHANGELOG.md @@ -0,0 +1,61 @@ +# Changelog + +## [integrations/instructor_embedders-v0.4.1] - 2024-10-18 + +### 📚 Documentation + +- Disable-class-def (#556) + +### 🧪 Testing + +- Do not retry tests in `hatch run test` command (#954) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + +## [integrations/instructor_embedders-v0.4.0] - 2024-02-21 + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) + +## [integrations/instructor_embedders-v0.3.0] - 2024-02-15 + +### 🚀 Features + +- Generate API docs (#380) + +### 📚 Documentation + +- Update paths and titles (#397) + +## [integrations/instructor_embedders-v0.2.1] - 2024-01-30 + +## [integrations/instructor_embedders-v0.2.0] - 2024-01-22 + +### ⚙️ Miscellaneous Tasks + +- Replace - with _ (#114) +- chore!: Rename `model_name_or_path` to `model` in the Instructor integration (#229) + +* rename model_name_or_path in doc embedder + +* fix tests for doc embedder + +* rename model_name_or_path to model in text embedder + +* fix tests for text embedder + +* feedback + + diff --git a/integrations/instructor_embedders/pyproject.toml b/integrations/instructor_embedders/pyproject.toml index 458c0ae0c..e165aa10d 100644 --- a/integrations/instructor_embedders/pyproject.toml +++ b/integrations/instructor_embedders/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ # Commenting some of them to not interfere with the dependencies of Haystack. #"transformers==4.20.0", "datasets>=2.2.0", + "huggingface_hub<0.26.0", #"pyarrow==8.0.0", "jsonlines", "numpy", @@ -64,6 +65,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/instructor_embedders-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -83,8 +85,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["38", "39", "310", "311"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index cbd8df479..c89eeacb4 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -43,6 +43,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/jina-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -58,8 +59,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] diff --git a/integrations/langfuse/CHANGELOG.md b/integrations/langfuse/CHANGELOG.md index ccd68ded3..7cf1cc0c4 100644 --- a/integrations/langfuse/CHANGELOG.md +++ b/integrations/langfuse/CHANGELOG.md @@ -1,5 +1,25 @@ # Changelog +## [integrations/langfuse-v0.6.0] - 2024-11-18 + +### 🚀 Features + +- Add support for ttft (#1161) + +### ⚙️ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + +## [integrations/langfuse-v0.5.0] - 2024-10-01 + +### ⚙️ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) + +### Langfuse + +- Add invocation_context to identify traces (#1089) + ## [integrations/langfuse-v0.4.0] - 2024-09-17 ### 🚀 Features diff --git a/integrations/langfuse/example/chat.py b/integrations/langfuse/example/chat.py index 0d9c42787..2308ed1f4 100644 --- a/integrations/langfuse/example/chat.py +++ b/integrations/langfuse/example/chat.py @@ -49,6 +49,16 @@ ChatMessage.from_user("Tell me about {{location}}"), ] - response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) + response = pipe.run( + data={ + "prompt_builder": { + "template_variables": {"location": "Berlin"}, + "template": messages, + }, + "tracer": { + "invocation_context": {"some_key": "some_value"}, + }, + } + ) print(response["llm"]["replies"][0]) print(response["tracer"]["trace_url"]) diff --git a/integrations/langfuse/pyproject.toml b/integrations/langfuse/pyproject.toml index 61de4596c..44397b572 100644 --- a/integrations/langfuse/pyproject.toml +++ b/integrations/langfuse/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/langfuse-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -64,8 +65,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py index 29f58d722..ff0a7c6ed 100644 --- a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py +++ b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py @@ -1,8 +1,12 @@ -from haystack import component, tracing +from typing import Any, Dict, Optional + +from haystack import component, logging, tracing from haystack_integrations.tracing.langfuse import LangfuseTracer from langfuse import Langfuse +logger = logging.getLogger(__name__) + @component class LangfuseConnector: @@ -105,12 +109,20 @@ def __init__(self, name: str, public: bool = False): tracing.enable_tracing(self.tracer) @component.output_types(name=str, trace_url=str) - def run(self): + def run(self, invocation_context: Optional[Dict[str, Any]] = None): """ Runs the LangfuseConnector component. + :param invocation_context: A dictionary with additional context for the invocation. This parameter + is useful when users want to mark this particular invocation with additional information, e.g. + a run id from their own execution framework, user id, etc. These key-value pairs are then visible + in the Langfuse traces. :returns: A dictionary with the following keys: - `name`: The name of the tracing component. - `trace_url`: The URL to the tracing data. """ + logger.debug( + "Langfuse tracer invoked with the following context: '{invocation_context}'", + invocation_context=invocation_context, + ) return {"name": self.name, "trace_url": self.tracer.get_trace_url()} diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index 94064a0d1..c1f8d4d93 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -1,7 +1,10 @@ import contextlib import os -from typing import Any, Dict, Iterator, Optional, Union +from contextvars import ContextVar +from datetime import datetime +from typing import Any, Dict, Iterator, List, Optional, Union +from haystack import logging from haystack.components.generators.openai_utils import _convert_message_to_openai_format from haystack.dataclasses import ChatMessage from haystack.tracing import Span, Tracer, tracer @@ -9,6 +12,8 @@ import langfuse +logger = logging.getLogger(__name__) + HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR = "HAYSTACK_LANGFUSE_ENFORCE_FLUSH" _SUPPORTED_GENERATORS = [ "AzureOpenAIGenerator", @@ -28,6 +33,17 @@ ] _ALL_SUPPORTED_GENERATORS = _SUPPORTED_GENERATORS + _SUPPORTED_CHAT_GENERATORS +# These are the keys used by Haystack for traces and span. 
+# We keep them here to avoid making typos when using them. +_PIPELINE_RUN_KEY = "haystack.pipeline.run" +_COMPONENT_NAME_KEY = "haystack.component.name" +_COMPONENT_TYPE_KEY = "haystack.component.type" +_COMPONENT_OUTPUT_KEY = "haystack.component.output" + +# Context var used to keep track of tracing related info. +# This mainly useful for parents spans. +tracing_context_var: ContextVar[Dict[Any, Any]] = ContextVar("tracing_context", default={}) + class LangfuseSpan(Span): """ @@ -82,7 +98,7 @@ def set_content_tag(self, key: str, value: Any) -> None: self._data[key] = value - def raw_span(self) -> Any: + def raw_span(self) -> "Union[langfuse.client.StatefulSpanClient, langfuse.client.StatefulTraceClient]": """ Return the underlying span instance. @@ -111,75 +127,90 @@ def __init__(self, tracer: "langfuse.Langfuse", name: str = "Haystack", public: and only accessible to the Langfuse account owner. """ self._tracer = tracer - self._context: list[LangfuseSpan] = [] + self._context: List[LangfuseSpan] = [] self._name = name self._public = public self.enforce_flush = os.getenv(HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR, "true").lower() == "true" @contextlib.contextmanager - def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> Iterator[Span]: - """ - Start and manage a new trace span. - :param operation_name: The name of the operation. - :param tags: A dictionary of tags to attach to the span. - :return: A context manager yielding the span. - """ + def trace( + self, operation_name: str, tags: Optional[Dict[str, Any]] = None, parent_span: Optional[Span] = None + ) -> Iterator[Span]: tags = tags or {} - span_name = tags.get("haystack.component.name", operation_name) - - if tags.get("haystack.component.type") in _ALL_SUPPORTED_GENERATORS: - span = LangfuseSpan(self.current_span().raw_span().generation(name=span_name)) + span_name = tags.get(_COMPONENT_NAME_KEY, operation_name) + + # Create new span depending whether there's a parent span or not + if not parent_span: + if operation_name != _PIPELINE_RUN_KEY: + logger.warning( + "Creating a new trace without a parent span is not recommended for operation '{operation_name}'.", + operation_name=operation_name, + ) + # Create a new trace if no parent span is provided + span = LangfuseSpan( + self._tracer.trace( + name=self._name, + public=self._public, + id=tracing_context_var.get().get("trace_id"), + user_id=tracing_context_var.get().get("user_id"), + session_id=tracing_context_var.get().get("session_id"), + tags=tracing_context_var.get().get("tags"), + version=tracing_context_var.get().get("version"), + ) + ) + elif tags.get(_COMPONENT_TYPE_KEY) in _ALL_SUPPORTED_GENERATORS: + span = LangfuseSpan(parent_span.raw_span().generation(name=span_name)) else: - span = LangfuseSpan(self.current_span().raw_span().span(name=span_name)) + span = LangfuseSpan(parent_span.raw_span().span(name=span_name)) self._context.append(span) span.set_tags(tags) yield span - if tags.get("haystack.component.type") in _SUPPORTED_GENERATORS: - meta = span._data.get("haystack.component.output", {}).get("meta") + # Update span metadata based on component type + if tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_GENERATORS: + # Haystack returns one meta dict for each message, but the 'usage' value + # is always the same, let's just pick the first item + meta = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("meta") if meta: - # Haystack returns one meta dict for each message, but the 'usage' value - # is always the same, let's just pick the first item m = meta[0] 
span._span.update(usage=m.get("usage") or None, model=m.get("model")) - elif tags.get("haystack.component.type") in _SUPPORTED_CHAT_GENERATORS: - replies = span._data.get("haystack.component.output", {}).get("replies") + elif tags.get(_COMPONENT_TYPE_KEY) in _SUPPORTED_CHAT_GENERATORS: + replies = span._data.get(_COMPONENT_OUTPUT_KEY, {}).get("replies") if replies: meta = replies[0].meta - span._span.update(usage=meta.get("usage") or None, model=meta.get("model")) - - pipeline_input = tags.get("haystack.pipeline.input_data", None) - if pipeline_input: - span._span.update(input=tags["haystack.pipeline.input_data"]) - pipeline_output = tags.get("haystack.pipeline.output_data", None) - if pipeline_output: - span._span.update(output=tags["haystack.pipeline.output_data"]) - - span.raw_span().end() + completion_start_time = meta.get("completion_start_time") + if completion_start_time: + try: + completion_start_time = datetime.fromisoformat(completion_start_time) + except ValueError: + logger.error(f"Failed to parse completion_start_time: {completion_start_time}") + completion_start_time = None + span._span.update( + usage=meta.get("usage") or None, + model=meta.get("model"), + completion_start_time=completion_start_time, + ) + + raw_span = span.raw_span() + if isinstance(raw_span, langfuse.client.StatefulSpanClient): + raw_span.end() self._context.pop() - if len(self._context) == 1: - # The root span has to be a trace, which need to be removed from the context after the pipeline run - self._context.pop() - - if self.enforce_flush: - self.flush() + if self.enforce_flush: + self.flush() def flush(self): self._tracer.flush() - def current_span(self) -> Span: + def current_span(self) -> Optional[Span]: """ - Return the currently active span. + Return the current active span. - :return: The currently active span. + :return: The current span if available, else None. """ - if not self._context: - # The root span has to be a trace - self._context.append(LangfuseSpan(self._tracer.trace(name=self._name, public=self._public))) - return self._context[-1] + return self._context[-1] if self._context else None def get_trace_url(self) -> str: """ diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py index c6bf4acdf..42ae1d07d 100644 --- a/integrations/langfuse/tests/test_tracer.py +++ b/integrations/langfuse/tests/test_tracer.py @@ -1,9 +1,43 @@ -import os +import datetime from unittest.mock import MagicMock, Mock, patch +from haystack.dataclasses import ChatMessage from haystack_integrations.tracing.langfuse.tracer import LangfuseTracer +class MockSpan: + def __init__(self): + self._data = {} + self._span = self + self.operation_name = "operation_name" + + def raw_span(self): + return self + + def span(self, name=None): + # assert correct operation name passed to the span + assert name == "operation_name" + return self + + def update(self, **kwargs): + self._data.update(kwargs) + + def generation(self, name=None): + return self + + def end(self): + pass + + +class MockTracer: + + def trace(self, name, **kwargs): + return MockSpan() + + def flush(self): + pass + + class TestLangfuseTracer: # LangfuseTracer can be initialized with a Langfuse instance, a name and a boolean value for public. 
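The new `completion_start_time` handling (time to first token) hinges on a small parsing fallback in the tracer: the value from the chat generator's metadata is parsed with `datetime.fromisoformat`, and logged and dropped when it is not valid ISO 8601. A self-contained sketch of that behavior (not the integration's own helper), using the same illustrative values as the new tests:

```python
from datetime import datetime
from typing import Optional


def parse_completion_start_time(raw: Optional[str]) -> Optional[datetime]:
    """Return the parsed timestamp, or None when it is missing or not ISO 8601."""
    if not raw:
        return None
    try:
        return datetime.fromisoformat(raw)
    except ValueError:
        return None


assert parse_completion_start_time("2021-07-27T16:02:08.012345") == datetime(2021, 7, 27, 16, 2, 8, 12345)
assert parse_completion_start_time("foobar") is None
```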
@@ -35,7 +69,7 @@ def test_create_new_span(self): tracer = LangfuseTracer(tracer=mock_tracer, name="Haystack", public=False) with tracer.trace("operation_name", tags={"tag1": "value1", "tag2": "value2"}) as span: - assert len(tracer._context) == 2, "The trace span should have been added to the the root context span" + assert len(tracer._context) == 1, "The trace span should have been added to the the root context span" assert span.raw_span().operation_name == "operation_name" assert span.raw_span().metadata == {"tag1": "value1", "tag2": "value2"} @@ -45,37 +79,6 @@ def test_create_new_span(self): # check that update method is called on the span instance with the provided key value pairs def test_update_span_with_pipeline_input_output_data(self): - class MockTracer: - - def trace(self, name, **kwargs): - return MockSpan() - - def flush(self): - pass - - class MockSpan: - def __init__(self): - self._data = {} - self._span = self - self.operation_name = "operation_name" - - def raw_span(self): - return self - - def span(self, name=None): - # assert correct operation name passed to the span - assert name == "operation_name" - return self - - def update(self, **kwargs): - self._data.update(kwargs) - - def generation(self, name=None): - return self - - def end(self): - pass - tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span: assert span.raw_span()._data["metadata"] == {"haystack.pipeline.input_data": "hello"} @@ -83,6 +86,40 @@ def end(self): with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.output_data": "bye"}) as span: assert span.raw_span()._data["metadata"] == {"haystack.pipeline.output_data": "bye"} + def test_trace_generation(self): + tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) + tags = { + "haystack.component.type": "OpenAIChatGenerator", + "haystack.component.output": { + "replies": [ + ChatMessage.from_assistant( + "", meta={"completion_start_time": "2021-07-27T16:02:08.012345", "model": "test_model"} + ) + ] + }, + } + with tracer.trace(operation_name="operation_name", tags=tags) as span: + ... + assert span.raw_span()._data["usage"] is None + assert span.raw_span()._data["model"] == "test_model" + assert span.raw_span()._data["completion_start_time"] == datetime.datetime(2021, 7, 27, 16, 2, 8, 12345) + + def test_trace_generation_invalid_start_time(self): + tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) + tags = { + "haystack.component.type": "OpenAIChatGenerator", + "haystack.component.output": { + "replies": [ + ChatMessage.from_assistant("", meta={"completion_start_time": "foobar", "model": "test_model"}), + ] + }, + } + with tracer.trace(operation_name="operation_name", tags=tags) as span: + ... 
+ assert span.raw_span()._data["usage"] is None + assert span.raw_span()._data["model"] == "test_model" + assert span.raw_span()._data["completion_start_time"] is None + def test_update_span_gets_flushed_by_default(self): tracer_mock = Mock() diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 936064e0a..e5737b861 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -43,27 +43,37 @@ def test_tracing_integration(llm_class, env_var, expected_trace): ChatMessage.from_user("Tell me about {{location}}"), ] - response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) + response = pipe.run( + data={ + "prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}, + "tracer": {"invocation_context": {"user_id": "user_42"}}, + } + ) assert "Berlin" in response["llm"]["replies"][0].content assert response["tracer"]["trace_url"] - # add a random delay between 1 and 3 seconds to make sure the trace is flushed - # and that the trace is available in Langfuse when we fetch it below - time.sleep(random.uniform(1, 3)) - - url = "https://cloud.langfuse.com/api/public/traces/" trace_url = response["tracer"]["trace_url"] uuid = os.path.basename(urlparse(trace_url).path) + url = f"https://cloud.langfuse.com/api/public/traces/{uuid}" - try: - response = requests.get( - url + uuid, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) + # Poll the Langfuse API a bit as the trace might not be ready right away + attempts = 5 + delay = 1 + while attempts >= 0: + res = requests.get( + url, auth=HTTPBasicAuth(os.environ["LANGFUSE_PUBLIC_KEY"], os.environ["LANGFUSE_SECRET_KEY"]) ) - assert response.status_code == 200, f"Failed to retrieve data from Langfuse API: {response.status_code}" + if attempts > 0 and res.status_code != 200: + attempts -= 1 + time.sleep(delay) + delay *= 2 + continue + assert res.status_code == 200, f"Failed to retrieve data from Langfuse API: {res.status_code}" # check if the trace contains the expected LLM name - assert expected_trace in str(response.content) + assert expected_trace in str(res.content) # check if the trace contains the expected generation span - assert "GENERATION" in str(response.content) - except requests.exceptions.RequestException as e: - pytest.fail(f"Failed to retrieve data from Langfuse API: {e}") + assert "GENERATION" in str(res.content) + # check if the trace contains the expected user_id + assert "user_42" in str(res.content) + break diff --git a/integrations/llama_cpp/pyproject.toml b/integrations/llama_cpp/pyproject.toml index 673df575a..a33434e1b 100644 --- a/integrations/llama_cpp/pyproject.toml +++ b/integrations/llama_cpp/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "llama-cpp-python>=0.2.87,<0.3.0"] +dependencies = ["haystack-ai", "llama-cpp-python>=0.2.87"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/llama_cpp#readme" @@ -45,6 +45,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/llama_cpp-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -65,8 +66,9 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/llama_cpp/tests/test_chat_generator.py b/integrations/llama_cpp/tests/test_chat_generator.py index 1d4c9cf82..802fe9128 100644 --- a/integrations/llama_cpp/tests/test_chat_generator.py +++ b/integrations/llama_cpp/tests/test_chat_generator.py @@ -342,7 +342,7 @@ def generator(self, model_path, capsys): hf_tokenizer_path = "meetkai/functionary-small-v2.4-GGUF" generator = LlamaCppChatGenerator( model=model_path, - n_ctx=8192, + n_ctx=512, n_batch=512, model_kwargs={ "chat_format": "functionary-v2", @@ -399,7 +399,6 @@ def test_function_call_and_execute(self, generator): "type": "string", "description": "The city and state, e.g. San Francisco, CA", }, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, }, "required": ["location"], }, @@ -407,7 +406,8 @@ def test_function_call_and_execute(self, generator): } ] - response = generator.run(messages=messages, generation_kwargs={"tools": tools}) + tool_choice = {"type": "function", "function": {"name": "get_current_temperature"}} + response = generator.run(messages=messages, generation_kwargs={"tools": tools, "tool_choice": tool_choice}) available_functions = { "get_current_temperature": self.get_current_temperature, diff --git a/integrations/mistral/pyproject.toml b/integrations/mistral/pyproject.toml index 16f332331..06d02c0aa 100644 --- a/integrations/mistral/pyproject.toml +++ b/integrations/mistral/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/mistral-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ diff --git a/integrations/mongodb_atlas/pyproject.toml b/integrations/mongodb_atlas/pyproject.toml index 95ed6c03a..bdf1a2dc1 100644 --- a/integrations/mongodb_atlas/pyproject.toml +++ b/integrations/mongodb_atlas/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/mongodb_atlas-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -62,8 +63,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md index f66536fe5..a536e431d 100644 --- a/integrations/nvidia/CHANGELOG.md +++ b/integrations/nvidia/CHANGELOG.md @@ -1,15 +1,23 @@ # Changelog -## [unreleased] +## [integrations/nvidia-v0.1.1] - 2024-11-14 + +### 🐛 Bug Fixes + +- Fixes to NvidiaRanker (#1191) + +## [integrations/nvidia-v0.1.0] - 2024-11-13 ### 🚀 Features - Update default embedding model to nvidia/nv-embedqa-e5-v5 (#1015) - Add NVIDIA NIM ranker support (#1023) +- Raise error when attempting to embed empty documents/strings with Nvidia embedders (#1118) ### 🐛 Bug Fixes - Lints in `nvidia-haystack` (#993) +- Missing Nvidia embedding truncate mode (#1043) ### 🚜 Refactor @@ -27,6 +35,8 @@ - Retry tests to reduce flakyness (#836) - Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) ### Docs diff --git a/integrations/nvidia/pyproject.toml b/integrations/nvidia/pyproject.toml index b5c6dd205..586b50848 100644 --- a/integrations/nvidia/pyproject.toml +++ b/integrations/nvidia/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "requests"] +dependencies = ["haystack-ai", "requests", "tqdm"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme" @@ -42,6 +42,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/nvidia-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -62,8 +63,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ diff --git a/integrations/nvidia/src/haystack_integrations/__init__.py b/integrations/nvidia/src/haystack_integrations/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/__init__.py b/integrations/nvidia/src/haystack_integrations/components/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py index bc2d9372c..827ad7dc6 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .document_embedder import NvidiaDocumentEmbedder from .text_embedder import NvidiaTextEmbedder from .truncate import EmbeddingTruncateMode diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index 3e911e4f4..606ec78fd 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional, Tuple, Union @@ -5,10 +9,9 @@ from haystack.utils import Secret, deserialize_secrets_inplace from tqdm import tqdm +from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from .truncate import EmbeddingTruncateMode - _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @@ -167,7 +170,9 @@ def from_dict(cls, data: Dict[str, Any]) 
-> "NvidiaDocumentEmbedder": :returns: The deserialized component. """ - deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: @@ -231,6 +236,11 @@ def run(self, documents: List[Document]): ) raise TypeError(msg) + for doc in documents: + if not doc.content: + msg = f"Document '{doc.id}' has no content to embed." + raise ValueError(msg) + texts_to_embed = self._prepare_texts_to_embed(documents) embeddings, metadata = self._embed_batch(texts_to_embed, self.batch_size) for doc, emb in zip(documents, embeddings): diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 0387c32b7..4b7072f33 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -1,13 +1,16 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.components.embedders.nvidia.truncate import EmbeddingTruncateMode from haystack_integrations.utils.nvidia import NimBackend, is_hosted, url_validation -from .truncate import EmbeddingTruncateMode - _DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" @@ -175,6 +178,9 @@ def run(self, text: str): "In case you want to embed a list of Documents, please use the NvidiaDocumentEmbedder." ) raise TypeError(msg) + elif not text: + msg = "Cannot embed an empty string." 
+ raise ValueError(msg) assert self.backend is not None text_to_embed = self.prefix + text + self.suffix diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py index 3a8eb9d07..931c3cce3 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/truncate.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/generators/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py index 18354ea17..b809d83b9 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from .generator import NvidiaGenerator __all__ = ["NvidiaGenerator"] diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py index 3eadcc5df..5bf71a9e1 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py b/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py index 29cb2f7f5..05daa1c54 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .ranker import NvidiaRanker __all__ = ["NvidiaRanker"] diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py index 46c736883..9938b37d1 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py +++ 
b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ -1,17 +1,23 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings from typing import Any, Dict, List, Optional, Union -from haystack import Document, component, default_from_dict, default_to_dict +from haystack import Document, component, default_from_dict, default_to_dict, logging from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.components.rankers.nvidia.truncate import RankerTruncateMode from haystack_integrations.utils.nvidia import NimBackend, url_validation -from .truncate import RankerTruncateMode +logger = logging.getLogger(__name__) _DEFAULT_MODEL = "nvidia/nv-rerankqa-mistral-4b-v3" _MODEL_ENDPOINT_MAP = { "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", + "nvidia/llama-3.2-nv-rerankqa-1b-v1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v1/reranking", } @@ -50,7 +56,7 @@ def __init__( model: Optional[str] = None, truncate: Optional[Union[RankerTruncateMode, str]] = None, api_url: Optional[str] = None, - api_key: Optional[Secret] = None, + api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), top_k: int = 5, ): """ @@ -99,6 +105,7 @@ def __init__( self._api_key = Secret.from_env_var("NVIDIA_API_KEY") self._top_k = top_k self._initialized = False + self._backend: Optional[Any] = None def to_dict(self) -> Dict[str, Any]: """ @@ -112,7 +119,7 @@ def to_dict(self) -> Dict[str, Any]: top_k=self._top_k, truncate=self._truncate, api_url=self._api_url, - api_key=self._api_key, + api_key=self._api_key.to_dict() if self._api_key else None, ) @classmethod @@ -123,7 +130,9 @@ def from_dict(cls, data: Dict[str, Any]) -> "NvidiaRanker": :param data: A dictionary containing the ranker's attributes. :returns: The deserialized ranker. """ - deserialize_secrets_inplace(data, keys=["api_key"]) + init_parameters = data.get("init_parameters", {}) + if init_parameters: + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) return default_from_dict(cls, data) def warm_up(self): @@ -169,16 +178,16 @@ def run( msg = "The ranker has not been loaded. Please call warm_up() before running." raise RuntimeError(msg) if not isinstance(query, str): - msg = "Ranker expects the `query` parameter to be a string." + msg = "NvidiaRanker expects the `query` parameter to be a string." raise TypeError(msg) if not isinstance(documents, list): - msg = "Ranker expects the `documents` parameter to be a list." + msg = "NvidiaRanker expects the `documents` parameter to be a list." raise TypeError(msg) if not all(isinstance(doc, Document) for doc in documents): - msg = "Ranker expects the `documents` parameter to be a list of Document objects." + msg = "NvidiaRanker expects the `documents` parameter to be a list of Document objects." raise TypeError(msg) if top_k is not None and not isinstance(top_k, int): - msg = "Ranker expects the `top_k` parameter to be an integer." + msg = "NvidiaRanker expects the `top_k` parameter to be an integer." 
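The serialization changes above follow one pattern: `to_dict` serializes the `Secret` explicitly, and `from_dict` only calls `deserialize_secrets_inplace` when `init_parameters` is present and non-empty, so data saved with all-default parameters still loads. A condensed sketch of the round trip (a hypothetical stand-in class; the real components also serialize their remaining init parameters):

```python
from typing import Any, Dict

from haystack import default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace


class ExampleComponent:  # stand-in; @component decorator and run() omitted for brevity
    def __init__(self, api_key: Secret = Secret.from_env_var("NVIDIA_API_KEY")):
        self._api_key = api_key

    def to_dict(self) -> Dict[str, Any]:
        # Secrets are not plain JSON values, so they are serialized explicitly.
        return default_to_dict(self, api_key=self._api_key.to_dict() if self._api_key else None)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ExampleComponent":
        # Guard against missing or empty init_parameters (e.g. {"init_parameters": {}}).
        init_parameters = data.get("init_parameters", {})
        if init_parameters:
            deserialize_secrets_inplace(init_parameters, keys=["api_key"])
        return default_from_dict(cls, data)
```

Round-tripping `ExampleComponent().to_dict()` through `from_dict` then works both with and without an explicit `api_key` entry.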
raise TypeError(msg) if len(documents) == 0: @@ -186,6 +195,7 @@ def run( top_k = top_k if top_k is not None else self._top_k if top_k < 1: + logger.warning("top_k should be at least 1, returning nothing") warnings.warn("top_k should be at least 1, returning nothing", stacklevel=2) return {"documents": []} diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py index 3b5d7f40a..649ceaf9d 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/truncate.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from enum import Enum diff --git a/integrations/nvidia/src/haystack_integrations/utils/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/__init__.py new file mode 100644 index 000000000..6b5e14dc1 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/utils/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py index da301d29d..f08cda6cd 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from .nim_backend import Model, NimBackend from .utils import is_hosted, url_validation diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py index 0d1f57e5c..0279cf608 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/nim_backend.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple @@ -50,16 +54,20 @@ def __init__( def embed(self, texts: List[str]) -> Tuple[List[List[float]], Dict[str, Any]]: url = f"{self.api_url}/embeddings" - res = self.session.post( - url, - json={ - "model": self.model, - "input": texts, - **self.model_kwargs, - }, - timeout=REQUEST_TIMEOUT, - ) - res.raise_for_status() + try: + res = self.session.post( + url, + json={ + "model": self.model, + "input": texts, + **self.model_kwargs, + }, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + except requests.HTTPError as e: + msg = f"Failed to query embedding endpoint: Error - {e.response.text}" + raise ValueError(msg) from e data = res.json() # Sort the embeddings by index, we don't know whether they're out of order or not @@ -73,21 +81,25 @@ def generate(self, prompt: str) -> Tuple[List[str], List[Dict[str, Any]]]: # This is the same for local containers and the cloud API. 
url = f"{self.api_url}/chat/completions" - res = self.session.post( - url, - json={ - "model": self.model, - "messages": [ - { - "role": "user", - "content": prompt, - }, - ], - **self.model_kwargs, - }, - timeout=REQUEST_TIMEOUT, - ) - res.raise_for_status() + try: + res = self.session.post( + url, + json={ + "model": self.model, + "messages": [ + { + "role": "user", + "content": prompt, + }, + ], + **self.model_kwargs, + }, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + except requests.HTTPError as e: + msg = f"Failed to query chat completion endpoint: Error - {e.response.text}" + raise ValueError(msg) from e completions = res.json() choices = completions["choices"] @@ -139,17 +151,21 @@ def rank( ) -> List[Dict[str, Any]]: url = endpoint or f"{self.api_url}/ranking" - res = self.session.post( - url, - json={ - "model": self.model, - "query": {"text": query}, - "passages": [{"text": doc.content} for doc in documents], - **self.model_kwargs, - }, - timeout=REQUEST_TIMEOUT, - ) - res.raise_for_status() + try: + res = self.session.post( + url, + json={ + "model": self.model, + "query": {"text": query}, + "passages": [{"text": doc.content} for doc in documents], + **self.model_kwargs, + }, + timeout=REQUEST_TIMEOUT, + ) + res.raise_for_status() + except requests.HTTPError as e: + msg = f"Failed to rank endpoint: Error - {e.response.text}" + raise ValueError(msg) from e data = res.json() assert "rankings" in data, f"Expected 'rankings' in response, got {data}" diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py index 7d4dfc3b4..f07989405 100644 --- a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py @@ -1,9 +1,13 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import warnings -from typing import List +from typing import List, Optional from urllib.parse import urlparse, urlunparse -def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str: +def url_validation(api_url: str, default_api_url: Optional[str], allowed_paths: List[str]) -> str: """ Validate and normalize an API URL. 
diff --git a/integrations/nvidia/tests/__init__.py b/integrations/nvidia/tests/__init__.py index 47611e0b9..38adc654d 100644 --- a/integrations/nvidia/tests/__init__.py +++ b/integrations/nvidia/tests/__init__.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + from .conftest import MockBackend __all__ = ["MockBackend"] diff --git a/integrations/nvidia/tests/conftest.py b/integrations/nvidia/tests/conftest.py index a6c78ba4e..b6346c672 100644 --- a/integrations/nvidia/tests/conftest.py +++ b/integrations/nvidia/tests/conftest.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + from typing import Any, Dict, List, Optional, Tuple import pytest diff --git a/integrations/nvidia/tests/test_base_url.py b/integrations/nvidia/tests/test_base_url.py index 426bacc25..506fbc385 100644 --- a/integrations/nvidia/tests/test_base_url.py +++ b/integrations/nvidia/tests/test_base_url.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import pytest from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index bef0f996e..7e0e02f3d 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import pytest @@ -104,7 +108,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): }, } - def from_dict(self, monkeypatch): + def test_from_dict(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", @@ -122,15 +126,32 @@ def from_dict(self, monkeypatch): }, } component = NvidiaDocumentEmbedder.from_dict(data) - assert component.model == "nvolveqa_40k" + assert component.model == "playground_nvolveqa_40k" assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" + assert component.batch_size == 10 + assert component.progress_bar is False + assert component.meta_fields_to_embed == ["test_field"] + assert component.embedding_separator == " | " + assert component.truncate == EmbeddingTruncateMode.START + + def test_from_dict_defaults(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", + "init_parameters": {}, + } + component = NvidiaDocumentEmbedder.from_dict(data) + assert component.model == "nvidia/nv-embedqa-e5-v5" + assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" + assert component.prefix == "" + assert component.suffix == "" assert component.batch_size == 32 assert component.progress_bar assert component.meta_fields_to_embed == [] assert component.embedding_separator == "\n" - assert component.truncate == EmbeddingTruncateMode.START + assert component.truncate is None def test_prepare_texts_to_embed_w_metadata(self): documents = [ @@ -326,6 +347,17 @@ def test_run_wrong_input_format(self): with pytest.raises(TypeError, match="NvidiaDocumentEmbedder expects a list of Documents as input"): 
embedder.run(documents=list_integers_input) + def test_run_empty_document(self): + model = "playground_nvolveqa_40k" + api_key = Secret.from_token("fake-api-key") + embedder = NvidiaDocumentEmbedder(model, api_key=api_key) + + embedder.warm_up() + embedder.backend = MockBackend(model=model, api_key=api_key) + + with pytest.raises(ValueError, match="no content to embed"): + embedder.run(documents=[Document(content="")]) + def test_run_on_empty_list(self): model = "playground_nvolveqa_40k" api_key = Secret.from_token("fake-api-key") diff --git a/integrations/nvidia/tests/test_embedding_truncate_mode.py b/integrations/nvidia/tests/test_embedding_truncate_mode.py index e74d0308c..16f9112ea 100644 --- a/integrations/nvidia/tests/test_embedding_truncate_mode.py +++ b/integrations/nvidia/tests/test_embedding_truncate_mode.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import pytest from haystack_integrations.components.embedders.nvidia import EmbeddingTruncateMode diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py index 0bd8b1fc6..055830ae5 100644 --- a/integrations/nvidia/tests/test_generator.py +++ b/integrations/nvidia/tests/test_generator.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2024-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 + import os import pytest diff --git a/integrations/nvidia/tests/test_ranker.py b/integrations/nvidia/tests/test_ranker.py index 566fd18a8..d66bb0f65 100644 --- a/integrations/nvidia/tests/test_ranker.py +++ b/integrations/nvidia/tests/test_ranker.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import re from typing import Any, Optional, Union @@ -256,3 +260,48 @@ def test_warm_up_once(self, monkeypatch) -> None: backend = client._backend client.warm_up() assert backend == client._backend + + def test_to_dict(self) -> None: + client = NvidiaRanker() + assert client.to_dict() == { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": { + "model": "nvidia/nv-rerankqa-mistral-4b-v3", + "top_k": 5, + "truncate": None, + "api_url": None, + "api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + }, + } + + def test_from_dict(self) -> None: + client = NvidiaRanker.from_dict( + { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": { + "model": "nvidia/nv-rerankqa-mistral-4b-v3", + "top_k": 5, + "truncate": None, + "api_url": None, + "api_key": {"type": "env_var", "env_vars": ["NVIDIA_API_KEY"], "strict": True}, + }, + } + ) + assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3" + assert client._top_k == 5 + assert client._truncate is None + assert client._api_url is None + assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") + + def test_from_dict_defaults(self) -> None: + client = NvidiaRanker.from_dict( + { + "type": "haystack_integrations.components.rankers.nvidia.ranker.NvidiaRanker", + "init_parameters": {}, + } + ) + assert client._model == "nvidia/nv-rerankqa-mistral-4b-v3" + assert client._top_k == 5 + assert client._truncate is None + assert client._api_url is None + assert client._api_key == Secret.from_env_var("NVIDIA_API_KEY") diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index 7c8428cc2..278fa5191 100644 --- 
a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -1,3 +1,7 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + import os import pytest @@ -77,7 +81,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): }, } - def from_dict(self, monkeypatch): + def test_from_dict(self, monkeypatch): monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", @@ -95,7 +99,20 @@ def from_dict(self, monkeypatch): assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" - assert component.truncate == "START" + assert component.truncate == EmbeddingTruncateMode.START + + def test_from_dict_defaults(self, monkeypatch): + monkeypatch.setenv("NVIDIA_API_KEY", "fake-api-key") + data = { + "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", + "init_parameters": {}, + } + component = NvidiaTextEmbedder.from_dict(data) + assert component.model == "nvidia/nv-embedqa-e5-v5" + assert component.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia" + assert component.prefix == "" + assert component.suffix == "" + assert component.truncate is None @pytest.mark.usefixtures("mock_local_models") def test_run_default_model(self): @@ -147,6 +164,17 @@ def test_run_wrong_input_format(self): with pytest.raises(TypeError, match="NvidiaTextEmbedder expects a string as an input"): embedder.run(text=list_integers_input) + def test_run_empty_string(self): + model = "playground_nvolveqa_40k" + api_key = Secret.from_token("fake-api-key") + embedder = NvidiaTextEmbedder(model, api_key=api_key) + + embedder.warm_up() + embedder.backend = MockBackend(model=model, api_key=api_key) + + with pytest.raises(ValueError, match="empty string"): + embedder.run(text="") + @pytest.mark.skipif( not os.environ.get("NVIDIA_NIM_EMBEDDER_MODEL", None) or not os.environ.get("NVIDIA_NIM_ENDPOINT_URL", None), reason="Export an env var called NVIDIA_NIM_EMBEDDER_MODEL containing the hosted model name and " diff --git a/integrations/ollama/CHANGELOG.md b/integrations/ollama/CHANGELOG.md index 5da725e87..55c6aa7b7 100644 --- a/integrations/ollama/CHANGELOG.md +++ b/integrations/ollama/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/ollama-v1.1.0] - 2024-10-11 + +### ๐Ÿš€ Features + +- Add `keep_alive` parameter to Ollama Generators (#1131) + +### โš™๏ธ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) + ## [integrations/ollama-v1.0.1] - 2024-09-26 ### ๐Ÿ› Bug Fixes diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index bc8555140..598d1d214 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -46,6 +46,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/ollama-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -66,8 +67,9 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py index 9502a187e..558fd593e 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage, StreamingChunk @@ -38,6 +38,7 @@ def __init__( url: str = "http://localhost:11434", generation_kwargs: Optional[Dict[str, Any]] = None, timeout: int = 120, + keep_alive: Optional[Union[float, str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -54,12 +55,21 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param keep_alive: + The option that controls how long the model will stay loaded into memory following the request. + If not set, it will use the default value from the Ollama (5 minutes). + The value can be set to: + - a duration string (such as "10m" or "24h") + - a number in seconds (such as 3600) + - any negative number which will keep the model loaded in memory (e.g. -1 or "-1m") + - '0' which will unload the model immediately after generating a response. 
""" self.timeout = timeout self.generation_kwargs = generation_kwargs or {} self.url = url self.model = model + self.keep_alive = keep_alive self.streaming_callback = streaming_callback self._client = Client(host=self.url, timeout=self.timeout) @@ -76,6 +86,7 @@ def to_dict(self) -> Dict[str, Any]: self, model=self.model, url=self.url, + keep_alive=self.keep_alive, generation_kwargs=self.generation_kwargs, timeout=self.timeout, streaming_callback=callback_name, @@ -165,7 +176,9 @@ def run( stream = self.streaming_callback is not None messages = [self._message_to_dict(message) for message in messages] - response = self._client.chat(model=self.model, messages=messages, stream=stream, options=generation_kwargs) + response = self._client.chat( + model=self.model, messages=messages, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs + ) if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) diff --git a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py index d92932c3e..058948e8a 100644 --- a/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py +++ b/integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk @@ -36,6 +36,7 @@ def __init__( template: Optional[str] = None, raw: bool = False, timeout: int = 120, + keep_alive: Optional[Union[float, str]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -59,6 +60,14 @@ def __init__( :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. + :param keep_alive: + The option that controls how long the model will stay loaded into memory following the request. + If not set, it will use the default value from the Ollama (5 minutes). + The value can be set to: + - a duration string (such as "10m" or "24h") + - a number in seconds (such as 3600) + - any negative number which will keep the model loaded in memory (e.g. -1 or "-1m") + - '0' which will unload the model immediately after generating a response. 
""" self.timeout = timeout self.raw = raw @@ -66,6 +75,7 @@ def __init__( self.system_prompt = system_prompt self.model = model self.url = url + self.keep_alive = keep_alive self.generation_kwargs = generation_kwargs or {} self.streaming_callback = streaming_callback @@ -87,6 +97,7 @@ def to_dict(self) -> Dict[str, Any]: system_prompt=self.system_prompt, model=self.model, url=self.url, + keep_alive=self.keep_alive, generation_kwargs=self.generation_kwargs, streaming_callback=callback_name, ) @@ -172,7 +183,9 @@ def run( stream = self.streaming_callback is not None - response = self._client.generate(model=self.model, prompt=prompt, stream=stream, options=generation_kwargs) + response = self._client.generate( + model=self.model, prompt=prompt, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs + ) if stream: chunks: List[StreamingChunk] = self._handle_streaming_response(response) diff --git a/integrations/ollama/tests/test_chat_generator.py b/integrations/ollama/tests/test_chat_generator.py index a46758df3..5ac9289aa 100644 --- a/integrations/ollama/tests/test_chat_generator.py +++ b/integrations/ollama/tests/test_chat_generator.py @@ -26,12 +26,14 @@ def test_init_default(self): assert component.url == "http://localhost:11434" assert component.generation_kwargs == {} assert component.timeout == 120 + assert component.keep_alive is None def test_init(self): component = OllamaChatGenerator( model="llama2", url="http://my-custom-endpoint:11434", generation_kwargs={"temperature": 0.5}, + keep_alive="10m", timeout=5, ) @@ -39,6 +41,7 @@ def test_init(self): assert component.url == "http://my-custom-endpoint:11434" assert component.generation_kwargs == {"temperature": 0.5} assert component.timeout == 5 + assert component.keep_alive == "10m" def test_to_dict(self): component = OllamaChatGenerator( @@ -46,6 +49,7 @@ def test_to_dict(self): streaming_callback=print_streaming_chunk, url="custom_url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + keep_alive="5m", ) data = component.to_dict() assert data == { @@ -53,6 +57,7 @@ def test_to_dict(self): "init_parameters": { "timeout": 120, "model": "llama2", + "keep_alive": "5m", "url": "custom_url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -66,6 +71,7 @@ def test_from_dict(self): "timeout": 120, "model": "llama2", "url": "custom_url", + "keep_alive": "5m", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, @@ -75,6 +81,7 @@ def test_from_dict(self): assert component.streaming_callback is print_streaming_chunk assert component.url == "custom_url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.keep_alive == "5m" def test_build_message_from_ollama_response(self): model = "some_model" diff --git a/integrations/ollama/tests/test_generator.py b/integrations/ollama/tests/test_generator.py index c4c6906db..b02370234 100644 --- a/integrations/ollama/tests/test_generator.py +++ b/integrations/ollama/tests/test_generator.py @@ -45,6 +45,7 @@ def test_init_default(self): assert component.template is None assert component.raw is False assert component.timeout == 120 + assert component.keep_alive is None assert component.streaming_callback is None def test_init(self): @@ -57,6 +58,7 @@ def callback(x: StreamingChunk): 
generation_kwargs={"temperature": 0.5}, system_prompt="You are Luigi from Super Mario Bros.", timeout=5, + keep_alive="10m", streaming_callback=callback, ) assert component.model == "llama2" @@ -66,6 +68,7 @@ def callback(x: StreamingChunk): assert component.template is None assert component.raw is False assert component.timeout == 5 + assert component.keep_alive == "10m" assert component.streaming_callback == callback component = OllamaGenerator() @@ -80,6 +83,7 @@ def callback(x: StreamingChunk): "model": "orca-mini", "url": "http://localhost:11434", "streaming_callback": None, + "keep_alive": None, "generation_kwargs": {}, }, } @@ -89,6 +93,7 @@ def test_to_dict_with_parameters(self): model="llama2", streaming_callback=print_streaming_chunk, url="going_to_51_pegasi_b_for_weekend", + keep_alive="10m", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) data = component.to_dict() @@ -100,6 +105,7 @@ def test_to_dict_with_parameters(self): "template": None, "system_prompt": None, "model": "llama2", + "keep_alive": "10m", "url": "going_to_51_pegasi_b_for_weekend", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -115,6 +121,7 @@ def test_from_dict(self): "template": None, "system_prompt": None, "model": "llama2", + "keep_alive": "5m", "url": "going_to_51_pegasi_b_for_weekend", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -125,6 +132,7 @@ def test_from_dict(self): assert component.streaming_callback is print_streaming_chunk assert component.url == "going_to_51_pegasi_b_for_weekend" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + assert component.keep_alive == "5m" @pytest.mark.integration def test_ollama_generator_run_streaming(self): diff --git a/integrations/opensearch/CHANGELOG.md b/integrations/opensearch/CHANGELOG.md index 713848915..afd8a57c2 100644 --- a/integrations/opensearch/CHANGELOG.md +++ b/integrations/opensearch/CHANGELOG.md @@ -1,6 +1,10 @@ # Changelog -## [integrations/opensearch-v1.0.0] - 2024-09-12 +## [integrations/opensearch-v1.1.0] - 2024-10-29 + +### ๐Ÿš€ Features + +- Efficient knn filtering support for OpenSearch (#1134) ### ๐Ÿ“š Documentation @@ -13,6 +17,9 @@ ### โš™๏ธ Miscellaneous Tasks - OpenSearch - remove legacy filter support (#1067) +- Update changelog after removing legacy filters (#1083) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) ### Docs @@ -83,8 +90,6 @@ This PR will also push the docs to Readme - Fix links in docstrings (#188) - - ### ๐Ÿšœ Refactor - Use `hatch_vcs` to manage integrations versioning (#103) @@ -95,15 +100,12 @@ This PR will also push the docs to Readme - Fix import and increase version (#77) - - ## [integrations/opensearch-v0.1.0] - 2023-12-04 ### ๐Ÿ› Bug Fixes - Fix license headers - ## [integrations/opensearch-v0.0.2] - 2023-11-30 ### ๐Ÿš€ Features diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index 24f1653bd..54c194470 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/opensearch-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -63,8 +64,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "boto3"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "boto3"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] diff --git a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py index e159634cf..1e9bb9132 100644 --- a/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py +++ b/integrations/opensearch/src/haystack_integrations/components/retrievers/opensearch/embedding_retriever.py @@ -31,6 +31,7 @@ def __init__( filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, custom_query: Optional[Dict[str, Any]] = None, raise_on_failure: bool = True, + efficient_filtering: bool = False, ): """ Create the OpenSearchEmbeddingRetriever component. @@ -85,6 +86,8 @@ def __init__( :param raise_on_failure: If `True`, raises an exception if the API call fails. If `False`, logs a warning and returns an empty list. + :param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search. + This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib". :raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore. """ @@ -100,6 +103,7 @@ def __init__( ) self._custom_query = custom_query self._raise_on_failure = raise_on_failure + self._efficient_filtering = efficient_filtering def to_dict(self) -> Dict[str, Any]: """ @@ -116,6 +120,7 @@ def to_dict(self) -> Dict[str, Any]: filter_policy=self._filter_policy.value, custom_query=self._custom_query, raise_on_failure=self._raise_on_failure, + efficient_filtering=self._efficient_filtering, ) @classmethod @@ -146,6 +151,7 @@ def run( filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, custom_query: Optional[Dict[str, Any]] = None, + efficient_filtering: Optional[bool] = None, ): """ Retrieve documents using a vector similarity metric. @@ -196,6 +202,9 @@ def run( ) ``` + :param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search. + This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib". + :returns: Dictionary with key "documents" containing the retrieved Documents. - documents: List of Document similar to `query_embedding`. 
@@ -208,6 +217,8 @@ def run( top_k = self._top_k if custom_query is None: custom_query = self._custom_query + if efficient_filtering is None: + efficient_filtering = self._efficient_filtering docs: List[Document] = [] @@ -217,6 +228,7 @@ def run( filters=filters, top_k=top_k, custom_query=custom_query, + efficient_filtering=efficient_filtering, ) except Exception as e: if self._raise_on_failure: diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 6f7a6c96e..4ec2420b3 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -438,6 +438,7 @@ def _embedding_retrieval( filters: Optional[Dict[str, Any]] = None, top_k: int = 10, custom_query: Optional[Dict[str, Any]] = None, + efficient_filtering: bool = False, ) -> List[Document]: """ Retrieves documents that are most similar to the query embedding using a vector similarity metric. @@ -474,6 +475,8 @@ def _embedding_retrieval( } ``` + :param efficient_filtering: If `True`, the filter will be applied during the approximate kNN search. + This is only supported for knn engines "faiss" and "lucene" and does not work with the default "nmslib". :raises ValueError: If `query_embedding` is an empty list :returns: List of Document that are most similar to `query_embedding` """ @@ -509,7 +512,10 @@ def _embedding_retrieval( } if filters: - body["query"]["bool"]["filter"] = normalize_filters(filters) + if efficient_filtering: + body["query"]["bool"]["must"][0]["knn"]["embedding"]["filter"] = normalize_filters(filters) + else: + body["query"]["bool"]["filter"] = normalize_filters(filters) body["size"] = top_k diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index 9cc4bf4ea..043f59891 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -337,6 +337,27 @@ def document_store_embedding_dim_4(self, request): yield store store.client.indices.delete(index=index, params={"ignore": [400, 404]}) + @pytest.fixture + def document_store_embedding_dim_4_faiss(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = ["https://localhost:9200"] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = OpenSearchDocumentStore( + hosts=hosts, + index=index, + http_auth=("admin", "admin"), + verify_certs=False, + embedding_dim=4, + method={"space_type": "innerproduct", "engine": "faiss", "name": "hnsw"}, + ) + yield store + store.client.indices.delete(index=index, params={"ignore": [400, 404]}) + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): """ The OpenSearchDocumentStore.filter_documents() method returns a Documents with their score set. 
@@ -690,6 +711,29 @@ def test_embedding_retrieval_with_filters(self, document_store_embedding_dim_4: assert len(results) == 1 assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_with_filters_efficient_filtering( + self, document_store_embedding_dim_4_faiss: OpenSearchDocumentStore + ): + docs = [ + Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]), + Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]), + Document( + content="Not very similar document with meta field", + embedding=[0.0, 0.8, 0.3, 0.9], + meta={"meta_field": "custom_value"}, + ), + ] + document_store_embedding_dim_4_faiss.write_documents(docs) + + filters = {"field": "meta_field", "operator": "==", "value": "custom_value"} + results = document_store_embedding_dim_4_faiss._embedding_retrieval( + query_embedding=[0.1, 0.1, 0.1, 0.1], + filters=filters, + efficient_filtering=True, + ) + assert len(results) == 1 + assert results[0].content == "Not very similar document with meta field" + def test_embedding_retrieval_pagination(self, document_store_embedding_dim_4: OpenSearchDocumentStore): """ Test that handling of pagination works as expected, when the matching documents are > 10. diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index 75c191946..84e9828ca 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -19,6 +19,7 @@ def test_init_default(): assert retriever._filters == {} assert retriever._top_k == 10 assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._efficient_filtering is False retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") assert retriever._filter_policy == FilterPolicy.REPLACE @@ -82,6 +83,7 @@ def test_to_dict(_mock_opensearch_client): "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": True, + "efficient_filtering": False, }, } @@ -101,6 +103,7 @@ def test_from_dict(_mock_opensearch_client): "filter_policy": "replace", "custom_query": {"some": "custom query"}, "raise_on_failure": False, + "efficient_filtering": True, }, } retriever = OpenSearchEmbeddingRetriever.from_dict(data) @@ -110,6 +113,7 @@ def test_from_dict(_mock_opensearch_client): assert retriever._custom_query == {"some": "custom query"} assert retriever._raise_on_failure is False assert retriever._filter_policy == FilterPolicy.REPLACE + assert retriever._efficient_filtering is True # For backwards compatibility with older versions of the retriever without a filter policy data = { @@ -139,6 +143,7 @@ def test_run(): filters={}, top_k=10, custom_query=None, + efficient_filtering=False, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -150,7 +155,11 @@ def test_run_init_params(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] retriever = OpenSearchEmbeddingRetriever( - document_store=mock_store, filters={"from": "init"}, top_k=11, custom_query="custom_query" + document_store=mock_store, + filters={"from": "init"}, + top_k=11, + custom_query="custom_query", + efficient_filtering=True, ) res = retriever.run(query_embedding=[0.5, 0.7]) mock_store._embedding_retrieval.assert_called_once_with( @@ -158,6 +167,7 @@ def test_run_init_params(): filters={"from": "init"}, top_k=11, 
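The fixture and test above show the intended setup: the filter is applied inside the approximate kNN clause, which requires the `faiss` or `lucene` engine rather than the default `nmslib`. A condensed usage sketch (hypothetical host, credentials, and embeddings):

```python
from haystack import Document
from haystack_integrations.components.retrievers.opensearch import OpenSearchEmbeddingRetriever
from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore

document_store = OpenSearchDocumentStore(
    hosts=["https://localhost:9200"],
    index="my_index",
    http_auth=("admin", "admin"),
    verify_certs=False,
    embedding_dim=4,
    # efficient filtering is only supported by the "faiss" and "lucene" knn engines
    method={"space_type": "innerproduct", "engine": "faiss", "name": "hnsw"},
)
document_store.write_documents(
    [Document(content="doc", embedding=[0.1, 0.2, 0.3, 0.4], meta={"meta_field": "custom_value"})]
)

retriever = OpenSearchEmbeddingRetriever(document_store=document_store, efficient_filtering=True)
result = retriever.run(
    query_embedding=[0.1, 0.1, 0.1, 0.1],
    filters={"field": "meta_field", "operator": "==", "value": "custom_value"},
)
print(result["documents"])
```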
custom_query="custom_query", + efficient_filtering=True, ) assert len(res) == 1 assert len(res["documents"]) == 1 @@ -169,12 +179,13 @@ def test_run_time_params(): mock_store = Mock(spec=OpenSearchDocumentStore) mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])] retriever = OpenSearchEmbeddingRetriever(document_store=mock_store, filters={"from": "init"}, top_k=11) - res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9) + res = retriever.run(query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, efficient_filtering=True) mock_store._embedding_retrieval.assert_called_once_with( query_embedding=[0.5, 0.7], filters={"from": "run"}, top_k=9, custom_query=None, + efficient_filtering=True, ) assert len(res) == 1 assert len(res["documents"]) == 1 diff --git a/integrations/optimum/pyproject.toml b/integrations/optimum/pyproject.toml index 305af6042..6149997ed 100644 --- a/integrations/optimum/pyproject.toml +++ b/integrations/optimum/pyproject.toml @@ -54,6 +54,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/optimum-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -76,8 +77,9 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index 014d163bc..3f20dfbb1 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -42,6 +42,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/pgvector-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -62,8 +63,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index a02c46200..8e9c0f2fc 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) CREATE_TABLE_STATEMENT = """ -CREATE TABLE IF NOT EXISTS {table_name} ( +CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} ( id VARCHAR(128) PRIMARY KEY, embedding VECTOR({embedding_dimension}), content TEXT, @@ -36,7 +36,7 @@ """ INSERT_STATEMENT = """ -INSERT INTO {table_name} +INSERT INTO {schema_name}.{table_name} (id, embedding, content, dataframe, blob_data, blob_meta, blob_mime_type, meta) VALUES (%(id)s, %(embedding)s, %(content)s, %(dataframe)s, %(blob_data)s, %(blob_meta)s, %(blob_mime_type)s, %(meta)s) """ @@ -54,7 +54,7 @@ KEYWORD_QUERY = """ SELECT {table_name}.*, ts_rank_cd(to_tsvector({language}, content), query) AS score -FROM {table_name}, plainto_tsquery({language}, %s) query +FROM {schema_name}.{table_name}, plainto_tsquery({language}, %s) query WHERE to_tsvector({language}, content) @@ query """ @@ -78,6 +78,7 @@ def __init__( self, *, connection_string: Secret = Secret.from_env_var("PG_CONN_STR"), + schema_name: str = "public", table_name: str = "haystack_documents", language: str = "english", embedding_dimension: int = 768, @@ -96,7 +97,12 @@ def __init__( A specific table to store Haystack documents will be created if it doesn't exist yet. :param connection_string: The connection string to use to connect to the PostgreSQL database, defined as an - environment variable, e.g.: `PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"` + environment variable. It can be provided in either URI format + e.g.: `PG_CONN_STR="postgresql://USER:PASSWORD@HOST:PORT/DB_NAME"`, or keyword/value format + e.g.: `PG_CONN_STR="host=HOST port=PORT dbname=DBNAME user=USER password=PASSWORD"` + See [PostgreSQL Documentation](https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) + for more details. + :param schema_name: The name of the schema the table is created in. The schema must already exist. :param table_name: The name of the table to use to store Haystack documents. :param language: The language to be used to parse query and document content in keyword retrieval. 
To see the list of available languages, you can run the following SQL query in your PostgreSQL database: @@ -133,6 +139,7 @@ def __init__( self.connection_string = connection_string self.table_name = table_name + self.schema_name = schema_name self.embedding_dimension = embedding_dimension if vector_function not in VALID_VECTOR_FUNCTIONS: msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" @@ -203,6 +210,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, connection_string=self.connection_string.to_dict(), + schema_name=self.schema_name, table_name=self.table_name, embedding_dimension=self.embedding_dimension, vector_function=self.vector_function, @@ -262,7 +270,9 @@ def _create_table_if_not_exists(self): """ create_sql = SQL(CREATE_TABLE_STATEMENT).format( - table_name=Identifier(self.table_name), embedding_dimension=SQLLiteral(self.embedding_dimension) + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + embedding_dimension=SQLLiteral(self.embedding_dimension), ) self._execute_sql(create_sql, error_msg="Could not create table in PgvectorDocumentStore") @@ -270,12 +280,18 @@ def _create_table_if_not_exists(self): def delete_table(self): """ Deletes the table used to store Haystack documents. - The name of the table (`table_name`) is defined when initializing the `PgvectorDocumentStore`. + The name of the schema (`schema_name`) and the name of the table (`table_name`) + are defined when initializing the `PgvectorDocumentStore`. """ + delete_sql = SQL("DROP TABLE IF EXISTS {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + ) - delete_sql = SQL("DROP TABLE IF EXISTS {table_name}").format(table_name=Identifier(self.table_name)) - - self._execute_sql(delete_sql, error_msg=f"Could not delete table {self.table_name} in PgvectorDocumentStore") + self._execute_sql( + delete_sql, + error_msg=f"Could not delete table {self.schema_name}.{self.table_name} in PgvectorDocumentStore", + ) def _create_keyword_index_if_not_exists(self): """ @@ -283,15 +299,16 @@ def _create_keyword_index_if_not_exists(self): """ index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.keyword_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.keyword_index_name), "Could not check if keyword index exists", ).fetchone() ) sql_create_index = SQL( - "CREATE INDEX {index_name} ON {table_name} USING GIN (to_tsvector({language}, content))" + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING GIN (to_tsvector({language}, content))" ).format( + schema_name=Identifier(self.schema_name), index_name=Identifier(self.keyword_index_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), @@ -314,8 +331,8 @@ def _handle_hnsw(self): index_exists = bool( self._execute_sql( - "SELECT 1 FROM pg_indexes WHERE tablename = %s AND indexname = %s", - (self.table_name, self.hnsw_index_name), + "SELECT 1 FROM pg_indexes WHERE schemaname = %s AND tablename = %s AND indexname = %s", + (self.schema_name, self.table_name, self.hnsw_index_name), "Could not check if HNSW index exists", ).fetchone() ) @@ -345,8 +362,13 @@ def _create_hnsw_index(self): if key in HNSW_INDEX_CREATION_VALID_KWARGS } - sql_create_index = SQL("CREATE INDEX {index_name} ON {table_name} USING hnsw (embedding 
{ops}) ").format( - index_name=Identifier(self.hnsw_index_name), table_name=Identifier(self.table_name), ops=SQL(pg_ops) + sql_create_index = SQL( + "CREATE INDEX {index_name} ON {schema_name}.{table_name} USING hnsw (embedding {ops}) " + ).format( + schema_name=Identifier(self.schema_name), + index_name=Identifier(self.hnsw_index_name), + table_name=Identifier(self.table_name), + ops=SQL(pg_ops), ) if actual_hnsw_index_creation_kwargs: @@ -365,7 +387,9 @@ def count_documents(self) -> int: Returns how many documents are present in the document store. """ - sql_count = SQL("SELECT COUNT(*) FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_count = SQL("SELECT COUNT(*) FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) count = self._execute_sql(sql_count, error_msg="Could not count documents in PgvectorDocumentStore").fetchone()[ 0 @@ -391,7 +415,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc msg = "Invalid filter syntax. See https://docs.haystack.deepset.ai/docs/metadata-filtering for details." raise ValueError(msg) - sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + sql_filter = SQL("SELECT * FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) params = () if filters: @@ -430,7 +456,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D db_documents = self._from_haystack_to_pg_documents(documents) - sql_insert = SQL(INSERT_STATEMENT).format(table_name=Identifier(self.table_name)) + sql_insert = SQL(INSERT_STATEMENT).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name) + ) if policy == DuplicatePolicy.OVERWRITE: sql_insert += SQL(UPDATE_STATEMENT) @@ -539,8 +567,10 @@ def delete_documents(self, document_ids: List[str]) -> None: document_ids_str = ", ".join(f"'{document_id}'" for document_id in document_ids) - delete_sql = SQL("DELETE FROM {table_name} WHERE id IN ({document_ids_str})").format( - table_name=Identifier(self.table_name), document_ids_str=SQL(document_ids_str) + delete_sql = SQL("DELETE FROM {schema_name}.{table_name} WHERE id IN ({document_ids_str})").format( + schema_name=Identifier(self.schema_name), + table_name=Identifier(self.table_name), + document_ids_str=SQL(document_ids_str), ) self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") @@ -566,6 +596,7 @@ def _keyword_retrieval( raise ValueError(msg) sql_select = SQL(KEYWORD_QUERY).format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), language=SQLLiteral(self.language), query=SQLLiteral(query), @@ -639,7 +670,8 @@ def _embedding_retrieval( elif vector_function == "l2_distance": score_definition = f"embedding <-> {query_embedding_for_postgres} AS score" - sql_select = SQL("SELECT *, {score} FROM {table_name}").format( + sql_select = SQL("SELECT *, {score} FROM {schema_name}.{table_name}").format( + schema_name=Identifier(self.schema_name), table_name=Identifier(self.table_name), score=SQL(score_definition), ) diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 93514b71c..4af4fc8de 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -47,6 +47,7 @@ def test_init(monkeypatch): 
monkeypatch.setenv("PG_CONN_STR", "some_connection_string") document_store = PgvectorDocumentStore( + schema_name="my_schema", table_name="my_table", embedding_dimension=512, vector_function="l2_distance", @@ -59,6 +60,7 @@ def test_init(monkeypatch): keyword_index_name="my_keyword_index", ) + assert document_store.schema_name == "my_schema" assert document_store.table_name == "my_table" assert document_store.embedding_dimension == 512 assert document_store.vector_function == "l2_distance" @@ -93,6 +95,7 @@ def test_to_dict(monkeypatch): "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, "table_name": "my_table", + "schema_name": "public", "embedding_dimension": 512, "vector_function": "l2_distance", "recreate_table": True, diff --git a/integrations/pgvector/tests/test_retrievers.py b/integrations/pgvector/tests/test_retrievers.py index 290891307..4125c3e3a 100644 --- a/integrations/pgvector/tests/test_retrievers.py +++ b/integrations/pgvector/tests/test_retrievers.py @@ -50,6 +50,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", @@ -175,6 +176,7 @@ def test_to_dict(self, mock_store): "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", "init_parameters": { "connection_string": {"env_vars": ["PG_CONN_STR"], "strict": True, "type": "env_var"}, + "schema_name": "public", "table_name": "haystack", "embedding_dimension": 768, "vector_function": "cosine_similarity", diff --git a/integrations/pinecone/pyproject.toml b/integrations/pinecone/pyproject.toml index 3f2e4d6bd..1a19cb2b7 100644 --- a/integrations/pinecone/pyproject.toml +++ b/integrations/pinecone/pyproject.toml @@ -44,6 +44,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/pinecone-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -66,8 +67,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "numpy"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/qdrant/CHANGELOG.md b/integrations/qdrant/CHANGELOG.md index a275529f8..57dd257d5 100644 --- a/integrations/qdrant/CHANGELOG.md +++ b/integrations/qdrant/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [integrations/qdrant-v7.0.0] - 2024-10-29 + +### โš™๏ธ Miscellaneous Tasks + +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + ## [integrations/qdrant-v6.0.0] - 2024-09-13 ## [integrations/qdrant-v5.1.0] - 2024-09-12 @@ -105,8 +112,6 @@ - Fix haystack-ai pin (#649) - - ## [integrations/qdrant-v3.2.0] - 2024-03-27 ### ๐Ÿš€ Features @@ -117,15 +122,11 @@ ### ๐Ÿ› Bug Fixes - Fix linter errors (#282) - - - Fix order of API docs (#447) This PR will also push the docs to Readme - Fixes (#518) - - ### ๐Ÿšœ Refactor - [**breaking**] Qdrant - update secret management (#405) @@ -156,8 +157,6 @@ This PR will also push the docs to Readme - Fix import paths for beta5 (#237) - - ### ๐Ÿšœ Refactor - Use `hatch_vcs` to manage integrations versioning (#103) diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index 898fd2dcf..f0e7e7342 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -44,6 +44,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/qdrant-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -58,8 +59,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index 88afd8f65..c8cb9a393 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -362,7 +362,6 @@ def write_documents( document_objects = self._handle_duplicate_documents( documents=documents, - index=self.index, policy=policy, ) @@ -468,7 +467,6 @@ def get_documents_generator( def get_documents_by_id( self, ids: List[str], - index: Optional[str] = None, ) -> List[Document]: """ Retrieves documents from Qdrant by their IDs. @@ -480,13 +478,11 @@ def get_documents_by_id( :returns: A list of documents. 
""" - index = index or self.index - documents: List[Document] = [] ids = [convert_id(_id) for _id in ids] records = self.client.retrieve( - collection_name=index, + collection_name=self.index, ids=ids, with_payload=True, with_vectors=True, @@ -987,7 +983,6 @@ def recreate_collection( def _handle_duplicate_documents( self, documents: List[Document], - index: Optional[str] = None, policy: DuplicatePolicy = None, ): """ @@ -995,31 +990,28 @@ def _handle_duplicate_documents( documents that are not in the index yet. :param documents: A list of Haystack Document objects. - :param index: name of the index :param policy: The duplicate policy to use when writing documents. :returns: A list of Haystack Document objects. """ - index = index or self.index if policy in (DuplicatePolicy.SKIP, DuplicatePolicy.FAIL): - documents = self._drop_duplicate_documents(documents, index) - documents_found = self.get_documents_by_id(ids=[doc.id for doc in documents], index=index) + documents = self._drop_duplicate_documents(documents) + documents_found = self.get_documents_by_id(ids=[doc.id for doc in documents]) ids_exist_in_db: List[str] = [doc.id for doc in documents_found] if len(ids_exist_in_db) > 0 and policy == DuplicatePolicy.FAIL: - msg = f"Document with ids '{', '.join(ids_exist_in_db)} already exists in index = '{index}'." + msg = f"Document with ids '{', '.join(ids_exist_in_db)} already exists in index = '{self.index}'." raise DuplicateDocumentError(msg) documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents)) return documents - def _drop_duplicate_documents(self, documents: List[Document], index: Optional[str] = None) -> List[Document]: + def _drop_duplicate_documents(self, documents: List[Document]) -> List[Document]: """ Drop duplicate documents based on same hash ID. :param documents: A list of Haystack Document objects. - :param index: Name of the index. :returns: A list of Haystack Document objects. """ _hash_ids: Set = set() @@ -1030,7 +1022,7 @@ def _drop_duplicate_documents(self, documents: List[Document], index: Optional[s logger.info( "Duplicate Documents: Document with id '%s' already exists in index '%s'", document.id, - index or self.index, + self.index, ) continue _documents.append(document) diff --git a/integrations/ragas/pyproject.toml b/integrations/ragas/pyproject.toml index dd56e35f6..179bcce16 100644 --- a/integrations/ragas/pyproject.toml +++ b/integrations/ragas/pyproject.toml @@ -41,6 +41,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/ragas-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -61,8 +62,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive {args:src/}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] diff --git a/integrations/snowflake/pyproject.toml b/integrations/snowflake/pyproject.toml index 355e9d090..0e089fe79 100644 --- a/integrations/snowflake/pyproject.toml +++ b/integrations/snowflake/pyproject.toml @@ -43,6 +43,7 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/snowflake-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -58,8 +59,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] diff --git a/integrations/unstructured/pyproject.toml b/integrations/unstructured/pyproject.toml index 88bd463b2..14f58594c 100644 --- a/integrations/unstructured/pyproject.toml +++ b/integrations/unstructured/pyproject.toml @@ -40,6 +40,7 @@ root = "../.." git_describe_command = 'git describe --tags --match="integrations/unstructured-v[0-9]*"' [tool.hatch.envs.default] +installer = "uv" dependencies = [ "coverage[toml]>=6.5", "pytest", @@ -60,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.9", "3.10", "3.11"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = [ diff --git a/integrations/weaviate/CHANGELOG.md b/integrations/weaviate/CHANGELOG.md index dacf3fef8..7f620c3a0 100644 --- a/integrations/weaviate/CHANGELOG.md +++ b/integrations/weaviate/CHANGELOG.md @@ -1,10 +1,25 @@ # Changelog -## [integrations/weaviate-v3.0.0] - 2024-09-12 +## [integrations/weaviate-v4.0.2] - 2024-11-13 + +### ๐Ÿ› Bug Fixes + +- Dependency for weaviate document store (#1186) + +## [integrations/weaviate-v4.0.1] - 2024-11-11 + +## [integrations/weaviate-v4.0.0] - 2024-10-18 + +### ๐Ÿ› Bug Fixes + +- Compatibility with Weaviate 4.9.0 (#1143) ### โš™๏ธ Miscellaneous Tasks - Weaviate - remove legacy filter support (#1070) +- Update changelog after removing legacy filters (#1083) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) ## [integrations/weaviate-v2.2.1] - 2024-09-07 @@ -58,8 +73,6 @@ This PR will also push the docs to Readme - Fix weaviate auth tests (#488) - - ### ๐Ÿ“š Documentation - Update category slug (#442) diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index 22d3a160d..e88397df9 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -7,7 +7,7 @@ name = "weaviate-haystack" dynamic = ["version"] description = "An integration of Weaviate vector database with Haystack" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = "Apache-2.0" keywords = [] authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] @@ -25,8 +25,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "weaviate-client>=4.0", - "haystack-pydoc-tools", + "weaviate-client>=4.9", "python-dateutil", ] @@ -47,7 +46,8 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/weaviate-v[0-9]*"' [tool.hatch.envs.default] -dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "ipython"] +installer = "uv" +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "ipython", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" @@ -61,8 +61,9 @@ docs = ["pydoc-markdown pydoc/config.yml"] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] +installer = "uv" detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["pip", "black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:}", "black --check --diff {args:.}"] diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index e312b1473..6acf0156e 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -286,6 +286,14 @@ def _to_data_object(self, document: Document) -> Dict[str, Any]: # The embedding vector is stored separately from the rest of the data del data["embedding"] + # _split_overlap meta field is unsupported because of a bug + # https://github.com/deepset-ai/haystack-core-integrations/issues/1172 + if "_split_overlap" in data: + data.pop("_split_overlap") + logger.warning( + "Document %s has the unsupported `_split_overlap` meta field. 
It will be ignored.", data["_original_id"] + ) + if "sparse_embedding" in data: sparse_embedding = data.pop("sparse_embedding", None) if sparse_embedding: diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 190c23408..00af322e4 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -265,6 +265,7 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): "session_pool_connections": 20, "session_pool_maxsize": 100, "session_pool_max_retries": 3, + "session_pool_timeout": 5, }, "proxies": {"http": "http://proxy:1234", "https": None, "grpc": None}, "timeout": [30, 90], @@ -302,6 +303,7 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): "connection": { "session_pool_connections": 20, "session_pool_maxsize": 20, + "session_pool_timeout": 5, }, "proxies": {"http": "http://proxy:1234"}, "timeout": [10, 60], @@ -338,6 +340,7 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): assert document_store._embedded_options.grpc_port == DEFAULT_GRPC_PORT assert document_store._additional_config.connection.session_pool_connections == 20 assert document_store._additional_config.connection.session_pool_maxsize == 20 + assert document_store._additional_config.connection.session_pool_timeout == 5 def test_to_data_object(self, document_store, test_files_path): doc = Document(content="test doc") @@ -505,6 +508,30 @@ def test_comparison_less_than_equal_with_iso_date(self, document_store, filterab def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): return super().test_comparison_not_equal_with_dataframe(document_store, filterable_docs) + def test_meta_split_overlap_is_skipped(self, document_store): + doc = Document( + content="The moonlight shimmered ", + meta={ + "source_id": "62049ba1d1e1d5ebb1f6230b0b00c5356b8706c56e0b9c36b1dfc86084cd75f0", + "page_number": 1, + "split_id": 0, + "split_idx_start": 0, + "_split_overlap": [ + {"doc_id": "68ed48ba830048c5d7815874ed2de794722e6d10866b6c55349a914fd9a0df65", "range": (0, 20)} + ], + }, + ) + document_store.write_documents([doc]) + + written_doc = document_store.filter_documents()[0] + + assert written_doc.content == "The moonlight shimmered " + assert written_doc.meta["source_id"] == "62049ba1d1e1d5ebb1f6230b0b00c5356b8706c56e0b9c36b1dfc86084cd75f0" + assert written_doc.meta["page_number"] == 1.0 + assert written_doc.meta["split_id"] == 0.0 + assert written_doc.meta["split_idx_start"] == 0.0 + assert "_split_overlap" not in written_doc.meta + def test_bm25_retrieval(self, document_store): document_store.write_documents( [
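For orientation, the sketch below restates the `_split_overlap` workaround from the Weaviate diff above as standalone Python. It is a minimal illustration only: `strip_unsupported_meta` is a hypothetical helper name, not part of the integration's API, and the payload shown is made up for the example.

```python
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)


def strip_unsupported_meta(data: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper mirroring the workaround added to _to_data_object:
    # the `_split_overlap` meta field is dropped before the payload reaches
    # Weaviate (see haystack-core-integrations issue #1172).
    if "_split_overlap" in data:
        data.pop("_split_overlap")
        logger.warning(
            "Document %s has the unsupported `_split_overlap` meta field. It will be ignored.",
            data.get("_original_id"),
        )
    return data


# Example payload: `_split_overlap` is removed, every other meta field is preserved.
payload = {
    "_original_id": "doc-1",
    "page_number": 1,
    "_split_overlap": [{"doc_id": "doc-0", "range": (0, 20)}],
}
assert "_split_overlap" not in strip_unsupported_meta(payload)
assert payload["page_number"] == 1
```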