From 01f08b951e5a119542646ac4535d272e2ebbae22 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 30 Jan 2024 09:56:01 +0100 Subject: [PATCH 01/26] feat: Sagemaker integration: `SagemakerGenerator` (#276) * basic generator and tests * readme * fix import paths * improve tests * to/from dict test * review feedback * readme * quotes * typo * readme * labeler --- .github/labeler.yml | 5 + .github/workflows/amazon_sagemaker.yml | 56 ++++ integrations/amazon_sagemaker/LICENSE.txt | 73 ++++++ integrations/amazon_sagemaker/README.md | 52 ++++ integrations/amazon_sagemaker/pyproject.toml | 177 +++++++++++++ .../generators/amazon_sagemaker/__init__.py | 6 + .../generators/amazon_sagemaker/errors.py | 46 ++++ .../generators/amazon_sagemaker/sagemaker.py | 224 ++++++++++++++++ .../amazon_sagemaker/tests/__init__.py | 3 + .../amazon_sagemaker/tests/test_sagemaker.py | 243 ++++++++++++++++++ 10 files changed, 885 insertions(+) create mode 100644 .github/workflows/amazon_sagemaker.yml create mode 100644 integrations/amazon_sagemaker/LICENSE.txt create mode 100644 integrations/amazon_sagemaker/README.md create mode 100644 integrations/amazon_sagemaker/pyproject.toml create mode 100644 integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py create mode 100644 integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py create mode 100644 integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py create mode 100644 integrations/amazon_sagemaker/tests/__init__.py create mode 100644 integrations/amazon_sagemaker/tests/test_sagemaker.py diff --git a/.github/labeler.yml b/.github/labeler.yml index ba74c43a2..f5eaa3374 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -4,6 +4,11 @@ integration:amazon-bedrock: - any-glob-to-any-file: "integrations/amazon_bedrock/**/*" - any-glob-to-any-file: ".github/workflows/amazon_bedrock.yml" +integration:amazon-sagemaker: + - changed-files: + - any-glob-to-any-file: "integrations/amazon_sagemaker/**/*" + - any-glob-to-any-file: ".github/workflows/amazon_sagemaker.yml" + integration:astra: - changed-files: - any-glob-to-any-file: "integrations/astra/**/*" diff --git a/.github/workflows/amazon_sagemaker.yml b/.github/workflows/amazon_sagemaker.yml new file mode 100644 index 000000000..88f397c85 --- /dev/null +++ b/.github/workflows/amazon_sagemaker.yml @@ -0,0 +1,56 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / amazon-sagemaker + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/amazon_sagemaker/**" + - ".github/workflows/amazon_sagemaker.yml" + +defaults: + run: + working-directory: integrations/amazon_sagemaker + +concurrency: + group: amazon-sagemaker-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.10"] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Run tests + run: hatch run cov diff --git a/integrations/amazon_sagemaker/LICENSE.txt b/integrations/amazon_sagemaker/LICENSE.txt new file mode 100644 index 000000000..137069b82 --- /dev/null +++ b/integrations/amazon_sagemaker/LICENSE.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. + +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/integrations/amazon_sagemaker/README.md b/integrations/amazon_sagemaker/README.md new file mode 100644 index 000000000..1ea01871d --- /dev/null +++ b/integrations/amazon_sagemaker/README.md @@ -0,0 +1,52 @@ +# amazon-sagemaker-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) + +----- + +**Table of Contents** + +- [Installation](#installation) +- [Contributing](#contributing) +- [License](#license) + +## Installation + +```console +pip install amazon-sagemaker-haystack +``` + +## Contributing + +`hatch` is the best way to interact with this project, to install it: +```sh +pip install hatch +``` + +With `hatch` installed, to run all the tests: +``` +hatch run test +``` + +> Note: You need to export your AWS credentials for Sagemaker integration tests to run (`AWS_ACCESS_KEY_ID` and +`AWS_SECRET_SECRET_KEY`). If those are missing, the integration tests will be skipped. + +To only run unit tests: +``` +hatch run test -m "not integration" +``` + +To only run integration tests: +``` +hatch run test -m "integration" +``` + +To run the linters `ruff` and `mypy`: +``` +hatch run lint:all +``` + +## License + +`amazon-sagemaker-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/amazon_sagemaker/pyproject.toml b/integrations/amazon_sagemaker/pyproject.toml new file mode 100644 index 000000000..916307156 --- /dev/null +++ b/integrations/amazon_sagemaker/pyproject.toml @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +[project] +name = "amazon-sagemaker-haystack" +dynamic = ["version"] +description = 'An integration of Amazon Sagemaker as an SagemakerGenerator component.' +readme = "README.md" +requires-python = ">=3.8" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "deepset GmbH", email = "info@deepset.ai" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "haystack-ai", + "boto3>=1.28.57", +] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker_haystack#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker_haystack/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_sagemaker_haystack" + +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + +[tool.hatch.version] +source = "vcs" +tag-pattern = 'integrations\/amazon_sagemaker-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/amazon_sagemaker-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.7", "3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Import sorting doesn't seem to work + "I001", + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["haystack_integrations"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +branch = true +parallel = true + +[tool.coverage.paths] +amazon_sagemaker_haystack = ["src"] +tests = ["tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "haystack_integrations.*", + "pytest.*", + "numpy.*", +] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py new file mode 100644 index 000000000..0fe45a8a1 --- /dev/null +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from haystack_integrations.components.generators.amazon_sagemaker.sagemaker import SagemakerGenerator + +__all__ = ["SagemakerGenerator"] diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py new file mode 100644 index 000000000..6c13d0fcb --- /dev/null +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/errors.py @@ -0,0 +1,46 @@ +from typing import Optional + + +class SagemakerError(Exception): + """ + Error generated by the Amazon Sagemaker integration. + """ + + def __init__( + self, + message: Optional[str] = None, + ): + super().__init__() + if message: + self.message = message + + def __getattr__(self, attr): + # If self.__cause__ is None, it will raise the expected AttributeError + getattr(self.__cause__, attr) + + def __str__(self): + return self.message + + def __repr__(self): + return str(self) + + +class AWSConfigurationError(SagemakerError): + """Exception raised when AWS is not configured correctly""" + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message) + + +class SagemakerNotReadyError(SagemakerError): + """Exception for issues that occur during Sagemaker inference""" + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message) + + +class SagemakerInferenceError(SagemakerError): + """Exception for issues that occur during Sagemaker inference""" + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message) diff --git a/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py new file mode 100644 index 000000000..35e54a055 --- /dev/null +++ b/integrations/amazon_sagemaker/src/haystack_integrations/components/generators/amazon_sagemaker/sagemaker.py @@ -0,0 +1,224 @@ +import json +import logging +import os +from typing import Any, ClassVar, Dict, List, Optional + +import requests +from haystack import component, default_from_dict, default_to_dict +from haystack.lazy_imports import LazyImport +from haystack_integrations.components.generators.amazon_sagemaker.errors import ( + AWSConfigurationError, + SagemakerInferenceError, + SagemakerNotReadyError, +) + +with LazyImport(message="Run 'pip install boto3'") as boto3_import: + import boto3 # type: ignore + from botocore.client import BaseClient # type: ignore + + +logger = logging.getLogger(__name__) + + +MODEL_NOT_READY_STATUS_CODE = 429 + + +@component +class SagemakerGenerator: + """ + Enables text generation using Sagemaker. It supports Large Language Models (LLMs) hosted and deployed on a SageMaker + Inference Endpoint. For guidance on how to deploy a model to SageMaker, refer to the + [SageMaker JumpStart foundation models documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-use.html). + + **Example:** + + First export your AWS credentials as environment variables: + ```bash + export AWS_ACCESS_KEY_ID= + export AWS_SECRET_ACCESS_KEY= + ``` + (Note: you may also need to set the session token and region name, depending on your AWS configuration) + + Then you can use the generator as follows: + ```python + from haystack.components.generators.sagemaker import SagemakerGenerator + generator = SagemakerGenerator(model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16") + generator.warm_up() + response = generator.run("What's Natural Language Processing? Be brief.") + print(response) + ``` + ``` + >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on + >> the interaction between computers and human language. It involves enabling computers to understand, interpret, + >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{}]} + ``` + """ + + model_generation_keys: ClassVar = ["generated_text", "generation"] + + def __init__( + self, + model: str, + aws_access_key_id_var: str = "AWS_ACCESS_KEY_ID", + aws_secret_access_key_var: str = "AWS_SECRET_ACCESS_KEY", + aws_session_token_var: str = "AWS_SESSION_TOKEN", + aws_region_name_var: str = "AWS_REGION", + aws_profile_name_var: str = "AWS_PROFILE", + aws_custom_attributes: Optional[Dict[str, Any]] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + ): + """ + Instantiates the session with SageMaker. + + :param model: The name for SageMaker Model Endpoint. + :param aws_access_key_id_var: The name of the env var where the AWS access key ID is stored. + :param aws_secret_access_key_var: The name of the env var where the AWS secret access key is stored. + :param aws_session_token_var: The name of the env var where the AWS session token is stored. + :param aws_region_name_var: The name of the env var where the AWS region name is stored. + :param aws_profile_name_var: The name of the env var where the AWS profile name is stored. + :param aws_custom_attributes: Custom attributes to be passed to SageMaker, for example `{"accept_eula": True}` + in case of Llama-2 models. + :param generation_kwargs: Additional keyword arguments for text generation. For a list of supported parameters + see your model's documentation page, for example here for HuggingFace models: + https://huggingface.co/blog/sagemaker-huggingface-llm#4-run-inference-and-chat-with-our-model + + Specifically, Llama-2 models support the following inference payload parameters: + + - `max_new_tokens`: Model generates text until the output length (excluding the input context length) + reaches `max_new_tokens`. If specified, it must be a positive integer. + - `temperature`: Controls the randomness in the output. Higher temperature results in output sequence with + low-probability words and lower temperature results in output sequence with high-probability words. + If `temperature=0`, it results in greedy decoding. If specified, it must be a positive float. + - `top_p`: In each step of text generation, sample from the smallest possible set of words with cumulative + probability `top_p`. If specified, it must be a float between 0 and 1. + - `return_full_text`: If `True`, input text will be part of the output generated text. If specified, it must + be boolean. The default value for it is `False`. + """ + self.model = model + self.aws_access_key_id_var = aws_access_key_id_var + self.aws_secret_access_key_var = aws_secret_access_key_var + self.aws_session_token_var = aws_session_token_var + self.aws_region_name_var = aws_region_name_var + self.aws_profile_name_var = aws_profile_name_var + self.aws_custom_attributes = aws_custom_attributes or {} + self.generation_kwargs = generation_kwargs or {"max_new_tokens": 1024} + self.client: Optional[BaseClient] = None + + if not os.getenv(self.aws_access_key_id_var) or not os.getenv(self.aws_secret_access_key_var): + msg = ( + f"Please provide AWS credentials via environment variables '{self.aws_access_key_id_var}' and " + f"'{self.aws_secret_access_key_var}'." + ) + raise AWSConfigurationError(msg) + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + return {"model": self.model} + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize the object to a dictionary. + """ + return default_to_dict( + self, + model=self.model, + aws_access_key_id_var=self.aws_access_key_id_var, + aws_secret_access_key_var=self.aws_secret_access_key_var, + aws_session_token_var=self.aws_session_token_var, + aws_region_name_var=self.aws_region_name_var, + aws_profile_name_var=self.aws_profile_name_var, + aws_custom_attributes=self.aws_custom_attributes, + generation_kwargs=self.generation_kwargs, + ) + + @classmethod + def from_dict(cls, data) -> "SagemakerGenerator": + """ + Deserialize the dictionary into an instance of SagemakerGenerator. + """ + return default_from_dict(cls, data) + + def warm_up(self): + """ + Initializes the SageMaker Inference client. + """ + boto3_import.check() + try: + session = boto3.Session( + aws_access_key_id=os.getenv(self.aws_access_key_id_var), + aws_secret_access_key=os.getenv(self.aws_secret_access_key_var), + aws_session_token=os.getenv(self.aws_session_token_var), + region_name=os.getenv(self.aws_region_name_var), + profile_name=os.getenv(self.aws_profile_name_var), + ) + self.client = session.client("sagemaker-runtime") + except Exception as e: + msg = ( + f"Could not connect to SageMaker Inference Endpoint '{self.model}'." + f"Make sure the Endpoint exists and AWS environment is configured." + ) + raise AWSConfigurationError(msg) from e + + @component.output_types(replies=List[str], meta=List[Dict[str, Any]]) + def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Invoke the text generation inference based on the provided messages and generation parameters. + + :param prompt: The string prompt to use for text generation. + :param generation_kwargs: Additional keyword arguments for text generation. These parameters will + potentially override the parameters passed in the `__init__` method. + + :return: A list of strings containing the generated responses and a list of dictionaries containing the metadata + for each response. + """ + if self.client is None: + msg = "SageMaker Inference client is not initialized. Please call warm_up() first." + raise ValueError(msg) + + generation_kwargs = generation_kwargs or self.generation_kwargs + custom_attributes = ";".join( + f"{k}={str(v).lower() if isinstance(v, bool) else str(v)}" for k, v in self.aws_custom_attributes.items() + ) + try: + body = json.dumps({"inputs": prompt, "parameters": generation_kwargs}) + response = self.client.invoke_endpoint( + EndpointName=self.model, + Body=body, + ContentType="application/json", + Accept="application/json", + CustomAttributes=custom_attributes, + ) + response_json = response.get("Body").read().decode("utf-8") + output: Dict[str, Dict[str, Any]] = json.loads(response_json) + + # The output might be either a list of dictionaries or a single dictionary + list_output: List[Dict[str, Any]] + if output and isinstance(output, dict): + list_output = [output] + elif isinstance(output, list) and all(isinstance(o, dict) for o in output): + list_output = output + else: + msg = f"Unexpected model response type: {type(output)}" + raise ValueError(msg) + + # The key where the replies are stored changes from model to model, so we need to look for it. + # All other keys in the response are added to the metadata. + # Unfortunately every model returns different metadata, most of them return none at all, + # so we can't replicate the metadata structure of other generators. + for key in self.model_generation_keys: + if key in list_output[0]: + break + replies = [o.pop(key, None) for o in list_output] + + return {"replies": replies, "meta": list_output * len(replies)} + + except requests.HTTPError as err: + res = err.response + if res.status_code == MODEL_NOT_READY_STATUS_CODE: + msg = f"Sagemaker model not ready: {res.text}" + raise SagemakerNotReadyError(msg) from err + + msg = f"SageMaker Inference returned an error. Status code: {res.status_code} Response body: {res.text}" + raise SagemakerInferenceError(msg, status_code=res.status_code) from err diff --git a/integrations/amazon_sagemaker/tests/__init__.py b/integrations/amazon_sagemaker/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/amazon_sagemaker/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/amazon_sagemaker/tests/test_sagemaker.py b/integrations/amazon_sagemaker/tests/test_sagemaker.py new file mode 100644 index 000000000..a22634be1 --- /dev/null +++ b/integrations/amazon_sagemaker/tests/test_sagemaker.py @@ -0,0 +1,243 @@ +import os +from unittest.mock import Mock + +import pytest +from haystack_integrations.components.generators.amazon_sagemaker import SagemakerGenerator +from haystack_integrations.components.generators.amazon_sagemaker.errors import AWSConfigurationError + + +class TestSagemakerGenerator: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + + component = SagemakerGenerator(model="test-model") + assert component.model == "test-model" + assert component.aws_access_key_id_var == "AWS_ACCESS_KEY_ID" + assert component.aws_secret_access_key_var == "AWS_SECRET_ACCESS_KEY" + assert component.aws_session_token_var == "AWS_SESSION_TOKEN" + assert component.aws_region_name_var == "AWS_REGION" + assert component.aws_profile_name_var == "AWS_PROFILE" + assert component.aws_custom_attributes == {} + assert component.generation_kwargs == {"max_new_tokens": 1024} + assert component.client is None + + def test_init_fail_wo_access_key_or_secret_key(self, monkeypatch): + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + with pytest.raises(AWSConfigurationError): + SagemakerGenerator(model="test-model") + + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + with pytest.raises(AWSConfigurationError): + SagemakerGenerator(model="test-model") + + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + with pytest.raises(AWSConfigurationError): + SagemakerGenerator(model="test-model") + + def test_init_with_parameters(self, monkeypatch): + monkeypatch.setenv("MY_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("MY_SECRET_ACCESS_KEY", "test-secret-key") + + component = SagemakerGenerator( + model="test-model", + aws_access_key_id_var="MY_ACCESS_KEY_ID", + aws_secret_access_key_var="MY_SECRET_ACCESS_KEY", + aws_session_token_var="MY_SESSION_TOKEN", + aws_region_name_var="MY_REGION", + aws_profile_name_var="MY_PROFILE", + aws_custom_attributes={"custom": "attr"}, + generation_kwargs={"generation": "kwargs"}, + ) + assert component.model == "test-model" + assert component.aws_access_key_id_var == "MY_ACCESS_KEY_ID" + assert component.aws_secret_access_key_var == "MY_SECRET_ACCESS_KEY" + assert component.aws_session_token_var == "MY_SESSION_TOKEN" + assert component.aws_region_name_var == "MY_REGION" + assert component.aws_profile_name_var == "MY_PROFILE" + assert component.aws_custom_attributes == {"custom": "attr"} + assert component.generation_kwargs == {"generation": "kwargs"} + assert component.client is None + + def test_to_from_dict(self, monkeypatch): + monkeypatch.setenv("MY_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("MY_SECRET_ACCESS_KEY", "test-secret-key") + + component = SagemakerGenerator( + model="test-model", + aws_access_key_id_var="MY_ACCESS_KEY_ID", + aws_secret_access_key_var="MY_SECRET_ACCESS_KEY", + aws_session_token_var="MY_SESSION_TOKEN", + aws_region_name_var="MY_REGION", + aws_profile_name_var="MY_PROFILE", + aws_custom_attributes={"custom": "attr"}, + generation_kwargs={"generation": "kwargs"}, + ) + serialized = component.to_dict() + assert serialized == { + "type": "haystack_integrations.components.generators.amazon_sagemaker.sagemaker.SagemakerGenerator", + "init_parameters": { + "model": "test-model", + "aws_access_key_id_var": "MY_ACCESS_KEY_ID", + "aws_secret_access_key_var": "MY_SECRET_ACCESS_KEY", + "aws_session_token_var": "MY_SESSION_TOKEN", + "aws_region_name_var": "MY_REGION", + "aws_profile_name_var": "MY_PROFILE", + "aws_custom_attributes": {"custom": "attr"}, + "generation_kwargs": {"generation": "kwargs"}, + }, + } + deserialized = SagemakerGenerator.from_dict(serialized) + assert deserialized.model == "test-model" + assert deserialized.aws_access_key_id_var == "MY_ACCESS_KEY_ID" + assert deserialized.aws_secret_access_key_var == "MY_SECRET_ACCESS_KEY" + assert deserialized.aws_session_token_var == "MY_SESSION_TOKEN" + assert deserialized.aws_region_name_var == "MY_REGION" + assert deserialized.aws_profile_name_var == "MY_PROFILE" + assert deserialized.aws_custom_attributes == {"custom": "attr"} + assert deserialized.generation_kwargs == {"generation": "kwargs"} + assert deserialized.client is None + + def test_run_with_list_of_dictionaries(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + client_mock = Mock() + client_mock.invoke_endpoint.return_value = { + "Body": Mock(read=lambda: b'[{"generated_text": "test-reply", "other": "metadata"}]') + } + + component = SagemakerGenerator(model="test-model") + component.client = client_mock # Simulate warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + assert "test-reply" in response["replies"][0] + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + assert response["meta"][0]["other"] == "metadata" + + def test_run_with_single_dictionary(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "test-access-key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test-secret-key") + client_mock = Mock() + client_mock.invoke_endpoint.return_value = { + "Body": Mock(read=lambda: b'{"generation": "test-reply", "other": "metadata"}') + } + + component = SagemakerGenerator(model="test-model") + component.client = client_mock # Simulate warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + assert "test-reply" in response["replies"][0] + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + assert response["meta"][0]["other"] == "metadata" + + @pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + ) + @pytest.mark.integration + def test_run_falcon(self): + component = SagemakerGenerator( + model="jumpstart-dft-hf-llm-falcon-7b-instruct-bf16", generation_kwargs={"max_new_tokens": 10} + ) + component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + + @pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + ) + @pytest.mark.integration + def test_run_llama2(self): + component = SagemakerGenerator( + model="jumpstart-dft-meta-textgenerationneuron-llama-2-7b", + generation_kwargs={"max_new_tokens": 10}, + aws_custom_attributes={"accept_eula": True}, + ) + component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] + + @pytest.mark.skipif( + (not os.environ.get("AWS_ACCESS_KEY_ID", None) or not os.environ.get("AWS_SECRET_ACCESS_KEY", None)), + reason="Export two env vars called AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to run this test.", + ) + @pytest.mark.integration + def test_run_bloomz(self): + component = SagemakerGenerator( + model="jumpstart-dft-hf-textgeneration-bloomz-1b1", generation_kwargs={"max_new_tokens": 10} + ) + component.warm_up() + response = component.run("What's Natural Language Processing?") + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, str) for reply in response["replies"]] + + # Coarse check: assuming no more than 4 chars per token. In any case it + # will fail if the `max_new_tokens` parameter is not respected, as the + # default is either 256 or 1024 + assert all(len(reply) <= 40 for reply in response["replies"]) + + assert "meta" in response + assert isinstance(response["meta"], list) + assert len(response["meta"]) == 1 + assert [isinstance(reply, dict) for reply in response["meta"]] From 762045d7137d271b0eba737fa93df488a0b1e056 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 30 Jan 2024 09:58:59 +0100 Subject: [PATCH 02/26] pin sentence transformers (#289) --- integrations/instructor_embedders/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/instructor_embedders/pyproject.toml b/integrations/instructor_embedders/pyproject.toml index 67cbcb7af..c8a591b69 100644 --- a/integrations/instructor_embedders/pyproject.toml +++ b/integrations/instructor_embedders/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "requests>=2.26.0", "scikit_learn>=1.0.2", "scipy", - "sentence_transformers>=2.2.0", + "sentence_transformers>=2.2.0,<2.3.0", "torch", "tqdm", "rich", From 29c869e819cf1b0fe7b7c0702c90943c0aa2964e Mon Sep 17 00:00:00 2001 From: ZanSara Date: Tue, 30 Jan 2024 11:32:37 +0100 Subject: [PATCH 03/26] Add Sagemaker to README (#291) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 20b17b377..39d669322 100644 --- a/README.md +++ b/README.md @@ -80,3 +80,4 @@ deepset-haystack | [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | | [uptrain-haystack](integrations/uptrain/) | Evaluator | [![PyPI - Version](https://img.shields.io/pypi/v/uptrain-haystack.svg)](https://pypi.org/project/uptrain-haystack) | [![Test / uptrain](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/uptrain.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/uptrain.yml) | +| [amazon-sagemaker-haystack](integrations/amazon_sagemaker/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/amazon-sagemaker-haystack.svg)](https://pypi.org/project/amazon-sagemaker-haystack) | [![Test / amazon_sagemaker](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/amazon_sagemaker.yml) | From cd737575df28f291f9786e2917a984f70a4ca567 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Tue, 30 Jan 2024 12:40:45 +0100 Subject: [PATCH 04/26] fix: Broken version pattern in `pyproject.toml` (#294) --- integrations/uptrain/pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/uptrain/pyproject.toml b/integrations/uptrain/pyproject.toml index 631b7dab8..498772313 100644 --- a/integrations/uptrain/pyproject.toml +++ b/integrations/uptrain/pyproject.toml @@ -34,11 +34,11 @@ packages = ["src/haystack_integrations"] [tool.hatch.version] source = "vcs" -tag-pattern = 'integrations\/uptrain(?P.*)' +tag-pattern = 'integrations\/uptrain-v(?P.*)' [tool.hatch.version.raw-options] root = "../.." -git_describe_command = 'git describe --tags --match="integrations/uptrain[0-9]*"' +git_describe_command = 'git describe --tags --match="integrations/uptrain-v[0-9]*"' [tool.hatch.envs.default] dependencies = ["coverage[toml]>=6.5", "pytest"] From 799c50349e3e546a30c67f171944548f026f95bb Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 30 Jan 2024 16:51:06 +0100 Subject: [PATCH 05/26] increase pinecone sleep time (#288) --- integrations/pinecone/tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/pinecone/tests/conftest.py b/integrations/pinecone/tests/conftest.py index c7a1342d5..79d2608f2 100644 --- a/integrations/pinecone/tests/conftest.py +++ b/integrations/pinecone/tests/conftest.py @@ -6,7 +6,7 @@ from haystack_integrations.document_stores.pinecone import PineconeDocumentStore # This is the approximate time it takes for the documents to be available -SLEEP_TIME = 20 +SLEEP_TIME = 25 @pytest.fixture() From 9014494d40a17f0678ccc45f8d51ffac1be514cf Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 30 Jan 2024 17:33:41 +0100 Subject: [PATCH 06/26] chore: Amazon Bedrock subproject refactoring (#293) * Initial migration * Update README.md * Update test instructions * Update integrations/amazon_bedrock/pyproject.toml Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Update integrations/amazon_bedrock/pyproject.toml Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Linting --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> --- integrations/amazon_bedrock/README.md | 19 ++++++++++++++++ integrations/amazon_bedrock/pyproject.toml | 21 ++++++++++-------- .../generators/__init__.py | 3 --- .../generators/amazon_bedrock}/__init__.py | 2 +- .../amazon_bedrock_adapters.py | 2 +- .../amazon_bedrock_handlers.py | 0 .../generators/amazon_bedrock}/errors.py | 0 .../generators/amazon_bedrock/generator.py} | 14 ++++++------ .../tests/test_amazon_bedrock.py | 22 +++++++++---------- 9 files changed, 51 insertions(+), 32 deletions(-) delete mode 100644 integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/__init__.py rename integrations/amazon_bedrock/src/{amazon_bedrock_haystack => haystack_integrations/components/generators/amazon_bedrock}/__init__.py (63%) rename integrations/amazon_bedrock/src/{amazon_bedrock_haystack/generators => haystack_integrations/components/generators/amazon_bedrock}/amazon_bedrock_adapters.py (98%) rename integrations/amazon_bedrock/src/{amazon_bedrock_haystack/generators => haystack_integrations/components/generators/amazon_bedrock}/amazon_bedrock_handlers.py (100%) rename integrations/amazon_bedrock/src/{amazon_bedrock_haystack => haystack_integrations/components/generators/amazon_bedrock}/errors.py (100%) rename integrations/amazon_bedrock/src/{amazon_bedrock_haystack/generators/amazon_bedrock.py => haystack_integrations/components/generators/amazon_bedrock/generator.py} (98%) diff --git a/integrations/amazon_bedrock/README.md b/integrations/amazon_bedrock/README.md index f84c8f3c4..3a689ef3b 100644 --- a/integrations/amazon_bedrock/README.md +++ b/integrations/amazon_bedrock/README.md @@ -8,6 +8,7 @@ **Table of Contents** - [Installation](#installation) +- [Contributing](#contributing) - [License](#license) ## Installation @@ -16,6 +17,24 @@ pip install amazon-bedrock-haystack ``` +## Contributing + +`hatch` is the best way to interact with this project, to install it: +```sh +pip install hatch +``` + +With `hatch` installed, to run all the tests: +``` +hatch run test +``` +> Note: there are no integration tests for this project. + +To run the linters `ruff` and `mypy`: +``` +hatch run lint:all +``` + ## License `amazon-bedrock-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/amazon_bedrock/pyproject.toml b/integrations/amazon_bedrock/pyproject.toml index 7e82924a8..6a2ce3eab 100644 --- a/integrations/amazon_bedrock/pyproject.toml +++ b/integrations/amazon_bedrock/pyproject.toml @@ -35,6 +35,9 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock" +[tool.hatch.build.targets.wheel] +packages = ["src/haystack_integrations"] + [tool.hatch.version] source = "vcs" tag-pattern = 'integrations\/amazon_bedrock-v(?P.*)' @@ -71,7 +74,8 @@ dependencies = [ "ruff>=0.0.243", ] [tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/amazon_bedrock_haystack tests}" +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" + style = [ "ruff {args:.}", "black --check --diff {args:.}", @@ -136,26 +140,24 @@ unfixable = [ ] [tool.ruff.isort] -known-first-party = ["amazon_bedrock_haystack"] +known-first-party = ["haystack_integrations"] [tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" +ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] [tool.coverage.run] -source_pkgs = ["amazon_bedrock_haystack", "tests"] +source_pkgs = ["src", "tests"] branch = true parallel = true -omit = [ - "src/amazon_bedrock_haystack/__about__.py", -] + [tool.coverage.paths] -amazon_bedrock_haystack = ["src/amazon_bedrock_haystack", "*/amazon_bedrock/src/amazon_bedrock_haystack"] -tests = ["tests", "*/amazon_bedrock_haystack/tests"] +amazon_bedrock_haystack = ["src/*"] +tests = ["tests"] [tool.coverage.report] exclude_lines = [ @@ -170,6 +172,7 @@ module = [ "transformers.*", "boto3.*", "haystack.*", + "haystack_integrations.*", "pytest.*", "numpy.*", ] diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/__init__.py b/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/__init__.py deleted file mode 100644 index e873bc332..000000000 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/__init__.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py similarity index 63% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/__init__.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py index 3e05179c0..236347b61 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/__init__.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator +from .generator import AmazonBedrockGenerator __all__ = ["AmazonBedrockGenerator"] diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_adapters.py similarity index 98% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_adapters.py index bec172867..b7e775cb8 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_adapters.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from amazon_bedrock_haystack.generators.amazon_bedrock_handlers import TokenStreamingHandler +from .amazon_bedrock_handlers import TokenStreamingHandler class BedrockModelAdapter(ABC): diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_handlers.py similarity index 100% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock_handlers.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_handlers.py diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py similarity index 100% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/errors.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/errors.py diff --git a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py similarity index 98% rename from integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index dda84fe14..c79ef9de4 100644 --- a/integrations/amazon_bedrock/src/amazon_bedrock_haystack/generators/amazon_bedrock.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -7,12 +7,7 @@ from botocore.exceptions import BotoCoreError, ClientError from haystack import component, default_from_dict, default_to_dict -from amazon_bedrock_haystack.errors import ( - AmazonBedrockConfigurationError, - AmazonBedrockInferenceError, - AWSConfigurationError, -) -from amazon_bedrock_haystack.generators.amazon_bedrock_adapters import ( +from .amazon_bedrock_adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, AnthropicClaudeAdapter, @@ -20,11 +15,16 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) -from amazon_bedrock_haystack.generators.amazon_bedrock_handlers import ( +from .amazon_bedrock_handlers import ( DefaultPromptHandler, DefaultTokenStreamingHandler, TokenStreamingHandler, ) +from .errors import ( + AmazonBedrockConfigurationError, + AmazonBedrockInferenceError, + AWSConfigurationError, +) logger = logging.getLogger(__name__) diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index a05c95ba3..c6bb0add4 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -4,9 +4,8 @@ import pytest from botocore.exceptions import BotoCoreError -from amazon_bedrock_haystack.errors import AmazonBedrockConfigurationError -from amazon_bedrock_haystack.generators.amazon_bedrock import AmazonBedrockGenerator -from amazon_bedrock_haystack.generators.amazon_bedrock_adapters import ( +from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator +from haystack_integrations.components.generators.amazon_bedrock.amazon_bedrock_adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, AnthropicClaudeAdapter, @@ -14,6 +13,7 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) +from haystack_integrations.components.generators.amazon_bedrock.errors import AmazonBedrockConfigurationError @pytest.fixture @@ -34,7 +34,7 @@ def mock_boto3_session(): @pytest.fixture def mock_prompt_handler(): with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock_handlers.DefaultPromptHandler" + "haystack_integrations.components.generators.amazon_bedrock.amazon_bedrock_handlers.DefaultPromptHandler" ) as mock_prompt_handler: yield mock_prompt_handler @@ -55,7 +55,7 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session): ) expected_dict = { - "type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator", + "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { "model": "anthropic.claude-v2", "max_length": 99, @@ -72,7 +72,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session): """ generator = AmazonBedrockGenerator.from_dict( { - "type": "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator", + "type": "haystack_integrations.components.generators.amazon_bedrock.generator.AmazonBedrockGenerator", "init_parameters": { "model": "anthropic.claude-v2", "max_length": 99, @@ -235,7 +235,7 @@ def test_supports_for_valid_aws_configuration(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( @@ -266,7 +266,7 @@ def test_supports_for_invalid_bedrock_config(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): AmazonBedrockGenerator.supports( @@ -282,7 +282,7 @@ def test_supports_for_invalid_bedrock_config_error_on_list_models(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises(AmazonBedrockConfigurationError, match="Could not connect to Amazon Bedrock."): AmazonBedrockGenerator.supports( @@ -314,7 +314,7 @@ def test_supports_with_stream_true_for_model_that_supports_streaming(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ): supported = AmazonBedrockGenerator.supports( @@ -335,7 +335,7 @@ def test_supports_with_stream_true_for_model_that_does_not_support_streaming(): # Patch the class method to return the mock session with patch( - "amazon_bedrock_haystack.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", + "haystack_integrations.components.generators.amazon_bedrock.AmazonBedrockGenerator.get_aws_session", return_value=mock_session, ), pytest.raises( AmazonBedrockConfigurationError, From b6115c21282257984ebc0046e4013174a29c5f6e Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Tue, 30 Jan 2024 18:04:18 +0100 Subject: [PATCH 07/26] chore: Adjust amazon bedrock helper classes names (#297) * Adjust amazon bedrock helper classes names * Linting * Update tests * More linting * Small update --- .../{amazon_bedrock_adapters.py => adapters.py} | 2 +- .../generators/amazon_bedrock/generator.py | 12 ++++++------ .../{amazon_bedrock_handlers.py => handlers.py} | 0 .../amazon_bedrock/tests/test_amazon_bedrock.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) rename integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/{amazon_bedrock_adapters.py => adapters.py} (99%) rename integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/{amazon_bedrock_handlers.py => handlers.py} (100%) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_adapters.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py similarity index 99% rename from integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_adapters.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py index b7e775cb8..40ba0bc67 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_adapters.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/adapters.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from .amazon_bedrock_handlers import TokenStreamingHandler +from .handlers import TokenStreamingHandler class BedrockModelAdapter(ABC): diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py index c79ef9de4..4c43c9a09 100644 --- a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py +++ b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/generator.py @@ -7,7 +7,7 @@ from botocore.exceptions import BotoCoreError, ClientError from haystack import component, default_from_dict, default_to_dict -from .amazon_bedrock_adapters import ( +from .adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, AnthropicClaudeAdapter, @@ -15,16 +15,16 @@ CohereCommandAdapter, MetaLlama2ChatAdapter, ) -from .amazon_bedrock_handlers import ( - DefaultPromptHandler, - DefaultTokenStreamingHandler, - TokenStreamingHandler, -) from .errors import ( AmazonBedrockConfigurationError, AmazonBedrockInferenceError, AWSConfigurationError, ) +from .handlers import ( + DefaultPromptHandler, + DefaultTokenStreamingHandler, + TokenStreamingHandler, +) logger = logging.getLogger(__name__) diff --git a/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_handlers.py b/integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py similarity index 100% rename from integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/amazon_bedrock_handlers.py rename to integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/handlers.py diff --git a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py index c6bb0add4..b08e9dfd5 100644 --- a/integrations/amazon_bedrock/tests/test_amazon_bedrock.py +++ b/integrations/amazon_bedrock/tests/test_amazon_bedrock.py @@ -5,7 +5,7 @@ from botocore.exceptions import BotoCoreError from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator -from haystack_integrations.components.generators.amazon_bedrock.amazon_bedrock_adapters import ( +from haystack_integrations.components.generators.amazon_bedrock.adapters import ( AI21LabsJurassic2Adapter, AmazonTitanAdapter, AnthropicClaudeAdapter, @@ -34,7 +34,7 @@ def mock_boto3_session(): @pytest.fixture def mock_prompt_handler(): with patch( - "haystack_integrations.components.generators.amazon_bedrock.amazon_bedrock_handlers.DefaultPromptHandler" + "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler" ) as mock_prompt_handler: yield mock_prompt_handler From 60038c064b33d8cb522ef71886b1bd85eafc7244 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 31 Jan 2024 11:34:20 +0100 Subject: [PATCH 08/26] try changing dummy vector (#301) --- .../document_stores/pinecone/document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py index a755b7e47..92ea987b4 100644 --- a/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py +++ b/integrations/pinecone/src/haystack_integrations/document_stores/pinecone/document_store.py @@ -85,7 +85,7 @@ def __init__( ) self.dimension = actual_dimension or dimension - self._dummy_vector = [0.0] * self.dimension + self._dummy_vector = [-10.0] * self.dimension self.environment = environment self.index = index self.namespace = namespace From 69803e923a8a9446f241c7fdadfa9eadbeb33a2e Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 31 Jan 2024 11:35:57 +0100 Subject: [PATCH 09/26] Add typing_extensions pin (#295) --- integrations/chroma/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index ce4641611..2653c491f 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ dependencies = [ "haystack-ai", "chromadb<0.4.20", # FIXME: investigate why filtering tests broke on 0.4.20 + "typing_extensions>=4.8.0", ] [project.urls] From dabf0712fd67db98521f97509cdbc0cd8444e910 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bilge=20Y=C3=BCcel?= Date: Wed, 31 Jan 2024 14:30:27 +0300 Subject: [PATCH 10/26] Update Breaking Change Proposal issue template (#299) * Update breaking-change-proposal.md * Update .github/ISSUE_TEMPLATE/breaking-change-proposal.md Co-authored-by: ZanSara --------- Co-authored-by: ZanSara --- .github/ISSUE_TEMPLATE/breaking-change-proposal.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/breaking-change-proposal.md b/.github/ISSUE_TEMPLATE/breaking-change-proposal.md index 71aa2a5e9..6c6fb9017 100644 --- a/.github/ISSUE_TEMPLATE/breaking-change-proposal.md +++ b/.github/ISSUE_TEMPLATE/breaking-change-proposal.md @@ -15,9 +15,12 @@ Briefly explain how the change is breaking and why is needed. ```[tasklist] ### Tasks -- [ ] The change is documented with docstrings and was merged in the `main` branch -- [ ] Integration tile on https://github.com/deepset-ai/haystack-integrations was updated +- [ ] The changes are merged in the `main` branch (Code + Docstrings) +- [ ] New package version declares the breaking change +- [ ] The package has been released on PyPI - [ ] Docs at https://docs.haystack.deepset.ai/ were updated +- [ ] Integration tile on https://github.com/deepset-ai/haystack-integrations was updated - [ ] Notebooks on https://github.com/deepset-ai/haystack-cookbook were updated (if needed) -- [ ] New package version declares the breaking change and package has been released on PyPI -``` \ No newline at end of file +- [ ] Tutorials on https://github.com/deepset-ai/haystack-tutorials were updated (if needed) +- [ ] Articles on https://github.com/deepset-ai/haystack-home/tree/main/content were updated (if needed) +``` From ae80056f6d7e3eeeded4850504659e76ec288fcf Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 31 Jan 2024 14:47:37 +0100 Subject: [PATCH 11/26] Pgvector - filters (#257) * very first draft * setup integration folder and workflow * update readme * making progress! * mypy overrides * making progress on index * drop sqlalchemy in favor of psycopggit add tests/test_document_store.py ! * good improvements! * docstrings * improve definition * small improvements * more test cases * standardize * start working on filters * inner_product * explicit create statement * address feedback * tests separation * filters - draft * change embedding_similarity_function to vector_function * explicit insert and update statements * remove useless condition * unit tests for conversion functions * tests change * simplify! * progress! * better error messages and more * cover also complex cases * fmt * make things work again * progress on simplification * further simplification * filters simplification * fmt * rm print * uncomment line * fix name * mv check filters is a dict in filter_documents * f-strings * NO_VALUE constant * handle nested logical conditions in _parse_logical_condition * add examples to _treat_meta_field * fix fmt * ellipsis fmt * more tests for unhappy paths * more tests for internal methods * black * log debug query and params --- .../pgvector/document_store.py | 51 +++- .../document_stores/pgvector/filters.py | 242 ++++++++++++++++++ integrations/pgvector/tests/conftest.py | 24 ++ .../pgvector/tests/test_document_store.py | 21 -- integrations/pgvector/tests/test_filters.py | 179 +++++++++++++ 5 files changed, 489 insertions(+), 28 deletions(-) create mode 100644 integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py create mode 100644 integrations/pgvector/tests/conftest.py create mode 100644 integrations/pgvector/tests/test_filters.py diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index bb1915a6f..b49bd87c3 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -8,6 +8,7 @@ from haystack.dataclasses.document import ByteStream, Document from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy +from haystack.utils.filters import convert from psycopg import Error, IntegrityError, connect from psycopg.abc import Query from psycopg.cursor import Cursor @@ -18,6 +19,8 @@ from pgvector.psycopg import register_vector +from .filters import _convert_filters_to_where_clause_and_params + logger = logging.getLogger(__name__) CREATE_TABLE_STATEMENT = """ @@ -158,11 +161,16 @@ def _execute_sql( params = params or () cursor = cursor or self._cursor + sql_query_str = sql_query.as_string(cursor) if not isinstance(sql_query, str) else sql_query + logger.debug("SQL query: %s\nParameters: %s", sql_query_str, params) + try: result = cursor.execute(sql_query, params) except Error as e: self._connection.rollback() - raise DocumentStoreError(error_msg) from e + detailed_error_msg = f"{error_msg}.\nYou can find the SQL query and the parameters in the debug logs." + raise DocumentStoreError(detailed_error_msg) from e + return result def _create_table_if_not_exists(self): @@ -257,15 +265,37 @@ def count_documents(self) -> int: ] return count - def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: # noqa: ARG002 - # TODO: implement filters - sql_get_docs = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the filters provided. + + For a detailed specification of the filters, + refer to the [documentation](https://docs.haystack.deepset.ai/v2.0/docs/metadata-filtering) + + :param filters: The filters to apply to the document list. + :return: A list of Documents that match the given filters. + """ + if filters: + if not isinstance(filters, dict): + msg = "Filters must be a dictionary" + raise TypeError(msg) + if "operator" not in filters and "conditions" not in filters: + filters = convert(filters) + + sql_filter = SQL("SELECT * FROM {table_name}").format(table_name=Identifier(self.table_name)) + + params = () + if filters: + sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters) + sql_filter += sql_where_clause result = self._execute_sql( - sql_get_docs, error_msg="Could not filter documents from PgvectorDocumentStore", cursor=self._dict_cursor + sql_filter, + params, + error_msg="Could not filter documents from PgvectorDocumentStore.", + cursor=self._dict_cursor, ) - # Fetch all the records records = result.fetchall() docs = self._from_pg_to_haystack_documents(records) return docs @@ -300,6 +330,9 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D sql_insert += SQL(" RETURNING id") + sql_query_str = sql_insert.as_string(self._cursor) if not isinstance(sql_insert, str) else sql_insert + logger.debug("SQL query: %s\nParameters: %s", sql_query_str, db_documents) + try: self._cursor.executemany(sql_insert, db_documents, returning=True) except IntegrityError as ie: @@ -307,7 +340,11 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D raise DuplicateDocumentError from ie except Error as e: self._connection.rollback() - raise DocumentStoreError from e + error_msg = ( + "Could not write documents to PgvectorDocumentStore. \n" + "You can find the SQL query and the parameters in the debug logs." + ) + raise DocumentStoreError(error_msg) from e # get the number of the inserted documents, inspired by psycopg3 docs # https://www.psycopg.org/psycopg3/docs/api/cursors.html#psycopg.Cursor.executemany diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py new file mode 100644 index 000000000..daa90f502 --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/filters.py @@ -0,0 +1,242 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from datetime import datetime +from itertools import chain +from typing import Any, Dict, List + +from haystack.errors import FilterError +from pandas import DataFrame +from psycopg.sql import SQL +from psycopg.types.json import Jsonb + +# we need this mapping to cast meta values to the correct type, +# since they are stored in the JSONB field as strings. +# this dict can be extended if needed +PYTHON_TYPES_TO_PG_TYPES = { + int: "integer", + float: "real", + bool: "boolean", +} + +NO_VALUE = "no_value" + + +def _convert_filters_to_where_clause_and_params(filters: Dict[str, Any]) -> tuple[SQL, tuple]: + """ + Convert Haystack filters to a WHERE clause and a tuple of params to query PostgreSQL. + """ + if "field" in filters: + query, values = _parse_comparison_condition(filters) + else: + query, values = _parse_logical_condition(filters) + + where_clause = SQL(" WHERE ") + SQL(query) + params = tuple(value for value in values if value != NO_VALUE) + + return where_clause, params + + +def _parse_logical_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "conditions" not in condition: + msg = f"'conditions' key missing in {condition}" + raise FilterError(msg) + + operator = condition["operator"] + if operator not in ["AND", "OR"]: + msg = f"Unknown logical operator '{operator}'. Valid operators are: 'AND', 'OR'" + raise FilterError(msg) + + # logical conditions can be nested, so we need to parse them recursively + conditions = [] + for c in condition["conditions"]: + if "field" in c: + query, vals = _parse_comparison_condition(c) + else: + query, vals = _parse_logical_condition(c) + conditions.append((query, vals)) + + query_parts, values = [], [] + for c in conditions: + query_parts.append(c[0]) + values.append(c[1]) + if isinstance(values[0], list): + values = list(chain.from_iterable(values)) + + if operator == "AND": + sql_query = f"({' AND '.join(query_parts)})" + elif operator == "OR": + sql_query = f"({' OR '.join(query_parts)})" + else: + msg = f"Unknown logical operator '{operator}'" + raise FilterError(msg) + + return sql_query, values + + +def _parse_comparison_condition(condition: Dict[str, Any]) -> tuple[str, List[Any]]: + field: str = condition["field"] + if "operator" not in condition: + msg = f"'operator' key missing in {condition}" + raise FilterError(msg) + if "value" not in condition: + msg = f"'value' key missing in {condition}" + raise FilterError(msg) + operator: str = condition["operator"] + if operator not in COMPARISON_OPERATORS: + msg = f"Unknown comparison operator '{operator}'. Valid operators are: {list(COMPARISON_OPERATORS.keys())}" + raise FilterError(msg) + + value: Any = condition["value"] + if isinstance(value, DataFrame): + # DataFrames are stored as JSONB and we query them as such + value = Jsonb(value.to_json()) + field = f"({field})::jsonb" + + if field.startswith("meta."): + field = _treat_meta_field(field, value) + + field, value = COMPARISON_OPERATORS[operator](field, value) + return field, [value] + + +def _treat_meta_field(field: str, value: Any) -> str: + """ + Internal method that modifies the field str + to make the meta JSONB field queryable. + + Examples: + >>> _treat_meta_field(field="meta.number", value=9) + "(meta->>'number')::integer" + + >>> _treat_meta_field(field="meta.name", value="my_name") + "meta->>'name'" + """ + + # use the ->> operator to access keys in the meta JSONB field + field_name = field.split(".", 1)[-1] + field = f"meta->>'{field_name}'" + + # meta fields are stored as strings in the JSONB field, + # so we need to cast them to the correct type + type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value)) + if isinstance(value, list) and len(value) > 0: + type_value = PYTHON_TYPES_TO_PG_TYPES.get(type(value[0])) + + if type_value: + field = f"({field})::{type_value}" + + return field + + +def _equal(field: str, value: Any) -> tuple[str, Any]: + if value is None: + # NO_VALUE is a placeholder that will be removed in _convert_filters_to_where_clause_and_params + return f"{field} IS NULL", NO_VALUE + return f"{field} = %s", value + + +def _not_equal(field: str, value: Any) -> tuple[str, Any]: + # we use IS DISTINCT FROM to correctly handle NULL values + # (not handled by !=) + return f"{field} IS DISTINCT FROM %s", value + + +def _greater_than(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} > %s", value + + +def _greater_than_equal(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} >= %s", value + + +def _less_than(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} < %s", value + + +def _less_than_equal(field: str, value: Any) -> tuple[str, Any]: + if isinstance(value, str): + try: + datetime.fromisoformat(value) + except (ValueError, TypeError) as exc: + msg = ( + "Can't compare strings using operators '>', '>=', '<', '<='. " + "Strings are only comparable if they are ISO formatted dates." + ) + raise FilterError(msg) from exc + if type(value) in [list, Jsonb]: + msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" + raise FilterError(msg) + + return f"{field} <= %s", value + + +def _not_in(field: str, value: Any) -> tuple[str, List]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'not in' comparator in Pinecone" + raise FilterError(msg) + + return f"{field} IS NULL OR {field} != ALL(%s)", [value] + + +def _in(field: str, value: Any) -> tuple[str, List]: + if not isinstance(value, list): + msg = f"{field}'s value must be a list when using 'in' comparator in Pinecone" + raise FilterError(msg) + + # see https://www.psycopg.org/psycopg3/docs/basic/adapt.html#lists-adaptation + return f"{field} = ANY(%s)", [value] + + +COMPARISON_OPERATORS = { + "==": _equal, + "!=": _not_equal, + ">": _greater_than, + ">=": _greater_than_equal, + "<": _less_than, + "<=": _less_than_equal, + "in": _in, + "not in": _not_in, +} diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py new file mode 100644 index 000000000..34260f409 --- /dev/null +++ b/integrations/pgvector/tests/conftest.py @@ -0,0 +1,24 @@ +import pytest +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +@pytest.fixture +def document_store(request): + connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" + table_name = f"haystack_{request.node.name}" + embedding_dimension = 768 + vector_function = "cosine_distance" + recreate_table = True + search_strategy = "exact_nearest_neighbor" + + store = PgvectorDocumentStore( + connection_string=connection_string, + table_name=table_name, + embedding_dimension=embedding_dimension, + vector_function=vector_function, + recreate_table=recreate_table, + search_strategy=search_strategy, + ) + yield store + + store.delete_table() diff --git a/integrations/pgvector/tests/test_document_store.py b/integrations/pgvector/tests/test_document_store.py index 9f3521838..e8d9107d7 100644 --- a/integrations/pgvector/tests/test_document_store.py +++ b/integrations/pgvector/tests/test_document_store.py @@ -14,27 +14,6 @@ class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): - @pytest.fixture - def document_store(self, request): - connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" - table_name = f"haystack_{request.node.name}" - embedding_dimension = 768 - vector_function = "cosine_distance" - recreate_table = True - search_strategy = "exact_nearest_neighbor" - - store = PgvectorDocumentStore( - connection_string=connection_string, - table_name=table_name, - embedding_dimension=embedding_dimension, - vector_function=vector_function, - recreate_table=recreate_table, - search_strategy=search_strategy, - ) - yield store - - store.delete_table() - def test_write_documents(self, document_store: PgvectorDocumentStore): docs = [Document(id="1")] assert document_store.write_documents(docs) == 1 diff --git a/integrations/pgvector/tests/test_filters.py b/integrations/pgvector/tests/test_filters.py new file mode 100644 index 000000000..8b2dc8ec9 --- /dev/null +++ b/integrations/pgvector/tests/test_filters.py @@ -0,0 +1,179 @@ +from typing import List + +import pytest +from haystack.dataclasses.document import Document +from haystack.testing.document_store import FilterDocumentsTest +from haystack_integrations.document_stores.pgvector.filters import ( + FilterError, + _convert_filters_to_where_clause_and_params, + _parse_comparison_condition, + _parse_logical_condition, + _treat_meta_field, +) +from pandas import DataFrame +from psycopg.sql import SQL +from psycopg.types.json import Jsonb + + +class TestFilters(FilterDocumentsTest): + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + This overrides the default assert_documents_are_equal from FilterDocumentsTest. + It is needed because the embeddings are not exactly the same when they are retrieved from Postgres. + """ + + assert len(received) == len(expected) + received.sort(key=lambda x: x.id) + expected.sort(key=lambda x: x.id) + for received_doc, expected_doc in zip(received, expected): + # we first compare the embeddings approximately + if received_doc.embedding is None: + assert expected_doc.embedding is None + else: + assert received_doc.embedding == pytest.approx(expected_doc.embedding) + + received_doc.embedding, expected_doc.embedding = None, None + assert received_doc == expected_doc + + def test_complex_filter(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + filters = { + "operator": "OR", + "conditions": [ + { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"field": "meta.page", "operator": "==", "value": "90"}, + {"field": "meta.chapter", "operator": "==", "value": "conclusion"}, + ], + }, + ], + } + + result = document_store.filter_documents(filters=filters) + + self.assert_documents_are_equal( + result, + [ + d + for d in filterable_docs + if (d.meta.get("number") == 100 and d.meta.get("chapter") == "intro") + or (d.meta.get("page") == "90" and d.meta.get("chapter") == "conclusion") + ], + ) + + @pytest.mark.skip(reason="NOT operator is not supported in PgvectorDocumentStore") + def test_not_operator(self, document_store, filterable_docs): ... + + def test_treat_meta_field(self): + assert _treat_meta_field(field="meta.number", value=9) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.number", value=[1, 2, 3]) == "(meta->>'number')::integer" + assert _treat_meta_field(field="meta.name", value="my_name") == "meta->>'name'" + assert _treat_meta_field(field="meta.name", value=["my_name"]) == "meta->>'name'" + assert _treat_meta_field(field="meta.number", value=1.1) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.number", value=[1.1, 2.2, 3.3]) == "(meta->>'number')::real" + assert _treat_meta_field(field="meta.bool", value=True) == "(meta->>'bool')::boolean" + assert _treat_meta_field(field="meta.bool", value=[True, False, True]) == "(meta->>'bool')::boolean" + + # do not cast the field if its value is not one of the known types, an empty list or None + assert _treat_meta_field(field="meta.other", value={"a": 3, "b": "example"}) == "meta->>'other'" + assert _treat_meta_field(field="meta.empty_list", value=[]) == "meta->>'empty_list'" + assert _treat_meta_field(field="meta.name", value=None) == "meta->>'name'" + + def test_comparison_condition_dataframe_jsonb_conversion(self): + dataframe = DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + condition = {"field": "meta.df", "operator": "==", "value": dataframe} + field, values = _parse_comparison_condition(condition) + assert field == "(meta.df)::jsonb = %s" + + # we check each slot of the Jsonb object because it does not implement __eq__ + assert values[0].obj == Jsonb(dataframe.to_json()).obj + assert values[0].dumps == Jsonb(dataframe.to_json()).dumps + + def test_comparison_condition_missing_operator(self): + condition = {"field": "meta.type", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_comparison_condition_missing_value(self): + condition = {"field": "meta.type", "operator": "=="} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_comparison_condition_unknown_operator(self): + condition = {"field": "meta.type", "operator": "unknown", "value": "article"} + with pytest.raises(FilterError): + _parse_comparison_condition(condition) + + def test_logical_condition_missing_operator(self): + condition = {"conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_missing_conditions(self): + condition = {"operator": "AND"} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_unknown_operator(self): + condition = {"operator": "unknown", "conditions": []} + with pytest.raises(FilterError): + _parse_logical_condition(condition) + + def test_logical_condition_nested(self): + condition = { + "operator": "AND", + "conditions": [ + { + "operator": "OR", + "conditions": [ + {"field": "meta.domain", "operator": "!=", "value": "science"}, + {"field": "meta.chapter", "operator": "in", "value": ["intro", "conclusion"]}, + ], + }, + { + "operator": "OR", + "conditions": [ + {"field": "meta.number", "operator": ">=", "value": 90}, + {"field": "meta.author", "operator": "not in", "value": ["John", "Jane"]}, + ], + }, + ], + } + query, values = _parse_logical_condition(condition) + assert query == ( + "((meta->>'domain' IS DISTINCT FROM %s OR meta->>'chapter' = ANY(%s)) " + "AND ((meta->>'number')::integer >= %s OR meta->>'author' IS NULL OR meta->>'author' != ALL(%s)))" + ) + assert values == ["science", [["intro", "conclusion"]], 90, [["John", "Jane"]]] + + def test_convert_filters_to_where_clause_and_params(self): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("((meta->>'number')::integer = %s AND meta->>'chapter' = %s)") + assert params == (100, "intro") + + def test_convert_filters_to_where_clause_and_params_handle_null(self): + filters = { + "operator": "AND", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": None}, + {"field": "meta.chapter", "operator": "==", "value": "intro"}, + ], + } + where_clause, params = _convert_filters_to_where_clause_and_params(filters) + assert where_clause == SQL(" WHERE ") + SQL("(meta->>'number' IS NULL AND meta->>'chapter' = %s)") + assert params == ("intro",) From 0d15e3675785a4db745b98a7c53f235ced57c7a2 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Wed, 31 Jan 2024 17:43:14 +0100 Subject: [PATCH 12/26] Pgvector - embedding retrieval (#298) * squash * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi * fix fmt --------- Co-authored-by: Massimiliano Pippi --- .../pgvector/document_store.py | 102 +++++++++++++- integrations/pgvector/tests/conftest.py | 2 +- .../tests/test_embedding_retrieval.py | 130 ++++++++++++++++++ 3 files changed, 229 insertions(+), 5 deletions(-) create mode 100644 integrations/pgvector/tests/test_embedding_retrieval.py diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index b49bd87c3..0abaaecce 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -52,8 +52,10 @@ meta = EXCLUDED.meta """ +VALID_VECTOR_FUNCTIONS = ["cosine_similarity", "inner_product", "l2_distance"] + VECTOR_FUNCTION_TO_POSTGRESQL_OPS = { - "cosine_distance": "vector_cosine_ops", + "cosine_similarity": "vector_cosine_ops", "inner_product": "vector_ip_ops", "l2_distance": "vector_l2_ops", } @@ -70,7 +72,7 @@ def __init__( connection_string: str, table_name: str = "haystack_documents", embedding_dimension: int = 768, - vector_function: Literal["cosine_distance", "inner_product", "l2_distance"] = "cosine_distance", + vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] = "cosine_similarity", recreate_table: bool = False, search_strategy: Literal["exact_nearest_neighbor", "hnsw"] = "exact_nearest_neighbor", hnsw_recreate_index_if_exists: bool = False, @@ -87,12 +89,23 @@ def __init__( :param table_name: The name of the table to use to store Haystack documents. Defaults to "haystack_documents". :param embedding_dimension: The dimension of the embedding. Defaults to 768. :param vector_function: The similarity function to use when searching for similar embeddings. - Defaults to "cosine_distance". Set it to one of the following values: - :type vector_function: Literal["cosine_distance", "inner_product", "l2_distance"] + Defaults to "cosine_similarity". "cosine_similarity" and "inner_product" are similarity functions and + higher scores indicate greater similarity between the documents. + "l2_distance" returns the straight-line distance between vectors, + and the most similar documents are the ones with the smallest score. + + Important: when using the "hnsw" search strategy, an index will be created that depends on the + `vector_function` passed here. Make sure subsequent queries will keep using the same + vector similarity function in order to take advantage of the index. + :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] :param recreate_table: Whether to recreate the table if it already exists. Defaults to False. :param search_strategy: The search strategy to use when searching for similar embeddings. Defaults to "exact_nearest_neighbor". "hnsw" is an approximate nearest neighbor search strategy, which trades off some accuracy for speed; it is recommended for large numbers of documents. + + Important: when using the "hnsw" search strategy, an index will be created that depends on the + `vector_function` passed here. Make sure subsequent queries will keep using the same + vector similarity function in order to take advantage of the index. :type search_strategy: Literal["exact_nearest_neighbor", "hnsw"] :param hnsw_recreate_index_if_exists: Whether to recreate the HNSW index if it already exists. Defaults to False. Only used if search_strategy is set to "hnsw". @@ -107,6 +120,9 @@ def __init__( self.connection_string = connection_string self.table_name = table_name self.embedding_dimension = embedding_dimension + if vector_function not in VALID_VECTOR_FUNCTIONS: + msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" + raise ValueError(msg) self.vector_function = vector_function self.recreate_table = recreate_table self.search_strategy = search_strategy @@ -423,3 +439,81 @@ def delete_documents(self, document_ids: List[str]) -> None: ) self._execute_sql(delete_sql, error_msg="Could not delete documents from PgvectorDocumentStore") + + def _embedding_retrieval( + self, + query_embedding: List[float], + *, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + + This method is not meant to be part of the public interface of + `PgvectorDocumentStore` and it should not be called directly. + `PgvectorEmbeddingRetriever` uses this method directly and is the public interface for it. + :raises ValueError + :return: List of Documents that are most similar to `query_embedding` + """ + + if not query_embedding: + msg = "query_embedding must be a non-empty list of floats" + raise ValueError(msg) + if len(query_embedding) != self.embedding_dimension: + msg = ( + f"query_embedding dimension ({len(query_embedding)}) does not match PgvectorDocumentStore " + f"embedding dimension ({self.embedding_dimension})." + ) + raise ValueError(msg) + + vector_function = vector_function or self.vector_function + if vector_function not in VALID_VECTOR_FUNCTIONS: + msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}, but got {vector_function}" + raise ValueError(msg) + + # the vector must be a string with this format: "'[3,1,2]'" + query_embedding_for_postgres = f"'[{','.join(str(el) for el in query_embedding)}]'" + + # to compute the scores, we use the approach described in pgvector README: + # https://github.com/pgvector/pgvector?tab=readme-ov-file#distances + # cosine_similarity and inner_product are modified from the result of the operator + if vector_function == "cosine_similarity": + score_definition = f"1 - (embedding <=> {query_embedding_for_postgres}) AS score" + elif vector_function == "inner_product": + score_definition = f"(embedding <#> {query_embedding_for_postgres}) * -1 AS score" + elif vector_function == "l2_distance": + score_definition = f"embedding <-> {query_embedding_for_postgres} AS score" + + sql_select = SQL("SELECT *, {score} FROM {table_name}").format( + table_name=Identifier(self.table_name), + score=SQL(score_definition), + ) + + sql_where_clause = SQL("") + params = () + if filters: + sql_where_clause, params = _convert_filters_to_where_clause_and_params(filters) + + # we always want to return the most similar documents first + # so when using l2_distance, the sort order must be ASC + sort_order = "ASC" if vector_function == "l2_distance" else "DESC" + + sql_sort = SQL(" ORDER BY score {sort_order} LIMIT {top_k}").format( + top_k=SQLLiteral(top_k), + sort_order=SQL(sort_order), + ) + + sql_query = sql_select + sql_where_clause + sql_sort + + result = self._execute_sql( + sql_query, + params, + error_msg="Could not retrieve documents from PgvectorDocumentStore.", + cursor=self._dict_cursor, + ) + + records = result.fetchall() + docs = self._from_pg_to_haystack_documents(records) + return docs diff --git a/integrations/pgvector/tests/conftest.py b/integrations/pgvector/tests/conftest.py index 34260f409..743e8de14 100644 --- a/integrations/pgvector/tests/conftest.py +++ b/integrations/pgvector/tests/conftest.py @@ -7,7 +7,7 @@ def document_store(request): connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" table_name = f"haystack_{request.node.name}" embedding_dimension = 768 - vector_function = "cosine_distance" + vector_function = "cosine_similarity" recreate_table = True search_strategy = "exact_nearest_neighbor" diff --git a/integrations/pgvector/tests/test_embedding_retrieval.py b/integrations/pgvector/tests/test_embedding_retrieval.py new file mode 100644 index 000000000..1d5e8e297 --- /dev/null +++ b/integrations/pgvector/tests/test_embedding_retrieval.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +import pytest +from haystack.dataclasses.document import Document +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore +from numpy.random import rand + + +class TestEmbeddingRetrieval: + @pytest.fixture + def document_store_w_hnsw_index(self, request): + connection_string = "postgresql://postgres:postgres@localhost:5432/postgres" + table_name = f"haystack_hnsw_{request.node.name}" + embedding_dimension = 768 + vector_function = "cosine_similarity" + recreate_table = True + search_strategy = "hnsw" + + store = PgvectorDocumentStore( + connection_string=connection_string, + table_name=table_name, + embedding_dimension=embedding_dimension, + vector_function=vector_function, + recreate_table=recreate_table, + search_strategy=search_strategy, + ) + yield store + + store.delete_table() + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_cosine_similarity(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (cosine sim)", embedding=most_similar_embedding), + Document(content="2nd best document (cosine sim)", embedding=second_best_embedding), + Document(content="Not very similar document (cosine sim)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store._embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, vector_function="cosine_similarity" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (cosine sim)" + assert results[1].content == "2nd best document (cosine sim)" + assert results[0].score > results[1].score + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_inner_product(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.8] * 768 + second_best_embedding = [0.8] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (inner product)", embedding=most_similar_embedding), + Document(content="2nd best document (inner product)", embedding=second_best_embedding), + Document(content="Not very similar document (inner product)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store._embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, vector_function="inner_product" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (inner product)" + assert results[1].content == "2nd best document (inner product)" + assert results[0].score > results[1].score + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_l2_distance(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 768 + most_similar_embedding = [0.1] * 765 + [0.15] * 3 + second_best_embedding = [0.1] * 700 + [0.1] * 3 + [0.2] * 65 + another_embedding = rand(768).tolist() + + docs = [ + Document(content="Most similar document (l2 dist)", embedding=most_similar_embedding), + Document(content="2nd best document (l2 dist)", embedding=second_best_embedding), + Document(content="Not very similar document (l2 dist)", embedding=another_embedding), + ] + + document_store.write_documents(docs) + + results = document_store._embedding_retrieval( + query_embedding=query_embedding, top_k=2, filters={}, vector_function="l2_distance" + ) + assert len(results) == 2 + assert results[0].content == "Most similar document (l2 dist)" + assert results[1].content == "2nd best document (l2 dist)" + assert results[0].score < results[1].score + + @pytest.mark.parametrize("document_store", ["document_store", "document_store_w_hnsw_index"], indirect=True) + def test_embedding_retrieval_with_filters(self, document_store: PgvectorDocumentStore): + docs = [Document(content=f"Document {i}", embedding=rand(768).tolist()) for i in range(10)] + + for i in range(10): + docs[i].meta["meta_field"] = "custom_value" if i % 2 == 0 else "other_value" + + document_store.write_documents(docs) + + query_embedding = [0.1] * 768 + filters = {"field": "meta.meta_field", "operator": "==", "value": "custom_value"} + + results = document_store._embedding_retrieval(query_embedding=query_embedding, top_k=3, filters=filters) + assert len(results) == 3 + for result in results: + assert result.meta["meta_field"] == "custom_value" + assert results[0].score > results[1].score > results[2].score + + def test_empty_query_embedding(self, document_store: PgvectorDocumentStore): + query_embedding: List[float] = [] + with pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=query_embedding) + + def test_query_embedding_wrong_dimension(self, document_store: PgvectorDocumentStore): + query_embedding = [0.1] * 4 + with pytest.raises(ValueError): + document_store._embedding_retrieval(query_embedding=query_embedding) From f56905a23a8e195025d8f1abd78caa4b339fe108 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Thu, 1 Feb 2024 09:58:14 +0100 Subject: [PATCH 13/26] elasticsearch: generate api docs (#322) * add api docs * working dir * typo --- .github/workflows/elasticsearch.yml | 11 ++++++-- integrations/elasticsearch/pydoc/config.yml | 31 +++++++++++++++++++++ integrations/elasticsearch/pyproject.toml | 4 +++ 3 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 integrations/elasticsearch/pydoc/config.yml diff --git a/.github/workflows/elasticsearch.yml b/.github/workflows/elasticsearch.yml index eb2c1748d..688e5c48f 100644 --- a/.github/workflows/elasticsearch.yml +++ b/.github/workflows/elasticsearch.yml @@ -10,6 +10,10 @@ on: - "integrations/elasticsearch/**" - ".github/workflows/elasticsearch.yml" +defaults: + run: + working-directory: integrations/elasticsearch + concurrency: group: elasticsearch-${{ github.head_ref }} cancel-in-progress: true @@ -40,14 +44,15 @@ jobs: run: pip install --upgrade hatch - name: Lint - working-directory: integrations/elasticsearch if: matrix.python-version == '3.9' run: hatch run lint:all - name: Run ElasticSearch container - working-directory: integrations/elasticsearch run: docker-compose up -d + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests - working-directory: integrations/elasticsearch run: hatch run cov diff --git a/integrations/elasticsearch/pydoc/config.yml b/integrations/elasticsearch/pydoc/config.yml new file mode 100644 index 000000000..dc5a090bc --- /dev/null +++ b/integrations/elasticsearch/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.elasticsearch.bm25_retriever", + "haystack_integrations.components.retrievers.elasticsearch.embedding_retriever", + "haystack_integrations.document_stores.elasticsearch.document_store", + "haystack_integrations.document_stores.elasticsearch.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Elasticsearch integration for Haystack + category_slug: haystack-integrations + title: Elasticsearch + slug: integrations-elasticsearch + order: 50 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_elasticsearch.md \ No newline at end of file diff --git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index af3d89c0c..b67df7e03 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "pytest-xdist", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -61,6 +62,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] From bdee9332c964332e99b0f1d7a87870856ad5dbdd Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 1 Feb 2024 12:22:50 +0100 Subject: [PATCH 14/26] Pinecone - decrease concurrency in tests (#323) * pinecone - decrease concurrency * decrease more sleep time --- integrations/pinecone/pyproject.toml | 4 ++-- integrations/pinecone/tests/conftest.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/pinecone/pyproject.toml b/integrations/pinecone/pyproject.toml index 2d73cdf58..c95ee0aac 100644 --- a/integrations/pinecone/pyproject.toml +++ b/integrations/pinecone/pyproject.toml @@ -54,8 +54,8 @@ dependencies = [ [tool.hatch.envs.default.scripts] # Pinecone tests are slow (require HTTP requests), so we run them in parallel # with pytest-xdist (https://pytest-xdist.readthedocs.io/en/stable/distribution.html) -test = "pytest -n auto --maxprocesses=3 {args:tests}" -test-cov = "coverage run -m pytest -n auto --maxprocesses=3 {args:tests}" +test = "pytest -n auto --maxprocesses=2 {args:tests}" +test-cov = "coverage run -m pytest -n auto --maxprocesses=2 {args:tests}" cov-report = [ "- coverage combine", "coverage report", diff --git a/integrations/pinecone/tests/conftest.py b/integrations/pinecone/tests/conftest.py index 79d2608f2..c7a1342d5 100644 --- a/integrations/pinecone/tests/conftest.py +++ b/integrations/pinecone/tests/conftest.py @@ -6,7 +6,7 @@ from haystack_integrations.document_stores.pinecone import PineconeDocumentStore # This is the approximate time it takes for the documents to be available -SLEEP_TIME = 25 +SLEEP_TIME = 20 @pytest.fixture() From 61daacb3f9e63af8f9df67cada8288ad333ad074 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Thu, 1 Feb 2024 15:14:29 +0100 Subject: [PATCH 15/26] add pydocconf (#321) --- .github/workflows/cohere.yml | 4 ++++ integrations/cohere/pydoc/config.yml | 32 ++++++++++++++++++++++++++++ integrations/cohere/pyproject.toml | 4 ++++ 3 files changed, 40 insertions(+) create mode 100644 integrations/cohere/pydoc/config.yml diff --git a/.github/workflows/cohere.yml b/.github/workflows/cohere.yml index 0f0030ec1..562556e47 100644 --- a/.github/workflows/cohere.yml +++ b/.github/workflows/cohere.yml @@ -52,5 +52,9 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests run: hatch run cov \ No newline at end of file diff --git a/integrations/cohere/pydoc/config.yml b/integrations/cohere/pydoc/config.yml new file mode 100644 index 000000000..9418739b5 --- /dev/null +++ b/integrations/cohere/pydoc/config.yml @@ -0,0 +1,32 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.embedders.cohere.document_embedder", + "haystack_integrations.components.embedders.cohere.text_embedder", + "haystack_integrations.components.embedders.cohere.utils", + "haystack_integrations.components.generators.cohere.generator", + "haystack_integrations.components.generators.cohere.chat.chat_generator", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Cohere integration for Haystack + category_slug: haystack-integrations + title: Cohere + slug: integrations-cohere + order: 40 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_cohere.md diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index 42349d9fb..332471674 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -49,6 +49,7 @@ git_describe_command = 'git describe --tags --match="integrations/cohere-v[0-9]* dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -61,6 +62,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.7", "3.8", "3.9", "3.10", "3.11"] From 3454815095b539558cdda083c6d51f76ed2b12ea Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Thu, 1 Feb 2024 17:01:26 +0100 Subject: [PATCH 16/26] Pgvector - Embedding Retriever (#320) * squash * squash * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi * Update integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py Co-authored-by: Massimiliano Pippi * fix fmt * adjust docstrings * Update integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py Co-authored-by: Massimiliano Pippi * Update integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py Co-authored-by: Massimiliano Pippi * improve docstrings * fmt --------- Co-authored-by: Massimiliano Pippi --- .../retrievers/pgvector/__init__.py | 6 + .../pgvector/embedding_retriever.py | 104 ++++++++++++++++ integrations/pgvector/tests/test_retriever.py | 112 ++++++++++++++++++ 3 files changed, 222 insertions(+) create mode 100644 integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py create mode 100644 integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py create mode 100644 integrations/pgvector/tests/test_retriever.py diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py new file mode 100644 index 000000000..ec0cf0dc4 --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from .embedding_retriever import PgvectorEmbeddingRetriever + +__all__ = ["PgvectorEmbeddingRetriever"] diff --git a/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py new file mode 100644 index 000000000..26807e9bd --- /dev/null +++ b/integrations/pgvector/src/haystack_integrations/components/retrievers/pgvector/embedding_retriever.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Literal, Optional + +from haystack import component, default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore +from haystack_integrations.document_stores.pgvector.document_store import VALID_VECTOR_FUNCTIONS + + +@component +class PgvectorEmbeddingRetriever: + """ + Retrieves documents from the PgvectorDocumentStore, based on their dense embeddings. + + Needs to be connected to the PgvectorDocumentStore. + """ + + def __init__( + self, + *, + document_store: PgvectorDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + ): + """ + Create the PgvectorEmbeddingRetriever component. + + :param document_store: An instance of PgvectorDocumentStore. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + :param top_k: Maximum number of Documents to return, defaults to 10. + :param vector_function: The similarity function to use when searching for similar embeddings. + Defaults to the one set in the `document_store` instance. + "cosine_similarity" and "inner_product" are similarity functions and + higher scores indicate greater similarity between the documents. + "l2_distance" returns the straight-line distance between vectors, + and the most similar documents are the ones with the smallest score. + + Important: if the document store is using the "hnsw" search strategy, the vector function + should match the one utilized during index creation to take advantage of the index. + :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] + + :raises ValueError: If `document_store` is not an instance of PgvectorDocumentStore. + """ + if not isinstance(document_store, PgvectorDocumentStore): + msg = "document_store must be an instance of PgvectorDocumentStore" + raise ValueError(msg) + + if vector_function and vector_function not in VALID_VECTOR_FUNCTIONS: + msg = f"vector_function must be one of {VALID_VECTOR_FUNCTIONS}" + raise ValueError(msg) + + self.document_store = document_store + self.filters = filters or {} + self.top_k = top_k + self.vector_function = vector_function or document_store.vector_function + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, + filters=self.filters, + top_k=self.top_k, + vector_function=self.vector_function, + document_store=self.document_store.to_dict(), + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "PgvectorEmbeddingRetriever": + data["init_parameters"]["document_store"] = default_from_dict( + PgvectorDocumentStore, data["init_parameters"]["document_store"] + ) + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + vector_function: Optional[Literal["cosine_similarity", "inner_product", "l2_distance"]] = None, + ): + """ + Retrieve documents from the PgvectorDocumentStore, based on their embeddings. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. + :param top_k: Maximum number of Documents to return. + :param vector_function: The similarity function to use when searching for similar embeddings. + :type vector_function: Literal["cosine_similarity", "inner_product", "l2_distance"] + :return: List of Documents similar to `query_embedding`. + """ + filters = filters or self.filters + top_k = top_k or self.top_k + vector_function = vector_function or self.vector_function + + docs = self.document_store._embedding_retrieval( + query_embedding=query_embedding, + filters=filters, + top_k=top_k, + vector_function=vector_function, + ) + return {"documents": docs} diff --git a/integrations/pgvector/tests/test_retriever.py b/integrations/pgvector/tests/test_retriever.py new file mode 100644 index 000000000..cca6bbc9f --- /dev/null +++ b/integrations/pgvector/tests/test_retriever.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import Mock + +from haystack.dataclasses import Document +from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + + +class TestRetriever: + def test_init_default(self, document_store: PgvectorDocumentStore): + retriever = PgvectorEmbeddingRetriever(document_store=document_store) + assert retriever.document_store == document_store + assert retriever.filters == {} + assert retriever.top_k == 10 + assert retriever.vector_function == document_store.vector_function + + def test_init(self, document_store: PgvectorDocumentStore): + retriever = PgvectorEmbeddingRetriever( + document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" + ) + assert retriever.document_store == document_store + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.vector_function == "l2_distance" + + def test_to_dict(self, document_store: PgvectorDocumentStore): + retriever = PgvectorEmbeddingRetriever( + document_store=document_store, filters={"field": "value"}, top_k=5, vector_function="l2_distance" + ) + res = retriever.to_dict() + t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" + assert res == { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_ef_search": None, + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "vector_function": "l2_distance", + }, + } + + def test_from_dict(self): + t = "haystack_integrations.components.retrievers.pgvector.embedding_retriever.PgvectorEmbeddingRetriever" + data = { + "type": t, + "init_parameters": { + "document_store": { + "type": "haystack_integrations.document_stores.pgvector.document_store.PgvectorDocumentStore", + "init_parameters": { + "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres", + "table_name": "haystack_test_to_dict", + "embedding_dimension": 768, + "vector_function": "cosine_similarity", + "recreate_table": True, + "search_strategy": "exact_nearest_neighbor", + "hnsw_recreate_index_if_exists": False, + "hnsw_index_creation_kwargs": {}, + "hnsw_ef_search": None, + }, + }, + "filters": {"field": "value"}, + "top_k": 5, + "vector_function": "l2_distance", + }, + } + + retriever = PgvectorEmbeddingRetriever.from_dict(data) + document_store = retriever.document_store + + assert isinstance(document_store, PgvectorDocumentStore) + assert document_store.connection_string == "postgresql://postgres:postgres@localhost:5432/postgres" + assert document_store.table_name == "haystack_test_to_dict" + assert document_store.embedding_dimension == 768 + assert document_store.vector_function == "cosine_similarity" + assert document_store.recreate_table + assert document_store.search_strategy == "exact_nearest_neighbor" + assert not document_store.hnsw_recreate_index_if_exists + assert document_store.hnsw_index_creation_kwargs == {} + assert document_store.hnsw_ef_search is None + + assert retriever.filters == {"field": "value"} + assert retriever.top_k == 5 + assert retriever.vector_function == "l2_distance" + + def test_run(self): + mock_store = Mock(spec=PgvectorDocumentStore) + doc = Document(content="Test doc", embedding=[0.1, 0.2]) + mock_store._embedding_retrieval.return_value = [doc] + + retriever = PgvectorEmbeddingRetriever(document_store=mock_store, vector_function="l2_distance") + res = retriever.run(query_embedding=[0.3, 0.5]) + + mock_store._embedding_retrieval.assert_called_once_with( + query_embedding=[0.3, 0.5], filters={}, top_k=10, vector_function="l2_distance" + ) + + assert res == {"documents": [doc]} From d477a21e56b216c937bcd1363b43d8981ac7a7ac Mon Sep 17 00:00:00 2001 From: Daria Fokina Date: Fri, 2 Feb 2024 08:38:01 +0100 Subject: [PATCH 17/26] astra: generate api docs (#327) --- .github/workflows/astra.yml | 4 ++++ integrations/astra/pydoc/config.yml | 30 +++++++++++++++++++++++++++++ integrations/astra/pyproject.toml | 5 ++++- 3 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 integrations/astra/pydoc/config.yml diff --git a/.github/workflows/astra.yml b/.github/workflows/astra.yml index b751550de..a1aab7154 100644 --- a/.github/workflows/astra.yml +++ b/.github/workflows/astra.yml @@ -53,6 +53,10 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests env: ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} diff --git a/integrations/astra/pydoc/config.yml b/integrations/astra/pydoc/config.yml new file mode 100644 index 000000000..68cc1c809 --- /dev/null +++ b/integrations/astra/pydoc/config.yml @@ -0,0 +1,30 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.astra.retriever", + "haystack_integrations.document_stores.astra.document_store", + "haystack_integrations.document_stores.astra.errors", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Astra integration for Haystack + category_slug: haystack-integrations + title: Astra + slug: integrations-astra + order: 20 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_astra.md \ No newline at end of file diff --git a/integrations/astra/pyproject.toml b/integrations/astra/pyproject.toml index 6b4e2565d..7599797a8 100644 --- a/integrations/astra/pyproject.toml +++ b/integrations/astra/pyproject.toml @@ -50,6 +50,7 @@ git_describe_command = 'git describe --tags --match="integrations/astra-v[0-9]*" dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -62,7 +63,9 @@ cov = [ "test-cov", "cov-report", ] - +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.7", "3.8", "3.9", "3.10", "3.11"] From 8c96def2b6abec323c71221935187f45c381ee9a Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 2 Feb 2024 08:44:04 +0100 Subject: [PATCH 18/26] opensearch: generate API docs (#324) --- .github/workflows/opensearch.yml | 10 ++++++-- integrations/opensearch/pydoc/config.yml | 31 ++++++++++++++++++++++++ integrations/opensearch/pyproject.toml | 5 ++++ 3 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 integrations/opensearch/pydoc/config.yml diff --git a/.github/workflows/opensearch.yml b/.github/workflows/opensearch.yml index aacb4ce71..72a01d090 100644 --- a/.github/workflows/opensearch.yml +++ b/.github/workflows/opensearch.yml @@ -18,6 +18,10 @@ env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" +defaults: + run: + working-directory: integrations/opensearch + jobs: run: name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} @@ -40,14 +44,16 @@ jobs: run: pip install --upgrade hatch - name: Lint - working-directory: integrations/opensearch if: matrix.python-version == '3.9' run: hatch run lint:all - name: Run opensearch container - working-directory: integrations/opensearch run: docker-compose up -d + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests working-directory: integrations/opensearch run: hatch run cov diff --git a/integrations/opensearch/pydoc/config.yml b/integrations/opensearch/pydoc/config.yml new file mode 100644 index 000000000..3e07f6625 --- /dev/null +++ b/integrations/opensearch/pydoc/config.yml @@ -0,0 +1,31 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.opensearch.bm25_retriever", + "haystack_integrations.components.retrievers.opensearch.embedding_retriever", + "haystack_integrations.document_stores.opensearch.document_store", + "haystack_integrations.document_stores.opensearch.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: OpenSearch integration for Haystack + category_slug: haystack-integrations + title: OpenSearch + slug: integrations-opensearch + order: 130 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_opensearch.md diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index 3edd544a2..794fa73fa 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "pytest-xdist", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -62,6 +63,10 @@ cov = [ "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] + [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11"] From 68358e7e992b7d724e6b61dc8a0abe9ab5287d8d Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 2 Feb 2024 08:44:24 +0100 Subject: [PATCH 19/26] fix linting (#328) --- .../document_stores/pgvector/document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py index 0abaaecce..097e86c7e 100644 --- a/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py +++ b/integrations/pgvector/src/haystack_integrations/document_stores/pgvector/document_store.py @@ -409,7 +409,7 @@ def _from_pg_to_haystack_documents(self, documents: List[Dict[str, Any]]) -> Lis # postgresql returns the embedding as a string # so we need to convert it to a list of floats - if "embedding" in document and document["embedding"]: + if document.get("embedding"): haystack_dict["embedding"] = [float(el) for el in document["embedding"].strip("[]").split(",")] haystack_document = Document.from_dict(haystack_dict) From 55ddf41c1b3e49a20ccda84c3b585b39df9fc074 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 2 Feb 2024 08:44:37 +0100 Subject: [PATCH 20/26] new entry for pgvector (#329) --- .github/labeler.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index f5eaa3374..4d060772c 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -69,6 +69,11 @@ integration:opensearch: - any-glob-to-any-file: "integrations/opensearch/**/*" - any-glob-to-any-file: ".github/workflows/opensearch.yml" +integration:pgvector: + - changed-files: + - any-glob-to-any-file: "integrations/pgvector/**/*" + - any-glob-to-any-file: ".github/workflows/pgvector.yml" + integration:pinecone: - changed-files: - any-glob-to-any-file: "integrations/pinecone/**/*" @@ -81,8 +86,8 @@ integration:qdrant: integration:unstructured-fileconverter: - changed-files: - - any-glob-to-any-file: "integrations/unstructured/fileconverter/**/*" - - any-glob-to-any-file: ".github/workflows/unstructured_fileconverter.yml" + - any-glob-to-any-file: "integrations/unstructured/**/*" + - any-glob-to-any-file: ".github/workflows/unstructured.yml" integration:uptrain: - changed-files: From 26680ca12640454254dc8d93a5a84649d06c96bb Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Fri, 2 Feb 2024 14:22:36 +0100 Subject: [PATCH 21/26] api docs (#325) --- .github/workflows/pgvector.yml | 12 ++++++++--- integrations/pgvector/pydoc/config.yml | 30 ++++++++++++++++++++++++++ integrations/pgvector/pyproject.toml | 4 ++++ 3 files changed, 43 insertions(+), 3 deletions(-) create mode 100644 integrations/pgvector/pydoc/config.yml diff --git a/.github/workflows/pgvector.yml b/.github/workflows/pgvector.yml index c985b765a..badb2565b 100644 --- a/.github/workflows/pgvector.yml +++ b/.github/workflows/pgvector.yml @@ -18,6 +18,10 @@ env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" +defaults: + run: + working-directory: integrations/pgvector + jobs: run: name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} @@ -49,10 +53,12 @@ jobs: run: pip install --upgrade hatch - name: Lint - working-directory: integrations/pgvector if: matrix.python-version == '3.9' - run: hatch run lint:all + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs - name: Run tests - working-directory: integrations/pgvector run: hatch run cov diff --git a/integrations/pgvector/pydoc/config.yml b/integrations/pgvector/pydoc/config.yml new file mode 100644 index 000000000..79974b4a1 --- /dev/null +++ b/integrations/pgvector/pydoc/config.yml @@ -0,0 +1,30 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.retrievers.pgvector.embedding_retriever", + "haystack_integrations.document_stores.pgvector.document_store", + "haystack_integrations.document_stores.pgvector.filters", + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Pgvector integration for Haystack + category_slug: haystack-integrations + title: Pgvector + slug: integrations-pgvector + order: 140 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_pgvector.md diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index b361af8b1..10ef5d314 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -51,6 +51,7 @@ dependencies = [ "coverage[toml]>=6.5", "pytest", "ipython", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -63,6 +64,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] From f7678e104399bcaab15fd3c71bc029efbcbb84a7 Mon Sep 17 00:00:00 2001 From: Daria Fokina Date: Fri, 2 Feb 2024 15:57:00 +0100 Subject: [PATCH 22/26] ollama: generate api docs (#332) * ollama: generate api docs * Update .github/workflows/ollama.yml Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> --- .github/workflows/ollama.yml | 5 +++++ integrations/ollama/pydoc/config.yml | 29 ++++++++++++++++++++++++++++ integrations/ollama/pyproject.toml | 4 ++++ 3 files changed, 38 insertions(+) create mode 100644 integrations/ollama/pydoc/config.yml diff --git a/.github/workflows/ollama.yml b/.github/workflows/ollama.yml index 7f61af14e..a375fc7db 100644 --- a/.github/workflows/ollama.yml +++ b/.github/workflows/ollama.yml @@ -54,6 +54,11 @@ jobs: - name: Pull the LLM in the Ollama service run: docker exec ollama ollama pull ${{ env.LLM_FOR_TESTS }} + - name: Generate docs + working-directory: integrations/ollama + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + - name: Run tests working-directory: integrations/ollama run: hatch run cov diff --git a/integrations/ollama/pydoc/config.yml b/integrations/ollama/pydoc/config.yml new file mode 100644 index 000000000..768694991 --- /dev/null +++ b/integrations/ollama/pydoc/config.yml @@ -0,0 +1,29 @@ +loaders: + - type: haystack_pydoc_tools.loaders.CustomPythonLoader + search_path: [../src] + modules: [ + "haystack_integrations.components.generators.ollama.generator", + "haystack_integrations.components.generators.ollama.chat.chat_generator" + ] + ignore_when_discovered: ["__init__"] +processors: + - type: filter + expression: + documented_only: true + do_not_filter_modules: false + skip_empty_modules: true + - type: smart + - type: crossref +renderer: + type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer + excerpt: Ollama integration for Haystack + category_slug: haystack-integrations + title: Ollama + slug: integrations-ollama + order: 120 + markdown: + descriptive_class_title: false + descriptive_module_title: true + add_method_class_prefix: true + add_member_class_prefix: false + filename: _readme_ollama.md \ No newline at end of file diff --git a/integrations/ollama/pyproject.toml b/integrations/ollama/pyproject.toml index 69cc2ed16..e3bb738b6 100644 --- a/integrations/ollama/pyproject.toml +++ b/integrations/ollama/pyproject.toml @@ -48,6 +48,7 @@ git_describe_command = 'git describe --tags --match="integrations/ollama-v[0-9]* dependencies = [ "coverage[toml]>=6.5", "pytest", + "haystack-pydoc-tools", ] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" @@ -60,6 +61,9 @@ cov = [ "test-cov", "cov-report", ] +docs = [ + "pydoc-markdown pydoc/config.yml" +] [[tool.hatch.envs.all.matrix]] python = ["3.8", "3.9", "3.10", "3.11", "3.12"] From ea50b696ba42050921dc3fec78e2ba8dc93f4cbb Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Mon, 5 Feb 2024 10:14:06 +0100 Subject: [PATCH 23/26] fix: Fix Cohere tests (#337) * Fix tests * Fix linting message --- .../components/generators/cohere/generator.py | 2 +- integrations/cohere/tests/test_cohere_chat_generator.py | 5 ++++- integrations/cohere/tests/test_cohere_generators.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index fee410eab..92fed51aa 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -122,7 +122,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator": """ init_params = data.get("init_parameters", {}) streaming_callback = None - if "streaming_callback" in init_params and init_params["streaming_callback"]: + if "streaming_callback" in init_params and init_params["streaming_callback"] is not None: parts = init_params["streaming_callback"].split(".") module_name = ".".join(parts[:-1]) function_name = parts[-1] diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index c91ada419..edefc1a43 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -260,7 +260,10 @@ def test_live_run(self): @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) - with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"): + with pytest.raises( + cohere.CohereAPIError, + match="model not found, make sure the correct model ID was used and that you have access to the model.", + ): component.run(chat_messages) @pytest.mark.skipif( diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index e2ce10405..90d4d3e28 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -164,7 +164,7 @@ def __init__(self): self.responses = "" def __call__(self, chunk): - self.responses += chunk.text + self.responses += chunk.content return chunk callback = Callback() From 40845c245ac2b43184bfcb8f4d0bd8b504123256 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Mon, 5 Feb 2024 10:18:11 +0100 Subject: [PATCH 24/26] Add COHERE_API_KEY secret to CI environment (#339) --- .github/workflows/cohere.yml | 43 +++++++++++++++++------------------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/.github/workflows/cohere.yml b/.github/workflows/cohere.yml index 562556e47..fb6b00680 100644 --- a/.github/workflows/cohere.yml +++ b/.github/workflows/cohere.yml @@ -7,8 +7,8 @@ on: - cron: "0 0 * * *" pull_request: paths: - - 'integrations/cohere/**' - - '.github/workflows/cohere.yml' + - "integrations/cohere/**" + - ".github/workflows/cohere.yml" defaults: run: @@ -21,6 +21,7 @@ concurrency: env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" + COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} jobs: run: @@ -30,31 +31,27 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ['3.9', '3.10'] + python-version: ["3.9", "3.10"] steps: - - name: Support longpaths - if: matrix.os == 'windows-latest' - working-directory: . - run: git config --system core.longpaths true + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true - - uses: actions/checkout@v4 + - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} - - name: Install Hatch - run: pip install --upgrade hatch + - name: Install Hatch + run: pip install --upgrade hatch - - name: Lint - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run lint:all + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all - - name: Generate docs - if: matrix.python-version == '3.9' && runner.os == 'Linux' - run: hatch run docs - - - name: Run tests - run: hatch run cov \ No newline at end of file + - name: Run tests + run: hatch run cov From 96f6ade703c7dd47b6eb9622485a7a285a98b8b0 Mon Sep 17 00:00:00 2001 From: Corentin Date: Mon, 5 Feb 2024 10:41:57 +0100 Subject: [PATCH 25/26] unstructured: fix metadata order mixed up (#336) * Optional meta field for UnstructuredFileConverter with proper tests * black lint * Adding multiple files and meta list test case * Black formatting test * Fixing metadata page number bug. Deep copy of dict * Folder of files test * Update integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py Co-authored-by: Stefano Fiorucci * Update integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py Co-authored-by: Stefano Fiorucci * Update integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py Co-authored-by: Stefano Fiorucci * Renaming "name" meta to "file_path" and deepcopy fix * Fix Ruff Complaining * Removing unique file logic using set that does not preserve file orders. Raise error if glob and metadata list because unsafe * Better test to make sure metadata order are preserved. * Make a failing test if metadata list and directory * filepaths as lists * Update integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py Co-authored-by: Stefano Fiorucci * update meta docstrings --------- Co-authored-by: Stefano Fiorucci --- .../converters/unstructured/converter.py | 26 +++++++++++-------- .../unstructured/tests/test_converter.py | 19 ++++++++++++-- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py b/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py index bee1d9a7b..188dd9e6e 100644 --- a/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py +++ b/integrations/unstructured/src/haystack_integrations/components/converters/unstructured/converter.py @@ -105,19 +105,23 @@ def run( This value can be either a list of dictionaries or a single dictionary. If it's a single dictionary, its content is added to the metadata of all produced Documents. If it's a list, the length of the list must match the number of paths, because the two lists will be zipped. - Please note that if the paths contain directories, the length of the meta list must match - the actual number of files contained. + Please note that if the paths contain directories, meta can only be a single dictionary + (same metadata for all files). Defaults to `None`. """ - - unique_paths = {Path(path) for path in paths} - filepaths = {path for path in unique_paths if path.is_file()} - filepaths_in_directories = { - filepath for path in unique_paths if path.is_dir() for filepath in path.glob("*.*") if filepath.is_file() - } - - all_filepaths = filepaths.union(filepaths_in_directories) - + paths_obj = [Path(path) for path in paths] + filepaths = [path for path in paths_obj if path.is_file()] + filepaths_in_directories = [ + filepath for path in paths_obj if path.is_dir() for filepath in path.glob("*.*") if filepath.is_file() + ] + if filepaths_in_directories and isinstance(meta, list): + error = """"If providing directories in the `paths` parameter, + `meta` can only be a dictionary (metadata applied to every file), + and not a list. To specify different metadata for each file, + provide an explicit list of direct paths instead.""" + raise ValueError(error) + + all_filepaths = filepaths + filepaths_in_directories # currently, the files are converted sequentially to gently handle API failures documents = [] meta_list = normalize_metadata(meta, sources_count=len(all_filepaths)) diff --git a/integrations/unstructured/tests/test_converter.py b/integrations/unstructured/tests/test_converter.py index d5266ac62..ca590ab2f 100644 --- a/integrations/unstructured/tests/test_converter.py +++ b/integrations/unstructured/tests/test_converter.py @@ -154,7 +154,10 @@ def test_run_one_doc_per_element_with_meta(self, samples_path): @pytest.mark.integration def test_run_one_doc_per_element_with_meta_list_two_files(self, samples_path): pdf_path = [samples_path / "sample_pdf.pdf", samples_path / "sample_pdf2.pdf"] - meta = [{"custom_meta": "foobar", "common_meta": "common"}, {"other_meta": "barfoo", "common_meta": "common"}] + meta = [ + {"custom_meta": "sample_pdf.pdf", "common_meta": "common"}, + {"custom_meta": "sample_pdf2.pdf", "common_meta": "common"}, + ] local_converter = UnstructuredFileConverter( api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-element" ) @@ -163,6 +166,7 @@ def test_run_one_doc_per_element_with_meta_list_two_files(self, samples_path): assert len(documents) > 4 for doc in documents: + assert doc.meta["custom_meta"] == doc.meta["filename"] assert "file_path" in doc.meta assert "page_number" in doc.meta # elements have a category attribute that is saved in the document meta @@ -171,9 +175,20 @@ def test_run_one_doc_per_element_with_meta_list_two_files(self, samples_path): assert doc.meta["common_meta"] == "common" @pytest.mark.integration - def test_run_one_doc_per_element_with_meta_list_folder(self, samples_path): + def test_run_one_doc_per_element_with_meta_list_folder_fail(self, samples_path): pdf_path = [samples_path] meta = [{"custom_meta": "foobar", "common_meta": "common"}, {"other_meta": "barfoo", "common_meta": "common"}] + local_converter = UnstructuredFileConverter( + api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-element" + ) + with pytest.raises(ValueError): + local_converter.run(paths=pdf_path, meta=meta)["documents"] + + @pytest.mark.integration + def test_run_one_doc_per_element_with_meta_list_folder(self, samples_path): + pdf_path = [samples_path] + meta = {"common_meta": "common"} + local_converter = UnstructuredFileConverter( api_url="http://localhost:8000/general/v0/general", document_creation_mode="one-doc-per-element" ) From 6d18bc400976a8b98b57e99bd148ae7c3b7368bc Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Mon, 5 Feb 2024 13:21:12 +0100 Subject: [PATCH 26/26] add example (#334) --- integrations/pgvector/examples/example.py | 58 +++++++++++++++++++++++ integrations/pgvector/pyproject.toml | 2 + 2 files changed, 60 insertions(+) create mode 100644 integrations/pgvector/examples/example.py diff --git a/integrations/pgvector/examples/example.py b/integrations/pgvector/examples/example.py new file mode 100644 index 000000000..14c2cba60 --- /dev/null +++ b/integrations/pgvector/examples/example.py @@ -0,0 +1,58 @@ +# Before running this example, ensure you have PostgreSQL installed with the pgvector extension. +# For a quick setup using Docker: +# docker run -d -p 5432:5432 -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres +# -e POSTGRES_DB=postgres ankane/pgvector + +# Install required packages for this example, including pgvector-haystack and other libraries needed +# for Markdown conversion and embeddings generation. Use the following command: +# pip install pgvector-haystack markdown-it-py mdit_plain "sentence-transformers>=2.2.0" + +# Download some Markdown files to index. +# git clone https://github.com/anakin87/neural-search-pills + +import glob + +from haystack import Pipeline +from haystack.components.converters import MarkdownToDocument +from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder +from haystack.components.preprocessors import DocumentSplitter +from haystack.components.writers import DocumentWriter +from haystack_integrations.components.retrievers.pgvector import PgvectorEmbeddingRetriever +from haystack_integrations.document_stores.pgvector import PgvectorDocumentStore + +# Initialize PgvectorDocumentStore +document_store = PgvectorDocumentStore( + connection_string="postgresql://postgres:postgres@localhost:5432/postgres", + table_name="haystack_test", + embedding_dimension=768, + vector_function="cosine_similarity", + recreate_table=True, + search_strategy="hnsw", +) + +# Create the indexing Pipeline and index some documents +file_paths = glob.glob("neural-search-pills/pills/*.md") + + +indexing = Pipeline() +indexing.add_component("converter", MarkdownToDocument()) +indexing.add_component("splitter", DocumentSplitter(split_by="sentence", split_length=2)) +indexing.add_component("embedder", SentenceTransformersDocumentEmbedder()) +indexing.add_component("writer", DocumentWriter(document_store)) +indexing.connect("converter", "splitter") +indexing.connect("splitter", "embedder") +indexing.connect("embedder", "writer") + +indexing.run({"converter": {"sources": file_paths}}) + +# Create the querying Pipeline and try a query +querying = Pipeline() +querying.add_component("embedder", SentenceTransformersTextEmbedder()) +querying.add_component("retriever", PgvectorEmbeddingRetriever(document_store=document_store, top_k=3)) +querying.connect("embedder", "retriever") + +results = querying.run({"embedder": {"text": "What is a cross-encoder?"}}) + +for doc in results["retriever"]["documents"]: + print(doc) + print("-" * 10) diff --git a/integrations/pgvector/pyproject.toml b/integrations/pgvector/pyproject.toml index 10ef5d314..65ded967f 100644 --- a/integrations/pgvector/pyproject.toml +++ b/integrations/pgvector/pyproject.toml @@ -153,6 +153,8 @@ ban-relative-imports = "parents" [tool.ruff.per-file-ignores] # Tests can use magic values, assertions, and relative imports "tests/**/*" = ["PLR2004", "S101", "TID252"] +# examples can contain "print" commands +"examples/**/*" = ["T201"] [tool.coverage.run] source_pkgs = ["src", "tests"]