diff --git a/.github/workflows/azure_ai_search.yml b/.github/workflows/azure_ai_search.yml new file mode 100644 index 000000000..1c10edc91 --- /dev/null +++ b/.github/workflows/azure_ai_search.yml @@ -0,0 +1,72 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / azure_ai_search + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/azure_ai_search/**" + - ".github/workflows/azure_ai_search.yml" + +concurrency: + group: azure_ai_search-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + AZURE_SEARCH_API_KEY: ${{ secrets.AZURE_SEARCH_API_KEY }} + AZURE_SEARCH_SERVICE_ENDPOINT: ${{ secrets.AZURE_SEARCH_SERVICE_ENDPOINT }} + +defaults: + run: + working-directory: integrations/azure_ai_search + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + max-parallel: 3 + matrix: + os: [ubuntu-latest, windows-latest] + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + + - name: Run tests + run: hatch run cov-retry + + - name: Nightly - run unit tests with Haystack main branch + if: github.event_name == 'schedule' + run: | + hatch run pip install git+https://github.com/deepset-ai/haystack.git + hatch run cov-retry -m "not integration" + + - name: Send event to Datadog for nightly failures + if: failure() && github.event_name == 'schedule' + uses: ./.github/actions/send_failure + with: + title: | + Core integrations nightly tests failure: ${{ github.workflow }} + api-key: ${{ secrets.CORE_DATADOG_API_KEY }} diff --git a/README.md b/README.md index 2b4a83253..af83d045d 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [cohere-haystack](integrations/cohere/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | | [deepeval-haystack](integrations/deepeval/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/deepeval-haystack.svg)](https://pypi.org/project/deepeval-haystack) | [![Test / deepeval](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/deepeval.yml) | | [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / 
elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | -| [fastembed-haystack](integrations/fastembed/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | +| [fastembed-haystack](integrations/fastembed/) | Embedder, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/fastembed-haystack.svg)](https://pypi.org/project/fastembed-haystack/) | [![Test / fastembed](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/fastembed.yml) | | [google-ai-haystack](integrations/google_ai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-ai-haystack.svg)](https://pypi.org/project/google-ai-haystack) | [![Test / google-ai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_ai.yml) | | [google-vertex-haystack](integrations/google_vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | | [instructor-embedders-haystack](integrations/instructor_embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | diff --git a/integrations/azure_ai_search/LICENSE b/integrations/azure_ai_search/LICENSE new file mode 100644 index 000000000..de4c7f39f --- /dev/null +++ b/integrations/azure_ai_search/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 deepset GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/integrations/azure_ai_search/README.md b/integrations/azure_ai_search/README.md new file mode 100644 index 000000000..915a23b63 --- /dev/null +++ b/integrations/azure_ai_search/README.md @@ -0,0 +1,26 @@ +# Azure AI Search Document Store for Haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/azure-ai-search-haystack.svg)](https://pypi.org/project/azure-ai-search-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/azure-ai-search-haystack.svg)](https://pypi.org/project/azure-ai-search-haystack) + +----- + +**Table of Contents** + +- [Azure AI Search Document Store for Haystack](#azure-ai-search-document-store-for-haystack) + - [Installation](#installation) + - [Examples](#examples) + - [License](#license) + +## Installation + +```console +pip install azure-ai-search-haystack +``` + +## Examples +You can find a code example showing how to use the Document Store and the Retriever in the documentation or in [this Colab](https://colab.research.google.com/drive/1YpDetI8BRbObPDEVdfqUcwhEX9UUXP-m?usp=sharing). + +## License + +`azure-ai-search-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py new file mode 100644 index 000000000..779f28935 --- /dev/null +++ b/integrations/azure_ai_search/example/document_store.py @@ -0,0 +1,44 @@ +from haystack import Document +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +""" +This example demonstrates how to use the AzureAISearchDocumentStore to write and filter documents. +To run this example, you'll need an Azure Search service endpoint and API key, which can either be +set as environment variables (AZURE_SEARCH_SERVICE_ENDPOINT and AZURE_SEARCH_API_KEY) or +provided directly to AzureAISearchDocumentStore(as params "api_key", "azure_endpoint"). +Otherwise you can use DefaultAzureCredential to authenticate with Azure services. 
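+As a sketch, the credentials could also be passed in directly instead of relying on
+environment variables (the endpoint and key values below are placeholders):
+
+    from haystack.utils import Secret
+
+    document_store = AzureAISearchDocumentStore(
+        azure_endpoint=Secret.from_token("https://<your-service>.search.windows.net"),
+        api_key=Secret.from_token("<your-api-key>"),
+        index_name="document-store-example",
+    )
+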
+See more details at https://learn.microsoft.com/en-us/azure/search/keyless-connections?tabs=python%2Cazure-cli +""" +document_store = AzureAISearchDocumentStore( + metadata_fields={"version": float, "label": str}, + index_name="document-store-example", +) + +documents = [ + Document( + content="This is an introduction to using Python for data analysis.", + meta={"version": 1.0, "label": "chapter_one"}, + ), + Document( + content="Learn how to use Python libraries for machine learning.", + meta={"version": 1.5, "label": "chapter_two"}, + ), + Document( + content="Advanced Python techniques for data visualization.", + meta={"version": 2.0, "label": "chapter_three"}, + ), +] +document_store.write_documents(documents, policy=DuplicatePolicy.SKIP) + +filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.version", "operator": ">", "value": 1.2}, + {"field": "meta.label", "operator": "in", "value": ["chapter_one", "chapter_three"]}, + ], +} + +results = document_store.filter_documents(filters) +print(results) diff --git a/integrations/azure_ai_search/example/embedding_retrieval.py b/integrations/azure_ai_search/example/embedding_retrieval.py new file mode 100644 index 000000000..088b08653 --- /dev/null +++ b/integrations/azure_ai_search/example/embedding_retrieval.py @@ -0,0 +1,58 @@ +from haystack import Document, Pipeline +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.writers import DocumentWriter +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +""" +This example demonstrates how to use the AzureAISearchEmbeddingRetriever to retrieve documents +using embeddings based on a query. To run this example, you'll need an Azure Search service endpoint +and API key, which can either be +set as environment variables (AZURE_SEARCH_SERVICE_ENDPOINT and AZURE_SEARCH_API_KEY) or +provided directly to AzureAISearchDocumentStore(as params "api_key", "azure_endpoint"). +Otherwise you can use DefaultAzureCredential to authenticate with Azure services. 
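+A minimal sketch of the keyless setup: with no AZURE_SEARCH_API_KEY set and no api_key passed,
+the document store falls back to DefaultAzureCredential (this assumes you are already signed in,
+for example via `az login`):
+
+    document_store = AzureAISearchDocumentStore(index_name="retrieval-example")
+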
+See more details at https://learn.microsoft.com/en-us/azure/search/keyless-connections?tabs=python%2Cazure-cli +""" + +document_store = AzureAISearchDocumentStore(index_name="retrieval-example") + +model = "sentence-transformers/all-mpnet-base-v2" + +documents = [ + Document(content="There are over 7,000 languages spoken around the world today."), + Document( + content="""Elephants have been observed to behave in a way that indicates a + high level of self-awareness, such as recognizing themselves in mirrors.""" + ), + Document( + content="""In certain parts of the world, like the Maldives, Puerto Rico, and + San Diego, you can witness the phenomenon of bioluminescent waves.""" + ), +] + +document_embedder = SentenceTransformersDocumentEmbedder(model=model) +document_embedder.warm_up() + +# Indexing Pipeline +indexing_pipeline = Pipeline() +indexing_pipeline.add_component(instance=document_embedder, name="doc_embedder") +indexing_pipeline.add_component( + instance=DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP), name="doc_writer" +) +indexing_pipeline.connect("doc_embedder", "doc_writer") + +indexing_pipeline.run({"doc_embedder": {"documents": documents}}) + +# Query Pipeline +query_pipeline = Pipeline() +query_pipeline.add_component("text_embedder", SentenceTransformersTextEmbedder(model=model)) +query_pipeline.add_component("retriever", AzureAISearchEmbeddingRetriever(document_store=document_store)) +query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding") + +query = "How many languages are there?" + +result = query_pipeline.run({"text_embedder": {"text": query}}) + +print(result["retriever"]["documents"][0]) diff --git a/integrations/azure_ai_search/pydoc/config.yml b/integrations/azure_ai_search/pydoc/config.yml new file mode 100644 index 000000000..ec411af60 --- /dev/null +++ b/integrations/azure_ai_search/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever", + "haystack_integrations.document_stores.azure_ai_search.document_store", + "haystack_integrations.document_stores.azure_ai_search.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: Azure AI Search integration for Haystack + category_slug: integrations-api + title: Azure AI Search + slug: integrations-azure_ai_search + order: 180 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_azure_ai_search.md diff --git a/integrations/azure_ai_search/pyproject.toml b/integrations/azure_ai_search/pyproject.toml new file mode 100644 index 000000000..49ca623e7 --- /dev/null +++ b/integrations/azure_ai_search/pyproject.toml @@ -0,0 +1,163 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "azure-ai-search-haystack" +dynamic = ["version"] +description = 'Haystack 2.x Document Store for Azure AI Search' +readme = "README.md" +requires-python = ">=3.8,<3.13" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "deepset", email = "info@deepset.ai" }] +classifiers = [ + "License :: 
OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai", "azure-search-documents>=11.5", "azure-identity"] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/azure_ai_search" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/azure-ai-search-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/azure-ai-search-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-rerunfailures", + "pytest-xdist", + "haystack-pydoc-tools", +] + +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +test-cov-retry = "test-cov --reruns 3 --reruns-delay 30 -x" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +cov-retry = ["test-cov-retry", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = ["ruff check {args:src/}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] +all = ["style", "typing"] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.black] +target-version = ["py38"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py38" +line-length = 120 + +[tool.ruff.lint] +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... 
True)` + "FBT003", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] +exclude = ["example"] + +[tool.ruff.lint.isort] +known-first-party = ["src"] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.lint.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252", "S311"] +"example/**/*" = ["T201"] + +[tool.coverage.run] +source = ["haystack_integrations"] +branch = true +parallel = false + + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + + +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + +[[tool.mypy.overrides]] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "azure.identity.*", "mypy.*", "azure.core.*", "azure.search.documents.*"] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py new file mode 100644 index 000000000..eb75ffa6c --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/__init__.py @@ -0,0 +1,3 @@ +from .embedding_retriever import AzureAISearchEmbeddingRetriever + +__all__ = ["AzureAISearchEmbeddingRetriever"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py new file mode 100644 index 000000000..ab649f874 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/retrievers/azure_ai_search/embedding_retriever.py @@ -0,0 +1,116 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.document_stores.types import FilterPolicy +from haystack.document_stores.types.filter_policy import apply_filter_policy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore, normalize_filters + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchEmbeddingRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a vector similarity metric. + Must be connected to the AzureAISearchDocumentStore to run. + + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchEmbeddingRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the approximate kNN search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + :filter_policy: Policy to determine how filters are applied. 
Possible options: + + """ + self._filters = filters or {} + self._top_k = top_k + self._document_store = document_store + self._filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AzureAISearchDocumentStore" + raise Exception(message) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + filters=self._filters, + top_k=self._top_k, + document_store=self._document_store.to_dict(), + filter_policy=self._filter_policy.value, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchEmbeddingRetriever": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + data["init_parameters"]["document_store"] = AzureAISearchDocumentStore.from_dict( + data["init_parameters"]["document_store"] + ) + + # Pipelines serialized with old versions of the component might not + # have the filter_policy field. + if "filter_policy" in data["init_parameters"]: + data["init_parameters"]["filter_policy"] = FilterPolicy.from_str(data["init_parameters"]["filter_policy"]) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query_embedding: floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. 
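+
+        Usage sketch (assumes the index was created with the default 768-dimensional embedding
+        space, so the query embedding must contain 768 floats):
+
+        ```python
+        retriever = AzureAISearchEmbeddingRetriever(document_store=document_store, top_k=3)
+        result = retriever.run(query_embedding=[0.1] * 768)
+        documents = result["documents"]
+        ```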
+ """ + + top_k = top_k or self._top_k + if filters is not None: + applied_filters = apply_filter_policy(self._filter_policy, self._filters, filters) + normalized_filters = normalize_filters(applied_filters) + else: + normalized_filters = "" + + try: + docs = self._document_store._embedding_retrieval( + query_embedding=query_embedding, + filters=normalized_filters, + top_k=top_k, + ) + except Exception as e: + raise e + + return {"documents": docs} diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py new file mode 100644 index 000000000..635878a38 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore +from .filters import normalize_filters + +__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH", "normalize_filters"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py new file mode 100644 index 000000000..0b59b6e37 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -0,0 +1,440 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import logging +import os +from dataclasses import asdict +from datetime import datetime +from typing import Any, Dict, List, Optional + +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.search.documents import SearchClient +from azure.search.documents.indexes import SearchIndexClient +from azure.search.documents.indexes.models import ( + HnswAlgorithmConfiguration, + HnswParameters, + SearchableField, + SearchField, + SearchFieldDataType, + SearchIndex, + SimpleField, + VectorSearch, + VectorSearchAlgorithmMetric, + VectorSearchProfile, +) +from azure.search.documents.models import VectorizedQuery +from haystack import default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.types import DuplicatePolicy +from haystack.utils import Secret, deserialize_secrets_inplace + +from .errors import AzureAISearchDocumentStoreConfigError +from .filters import normalize_filters + +type_mapping = { + str: "Edm.String", + bool: "Edm.Boolean", + int: "Edm.Int32", + float: "Edm.Double", + datetime: "Edm.DateTimeOffset", +} + +DEFAULT_VECTOR_SEARCH = VectorSearch( + profiles=[ + VectorSearchProfile(name="default-vector-config", algorithm_configuration_name="cosine-algorithm-config") + ], + algorithms=[ + HnswAlgorithmConfiguration( + name="cosine-algorithm-config", + parameters=HnswParameters( + metric=VectorSearchAlgorithmMetric.COSINE, + ), + ) + ], +) + +logger = logging.getLogger(__name__) +logging.getLogger("azure").setLevel(logging.ERROR) +logging.getLogger("azure.identity").setLevel(logging.DEBUG) + + +class AzureAISearchDocumentStore: + def __init__( + self, + *, + api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False), # 
noqa: B008 + azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=True), # noqa: B008 + index_name: str = "default", + embedding_dimension: int = 768, + metadata_fields: Optional[Dict[str, type]] = None, + vector_search_configuration: VectorSearch = None, + **kwargs, + ): + """ + A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/) + as the backend. + + :param azure_endpoint: The URL endpoint of an Azure AI Search service. + :param api_key: The API key to use for authentication. + :param index_name: Name of index in Azure AI Search, if it doesn't exist it will be created. + :param embedding_dimension: Dimension of the embeddings. + :param metadata_fields: A dictionary of metadata keys and their types to create + additional fields in index schema. As fields in Azure SearchIndex cannot be dynamic, + it is necessary to specify the metadata fields in advance. + (e.g. metadata_fields = {"author": str, "date": datetime}) + :param vector_search_configuration: Configuration option related to vector search. + Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. + + :param kwargs: Optional keyword parameters for Azure AI Search. + Some of the supported parameters: + - `api_version`: The Search API version to use for requests. + - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD). + The audience is not considered when using a shared key. If audience is not provided, + the public cloud audience will be assumed. + + For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/) + """ + + azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None + if not azure_endpoint: + msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT." + raise ValueError(msg) + + api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None + + self._client = None + self._index_client = None + self._index_fields = [] # type: List[Any] # stores all fields in the final schema of index + self._api_key = api_key + self._azure_endpoint = azure_endpoint + self._index_name = index_name + self._embedding_dimension = embedding_dimension + self._dummy_vector = [-10.0] * self._embedding_dimension + self._metadata_fields = metadata_fields + self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH + self._kwargs = kwargs + + @property + def client(self) -> SearchClient: + + # resolve secrets for authentication + resolved_endpoint = ( + self._azure_endpoint.resolve_value() if isinstance(self._azure_endpoint, Secret) else self._azure_endpoint + ) + resolved_key = self._api_key.resolve_value() if isinstance(self._api_key, Secret) else self._api_key + + credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential() + try: + if not self._index_client: + self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs) + if not self._index_exists(self._index_name): + # Create a new index if it does not exist + logger.debug( + "The index '%s' does not exist. 
A new index will be created.", + self._index_name, + ) + self._create_index(self._index_name) + except (HttpResponseError, ClientAuthenticationError) as error: + msg = f"Failed to authenticate with Azure Search: {error}" + raise AzureAISearchDocumentStoreConfigError(msg) from error + + if self._index_client: + # Get the search client, if index client is initialized + index_fields = self._index_client.get_index(self._index_name).fields + self._index_fields = [field.name for field in index_fields] + self._client = self._index_client.get_search_client(self._index_name) + else: + msg = "Search Index Client is not initialized." + raise AzureAISearchDocumentStoreConfigError(msg) + + return self._client + + def _create_index(self, index_name: str, **kwargs) -> None: + """ + Creates a new search index. + :param index_name: Name of the index to create. If None, the index name from the constructor is used. + :param kwargs: Optional keyword parameters. + """ + + # default fields to create index based on Haystack Document (id, content, embedding) + default_fields = [ + SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True), + SearchableField(name="content", type=SearchFieldDataType.String), + SearchField( + name="embedding", + type=SearchFieldDataType.Collection(SearchFieldDataType.Single), + searchable=True, + hidden=False, + vector_search_dimensions=self._embedding_dimension, + vector_search_profile_name="default-vector-config", + ), + ] + + if not index_name: + index_name = self._index_name + if self._metadata_fields: + default_fields.extend(self._create_metadata_index_fields(self._metadata_fields)) + index = SearchIndex( + name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs + ) + if self._index_client: + self._index_client.create_index(index) + + def to_dict(self) -> Dict[str, Any]: + # This is not the best solution to serialise this class but is the fastest to implement. + # Not all kwargs types can be serialised to text so this can fail. We must serialise each + # type explicitly to handle this properly. + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None, + api_key=self._api_key.to_dict() if self._api_key is not None else None, + index_name=self._index_name, + embedding_dimension=self._embedding_dimension, + metadata_fields=self._metadata_fields, + vector_search_configuration=self._vector_search_configuration.as_dict(), + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"]) + if (vector_search_configuration := data["init_parameters"].get("vector_search_configuration")) is not None: + data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration) + return default_from_dict(cls, data) + + def count_documents(self) -> int: + """ + Returns how many documents are present in the search index. + + :returns: list of retrieved documents. 
+ """ + return self.client.get_document_count() + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: + """ + Writes the provided documents to search index. + + :param documents: documents to write to the index. + :return: the number of documents added to index. + """ + + def _convert_input_document(documents: Document): + document_dict = asdict(documents) + if not isinstance(document_dict["id"], str): + msg = f"Document id {document_dict['id']} is not a string, " + raise Exception(msg) + index_document = self._convert_haystack_documents_to_azure(document_dict) + + return index_document + + if len(documents) > 0: + if not isinstance(documents[0], Document): + msg = "param 'documents' must contain a list of objects of type Document" + raise ValueError(msg) + + if policy not in [DuplicatePolicy.NONE, DuplicatePolicy.OVERWRITE]: + logger.warning( + f"AzureAISearchDocumentStore only supports `DuplicatePolicy.OVERWRITE`" + f"but got {policy}. Overwriting duplicates is enabled by default." + ) + client = self.client + documents_to_write = [(_convert_input_document(doc)) for doc in documents] + + if documents_to_write != []: + client.upload_documents(documents_to_write) + return len(documents_to_write) + + def delete_documents(self, document_ids: List[str]) -> None: + """ + Deletes all documents with a matching document_ids from the search index. + + :param document_ids: ids of the documents to be deleted. + """ + if self.count_documents() == 0: + return + documents = self._get_raw_documents_by_id(document_ids) + if documents: + self.client.delete_documents(documents) + + def get_documents_by_id(self, document_ids: List[str]) -> List[Document]: + return self._convert_search_result_to_documents(self._get_raw_documents_by_id(document_ids)) + + def search_documents(self, search_text: str = "*", top_k: int = 10) -> List[Document]: + """ + Returns all documents that match the provided search_text. + If search_text is None, returns all documents. + :param search_text: the text to search for in the Document list. + :param top_k: Maximum number of documents to return. + :returns: A list of Documents that match the given search_text. + """ + result = self.client.search(search_text=search_text, top=top_k) + return self._convert_search_result_to_documents(list(result)) + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the provided filters. + Filters should be given as a dictionary supporting filtering by metadata. For details on + filters, see the [metadata filtering documentation](https://docs.haystack.deepset.ai/docs/metadata-filtering). + + :param filters: the filters to apply to the document list. + :returns: A list of Documents that match the given filters. + """ + if filters: + normalized_filters = normalize_filters(filters) + result = self.client.search(filter=normalized_filters) + return self._convert_search_result_to_documents(result) + else: + return self.search_documents() + + def _convert_search_result_to_documents(self, azure_docs: List[Dict[str, Any]]) -> List[Document]: + """ + Converts Azure search results to Haystack Documents. 
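+
+        :param azure_docs: Raw result dictionaries returned by the Azure Search client.
+        :returns: A list of Haystack Documents; index fields other than `id`, `content`, and
+            `embedding` are mapped back into the document metadata.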
+ """ + documents = [] + + for azure_doc in azure_docs: + embedding = azure_doc.get("embedding") + if embedding == self._dummy_vector: + embedding = None + + # Anything besides default fields (id, content, and embedding) is considered metadata + meta = { + key: value + for key, value in azure_doc.items() + if key not in ["id", "content", "embedding"] and key in self._index_fields and value is not None + } + + # Create the document with meta only if it's non-empty + doc = Document( + id=azure_doc["id"], content=azure_doc["content"], embedding=embedding, meta=meta if meta else {} + ) + + documents.append(doc) + return documents + + def _index_exists(self, index_name: Optional[str]) -> bool: + """ + Check if the index exists in the Azure AI Search service. + + :param index_name: The name of the index to check. + :returns bool: whether the index exists. + """ + + if self._index_client and index_name: + return index_name in self._index_client.list_index_names() + else: + msg = "Index name is required to check if the index exists." + raise ValueError(msg) + + def _get_raw_documents_by_id(self, document_ids: List[str]): + """ + Retrieves all Azure documents with a matching document_ids from the document store. + + :param document_ids: ids of the documents to be retrieved. + :returns: list of retrieved Azure documents. + """ + azure_documents = [] + for doc_id in document_ids: + try: + document = self.client.get_document(doc_id) + azure_documents.append(document) + except ResourceNotFoundError: + logger.warning(f"Document with ID {doc_id} not found.") + return azure_documents + + def _convert_haystack_documents_to_azure(self, document: Dict[str, Any]) -> Dict[str, Any]: + """Map the document keys to fields of search index""" + + # Because Azure Search does not allow dynamic fields, we only include fields that are part of the schema + index_document = {k: v for k, v in {**document, **document.get("meta", {})}.items() if k in self._index_fields} + if index_document["embedding"] is None: + index_document["embedding"] = self._dummy_vector + + return index_document + + def _create_metadata_index_fields(self, metadata: Dict[str, Any]) -> List[SimpleField]: + """Create a list of index fields for storing metadata values.""" + + index_fields = [] + metadata_field_mapping = self._map_metadata_field_types(metadata) + + for key, field_type in metadata_field_mapping.items(): + index_fields.append(SimpleField(name=key, type=field_type, filterable=True)) + + return index_fields + + def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]: + """Map metadata field types to Azure Search field types.""" + + metadata_field_mapping = {} + + for key, value_type in metadata.items(): + + if not key[0].isalpha(): + msg = ( + f"Azure Search index only allows field names starting with letters. " + f"Invalid key: {key} will be dropped." + ) + logger.warning(msg) + continue + + field_type = type_mapping.get(value_type) + if not field_type: + error_message = f"Unsupported field type for key '{key}': {value_type}" + raise ValueError(error_message) + metadata_field_mapping[key] = field_type + + return metadata_field_mapping + + def _embedding_retrieval( + self, + query_embedding: List[float], + *, + top_k: int = 10, + fields: Optional[List[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + It uses the vector configuration of the document store. 
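+        This is the `vector_search_configuration` passed to the constructor, or
+        `DEFAULT_VECTOR_SEARCH` if none was given.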
By default it uses the HNSW algorithm + with cosine similarity. + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return, defaults to 10 + + :raises ValueError: If `query_embedding` is an empty list + :returns: List of Document that are most similar to `query_embedding` + """ + + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=top_k, fields="embedding") + result = self.client.search(search_text=None, vector_queries=[vector_query], select=fields, filter=filters) + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py new file mode 100644 index 000000000..0fbc80696 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py @@ -0,0 +1,20 @@ +from haystack.document_stores.errors import DocumentStoreError +from haystack.errors import FilterError + + +class AzureAISearchDocumentStoreError(DocumentStoreError): + """Parent class for all AzureAISearchDocumentStore exceptions.""" + + pass + + +class AzureAISearchDocumentStoreConfigError(AzureAISearchDocumentStoreError): + """Raised when a configuration is not valid for a AzureAISearchDocumentStore.""" + + pass + + +class AzureAISearchDocumentStoreFilterError(FilterError): + """Raised when filter is not valid for AzureAISearchDocumentStore.""" + + pass diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py new file mode 100644 index 000000000..650e3f8be --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/filters.py @@ -0,0 +1,112 @@ +from typing import Any, Dict + +from dateutil import parser + +from .errors import AzureAISearchDocumentStoreFilterError + +LOGICAL_OPERATORS = {"AND": "and", "OR": "or", "NOT": "not"} + + +def normalize_filters(filters: Dict[str, Any]) -> str: + """ + Converts Haystack filters in Azure AI Search compatible filters. + """ + if not isinstance(filters, dict): + msg = """Filters must be a dictionary. + See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + if "field" in filters: + return _parse_comparison_condition(filters) + return _parse_logical_condition(filters) + + +def _parse_logical_condition(condition: Dict[str, Any]) -> str: + missing_keys = [key for key in ("operator", "conditions") if key not in condition] + if missing_keys: + msg = f"""Missing key(s) {missing_keys} in {condition}. 
+ See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + operator = condition["operator"] + if operator not in LOGICAL_OPERATORS: + msg = f"Unknown operator {operator}" + raise AzureAISearchDocumentStoreFilterError(msg) + conditions = [] + for c in condition["conditions"]: + # Recursively parse if the condition itself is a logical condition + if isinstance(c, dict) and "operator" in c and c["operator"] in LOGICAL_OPERATORS: + conditions.append(_parse_logical_condition(c)) + else: + # Otherwise, parse it as a comparison condition + conditions.append(_parse_comparison_condition(c)) + + # Format the result based on the operator + if operator == "NOT": + return f"not ({' and '.join([f'({c})' for c in conditions])})" + else: + return f" {LOGICAL_OPERATORS[operator]} ".join([f"({c})" for c in conditions]) + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> str: + missing_keys = [key for key in ("field", "operator", "value") if key not in condition] + if missing_keys: + msg = f"""Missing key(s) {missing_keys} in {condition}. + See https://docs.haystack.deepset.ai/docs/metadata-filtering for details on filters syntax.""" + raise AzureAISearchDocumentStoreFilterError(msg) + + # Remove the "meta." prefix from the field name if present + field = condition["field"][5:] if condition["field"].startswith("meta.") else condition["field"] + operator = condition["operator"] + value = "null" if condition["value"] is None else condition["value"] + + if operator not in COMPARISON_OPERATORS: + msg = f"Unknown operator {operator}. Valid operators are: {list(COMPARISON_OPERATORS.keys())}" + raise AzureAISearchDocumentStoreFilterError(msg) + + return COMPARISON_OPERATORS[operator](field, value) + + +def _eq(field: str, value: Any) -> str: + return f"{field} eq '{value}'" if isinstance(value, str) and value != "null" else f"{field} eq {value}" + + +def _ne(field: str, value: Any) -> str: + return f"not ({field} eq '{value}')" if isinstance(value, str) and value != "null" else f"not ({field} eq {value})" + + +def _in(field: str, value: Any) -> str: + if not isinstance(value, list) or any(not isinstance(v, str) for v in value): + msg = "Azure AI Search only supports a list of strings for 'in' comparators" + raise AzureAISearchDocumentStoreFilterError(msg) + values = ", ".join(map(str, value)) + return f"search.in({field},'{values}')" + + +def _comparison_operator(field: str, value: Any, operator: str) -> str: + _validate_type(value, operator) + return f"{field} {operator} {value}" + + +def _validate_type(value: Any, operator: str) -> None: + """Validates that the value is either an integer, float, or ISO 8601 string.""" + msg = f"Invalid value type for '{operator}' comparator. Supported types are: int, float, or ISO 8601 string." 
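+    # Ints and floats pass as-is; strings are accepted only if dateutil.parser.isoparse
+    # can parse them as ISO 8601 timestamps, otherwise the error above is raised.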
+ + if isinstance(value, str): + try: + parser.isoparse(value) + except ValueError as e: + raise AzureAISearchDocumentStoreFilterError(msg) from e + elif not isinstance(value, (int, float)): + raise AzureAISearchDocumentStoreFilterError(msg) + + +COMPARISON_OPERATORS = { + "==": _eq, + "!=": _ne, + "in": _in, + ">": lambda f, v: _comparison_operator(f, v, "gt"), + ">=": lambda f, v: _comparison_operator(f, v, "ge"), + "<": lambda f, v: _comparison_operator(f, v, "lt"), + "<=": lambda f, v: _comparison_operator(f, v, "le"), +} diff --git a/integrations/azure_ai_search/tests/__init__.py b/integrations/azure_ai_search/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/azure_ai_search/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py new file mode 100644 index 000000000..3017c79c2 --- /dev/null +++ b/integrations/azure_ai_search/tests/conftest.py @@ -0,0 +1,68 @@ +import os +import time +import uuid + +import pytest +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ResourceNotFoundError +from azure.search.documents.indexes import SearchIndexClient +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +# This is the approximate time in seconds it takes for the documents to be available in Azure Search index +SLEEP_TIME_IN_SECONDS = 5 + + +@pytest.fixture() +def sleep_time(): + return SLEEP_TIME_IN_SECONDS + + +@pytest.fixture +def document_store(request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. 
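+
+    It creates a uniquely named index for each test, wraps write_documents and delete_documents
+    so that tests wait for the documents to become searchable, and deletes the index on teardown.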
+ """ + index_name = f"haystack_test_{uuid.uuid4().hex}" + metadata_fields = getattr(request, "param", {}).get("metadata_fields", None) + + azure_endpoint = os.environ["AZURE_SEARCH_SERVICE_ENDPOINT"] + api_key = os.environ["AZURE_SEARCH_API_KEY"] + + client = SearchIndexClient(azure_endpoint, AzureKeyCredential(api_key)) + if index_name in client.list_index_names(): + client.delete_index(index_name) + + store = AzureAISearchDocumentStore( + api_key=api_key, + azure_endpoint=azure_endpoint, + index_name=index_name, + create_index=True, + embedding_dimension=768, + metadata_fields=metadata_fields, + ) + + # Override some methods to wait for the documents to be available + original_write_documents = store.write_documents + + def write_documents_and_wait(documents, policy=DuplicatePolicy.OVERWRITE): + written_docs = original_write_documents(documents, policy) + time.sleep(SLEEP_TIME_IN_SECONDS) + return written_docs + + original_delete_documents = store.delete_documents + + def delete_documents_and_wait(filters): + original_delete_documents(filters) + time.sleep(SLEEP_TIME_IN_SECONDS) + + store.write_documents = write_documents_and_wait + store.delete_documents = delete_documents_and_wait + + yield store + try: + client.delete_index(index_name) + except ResourceNotFoundError: + pass diff --git a/integrations/azure_ai_search/tests/test_document_store.py b/integrations/azure_ai_search/tests/test_document_store.py new file mode 100644 index 000000000..1bcd967c6 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_document_store.py @@ -0,0 +1,410 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +import random +from datetime import datetime, timezone +from typing import List +from unittest.mock import patch + +import pytest +from haystack.dataclasses.document import Document +from haystack.errors import FilterError +from haystack.testing.document_store import ( + CountDocumentsTest, + DeleteDocumentsTest, + FilterDocumentsTest, + WriteDocumentsTest, +) +from haystack.utils.auth import EnvVarSecret, Secret + +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_to_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") + document_store = AzureAISearchDocumentStore() + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + }, + } + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_from_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + 
monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") + + data = { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": True, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "embedding_dimension": 768, + "index_name": "default", + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + }, + } + document_store = AzureAISearchDocumentStore.from_dict(data) + assert isinstance(document_store._api_key, EnvVarSecret) + assert isinstance(document_store._azure_endpoint, EnvVarSecret) + assert document_store._index_name == "default" + assert document_store._embedding_dimension == 768 + assert document_store._metadata_fields is None + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init_is_lazy(_mock_azure_search_client): + AzureAISearchDocumentStore(azure_endpoint=Secret.from_token("test_endpoint")) + _mock_azure_search_client.assert_not_called() + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init(_mock_azure_search_client): + + document_store = AzureAISearchDocumentStore( + api_key=Secret.from_token("fake-api-key"), + azure_endpoint=Secret.from_token("fake_endpoint"), + index_name="my_index", + embedding_dimension=15, + metadata_fields={"Title": str, "Pages": int}, + ) + + assert document_store._index_name == "my_index" + assert document_store._embedding_dimension == 15 + assert document_store._metadata_fields == {"Title": str, "Pages": int} + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): + + def test_write_documents(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + assert document_store.write_documents(docs) == 1 + + # Parametrize the test with metadata fields + @pytest.mark.parametrize( + "document_store", + [ + {"metadata_fields": {"author": str, "publication_year": int, "rating": float}}, + ], + indirect=True, + ) + def test_write_documents_with_meta(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document( + id="1", + meta={"author": "Tom", "publication_year": 2021, "rating": 4.5}, + content="This is a test document.", + ) + ] + document_store.write_documents(docs) + doc = document_store.get_documents_by_id(["1"]) + assert doc[0] == docs[0] + + @pytest.mark.skip(reason="Azure AI search index overwrites duplicate documents by default") + def test_write_documents_duplicate_fail(self, document_store: AzureAISearchDocumentStore): ... + + @pytest.mark.skip(reason="Azure AI search index overwrites duplicate documents by default") + def test_write_documents_duplicate_skip(self, document_store: AzureAISearchDocumentStore): ... 
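For orientation, a minimal usage sketch (illustrative only, not part of this patch) of how the document store and the filter syntax exercised by the tests above fit together. It assumes AZURE_SEARCH_API_KEY and AZURE_SEARCH_SERVICE_ENDPOINT are exported, mirroring the defaults in the serialization tests; the index name, metadata fields, and values are made up.

# Sketch only: not part of the diff above. Mirrors the conftest fixture's constructor arguments.
from haystack import Document

from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore

store = AzureAISearchDocumentStore(
    index_name="my-search-index",          # illustrative name
    create_index=True,                     # as the test fixture does
    embedding_dimension=768,
    metadata_fields={"author": str, "rating": float},
)
store.write_documents(
    [Document(id="1", content="A test document", meta={"author": "Tom", "rating": 4.5})]
)

# The filter utilities shown earlier translate this dict into an OData expression
# (roughly "rating gt 4.0") before the query is sent to Azure AI Search.
docs = store.filter_documents(filters={"field": "meta.rating", "operator": ">", "value": 4.0})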
+ + +def _random_embeddings(n): + return [round(random.random(), 7) for _ in range(n)] # nosec: S311 + + +TEST_EMBEDDING_1 = _random_embeddings(768) +TEST_EMBEDDING_2 = _random_embeddings(768) + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.parametrize( + "document_store", + [ + {"metadata_fields": {"name": str, "page": str, "chapter": str, "number": int, "date": datetime}}, + ], + indirect=True, +) +class TestFilters(FilterDocumentsTest): + + # Overriding to change "date" to compatible ISO 8601 format + # and remove incompatible fields (dataframes) for Azure search index + @pytest.fixture + def filterable_docs(self) -> List[Document]: + """Fixture that returns a list of Documents that can be used to test filtering.""" + documents = [] + for i in range(3): + documents.append( + Document( + content=f"A Foo Document {i}", + meta={ + "name": f"name_{i}", + "page": "100", + "chapter": "intro", + "number": 2, + "date": "1969-07-21T20:17:40Z", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Bar Document {i}", + meta={ + "name": f"name_{i}", + "page": "123", + "chapter": "abstract", + "number": -2, + "date": "1972-12-11T19:54:58Z", + }, + embedding=_random_embeddings(768), + ) + ) + documents.append( + Document( + content=f"A Foobar Document {i}", + meta={ + "name": f"name_{i}", + "page": "90", + "chapter": "conclusion", + "number": -10, + "date": "1989-11-09T17:53:00Z", + }, + embedding=_random_embeddings(768), + ) + ) + + documents.append( + Document(content=f"Doc {i} with zeros emb", meta={"name": "zeros_doc"}, embedding=TEST_EMBEDDING_1) + ) + documents.append( + Document(content=f"Doc {i} with ones emb", meta={"name": "ones_doc"}, embedding=TEST_EMBEDDING_2) + ) + return documents + + # Overriding to compare the documents with the same order + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + Assert that two lists of Documents are equal. + + This is used in every test, if a Document Store implementation has a different behaviour + it should override this method. This can happen for example when the Document Store sets + a score to returned Documents. Since we can't know what the score will be, we can't compare + the Documents reliably. + """ + sorted_recieved = sorted(received, key=lambda doc: doc.id) + sorted_expected = sorted(expected, key=lambda doc: doc.id) + assert sorted_recieved == sorted_expected + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_greater_than_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_less_than_with_dataframe(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_greater_than_equal_with_dataframe(self, document_store, filterable_docs): ... 
+ + @pytest.mark.skip(reason="Azure AI search index does not support dataframes") + def test_comparison_less_than_equal_with_dataframe(self, document_store, filterable_docs): ... + + # Azure search index supports UTC datetime in ISO 8601 format + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with > comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">", "value": "1972-12-11T19:54:58Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + > datetime.strptime("1972-12-11T19:54:58Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with >= comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": ">=", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + >= datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with < comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + < datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): + """Test filter_documents() with <= comparator and datetime""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + {"field": "meta.date", "operator": "<=", "value": "1969-07-21T20:17:40Z"} + ) + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if d.meta.get("date") is not None + and datetime.strptime(d.meta["date"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + <= datetime.strptime("1969-07-21T20:17:40Z", "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) + ], + ) + + # Override as comparison operators with None/null raise errors + def test_comparison_greater_than_with_none(self, document_store, filterable_docs): + """Test filter_documents() with > comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": ">", "value": None}) + + def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs): + """Test filter_documents() with >= comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": ">=", "value": None}) + + 
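The ISO 8601 comparators tested here combine with the logical operators covered by the nested-filter test further down. As an illustrative sketch (not part of the patch, with made-up values, and `document_store` standing in for the conftest fixture above):

# Sketch only: a date-range filter built from the comparators exercised above.
date_range = {
    "operator": "AND",
    "conditions": [
        {"field": "meta.date", "operator": ">=", "value": "1969-07-21T20:17:40Z"},
        {"field": "meta.date", "operator": "<", "value": "1989-11-09T17:53:00Z"},
    ],
}
# results = document_store.filter_documents(filters=date_range)
# Note: None is not a valid value for these range comparators; the overrides
# below assert that such filters raise a FilterError.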
def test_comparison_less_than_with_none(self, document_store, filterable_docs): + """Test filter_documents() with < comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": "<", "value": None}) + + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): + """Test filter_documents() with <= comparator and None""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"field": "meta.number", "operator": "<=", "value": None}) + + # Override as Azure AI Search supports 'in' operator only for strings + def test_comparison_in(self, document_store, filterable_docs): + """Test filter_documents() with 'in' comparator""" + document_store.write_documents(filterable_docs) + result = document_store.filter_documents({"field": "meta.page", "operator": "in", "value": ["100", "123"]}) + assert len(result) + expected = [d for d in filterable_docs if d.meta.get("page") is not None and d.meta["page"] in ["100", "123"]] + self.assert_documents_are_equal(result, expected) + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in_with_with_non_list(self, document_store, filterable_docs): ... + + @pytest.mark.skip(reason="Azure AI search index does not support not in operator") + def test_comparison_not_in_with_with_non_list_iterable(self, document_store, filterable_docs): ... + + def test_missing_condition_operator_key(self, document_store, filterable_docs): + """Test filter_documents() with missing operator key""" + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents( + filters={"conditions": [{"field": "meta.name", "operator": "eq", "value": "test"}]} + ) + + def test_nested_logical_filters(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + {"field": "meta.name", "operator": "==", "value": "name_0"}, + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "!=", "value": 0}, + {"field": "meta.page", "operator": "==", "value": "123"}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + {"field": "meta.page", "operator": "==", "value": "90"}, + ], + }, + ], + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + # Ensure all required fields are present in doc.meta + ("name" in doc.meta and doc.meta.get("name") == "name_0") + or ( + all(key in doc.meta for key in ["number", "page"]) + and doc.meta.get("number") != 0 + and doc.meta.get("page") == "123" + ) + or ( + all(key in doc.meta for key in ["page", "chapter"]) + and doc.meta.get("chapter") == "conclusion" + and doc.meta.get("page") == "90" + ) + ) + ], + ) diff --git a/integrations/azure_ai_search/tests/test_embedding_retriever.py b/integrations/azure_ai_search/tests/test_embedding_retriever.py new file mode 100644 index 000000000..d4615ec44 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_embedding_retriever.py @@ -0,0 +1,145 @@ +# 
SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from typing import List +from unittest.mock import Mock + +import pytest +from azure.core.exceptions import HttpResponseError +from haystack.dataclasses import Document +from haystack.document_stores.types import FilterPolicy +from numpy.random import rand # type: ignore + +from haystack_integrations.components.retrievers.azure_ai_search import AzureAISearchEmbeddingRetriever +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=AzureAISearchDocumentStore) + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == FilterPolicy.REPLACE + + retriever = AzureAISearchEmbeddingRetriever(document_store=mock_store, filter_policy="replace") + assert retriever._filter_policy == FilterPolicy.REPLACE + + with pytest.raises(ValueError): + AzureAISearchEmbeddingRetriever(document_store=mock_store, filter_policy="unknown") + + +def test_to_dict(): + document_store = AzureAISearchDocumentStore(hosts="some fake host") + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever.AzureAISearchEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + + +def test_from_dict(): + data = { + "type": "haystack_integrations.components.retrievers.azure_ai_search.embedding_retriever.AzureAISearchEmbeddingRetriever", # noqa: E501 + "init_parameters": { + "filters": {}, + "top_k": 10, + "document_store": { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", # noqa: E501 + "init_parameters": { + "azure_endpoint": { + "type": "env_var", + "env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], + "strict": True, + }, + "api_key": {"type": "env_var", "env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "hosts": "some fake host", + }, + }, + "filter_policy": "replace", + }, + } + retriever = AzureAISearchEmbeddingRetriever.from_dict(data) + assert isinstance(retriever._document_store, AzureAISearchDocumentStore) + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._filter_policy == 
FilterPolicy.REPLACE + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +@pytest.mark.integration +class TestRetriever: + + def test_run(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + document_store.write_documents(docs) + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + res = retriever.run(query_embedding=[0.1] * 768) + assert res["documents"] == docs + + def test_embedding_retrieval(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 200 + [0.1] * 300 + [0.2] * 268 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="This is first document", embedding=most_similar_embedding), + Document(content="This is second document", embedding=second_best_embedding), + Document(content="This is third document", embedding=another_embedding), + ] + + document_store.write_documents(docs) + retriever = AzureAISearchEmbeddingRetriever(document_store=document_store) + results = retriever.run(query_embedding=query_embedding) + assert results["documents"][0].content == "This is first document" + + def test_empty_query_embedding(self, document_store: AzureAISearchDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: AzureAISearchDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(HttpResponseError): + document_store._embedding_retrieval(query_embedding=query_embedding) diff --git a/integrations/deepeval/CHANGELOG.md b/integrations/deepeval/CHANGELOG.md new file mode 100644 index 000000000..a296c7cfa --- /dev/null +++ b/integrations/deepeval/CHANGELOG.md @@ -0,0 +1,35 @@ +# Changelog + +## [integrations/deepeval-v0.1.2] - 2024-11-14 + +### ๐Ÿš€ Features + +- Implement `DeepEvalEvaluator` (#346) + +### ๐Ÿ› Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme +- Deepeval - pin indirect dependencies based on python version (#1187) + +### ๐Ÿ“š Documentation + +- Update paths and titles (#397) +- Update category slug (#442) +- Update `deepeval-haystack` docstrings (#527) +- Disable-class-def (#556) + +### ๐Ÿงช Testing + +- Do not retry tests in `hatch run test` command (#954) + +### โš™๏ธ Miscellaneous Tasks + +- Exculde evaluator private classes in API docs (#392) +- Retry tests to reduce flakyness (#836) +- Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) + + diff --git a/integrations/deepeval/pyproject.toml b/integrations/deepeval/pyproject.toml index 6ef64387b..78cc2542a 100644 --- a/integrations/deepeval/pyproject.toml +++ b/integrations/deepeval/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai", "deepeval==0.20.57"] +dependencies = ["haystack-ai", "deepeval==0.20.57", "langchain<0.3; python_version < '3.10'"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/deepeval" diff --git a/integrations/fastembed/CHANGELOG.md 
b/integrations/fastembed/CHANGELOG.md index b5c194d8b..5dd62d130 100644 --- a/integrations/fastembed/CHANGELOG.md +++ b/integrations/fastembed/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/fastembed-v1.4.0] - 2024-11-13 + +### โš™๏ธ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + ## [integrations/fastembed-v1.3.0] - 2024-10-07 ### ๐Ÿš€ Features diff --git a/integrations/fastembed/examples/ranker_example.py b/integrations/fastembed/examples/ranker_example.py new file mode 100644 index 000000000..7a31e4646 --- /dev/null +++ b/integrations/fastembed/examples/ranker_example.py @@ -0,0 +1,22 @@ +from haystack import Document + +from haystack_integrations.components.rankers.fastembed import FastembedRanker + +query = "Who is maintaining Qdrant?" +documents = [ + Document( + content="This is built to be faster and lighter than other embedding libraries e.g. Transformers, Sentence-Transformers, etc." + ), + Document(content="fastembed is supported by and maintained by Qdrant."), +] + +ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2") +ranker.warm_up() +reranked_documents = ranker.run(query=query, documents=documents)["documents"] + + +print(reranked_documents[0]) + +# Document(id=..., +# content: 'fastembed is supported by and maintained by Qdrant.', +# score: 5.472434997558594..) diff --git a/integrations/fastembed/pydoc/config.yml b/integrations/fastembed/pydoc/config.yml index aad50e52c..8ab538cf8 100644 --- a/integrations/fastembed/pydoc/config.yml +++ b/integrations/fastembed/pydoc/config.yml @@ -6,7 +6,8 @@ loaders: "haystack_integrations.components.embedders.fastembed.fastembed_document_embedder", "haystack_integrations.components.embedders.fastembed.fastembed_text_embedder", "haystack_integrations.components.embedders.fastembed.fastembed_sparse_document_embedder", - "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder" + "haystack_integrations.components.embedders.fastembed.fastembed_sparse_text_embedder", + "haystack_integrations.components.rankers.fastembed.ranker" ] ignore_when_discovered: ["__init__"] processors: diff --git a/integrations/fastembed/pyproject.toml b/integrations/fastembed/pyproject.toml index b9f1f6cfd..abae78d8a 100644 --- a/integrations/fastembed/pyproject.toml +++ b/integrations/fastembed/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.0.1", "fastembed>=0.2.5", "onnxruntime<1.20.0"] +dependencies = ["haystack-ai>=2.0.1", "fastembed>=0.4.2"] [project.urls] Source = "https://github.com/deepset-ai/haystack-core-integrations" @@ -154,6 +154,10 @@ omit = ["*/tests/*", "*/__init__.py"] show_missing = true exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + [[tool.mypy.overrides]] module = [ "haystack.*", diff --git a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py new file mode 100644 index 000000000..ece5e858b --- /dev/null +++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/__init__.py @@ -0,0 +1,3 @@ +from .ranker import FastembedRanker + +__all__ = ["FastembedRanker"] diff --git 
a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py new file mode 100644 index 000000000..8f077a30c --- /dev/null +++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py @@ -0,0 +1,202 @@ +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict, logging + +from fastembed.rerank.cross_encoder import TextCrossEncoder + +logger = logging.getLogger(__name__) + + +@component +class FastembedRanker: + """ + Ranks Documents based on their similarity to the query using + [Fastembed models](https://qdrant.github.io/fastembed/examples/Supported_Models/). + + Documents are indexed from most to least semantically relevant to the query. + + Usage example: + ```python + from haystack import Document + from haystack_integrations.components.rankers.fastembed import FastembedRanker + + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2", top_k=2) + + docs = [Document(content="Paris"), Document(content="Berlin")] + query = "What is the capital of germany?" + output = ranker.run(query=query, documents=docs) + print(output["documents"][0].content) + + # Berlin + ``` + """ + + def __init__( + self, + model_name: str = "Xenova/ms-marco-MiniLM-L-6-v2", + top_k: int = 10, + cache_dir: Optional[str] = None, + threads: Optional[int] = None, + batch_size: int = 64, + parallel: Optional[int] = None, + local_files_only: bool = False, + meta_fields_to_embed: Optional[List[str]] = None, + meta_data_separator: str = "\n", + ): + """ + Creates an instance of the 'FastembedRanker'. + + :param model_name: Fastembed model name. Check the list of supported models in the [Fastembed documentation](https://qdrant.github.io/fastembed/examples/Supported_Models/). + :param top_k: The maximum number of documents to return. + :param cache_dir: The path to the cache directory. + Can be set using the `FASTEMBED_CACHE_PATH` env variable. + Defaults to `fastembed_cache` in the system's temp directory. + :param threads: The number of threads single onnxruntime session can use. Defaults to None. + :param batch_size: Number of strings to encode at once. + :param parallel: + If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets. + If 0, use all available cores. + If None, don't use data-parallel processing, use default onnxruntime threading instead. + :param local_files_only: If `True`, only use the model files in the `cache_dir`. + :param meta_fields_to_embed: List of meta fields that should be concatenated + with the document content for reranking. + :param meta_data_separator: Separator used to concatenate the meta fields + to the Document content. + """ + if top_k <= 0: + msg = f"top_k must be > 0, but got {top_k}" + raise ValueError(msg) + + self.model_name = model_name + self.top_k = top_k + self.cache_dir = cache_dir + self.threads = threads + self.batch_size = batch_size + self.parallel = parallel + self.local_files_only = local_files_only + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.meta_data_separator = meta_data_separator + self._model = None + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. 
+ """ + return default_to_dict( + self, + model_name=self.model_name, + top_k=self.top_k, + cache_dir=self.cache_dir, + threads=self.threads, + batch_size=self.batch_size, + parallel=self.parallel, + local_files_only=self.local_files_only, + meta_fields_to_embed=self.meta_fields_to_embed, + meta_data_separator=self.meta_data_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "FastembedRanker": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + return default_from_dict(cls, data) + + def warm_up(self): + """ + Initializes the component. + """ + if self._model is None: + self._model = TextCrossEncoder( + model_name=self.model_name, + cache_dir=self.cache_dir, + threads=self.threads, + local_files_only=self.local_files_only, + ) + + def _prepare_fastembed_input_docs(self, documents: List[Document]) -> List[str]: + """ + Prepare the input by concatenating the document text with the metadata fields specified. + :param documents: The list of Document objects. + + :return: A list of strings to be given as input to Fastembed model. + """ + concatenated_input_list = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta.get(key) + ] + concatenated_input = self.meta_data_separator.join([*meta_values_to_embed, doc.content or ""]) + concatenated_input_list.append(concatenated_input) + + return concatenated_input_list + + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): + """ + Returns a list of documents ranked by their similarity to the given query, using FastEmbed. + + :param query: + The input query to compare the documents to. + :param documents: + A list of documents to be ranked. + :param top_k: + The maximum number of documents to return. + + :returns: + A dictionary with the following keys: + - `documents`: A list of documents closest to the query, sorted from most similar to least similar. + + :raises ValueError: If `top_k` is not > 0. + """ + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + msg = "FastembedRanker expects a list of Documents as input. " + raise TypeError(msg) + if query == "": + msg = "No query provided" + raise ValueError(msg) + + if not documents: + return {"documents": []} + + top_k = top_k or self.top_k + if top_k <= 0: + msg = f"top_k must be > 0, but got {top_k}" + raise ValueError(msg) + + if self._model is None: + msg = "The ranker model has not been loaded. Please call warm_up() before running." 
+ raise RuntimeError(msg) + + fastembed_input_docs = self._prepare_fastembed_input_docs(documents) + + scores = list( + self._model.rerank( + query=query, + documents=fastembed_input_docs, + batch_size=self.batch_size, + parallel=self.parallel, + ) + ) + + # Combine the two lists into a single list of tuples + doc_scores = list(zip(documents, scores)) + + # Sort the list of tuples by the score in descending order + sorted_doc_scores = sorted(doc_scores, key=lambda x: x[1], reverse=True) + + # Get the top_k documents + top_k_documents = [] + for doc, score in sorted_doc_scores[:top_k]: + doc.score = score + top_k_documents.append(doc) + + return {"documents": top_k_documents} diff --git a/integrations/fastembed/tests/test_fastembed_ranker.py b/integrations/fastembed/tests/test_fastembed_ranker.py new file mode 100644 index 000000000..e38229c87 --- /dev/null +++ b/integrations/fastembed/tests/test_fastembed_ranker.py @@ -0,0 +1,292 @@ +from unittest.mock import MagicMock + +import pytest +from haystack import Document, default_from_dict + +from haystack_integrations.components.rankers.fastembed.ranker import ( + FastembedRanker, +) + + +class TestFastembedRanker: + def test_init_default(self): + """ + Test default initialization parameters for FastembedRanker. + """ + ranker = FastembedRanker(model_name="BAAI/bge-reranker-base") + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.top_k == 10 + assert ranker.cache_dir is None + assert ranker.threads is None + assert ranker.batch_size == 64 + assert ranker.parallel is None + assert not ranker.local_files_only + assert ranker.meta_fields_to_embed == [] + assert ranker.meta_data_separator == "\n" + + def test_init_with_parameters(self): + """ + Test custom initialization parameters for FastembedRanker. + """ + ranker = FastembedRanker( + model_name="BAAI/bge-reranker-base", + top_k=64, + cache_dir="fake_dir", + threads=2, + batch_size=50, + parallel=1, + local_files_only=True, + meta_fields_to_embed=["test_field"], + meta_data_separator=" | ", + ) + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.top_k == 64 + assert ranker.cache_dir == "fake_dir" + assert ranker.threads == 2 + assert ranker.batch_size == 50 + assert ranker.parallel == 1 + assert ranker.local_files_only + assert ranker.meta_fields_to_embed == ["test_field"] + assert ranker.meta_data_separator == " | " + + def test_init_with_incorrect_input(self): + """ + Test for checking incorrect input format on init + """ + with pytest.raises( + ValueError, + match="top_k must be > 0, but got 0", + ): + FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2", top_k=0) + + with pytest.raises( + ValueError, + match="top_k must be > 0, but got -3", + ): + FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2", top_k=-3) + + def test_to_dict(self): + """ + Test serialization of FastembedRanker to a dictionary, using default initialization parameters. 
+ """ + ranker = FastembedRanker(model_name="BAAI/bge-reranker-base") + ranker_dict = ranker.to_dict() + assert ranker_dict == { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "top_k": 10, + "cache_dir": None, + "threads": None, + "batch_size": 64, + "parallel": None, + "local_files_only": False, + "meta_fields_to_embed": [], + "meta_data_separator": "\n", + }, + } + + def test_to_dict_with_custom_init_parameters(self): + """ + Test serialization of FastembedRanker to a dictionary, using custom initialization parameters. + """ + ranker = FastembedRanker( + model_name="BAAI/bge-reranker-base", + cache_dir="fake_dir", + threads=2, + top_k=5, + batch_size=50, + parallel=1, + local_files_only=True, + meta_fields_to_embed=["test_field"], + meta_data_separator=" | ", + ) + ranker_dict = ranker.to_dict() + assert ranker_dict == { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "cache_dir": "fake_dir", + "threads": 2, + "top_k": 5, + "batch_size": 50, + "parallel": 1, + "local_files_only": True, + "meta_fields_to_embed": ["test_field"], + "meta_data_separator": " | ", + }, + } + + def test_from_dict(self): + """ + Test deserialization of FastembedRanker from a dictionary, using default initialization parameters. + """ + ranker_dict = { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "cache_dir": None, + "threads": None, + "top_k": 5, + "batch_size": 50, + "parallel": None, + "local_files_only": False, + "meta_fields_to_embed": [], + "meta_data_separator": "\n", + }, + } + ranker = default_from_dict(FastembedRanker, ranker_dict) + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.cache_dir is None + assert ranker.threads is None + assert ranker.top_k == 5 + assert ranker.batch_size == 50 + assert ranker.parallel is None + assert not ranker.local_files_only + assert ranker.meta_fields_to_embed == [] + assert ranker.meta_data_separator == "\n" + + def test_from_dict_with_custom_init_parameters(self): + """ + Test deserialization of FastembedRanker from a dictionary, using custom initialization parameters. + """ + ranker_dict = { + "type": "haystack_integrations.components.rankers.fastembed.ranker.FastembedRanker", + "init_parameters": { + "model_name": "BAAI/bge-reranker-base", + "cache_dir": "fake_dir", + "threads": 2, + "top_k": 5, + "batch_size": 50, + "parallel": 1, + "local_files_only": True, + "meta_fields_to_embed": ["test_field"], + "meta_data_separator": " | ", + }, + } + ranker = default_from_dict(FastembedRanker, ranker_dict) + assert ranker.model_name == "BAAI/bge-reranker-base" + assert ranker.cache_dir == "fake_dir" + assert ranker.threads == 2 + assert ranker.top_k == 5 + assert ranker.batch_size == 50 + assert ranker.parallel == 1 + assert ranker.local_files_only + assert ranker.meta_fields_to_embed == ["test_field"] + assert ranker.meta_data_separator == " | " + + def test_run_incorrect_input_format(self): + """ + Test for checking incorrect input format. 
+ """ + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2") + ranker._model = "mock_model" + + query = "query" + string_input = "text" + list_integers_input = [1, 2, 3] + list_document = [Document("Document 1")] + + with pytest.raises( + TypeError, + match="FastembedRanker expects a list of Documents as input.", + ): + ranker.run(query=query, documents=string_input) + + with pytest.raises( + TypeError, + match="FastembedRanker expects a list of Documents as input.", + ): + ranker.run(query=query, documents=list_integers_input) + + with pytest.raises( + ValueError, + match="No query provided", + ): + ranker.run(query="", documents=list_document) + + with pytest.raises( + ValueError, + match="top_k must be > 0, but got -3", + ): + ranker.run(query=query, documents=list_document, top_k=-3) + + def test_run_no_warmup(self): + """ + Test for checking error when calling without a warmup. + """ + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2") + + query = "query" + list_document = [Document("Document 1")] + + with pytest.raises( + RuntimeError, + ): + ranker.run(query=query, documents=list_document) + + def test_run_empty_document_list(self): + """ + Test for no error when sending no documents. + """ + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2") + ranker._model = "mock_model" + + query = "query" + list_document = [] + + result = ranker.run(query=query, documents=list_document) + assert len(result["documents"]) == 0 + + def test_embed_metadata(self): + """ + Tests the embedding of metadata fields in document content for ranking. + """ + ranker = FastembedRanker( + model_name="model_name", + meta_fields_to_embed=["meta_field"], + ) + ranker._model = MagicMock() + + documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)] + query = "test" + ranker.run(query=query, documents=documents) + + ranker._model.rerank.assert_called_once_with( + query=query, + documents=[ + "meta_value 0\ndocument-number 0", + "meta_value 1\ndocument-number 1", + "meta_value 2\ndocument-number 2", + "meta_value 3\ndocument-number 3", + "meta_value 4\ndocument-number 4", + ], + batch_size=64, + parallel=None, + ) + + @pytest.mark.integration + def test_run(self): + ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2", top_k=2) + ranker.warm_up() + + query = "Who is maintaining Qdrant?" + documents = [ + Document( + content="This is built to be faster and lighter than other embedding \ +libraries e.g. Transformers, Sentence-Transformers, etc." + ), + Document(content="This is some random input"), + Document(content="fastembed is supported by and maintained by Qdrant."), + ] + + result = ranker.run(query=query, documents=documents) + + assert len(result["documents"]) == 2 + first_document = result["documents"][0] + second_document = result["documents"][1] + + assert isinstance(first_document, Document) + assert isinstance(second_document, Document) + assert first_document.content == "fastembed is supported by and maintained by Qdrant." 
+ assert first_document.score > second_document.score diff --git a/integrations/google_ai/CHANGELOG.md b/integrations/google_ai/CHANGELOG.md index 3f3ecaf79..8f09db79a 100644 --- a/integrations/google_ai/CHANGELOG.md +++ b/integrations/google_ai/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [integrations/google_ai-v3.0.0] - 2024-11-12 + +### ๐Ÿ› Bug Fixes + +- `GoogleAIGeminiGenerator` - remove support for tools and change output type (#1177) + +### โš™๏ธ Miscellaneous Tasks + +- Adopt uv as installer (#1142) + ## [integrations/google_ai-v2.0.1] - 2024-10-15 ### ๐Ÿš€ Features diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py index 218e16c4c..b032169df 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/gemini.py @@ -2,7 +2,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import google.generativeai as genai -from google.ai.generativelanguage import Content, Part, Tool +from google.ai.generativelanguage import Content, Part from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import GenerateContentResponse, HarmBlockThreshold, HarmCategory from haystack.core.component import component @@ -62,6 +62,16 @@ class GoogleAIGeminiGenerator: ``` """ + def __new__(cls, *_, **kwargs): + if "tools" in kwargs: + msg = ( + "GoogleAIGeminiGenerator does not support the `tools` parameter. " + " Use GoogleAIGeminiChatGenerator instead." + ) + raise TypeError(msg) + return super(GoogleAIGeminiGenerator, cls).__new__(cls) # noqa: UP008 + # super(__class__, cls) is needed because of the component decorator + def __init__( self, *, @@ -69,7 +79,6 @@ def __init__( model: str = "gemini-1.5-flash", generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, - tools: Optional[List[Tool]] = None, streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, ): """ @@ -86,7 +95,6 @@ def __init__( :param safety_settings: The safety settings to use. A dictionary with `HarmCategory` as keys and `HarmBlockThreshold` as values. For more information, see [the API reference](https://ai.google.dev/api) - :param tools: A list of Tool objects that can be used for [Function calling](https://ai.google.dev/docs/function_calling). :param streaming_callback: A callback function that is called when a new token is received from the stream. The callback function accepts StreamingChunk as an argument. 
""" @@ -96,8 +104,7 @@ def __init__( self._model_name = model self._generation_config = generation_config self._safety_settings = safety_settings - self._tools = tools - self._model = GenerativeModel(self._model_name, tools=self._tools) + self._model = GenerativeModel(self._model_name) self._streaming_callback = streaming_callback def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: @@ -126,11 +133,8 @@ def to_dict(self) -> Dict[str, Any]: model=self._model_name, generation_config=self._generation_config, safety_settings=self._safety_settings, - tools=self._tools, streaming_callback=callback_name, ) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.serialize(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -149,8 +153,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "GoogleAIGeminiGenerator": """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) - if (tools := data["init_parameters"].get("tools")) is not None: - data["init_parameters"]["tools"] = [Tool.deserialize(t) for t in tools] if (generation_config := data["init_parameters"].get("generation_config")) is not None: data["init_parameters"]["generation_config"] = GenerationConfig(**generation_config) if (safety_settings := data["init_parameters"].get("safety_settings")) is not None: @@ -178,7 +180,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: msg = f"Unsupported type {type(part)} for part {part}" raise ValueError(msg) - @component.output_types(replies=List[Union[str, Dict[str, str]]]) + @component.output_types(replies=List[str]) def run( self, parts: Variadic[Union[str, ByteStream, Part]], @@ -192,7 +194,7 @@ def run( :param streaming_callback: A callback function that is called when a new token is received from the stream. :returns: A dictionary containing the following key: - - `replies`: A list of strings or dictionaries with function calls. + - `replies`: A list of strings containing the generated responses. 
""" # check if streaming_callback is passed @@ -221,12 +223,6 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[str]: for part in candidate.content.parts: if part.text != "": replies.append(part.text) - elif part.function_call is not None: - function_call = { - "name": part.function_call.name, - "args": dict(part.function_call.args.items()), - } - replies.append(function_call) return replies def _get_stream_response( diff --git a/integrations/google_ai/tests/generators/test_gemini.py b/integrations/google_ai/tests/generators/test_gemini.py index 7206b7a43..07d194a59 100644 --- a/integrations/google_ai/tests/generators/test_gemini.py +++ b/integrations/google_ai/tests/generators/test_gemini.py @@ -2,32 +2,12 @@ from unittest.mock import patch import pytest -from google.ai.generativelanguage import FunctionDeclaration, Tool from google.generativeai import GenerationConfig, GenerativeModel from google.generativeai.types import HarmBlockThreshold, HarmCategory from haystack.dataclasses import StreamingChunk from haystack_integrations.components.generators.google_ai import GoogleAIGeminiGenerator -GET_CURRENT_WEATHER_FUNC = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, -) - def test_init(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -41,40 +21,24 @@ def test_init(monkeypatch): top_k=0.5, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. 
San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - tool = Tool(function_declarations=[get_current_weather_func]) with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure") as mock_genai_configure: gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], ) mock_genai_configure.assert_called_once_with(api_key="test") assert gemini._model_name == "gemini-1.5-flash" assert gemini._generation_config == generation_config assert gemini._safety_settings == safety_settings - assert gemini._tools == [tool] assert isinstance(gemini._model, GenerativeModel) +def test_init_fails_with_tools(): + with pytest.raises(TypeError, match="GoogleAIGeminiGenerator does not support the `tools` parameter."): + GoogleAIGeminiGenerator(tools=["tool1", "tool2"]) + + def test_to_dict(monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "test") @@ -88,7 +52,6 @@ def test_to_dict(monkeypatch): "generation_config": None, "safety_settings": None, "streaming_callback": None, - "tools": None, }, } @@ -105,32 +68,11 @@ def test_to_dict_with_param(monkeypatch): top_k=2, ) safety_settings = {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - get_current_weather_func = FunctionDeclaration( - name="get_current_weather", - description="Get the current weather in a given location", - parameters={ - "type_": "OBJECT", - "properties": { - "location": {"type_": "STRING", "description": "The city and state, e.g. San Francisco, CA"}, - "unit": { - "type_": "STRING", - "enum": [ - "celsius", - "fahrenheit", - ], - }, - }, - "required": ["location"], - }, - ) - - tool = Tool(function_declarations=[get_current_weather_func]) with patch("haystack_integrations.components.generators.google_ai.gemini.genai.configure"): gemini = GoogleAIGeminiGenerator( generation_config=generation_config, safety_settings=safety_settings, - tools=[tool], ) assert gemini.to_dict() == { "type": "haystack_integrations.components.generators.google_ai.gemini.GoogleAIGeminiGenerator", @@ -147,11 +89,6 @@ def test_to_dict_with_param(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" - ], }, } @@ -175,11 +112,6 @@ def test_from_dict_with_param(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. 
San Francisco, CAB\x08location" - ], }, } ) @@ -194,7 +126,6 @@ def test_from_dict_with_param(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel) @@ -217,11 +148,6 @@ def test_from_dict(monkeypatch): }, "safety_settings": {10: 3}, "streaming_callback": None, - "tools": [ - b"\n\xad\x01\n\x13get_current_weather\x12+Get the current weather in a given location\x1ai" - b"\x08\x06:\x1f\n\x04unit\x12\x17\x08\x01*\x07celsius*\nfahrenheit::\n\x08location\x12.\x08" - b"\x01\x1a*The city and state, e.g. San Francisco, CAB\x08location" - ], }, } ) @@ -236,7 +162,6 @@ def test_from_dict(monkeypatch): top_k=0.5, ) assert gemini._safety_settings == {HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH} - assert gemini._tools == [Tool(function_declarations=[GET_CURRENT_WEATHER_FUNC])] assert isinstance(gemini._model, GenerativeModel) diff --git a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py index 94064a0d1..c9c8a354e 100644 --- a/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py +++ b/integrations/langfuse/src/haystack_integrations/tracing/langfuse/tracer.py @@ -1,5 +1,7 @@ import contextlib +import logging import os +from datetime import datetime from typing import Any, Dict, Iterator, Optional, Union from haystack.components.generators.openai_utils import _convert_message_to_openai_format @@ -9,6 +11,8 @@ import langfuse +logger = logging.getLogger(__name__) + HAYSTACK_LANGFUSE_ENFORCE_FLUSH_ENV_VAR = "HAYSTACK_LANGFUSE_ENFORCE_FLUSH" _SUPPORTED_GENERATORS = [ "AzureOpenAIGenerator", @@ -148,7 +152,18 @@ def trace(self, operation_name: str, tags: Optional[Dict[str, Any]] = None) -> I replies = span._data.get("haystack.component.output", {}).get("replies") if replies: meta = replies[0].meta - span._span.update(usage=meta.get("usage") or None, model=meta.get("model")) + completion_start_time = meta.get("completion_start_time") + if completion_start_time: + try: + completion_start_time = datetime.fromisoformat(completion_start_time) + except ValueError: + logger.error(f"Failed to parse completion_start_time: {completion_start_time}") + completion_start_time = None + span._span.update( + usage=meta.get("usage") or None, + model=meta.get("model"), + completion_start_time=completion_start_time, + ) pipeline_input = tags.get("haystack.pipeline.input_data", None) if pipeline_input: diff --git a/integrations/langfuse/tests/test_tracer.py b/integrations/langfuse/tests/test_tracer.py index c6bf4acdf..9ee8e5dc4 100644 --- a/integrations/langfuse/tests/test_tracer.py +++ b/integrations/langfuse/tests/test_tracer.py @@ -1,9 +1,43 @@ -import os +import datetime from unittest.mock import MagicMock, Mock, patch +from haystack.dataclasses import ChatMessage from haystack_integrations.tracing.langfuse.tracer import LangfuseTracer +class MockSpan: + def __init__(self): + self._data = {} + self._span = self + self.operation_name = "operation_name" + + def raw_span(self): + return self + + def span(self, name=None): + # assert correct operation name passed to the span + assert name == "operation_name" + return self + + def update(self, **kwargs): + self._data.update(kwargs) + + def generation(self, name=None): + return self + + def end(self): + 
pass + + +class MockTracer: + + def trace(self, name, **kwargs): + return MockSpan() + + def flush(self): + pass + + class TestLangfuseTracer: # LangfuseTracer can be initialized with a Langfuse instance, a name and a boolean value for public. @@ -45,37 +79,6 @@ def test_create_new_span(self): # check that update method is called on the span instance with the provided key value pairs def test_update_span_with_pipeline_input_output_data(self): - class MockTracer: - - def trace(self, name, **kwargs): - return MockSpan() - - def flush(self): - pass - - class MockSpan: - def __init__(self): - self._data = {} - self._span = self - self.operation_name = "operation_name" - - def raw_span(self): - return self - - def span(self, name=None): - # assert correct operation name passed to the span - assert name == "operation_name" - return self - - def update(self, **kwargs): - self._data.update(kwargs) - - def generation(self, name=None): - return self - - def end(self): - pass - tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.input_data": "hello"}) as span: assert span.raw_span()._data["metadata"] == {"haystack.pipeline.input_data": "hello"} @@ -83,6 +86,40 @@ def end(self): with tracer.trace(operation_name="operation_name", tags={"haystack.pipeline.output_data": "bye"}) as span: assert span.raw_span()._data["metadata"] == {"haystack.pipeline.output_data": "bye"} + def test_trace_generation(self): + tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) + tags = { + "haystack.component.type": "OpenAIChatGenerator", + "haystack.component.output": { + "replies": [ + ChatMessage.from_assistant( + "", meta={"completion_start_time": "2021-07-27T16:02:08.012345", "model": "test_model"} + ) + ] + }, + } + with tracer.trace(operation_name="operation_name", tags=tags) as span: + ... + assert span.raw_span()._data["usage"] is None + assert span.raw_span()._data["model"] == "test_model" + assert span.raw_span()._data["completion_start_time"] == datetime.datetime(2021, 7, 27, 16, 2, 8, 12345) + + def test_trace_generation_invalid_start_time(self): + tracer = LangfuseTracer(tracer=MockTracer(), name="Haystack", public=False) + tags = { + "haystack.component.type": "OpenAIChatGenerator", + "haystack.component.output": { + "replies": [ + ChatMessage.from_assistant("", meta={"completion_start_time": "foobar", "model": "test_model"}), + ] + }, + } + with tracer.trace(operation_name="operation_name", tags=tags) as span: + ... 
+ assert span.raw_span()._data["usage"] is None + assert span.raw_span()._data["model"] == "test_model" + assert span.raw_span()._data["completion_start_time"] is None + def test_update_span_gets_flushed_by_default(self): tracer_mock = Mock() diff --git a/integrations/nvidia/CHANGELOG.md b/integrations/nvidia/CHANGELOG.md index f66536fe5..75b31d033 100644 --- a/integrations/nvidia/CHANGELOG.md +++ b/integrations/nvidia/CHANGELOG.md @@ -1,15 +1,17 @@ # Changelog -## [unreleased] +## [integrations/nvidia-v0.1.0] - 2024-11-13 ### ๐Ÿš€ Features - Update default embedding model to nvidia/nv-embedqa-e5-v5 (#1015) - Add NVIDIA NIM ranker support (#1023) +- Raise error when attempting to embed empty documents/strings with Nvidia embedders (#1118) ### ๐Ÿ› Bug Fixes - Lints in `nvidia-haystack` (#993) +- Missing Nvidia embedding truncate mode (#1043) ### ๐Ÿšœ Refactor @@ -27,6 +29,8 @@ - Retry tests to reduce flakyness (#836) - Update ruff invocation to include check parameter (#853) +- Update ruff linting scripts and settings (#1105) +- Adopt uv as installer (#1142) ### Docs diff --git a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py index 46c736883..1553d1ac3 100644 --- a/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py +++ b/integrations/nvidia/src/haystack_integrations/components/rankers/nvidia/ranker.py @@ -12,6 +12,7 @@ _MODEL_ENDPOINT_MAP = { "nvidia/nv-rerankqa-mistral-4b-v3": "https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking", + "nvidia/llama-3.2-nv-rerankqa-1b-v1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v1/reranking", } diff --git a/integrations/weaviate/CHANGELOG.md b/integrations/weaviate/CHANGELOG.md index ec15cbeef..7f620c3a0 100644 --- a/integrations/weaviate/CHANGELOG.md +++ b/integrations/weaviate/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [integrations/weaviate-v4.0.2] - 2024-11-13 + +### ๐Ÿ› Bug Fixes + +- Dependency for weaviate document store (#1186) + +## [integrations/weaviate-v4.0.1] - 2024-11-11 + ## [integrations/weaviate-v4.0.0] - 2024-10-18 ### ๐Ÿ› Bug Fixes diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index 70b045bc4..e88397df9 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -26,7 +26,6 @@ classifiers = [ dependencies = [ "haystack-ai", "weaviate-client>=4.9", - "haystack-pydoc-tools", "python-dateutil", ] @@ -48,7 +47,7 @@ git_describe_command = 'git describe --tags --match="integrations/weaviate-v[0-9 [tool.hatch.envs.default] installer = "uv" -dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "ipython"] +dependencies = ["coverage[toml]>=6.5", "pytest", "pytest-rerunfailures", "ipython", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index e312b1473..6acf0156e 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -286,6 +286,14 @@ def _to_data_object(self, document: Document) -> Dict[str, Any]: # The 
embedding vector is stored separately from the rest of the data del data["embedding"] + # _split_overlap meta field is unsupported because of a bug + # https://github.com/deepset-ai/haystack-core-integrations/issues/1172 + if "_split_overlap" in data: + data.pop("_split_overlap") + logger.warning( + "Document %s has the unsupported `_split_overlap` meta field. It will be ignored.", data["_original_id"] + ) + if "sparse_embedding" in data: sparse_embedding = data.pop("sparse_embedding", None) if sparse_embedding: diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 70f1e1eb2..00af322e4 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -508,6 +508,30 @@ def test_comparison_less_than_equal_with_iso_date(self, document_store, filterab def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): return super().test_comparison_not_equal_with_dataframe(document_store, filterable_docs) + def test_meta_split_overlap_is_skipped(self, document_store): + doc = Document( + content="The moonlight shimmered ", + meta={ + "source_id": "62049ba1d1e1d5ebb1f6230b0b00c5356b8706c56e0b9c36b1dfc86084cd75f0", + "page_number": 1, + "split_id": 0, + "split_idx_start": 0, + "_split_overlap": [ + {"doc_id": "68ed48ba830048c5d7815874ed2de794722e6d10866b6c55349a914fd9a0df65", "range": (0, 20)} + ], + }, + ) + document_store.write_documents([doc]) + + written_doc = document_store.filter_documents()[0] + + assert written_doc.content == "The moonlight shimmered " + assert written_doc.meta["source_id"] == "62049ba1d1e1d5ebb1f6230b0b00c5356b8706c56e0b9c36b1dfc86084cd75f0" + assert written_doc.meta["page_number"] == 1.0 + assert written_doc.meta["split_id"] == 0.0 + assert written_doc.meta["split_idx_start"] == 0.0 + assert "_split_overlap" not in written_doc.meta + def test_bm25_retrieval(self, document_store): document_store.write_documents( [
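
The new `test_meta_split_overlap_is_skipped` test exercises the `_to_data_object` change above: before writing, the document store drops the `_split_overlap` meta field (unsupported because of issue #1172) and logs a warning, while every other meta key is preserved. A rough sketch of that filtering step over a plain dict of meta fields — the function name, logger setup, and sample values are illustrative:

```python
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)


def drop_unsupported_meta(data: Dict[str, Any]) -> Dict[str, Any]:
    """Remove the `_split_overlap` meta field before sending the object to Weaviate."""
    if "_split_overlap" in data:
        data.pop("_split_overlap")
        logger.warning(
            "Document %s has the unsupported `_split_overlap` meta field. It will be ignored.",
            data.get("_original_id"),
        )
    return data


data = {
    "_original_id": "doc-1",  # illustrative id
    "page_number": 1,
    "_split_overlap": [{"doc_id": "other-doc", "range": (0, 20)}],
}
assert "_split_overlap" not in drop_unsupported_meta(data)
assert data["page_number"] == 1
```

Note that the test asserts `page_number == 1.0` and `split_id == 0.0` after a round trip: Weaviate stores numeric properties as floats, so integer meta values come back as floats even though the `_split_overlap` filtering itself leaves them untouched.
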