From 3ebaa496e7de10a1b5c4fa096997497ab0136e03 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Wed, 2 Oct 2024 21:54:28 +0200 Subject: [PATCH] CRUD operations and tests for document store --- integrations/azure_ai_search/.gitignore | 163 ++++++++ integrations/azure_ai_search/CHANGELOG.md | 99 +++++ integrations/azure_ai_search/LICENSE | 201 ++++++++++ integrations/azure_ai_search/README.md | 32 ++ integrations/azure_ai_search/pydoc/config.yml | 32 ++ integrations/azure_ai_search/pyproject.toml | 158 ++++++++ .../retrievers/embedding_retriever.py | 105 +++++ .../azure_ai_search/__init__.py | 6 + .../azure_ai_search/document_store.py | 378 ++++++++++++++++++ .../document_stores/azure_ai_search/errors.py | 13 + .../azure_ai_search/tests/__init__.py | 3 + .../azure_ai_search/tests/conftest.py | 67 ++++ .../tests/test_document_store.py | 118 ++++++ 13 files changed, 1375 insertions(+) create mode 100644 integrations/azure_ai_search/.gitignore create mode 100644 integrations/azure_ai_search/CHANGELOG.md create mode 100644 integrations/azure_ai_search/LICENSE create mode 100644 integrations/azure_ai_search/README.md create mode 100644 integrations/azure_ai_search/pydoc/config.yml create mode 100644 integrations/azure_ai_search/pyproject.toml create mode 100644 integrations/azure_ai_search/src/haystack_integrations/components/azure_ai_search/retrievers/embedding_retriever.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py create mode 100644 integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py create mode 100644 integrations/azure_ai_search/tests/__init__.py create mode 100644 integrations/azure_ai_search/tests/conftest.py create mode 100644 integrations/azure_ai_search/tests/test_document_store.py diff --git a/integrations/azure_ai_search/.gitignore b/integrations/azure_ai_search/.gitignore new file mode 100644 index 000000000..d1c340c1f --- /dev/null +++ b/integrations/azure_ai_search/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# VS Code +.vscode diff --git a/integrations/azure_ai_search/CHANGELOG.md b/integrations/azure_ai_search/CHANGELOG.md new file mode 100644 index 000000000..dd1ddb86e --- /dev/null +++ b/integrations/azure_ai_search/CHANGELOG.md @@ -0,0 +1,99 @@ +# Changelog + +## [integrations/opensearch-v0.8.1] - 2024-07-15 + +### 🚀 Features + +- Add raise_on_failure param to OpenSearch retrievers (#852) +- Add filter_policy to opensearch integration (#822) + +### 🐛 Bug Fixes + +- `OpenSearch` - Fallback to default filter policy when deserializing retrievers without the init parameter (#895) + +### ⚙️ Miscellaneous Tasks + +- Update ruff invocation to include check parameter (#853) + +## [integrations/opensearch-v0.7.1] - 2024-06-27 + +### 🐛 Bug Fixes + +- Serialization for custom_query in OpenSearch retrievers (#851) +- Support legacy filters with OpenSearchDocumentStore (#850) + +## [integrations/opensearch-v0.7.0] - 2024-06-25 + +### 🚀 Features + +- Defer the database connection to when it's needed (#753) +- Improve `OpenSearchDocumentStore.__init__` arguments (#739) +- Return_embeddings flag for opensearch (#784) +- Add create_index option to OpenSearchDocumentStore (#840) +- Add custom_query param to OpenSearch retrievers (#841) + +### 🐛 Bug Fixes + +- Fix order of API docs (#447) + +This PR will also push the docs to Readme + +### 📚 Documentation + +- Update category slug (#442) +- Fixing opensearch docstrings (#521) +- Small consistency improvements (#536) +- Disable-class-def (#556) + +### ⚙️ Miscellaneous Tasks + +- Retry tests to reduce flakyness (#836) + +### Opensearch + +- Generate API docs (#324) + +## [integrations/opensearch-v0.2.0] - 2024-01-17 + +### 🐛 Bug Fixes + +- Fix links in docstrings (#188) + + + +### 🚜 Refactor + +- Use `hatch_vcs` to manage integrations versioning (#103) + +## [integrations/opensearch-v0.1.1] - 2023-12-05 + +### 🐛 Bug Fixes + +- Fix import and increase version (#77) + + + +## [integrations/opensearch-v0.1.0] - 2023-12-04 + +### 🐛 Bug Fixes + +- Fix license headers + + +## [integrations/opensearch-v0.0.2] - 2023-11-30 + +### 🚀 Features + +- Extend OpenSearch params support (#70) + +### Build + +- Bump OpenSearch integration version to 0.0.2 (#71) + +## [integrations/opensearch-v0.0.1] - 2023-11-30 + +### 🚀 Features + +- [OpenSearch] add document store, BM25Retriever and EmbeddingRetriever (#68) + + diff --git a/integrations/azure_ai_search/LICENSE b/integrations/azure_ai_search/LICENSE new file mode 100644 index 000000000..de4c7f39f --- /dev/null +++ b/integrations/azure_ai_search/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 deepset GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/integrations/azure_ai_search/README.md b/integrations/azure_ai_search/README.md new file mode 100644 index 000000000..40a2f8eaa --- /dev/null +++ b/integrations/azure_ai_search/README.md @@ -0,0 +1,32 @@ +[![test](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) + +[![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) + +# OpenSearch Document Store + +Document Store for Haystack 2.x, supports OpenSearch. + +## Installation + +```console +pip install opensearch-haystack +``` + +## Testing + +To run tests first start a Docker container running OpenSearch. We provide a utility `docker-compose.yml` for that: + +```console +docker-compose up +``` + +Then run tests: + +```console +hatch run test +``` + +## License + +`opensearch-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/azure_ai_search/pydoc/config.yml b/integrations/azure_ai_search/pydoc/config.yml new file mode 100644 index 000000000..7b2e20d83 --- /dev/null +++ b/integrations/azure_ai_search/pydoc/config.yml @@ -0,0 +1,32 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.opensearch.bm25_retriever", + "haystack_integrations.components.retrievers.opensearch.embedding_retriever", + "haystack_integrations.document_stores.opensearch.document_store", + "haystack_integrations.document_stores.opensearch.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmeIntegrationRenderer + excerpt: OpenSearch integration for Haystack + category_slug: integrations-api + title: OpenSearch + slug: integrations-opensearch + order: 180 + markdown: + descriptive_class_title: false + classdef_code_block: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_opensearch.md diff --git a/integrations/azure_ai_search/pyproject.toml b/integrations/azure_ai_search/pyproject.toml new file mode 100644 index 000000000..c7061cae4 --- /dev/null +++ b/integrations/azure_ai_search/pyproject.toml @@ -0,0 +1,158 @@ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "azure-ai-search-haystack" +dynamic = ["version"] +description = 'Haystack 2.x Document Store for Azure AI Search' +readme = "README.md" +requires-python = ">=3.8" +license = "Apache-2.0" +keywords = [] +authors = [{ name = "deepset", email = "info@deepset.ai" }] +classifiers = [ + "License :: OSI Approved :: Apache Software License", + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["haystack-ai", "azure-search-documents>=11.5"] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/opensearch#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/opensearch" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/azure-ai-search-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/azure-ai-search-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-rerunfailures", + "pytest-xdist", + "haystack-pydoc-tools", +] +[tool.hatch.envs.default.scripts] +test = "pytest --reruns 0 --reruns-delay 30 -x {args:tests}" +test-cov = "coverage run -m pytest --reruns 3 --reruns-delay 30 -x {args:tests}" +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] + +docs = ["pydoc-markdown pydoc/config.yml"] + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = ["ruff check {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff check --fix {args:.}", "style"] +all = ["style", "typing"] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.black] +target-version = ["py38"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py38" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["src"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source = ["haystack_integrations"] +branch = true +parallel = false + + +[tool.coverage.report] +omit = ["*/tests/*", "*/__init__.py"] +show_missing = true +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] + + +[tool.pytest.ini_options] +minversion = "6.0" +markers = ["unit: unit tests", "integration: integration tests"] + +[[tool.mypy.overrides]] +module = ["haystack.*", "haystack_integrations.*", "pytest.*", "azure-ai-search.*"] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/azure_ai_search/src/haystack_integrations/components/azure_ai_search/retrievers/embedding_retriever.py b/integrations/azure_ai_search/src/haystack_integrations/components/azure_ai_search/retrievers/embedding_retriever.py new file mode 100644 index 000000000..bb16e8027 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/components/azure_ai_search/retrievers/embedding_retriever.py @@ -0,0 +1,105 @@ +import logging +import os +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Union + +from azure.search.documents.models import VectorizedQuery +from haystack import Document, component +from haystack.document_stores.types import FilterPolicy +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +# from haystack.components.embedders import AzureOpenAIDocumentEmbedder, AzureOpenAITextEmbedder +# from .vectorizer import create_vectorizer, get_document_emebeddings, get_text_embeddings + +logger = logging.getLogger(__name__) + + +@component +class AzureAISearchEmbeddingRetriever: + """ + Retrieves documents from the AzureAISearchDocumentStore using a vector similarity metric. + + Must be connected to the AzureAISearchDocumentStore to run. + """ + + def __init__( + self, + *, + document_store: AzureAISearchDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + filter_policy: Union[str, FilterPolicy] = FilterPolicy.REPLACE, + ): + """ + Create the AzureAISearchEmbeddingRetriever component. + + :param document_store: An instance of AzureAISearchDocumentStore to use with the Retriever. + :param filters: Filters applied when fetching documents from the Document Store. + Filters are applied during the approximate kNN search to ensure the Retriever returns + `top_k` matching documents. + :param top_k: Maximum number of documents to return. + + """ + self.filters = filters or {} + self.top_k = top_k + self.document_store = document_store + self.filter_policy = ( + filter_policy if isinstance(filter_policy, FilterPolicy) else FilterPolicy.from_str(filter_policy) + ) + + if not isinstance(document_store, AzureAISearchDocumentStore): + message = "document_store must be an instance of AstraDocumentStore" + raise Exception(message) + + @component.output_types(documents=List[Document]) + def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None): + """Retrieve documents from the AzureAISearchDocumentStore. + + :param query_embedding: floats representing the query embedding + :param filters: Filters applied to the retrieved Documents. The way runtime filters are applied depends on + the `filter_policy` chosen at retriever initialization. See init method docstring for more + details. + :param top_k: the maximum number of documents to retrieve. + :returns: a dictionary with the following keys: + - `documents`: A list of documents retrieved from the AzureAISearchDocumentStore. + """ + # filters = apply_filter_policy(self.filter_policy, self.filters, filters) + top_k = top_k or self.top_k + + return {"documents": self._vector_search(query_embedding, top_k, filters=filters)} + + def _vector_search( + self, + query_embedding: List[float], + *, + top_k: int = 10, + fields: Optional[List[str]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + It uses the vector configuration of the document store. By default it uses the HNSW algorithm with cosine similarity. + + This method is not meant to be part of the public interface of + `AzureAISearchDocumentStore` nor called directly. + `AzureAISearchEmbeddingRetriever` uses this method directly and is the public interface for it. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. + :param top_k: Maximum number of Documents to return, defaults to 10 + + :raises ValueError: If `query_embedding` is an empty list + :returns: List of Document that are most similar to `query_embedding` + """ + + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + + # embedding = get_embeddings(input=query, model=embedding_model_name, dimensions=self._embedding_dimension) + + vector_query = VectorizedQuery(vector=query_embedding, k_nearest_neighbors=3, fields="embeddings") + + results = self.client.search(search_text=None, vector_queries=[vector_query], select=fields) + + return results diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py new file mode 100644 index 000000000..51fb2b911 --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .document_store import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + +__all__ = ["AzureAISearchDocumentStore", "DEFAULT_VECTOR_SEARCH"] diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py new file mode 100644 index 000000000..80f0d785d --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -0,0 +1,378 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import logging +import os +from dataclasses import asdict +from typing import Any, Dict, List, Optional + +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.search.documents import SearchClient +from azure.search.documents.indexes import SearchIndexClient +from azure.search.documents.indexes.models import ( + HnswAlgorithmConfiguration, + HnswParameters, + SearchableField, + SearchField, + SearchFieldDataType, + SearchIndex, + SimpleField, + VectorSearch, + VectorSearchAlgorithmMetric, + VectorSearchProfile, +) +from haystack import default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores.errors import DuplicateDocumentError +from haystack.document_stores.types import DuplicatePolicy +from haystack.utils import Secret, deserialize_secrets_inplace + +from .errors import AzureAISearchDocumentStoreConfigError + +type_mapping = {str: "Edm.String", bool: "Edm.Boolean", int: "Edm.Int32", float: "Edm.Double"} + +MAX_UPLOAD_BATCH_SIZE = 1000 + +DEFAULT_VECTOR_SEARCH = VectorSearch( + profiles=[ + VectorSearchProfile(name="default-vector-config", algorithm_configuration_name="cosine-algorithm-config") + ], + algorithms=[ + HnswAlgorithmConfiguration( + name="cosine-algorithm-config", + parameters=HnswParameters( + metric=VectorSearchAlgorithmMetric.COSINE, + ), + ) + ], +) + +logger = logging.getLogger(__name__) +logging.getLogger("azure").setLevel(logging.ERROR) +logging.getLogger("azure.identity").setLevel(logging.DEBUG) + + +class AzureAISearchDocumentStore: + def __init__( + self, + *, + api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False), + azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=False), + index_name: str = "default", + embedding_dimension: int = 768, # whats a better default value + metadata_fields: Optional[Dict[str, type]] = None, + vector_search_configuration: VectorSearch = None, + create_index: bool = True, + **kwargs, + ): + """ + A document store using [Azure AI Search](https://azure.microsoft.com/products/ai-services/ai-search/) + as the backend. + + :param azure_endpoint: The URL endpoint of an Azure AI Search service. + :param api_key: The API key to use for authentication. + :param index_name: Name of index in Azure AI Search, if it doesn't exist it will be created. + :param embedding_dimension: Dimension of the embeddings. + :param metadata_fields: A dictionary of metatada keys and their types to create + additional fields in index schema. As fields in Azure SearchIndex cannot be dynamic, it is necessary to specify the metadata fields in advance. + :param vector_search_configuration: Configuration option related to vector search. + Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches. + + :param kwargs: Optional keyword parameters for Azure AI Search. + Some of the supported parameters: + - `api_version`: The Search API version to use for requests. + - `audience`: sets the Audience to use for authentication with Azure Active Directory (AAD). + The audience is not considered when using a shared key. If audience is not provided, the public cloud audience will be assumed. + + For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/) + """ + + azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") + if not azure_endpoint: + msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT." + raise ValueError(msg) + api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") + + self._client = None + self._index_client = None + self._index_fields = None # stores all fields in the final schema of index + self._api_key = api_key + self._azure_endpoint = azure_endpoint + self._index_name = index_name + self._embedding_dimension = embedding_dimension + self._metadata_fields = metadata_fields + self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH + self._create_index = create_index + self._kwargs = kwargs + + @property + def client(self) -> SearchClient: + + if isinstance(self._azure_endpoint, Secret): + self._azure_endpoint = self._azure_endpoint.resolve_value() + + if isinstance(self._api_key, Secret): + self._api_key = self._api_key.resolve_value() + credential = AzureKeyCredential(self._api_key) if self._api_key else DefaultAzureCredential() + try: + if not self._index_client: + self._index_client = SearchIndexClient(self._azure_endpoint, credential, **self._kwargs) + if not self.index_exists(self._index_name): + # Create a new index if it does not exist + logger.debug( + "The index '%s' does not exist. A new index will be created.", + self._index_name, + ) + self.create_index(self._index_name) + except (HttpResponseError, ClientAuthenticationError) as error: + msg = f"Failed to authenticate with Azure Search: {error}" + raise AzureAISearchDocumentStoreConfigError(msg) from error + + self._client = self._index_client.get_search_client(self._index_name) + return self._client + + def create_index(self, index_name: str, **kwargs) -> None: + """ + Creates a new search index. + :param index_name: Name of the index to create. If None, the index name from the constructor is used. + :param kwargs: Optional keyword parameters. + + """ + + # default fields to create index based on Haystack Document + default_fields = [ + SimpleField(name="id", type=SearchFieldDataType.String, key=True, filterable=True), + SearchableField(name="content", type=SearchFieldDataType.String), + SearchField( + name="embedding", + type=SearchFieldDataType.Collection(SearchFieldDataType.Single), + searchable=True, + vector_search_dimensions=self._embedding_dimension, + vector_search_profile_name="default-vector-config", + ), + ] + + if not index_name: + index_name = self._index_name + fields = default_fields + if self._metadata_fields: + fields.extend(self._create_metadata_index_fields(self._metadata_fields)) + + self._index_fields = fields + index = SearchIndex(name=index_name, fields=fields, vector_search=self._vector_search_configuration, **kwargs) + self._index_client.create_index(index) + + def to_dict(self) -> Dict[str, Any]: + # This is not the best solution to serialise this class but is the fastest to implement. + # Not all kwargs types can be serialised to text so this can fail. We must serialise each + # type explicitly to handle this properly. + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint is not None else None, + api_key=self._api_key.to_dict() if self._api_key is not None else None, + index_name=self._index_name, + create_index=self._create_index, + embedding_dimension=self._embedding_dimension, + metadata_fields=self._metadata_fields, + vector_search_configuration=self._vector_search_configuration, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + + :returns: + Deserialized component. + """ + + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"]) + return default_from_dict(cls, data) + + def count_documents(self, **kwargs: Any) -> int: + """ + Returns how many documents are present in the search index. + + :param kwargs: additional keyword parameters. + :returns: list of retrieved documents. + """ + return self.client.get_document_count(**kwargs) + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> int: + """ + Writes the provided documents to search index. + + :param documents: documents to write to the index. + :return: the number of documents added to index. + """ + + if len(documents) > 0: + if not isinstance(documents[0], Document): + msg = "param 'documents' must contain a list of objects of type Document" + raise ValueError(msg) + + def _convert_input_document(documents: Document): + document_dict = asdict(documents) + if not isinstance(document_dict["id"], str): + msg = f"Document id {document_dict['id']} is not a string, " + raise Exception(msg) + index_document = self._default_index_mapping(document_dict) + return index_document + + documents_to_write = [] + for doc in documents: + try: + self.client.get_document(doc.id) + if policy == DuplicatePolicy.SKIP: + logger.info(f"Document with ID {doc.id} already exists. Skipping.") + continue + elif policy == DuplicatePolicy.FAIL: + msg = f"Document with ID {doc.id} already exists." + raise DuplicateDocumentError(msg) + elif policy == DuplicatePolicy.OVERWRITE: + logger.info(f"Document with ID {doc.id} already exists. Overwriting.") + documents_to_write.append(_convert_input_document(doc)) + except ResourceNotFoundError: + # Document does not exist, safe to add + documents_to_write.append(_convert_input_document(doc)) + + if documents_to_write != []: + self.client.upload_documents(documents_to_write) + return len(documents_to_write) + + def delete_documents(self, document_ids: List[str]) -> None: + """ + Deletes all documents with a matching document_ids from the search index. + + :param document_ids: ids of the documents to be deleted. + """ + if self.count_documents == 0: + return + documents = self._get_raw_documents_by_id(document_ids) + if documents: + self.client.delete_documents(documents) + + def _get_raw_documents_by_id(self, document_ids: List[str]): + """ + Retrieves all Azure documents with a matching document_ids from the document store. + + :param document_ids: ids of the documents to be retrieved. + :returns: list of retrieved Azure documents. + """ + azure_documents = [] + for doc_id in document_ids: + try: + document = self.client.get_document(doc_id) + azure_documents.append(document) + except ResourceNotFoundError: + logger.warning(f"Document with ID {doc_id} not found.") + return azure_documents + + def get_documents_by_id(self, document_ids: List[str]) -> List[Document]: + return self._convert_search_result_to_documents(self._get_raw_documents_by_id(document_ids)) + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + + # TODO: Implement this method to filter documents based on metadata fields + # For now the implementation is similar to search_documents + """ + Calls the Azure AI Search client's search method and handles pagination. + :param search_text: The text to search for. If not supplied, all documents will be retrieved. + :returns: A list of Documents that match the given filters. + """ + + search_text = "*" # default to search all documents + azure_docs = [] + if filters: + # Handle filtering by 'id' first + document_ids = filters.get("id") + if document_ids: + azure_docs = self._get_raw_documents_by_id(document_ids) + return self._convert_search_result_to_documents(azure_docs) + + # Handle filtering by 'content' + search_text = filters.get("content", "*") + + # Perform search with pagination + result = self.client.search(search_text=search_text, top=self.count_documents()) + azure_docs = list(result) + return self._convert_search_result_to_documents(azure_docs) + + def _convert_search_result_to_documents(self, azure_docs: List[Dict[str, Any]]) -> List[Document]: + + documents = [] + for azure_doc in azure_docs: + + embedding = azure_doc.get("embedding") + if embedding == self._dummy_vector: + embedding = None + + # Filter out meta fields + meta = { + key: value + for key, value in azure_doc.items() + if key not in ["id", "content", "embedding"] and not key.startswith("@") + } + + # Create the document with meta only if it's non-empty + doc = Document( + id=azure_doc["id"], content=azure_doc["content"], embedding=embedding, meta=meta if meta else {} + ) + + documents.append(doc) + return documents + + def index_exists(self, index_name: Optional[str]) -> bool: + if self._index_client and index_name: + return index_name in self._index_client.list_index_names() + + def _default_index_mapping(self, document: Dict[str, Any]) -> Dict[str, Any]: + """Map the document keys to fields of search index""" + + keys_to_remove = ["dataframe", "blob", "sparse_embedding", "score"] + index_document = {k: v for k, v in document.items() if k not in keys_to_remove} + + metadata = index_document.pop("meta", None) + for key, value in metadata.items(): + index_document[key] = value + if index_document["embedding"] is None: + self._dummy_vector = [-10.0] * self._embedding_dimension + index_document["embedding"] = self._dummy_vector + + return index_document + + def _create_metadata_index_fields(self, metadata: Dict[str, Any]) -> List[SimpleField]: + """Create a list of index fields for storing metadata values.""" + + index_fields = [] + metadata_field_mapping = self._map_metadata_field_types(metadata) + + for key, field_type in metadata_field_mapping.items(): + index_fields.append(SimpleField(name=key, type=field_type, filterable=True)) + + return index_fields + + def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]: + """Map metadata field types to Azure Search field types.""" + metadata_field_mapping = {} + + for key, value_type in metadata.items(): + field_type = type_mapping.get(value_type) + if not field_type: + error_message = f"Unsupported field type for key '{key}': {value_type}" + raise ValueError(error_message) + metadata_field_mapping[key] = field_type + + return metadata_field_mapping diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py new file mode 100644 index 000000000..b2756050a --- /dev/null +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/errors.py @@ -0,0 +1,13 @@ +from haystack.document_stores.errors import DocumentStoreError + + +class AzureAISearchDocumentStoreError(DocumentStoreError): + """Parent class for all AzureAISearchDocumentStore exceptions.""" + + pass + + +class AzureAISearchDocumentStoreConfigError(AzureAISearchDocumentStoreError): + """Raised when a configuration is not valid for a AzureAISearchDocumentStore.""" + + pass diff --git a/integrations/azure_ai_search/tests/__init__.py b/integrations/azure_ai_search/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/azure_ai_search/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/azure_ai_search/tests/conftest.py b/integrations/azure_ai_search/tests/conftest.py new file mode 100644 index 000000000..189d04e57 --- /dev/null +++ b/integrations/azure_ai_search/tests/conftest.py @@ -0,0 +1,67 @@ +import os +import time + +import pytest +from azure.core.credentials import AzureKeyCredential +from azure.core.exceptions import ResourceNotFoundError +from azure.search.documents.indexes import SearchIndexClient +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.azure_ai_search import AzureAISearchDocumentStore + +# This is the approximate time in seconds it takes for the documents to be available +SLEEP_TIME_IN_SECONDS = 5 + + +@pytest.fixture() +def sleep_time(): + return SLEEP_TIME_IN_SECONDS + + +@pytest.fixture +def document_store(request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + index_name = "haystack_test_integration" + metadata_fields = getattr(request, "param", {}).get("metadata_fields", None) + + azure_endpoint = os.environ["AZURE_SEARCH_SERVICE_ENDPOINT"] + api_key = os.environ["AZURE_SEARCH_API_KEY"] + + client = SearchIndexClient(azure_endpoint, AzureKeyCredential(api_key)) + if index_name in client.list_index_names(): + client.delete_index(index_name) + + store = AzureAISearchDocumentStore( + api_key=api_key, + azure_endpoint=azure_endpoint, + index_name=index_name, + create_index=True, + embedding_dimension=15, + metadata_fields=metadata_fields, + ) + + # Override some methods to wait for the documents to be available + original_write_documents = store.write_documents + + def write_documents_and_wait(documents, policy=DuplicatePolicy.NONE): + written_docs = original_write_documents(documents, policy) + time.sleep(SLEEP_TIME_IN_SECONDS) + return written_docs + + original_delete_documents = store.delete_documents + + def delete_documents_and_wait(filters): + original_delete_documents(filters) + time.sleep(SLEEP_TIME_IN_SECONDS) + + store.write_documents = write_documents_and_wait + store.delete_documents = delete_documents_and_wait + + yield store + try: + client.delete_index(index_name) + except ResourceNotFoundError: + pass diff --git a/integrations/azure_ai_search/tests/test_document_store.py b/integrations/azure_ai_search/tests/test_document_store.py new file mode 100644 index 000000000..c7fae3f18 --- /dev/null +++ b/integrations/azure_ai_search/tests/test_document_store.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import patch +import pytest +from haystack.dataclasses.document import Document +from haystack.testing.document_store import ( + CountDocumentsTest, + DeleteDocumentsTest, + WriteDocumentsTest, +) +from haystack.utils.auth import EnvVarSecret, Secret + +from haystack_integrations.document_stores.azure_ai_search import DEFAULT_VECTOR_SEARCH, AzureAISearchDocumentStore + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_to_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") + document_store = AzureAISearchDocumentStore() + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": False, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "index_name": "default", + "embedding_dimension": 768, + "metadata_fields": None, + "create_index": True, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + }, + } + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_from_dict(monkeypatch): + monkeypatch.setenv("AZURE_SEARCH_API_KEY", "test-api-key") + monkeypatch.setenv("AZURE_SEARCH_SERVICE_ENDPOINT", "test-endpoint") + + data = { + "type": "haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore", + "init_parameters": { + "azure_endpoint": {"env_vars": ["AZURE_SEARCH_SERVICE_ENDPOINT"], "strict": False, "type": "env_var"}, + "api_key": {"env_vars": ["AZURE_SEARCH_API_KEY"], "strict": False, "type": "env_var"}, + "embedding_dimension": 768, + "index_name": "default", + "metadata_fields": None, + "create_index": False, + "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + }, + } + document_store = AzureAISearchDocumentStore.from_dict(data) + assert isinstance(document_store._api_key, EnvVarSecret) + assert isinstance(document_store._azure_endpoint, EnvVarSecret) + assert document_store._index_name == "default" + assert document_store._embedding_dimension == 768 + assert document_store._metadata_fields is None + assert document_store._create_index is False + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init_is_lazy(_mock_azure_search_client): + AzureAISearchDocumentStore(azure_endppoint=Secret.from_token("test_endpoint")) + _mock_azure_search_client.assert_not_called() + + +@patch("haystack_integrations.document_stores.azure_ai_search.document_store.AzureAISearchDocumentStore") +def test_init(_mock_azure_search_client): + + document_store = AzureAISearchDocumentStore( + api_key=Secret.from_token("fake-api-key"), + azure_endpoint=Secret.from_token("fake_endpoint"), + index_name="my_index", + create_index=False, + embedding_dimension=15, + metadata_fields={"Title": str, "Pages": int}, + ) + + assert document_store._index_name == "my_index" + assert document_store._create_index is False + assert document_store._embedding_dimension == 15 + assert document_store._metadata_fields == {"Title": str, "Pages": int} + assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH + + +@pytest.mark.skipif( + not os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT", None) and not os.environ.get("AZURE_SEARCH_API_KEY", None), + reason="Missing AZURE_SEARCH_SERVICE_ENDPOINT or AZURE_SEARCH_API_KEY.", +) +class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): + + def test_write_documents(self, document_store: AzureAISearchDocumentStore): + docs = [Document(id="1")] + assert document_store.write_documents(docs) == 1 + + # Parametrize the test with metadata fields + @pytest.mark.parametrize( + "document_store", + [ + {"metadata_fields": {"author": str, "publication_year": int, "rating": float}}, + ], + indirect=True, + ) + def test_write_documents_with_meta(self, document_store: AzureAISearchDocumentStore): + docs = [ + Document( + id="1", + meta={"author": "Tom", "publication_year": 2021, "rating": 4.5}, + content="This is a test document.", + ) + ] + document_store.write_documents(docs) + doc = document_store.get_documents_by_id(["1"]) + assert doc[0] == docs[0]