From ac8bf9caa673808ebaa5658ceebd79fd5355e999 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 Dec 2023 17:06:15 +0100 Subject: [PATCH 01/29] Bump actions/labeler from 4 to 5 (#78) Bumps [actions/labeler](https://github.com/actions/labeler) from 4 to 5. - [Release notes](https://github.com/actions/labeler/releases) - [Commits](https://github.com/actions/labeler/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/labeler dependency-type: direct:production update-type: version-update:semver-major ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/labeler.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index d3e9adbd9..2af558297 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -10,6 +10,6 @@ jobs: triage: runs-on: ubuntu-latest steps: - - uses: actions/labeler@v4 + - uses: actions/labeler@v5 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" From d58b62ba211520d60cef13581d578efb47bc4412 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 6 Dec 2023 17:26:25 +0100 Subject: [PATCH 02/29] Bump actions/setup-python from 4 to 5 (#82) Bumps [actions/setup-python](https://github.com/actions/setup-python) from 4 to 5. - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/setup-python dependency-type: direct:production update-type: version-update:semver-major ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/chroma.yml | 2 +- .github/workflows/cohere.yml | 2 +- .github/workflows/elasticsearch.yml | 2 +- .github/workflows/gradient.yml | 2 +- .github/workflows/instructor_embedders.yml | 2 +- .github/workflows/nodes_text2speech.yml | 2 +- .github/workflows/opensearch.yml | 2 +- .github/workflows/unstructured_fileconverter.yml | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/chroma.yml b/.github/workflows/chroma.yml index 88020818e..89b6a5b24 100644 --- a/.github/workflows/chroma.yml +++ b/.github/workflows/chroma.yml @@ -41,7 +41,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/cohere.yml b/.github/workflows/cohere.yml index b40cd4953..0f0030ec1 100644 --- a/.github/workflows/cohere.yml +++ b/.github/workflows/cohere.yml @@ -41,7 +41,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/elasticsearch.yml b/.github/workflows/elasticsearch.yml index 08254a58b..eb2c1748d 100644 --- a/.github/workflows/elasticsearch.yml +++ b/.github/workflows/elasticsearch.yml @@ -32,7 +32,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/gradient.yml b/.github/workflows/gradient.yml index f717ba2c9..8bab11d39 100644 --- a/.github/workflows/gradient.yml +++ b/.github/workflows/gradient.yml @@ -41,7 +41,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python ${{ 
matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/instructor_embedders.yml b/.github/workflows/instructor_embedders.yml index 626933fa0..0c7765f11 100644 --- a/.github/workflows/instructor_embedders.yml +++ b/.github/workflows/instructor_embedders.yml @@ -20,7 +20,7 @@ jobs: uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' diff --git a/.github/workflows/nodes_text2speech.yml b/.github/workflows/nodes_text2speech.yml index 555ad3c0d..4b9bccaf5 100644 --- a/.github/workflows/nodes_text2speech.yml +++ b/.github/workflows/nodes_text2speech.yml @@ -18,7 +18,7 @@ jobs: uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' diff --git a/.github/workflows/opensearch.yml b/.github/workflows/opensearch.yml index faf359cfc..aacb4ce71 100644 --- a/.github/workflows/opensearch.yml +++ b/.github/workflows/opensearch.yml @@ -32,7 +32,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/unstructured_fileconverter.yml b/.github/workflows/unstructured_fileconverter.yml index 8d7ece048..83e355ac6 100644 --- a/.github/workflows/unstructured_fileconverter.yml +++ b/.github/workflows/unstructured_fileconverter.yml @@ -42,7 +42,7 @@ jobs: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} From 67d325f8e8f9b7b685724c8512d8824418b3f89e Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Wed, 6 Dec 2023 19:25:29 +0100 Subject: [PATCH 03/29] Revert "Bump actions/labeler from 4 to 5 
(#78)" (#85) This reverts commit ac8bf9caa673808ebaa5658ceebd79fd5355e999. --- .github/workflows/labeler.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 2af558297..d3e9adbd9 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -10,6 +10,6 @@ jobs: triage: runs-on: ubuntu-latest steps: - - uses: actions/labeler@v5 + - uses: actions/labeler@v4 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" From 0634b233b439974aebeb0031d12475208abea445 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Wed, 6 Dec 2023 19:28:54 +0100 Subject: [PATCH 04/29] chore: remove LazyImport from integrations (#83) * remove LazyImport from integrations * fix linter --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> --- .../embedders/gradient_document_embedder.py | 6 +----- .../gradient_haystack/embedders/gradient_text_embedder.py | 6 +----- .../gradient/src/gradient_haystack/generator/base.py | 7 +------ .../embedding_backend/instructor_backend.py | 6 +----- 4 files changed, 4 insertions(+), 21 deletions(-) diff --git a/integrations/gradient/src/gradient_haystack/embedders/gradient_document_embedder.py b/integrations/gradient/src/gradient_haystack/embedders/gradient_document_embedder.py index 81a93ad2b..a716b8cf2 100644 --- a/integrations/gradient/src/gradient_haystack/embedders/gradient_document_embedder.py +++ b/integrations/gradient/src/gradient_haystack/embedders/gradient_document_embedder.py @@ -1,11 +1,8 @@ import logging from typing import Any, Dict, List, Optional +from gradientai import Gradient from haystack import Document, component, default_to_dict -from haystack.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install gradientai'") as gradientai_import: - from gradientai import Gradient logger = logging.getLogger(__name__) @@ -49,7 +46,6 @@ def __init__( variable GRADIENT_WORKSPACE_ID. 
:param host: The Gradient host. By default it uses https://api.gradient.ai/. """ - gradientai_import.check() self._batch_size = batch_size self._host = host self._model_name = model_name diff --git a/integrations/gradient/src/gradient_haystack/embedders/gradient_text_embedder.py b/integrations/gradient/src/gradient_haystack/embedders/gradient_text_embedder.py index 53996b785..013d375ff 100644 --- a/integrations/gradient/src/gradient_haystack/embedders/gradient_text_embedder.py +++ b/integrations/gradient/src/gradient_haystack/embedders/gradient_text_embedder.py @@ -1,10 +1,7 @@ from typing import Any, Dict, List, Optional +from gradientai import Gradient from haystack import component, default_to_dict -from haystack.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install gradientai'") as gradientai_import: - from gradientai import Gradient @component @@ -43,7 +40,6 @@ def __init__( variable GRADIENT_WORKSPACE_ID. :param host: The Gradient host. By default it uses https://api.gradient.ai/. """ - gradientai_import.check() self._host = host self._model_name = model_name diff --git a/integrations/gradient/src/gradient_haystack/generator/base.py b/integrations/gradient/src/gradient_haystack/generator/base.py index 536525377..3adf0be01 100644 --- a/integrations/gradient/src/gradient_haystack/generator/base.py +++ b/integrations/gradient/src/gradient_haystack/generator/base.py @@ -1,11 +1,8 @@ import logging from typing import Any, Dict, List, Optional +from gradientai import Gradient from haystack import component, default_to_dict -from haystack.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install gradientai'") as gradientai_import: - from gradientai import Gradient logger = logging.getLogger(__name__) @@ -57,8 +54,6 @@ def __init__( :param workspace_id: The Gradient workspace ID. If not provided it's read from the environment variable GRADIENT_WORKSPACE_ID. 
""" - gradientai_import.check() - self._access_token = access_token self._base_model_slug = base_model_slug self._host = host diff --git a/integrations/instructor-embedders/instructor_embedders/embedding_backend/instructor_backend.py b/integrations/instructor-embedders/instructor_embedders/embedding_backend/instructor_backend.py index 5be300dd3..efe35e9b7 100644 --- a/integrations/instructor-embedders/instructor_embedders/embedding_backend/instructor_backend.py +++ b/integrations/instructor-embedders/instructor_embedders/embedding_backend/instructor_backend.py @@ -3,10 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import ClassVar, Dict, List, Optional, Union -from haystack.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install InstructorEmbedding'") as instructor_embeddings_import: - from InstructorEmbedding import INSTRUCTOR +from InstructorEmbedding import INSTRUCTOR class _InstructorEmbeddingBackendFactory: @@ -40,7 +37,6 @@ class _InstructorEmbeddingBackend: def __init__( self, model_name_or_path: str, device: Optional[str] = None, use_auth_token: Union[bool, str, None] = None ): - instructor_embeddings_import.check() self.model = INSTRUCTOR(model_name_or_path=model_name_or_path, device=device, use_auth_token=use_auth_token) def embed(self, data: List[List[str]], **kwargs) -> List[List[float]]: From 2627712abbb671acda1e67811babcff35161489a Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Thu, 7 Dec 2023 09:09:27 +0100 Subject: [PATCH 05/29] [cohere] Add text and document embedders (#80) * Add text and document embedders ------ Co-authored-by: vrunm <97465624+vrunm@users.noreply.github.com> * refactoring * linting * add pytest markers * more * fix api url management * fix integration tests * fix integrations tests for good * final cleanup * review feedback --------- Co-authored-by: vrunm <97465624+vrunm@users.noreply.github.com> --- integrations/cohere/README.md | 45 +++++- integrations/cohere/pyproject.toml | 17 +- 
.../src/cohere_haystack/embedders/__init__.py | 3 + .../embedders/document_embedder.py | 153 ++++++++++++++++++ .../embedders/text_embedder.py | 104 ++++++++++++ .../src/cohere_haystack/embedders/utils.py | 57 +++++++ .../cohere/src/cohere_haystack/generator.py | 12 +- .../cohere/tests/test_cohere_generators.py | 20 ++- .../cohere/tests/test_document_embedder.py | 136 ++++++++++++++++ .../cohere/tests/test_text_embedder.py | 113 +++++++++++++ 10 files changed, 634 insertions(+), 26 deletions(-) create mode 100644 integrations/cohere/src/cohere_haystack/embedders/__init__.py create mode 100644 integrations/cohere/src/cohere_haystack/embedders/document_embedder.py create mode 100644 integrations/cohere/src/cohere_haystack/embedders/text_embedder.py create mode 100644 integrations/cohere/src/cohere_haystack/embedders/utils.py create mode 100644 integrations/cohere/tests/test_document_embedder.py create mode 100644 integrations/cohere/tests/test_text_embedder.py diff --git a/integrations/cohere/README.md b/integrations/cohere/README.md index 79cefed21..86a43bf83 100644 --- a/integrations/cohere/README.md +++ b/integrations/cohere/README.md @@ -7,8 +7,10 @@ **Table of Contents** -- [Installation](#installation) -- [License](#license) +- [cohere-haystack](#cohere-haystack) + - [Installation](#installation) + - [Contributing](#contributing) + - [License](#license) ## Installation @@ -16,6 +18,45 @@ pip install cohere-haystack ``` +## Contributing + +`hatch` is the best way to interact with this project, to install it: +```sh +pip install hatch +``` + +With `hatch` installed, to run all the tests: +``` +hatch run test +``` +> Note: integration tests will be skipped unless the env var COHERE_API_KEY is set. The api key needs to be valid +> in order to pass the tests. 
+ +To only run unit tests: +``` +hatch run test -m"not integration" +``` + +To only run embedders tests: +``` +hatch run test -m"embedders" +``` + +To only run generators tests: +``` +hatch run test -m"generators" +``` + +Markers can be combined, for example you can run only integration tests for embedders with: +``` +hatch run test -m"integrations and embedders" +``` + +To run the linters `ruff` and `mypy`: +``` +hatch run lint:all +``` + ## License `cohere-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index e291907fd..5d589df7b 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -97,7 +97,6 @@ select = [ "E", "EM", "F", - "FBT", "I", "ICN", "ISC", @@ -118,8 +117,6 @@ select = [ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", - # Allow boolean positional values in function calls, like `dict.get(... 
True)` - "FBT003", # Ignore checks for possible passwords "S105", "S106", "S107", # Ignore complexity @@ -163,6 +160,16 @@ exclude_lines = [ module = [ "cohere.*", "haystack.*", - "pytest.*" + "pytest.*", + "numpy.*", ] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true + +[tool.pytest.ini_options] +addopts = "--strict-markers" +markers = [ + "integration: integration tests", + "embedders: embedders tests", + "generators: generators tests", +] +log_cli = true \ No newline at end of file diff --git a/integrations/cohere/src/cohere_haystack/embedders/__init__.py b/integrations/cohere/src/cohere_haystack/embedders/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/cohere/src/cohere_haystack/embedders/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py new file mode 100644 index 000000000..681471947 --- /dev/null +++ b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import os +from typing import Any, Dict, List, Optional + +from cohere import COHERE_API_URL, AsyncClient, Client +from haystack import Document, component, default_to_dict + +from cohere_haystack.embedders.utils import get_async_response, get_response + + +@component +class CohereDocumentEmbedder: + """ + A component for computing Document embeddings using Cohere models. + The embedding of each Document is stored in the `embedding` field of the Document. 
+ """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "embed-english-v2.0", + api_base_url: str = COHERE_API_URL, + truncate: str = "END", + use_async_client: bool = False, + max_retries: int = 3, + timeout: int = 120, + batch_size: int = 32, + progress_bar: bool = True, + metadata_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + ): + """ + Create a CohereDocumentEmbedder component. + + :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment + variable COHERE_API_KEY (recommended). + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are + `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, + `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. + :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to + `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both + cases, input is discarded until the remaining input is exactly the maximum input token length for the model. + If NONE is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use + AsyncClient for applications with many concurrent calls. + :param max_retries: maximal number of retries for requests, defaults to `3`. + :param timeout: request timeout in seconds, defaults to `120`. + :param batch_size: Number of Documents to encode at once. + :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments + to keep the logs clean. + :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text. 
+ :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + """ + + if api_key is None: + try: + api_key = os.environ["COHERE_API_KEY"] + except KeyError as error_msg: + msg = ( + "CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment " + "variable COHERE_API_KEY (recommended) or by passing it explicitly." + ) + raise ValueError(msg) from error_msg + + self.api_key = api_key + self.model_name = model_name + self.api_base_url = api_base_url + self.truncate = truncate + self.use_async_client = use_async_client + self.max_retries = max_retries + self.timeout = timeout + self.batch_size = batch_size + self.progress_bar = progress_bar + self.metadata_fields_to_embed = metadata_fields_to_embed or [] + self.embedding_separator = embedding_separator + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary omitting the api_key field. + """ + return default_to_dict( + self, + model_name=self.model_name, + api_base_url=self.api_base_url, + truncate=self.truncate, + use_async_client=self.use_async_client, + max_retries=self.max_retries, + timeout=self.timeout, + batch_size=self.batch_size, + progress_bar=self.progress_bar, + metadata_fields_to_embed=self.metadata_fields_to_embed, + embedding_separator=self.embedding_separator, + ) + + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. 
+ """ + texts_to_embed: List[str] = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.metadata_fields_to_embed if doc.meta.get(key) is not None + ] + + text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) # noqa: RUF005 + texts_to_embed.append(text_to_embed) + return texts_to_embed + + @component.output_types(documents=List[Document], metadata=Dict[str, Any]) + def run(self, documents: List[Document]): + """ + Embed a list of Documents. + The embedding of each Document is stored in the `embedding` field of the Document. + + :param documents: A list of Documents to embed. + """ + + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + msg = ( + "CohereDocumentEmbedder expects a list of Documents as input." + "In case you want to embed a string, please use the CohereTextEmbedder." + ) + raise TypeError(msg) + + if not documents: + # return early if we were passed an empty list + return {"documents": [], "metadata": {}} + + texts_to_embed = self._prepare_texts_to_embed(documents) + + if self.use_async_client: + cohere_client = AsyncClient( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + all_embeddings, metadata = asyncio.run( + get_async_response(cohere_client, texts_to_embed, self.model_name, self.truncate) + ) + else: + cohere_client = Client( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + all_embeddings, metadata = get_response( + cohere_client, texts_to_embed, self.model_name, self.truncate, self.batch_size, self.progress_bar + ) + + for doc, embeddings in zip(documents, all_embeddings): + doc.embedding = embeddings + + return {"documents": documents, "metadata": metadata} diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py new file mode 100644 index 
000000000..936926b99 --- /dev/null +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import os +from typing import Any, Dict, List, Optional + +from cohere import COHERE_API_URL, AsyncClient, Client +from haystack import component, default_to_dict + +from cohere_haystack.embedders.utils import get_async_response, get_response + + +@component +class CohereTextEmbedder: + """ + A component for embedding strings using Cohere models. + """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "embed-english-v2.0", + api_base_url: str = COHERE_API_URL, + truncate: str = "END", + use_async_client: bool = False, + max_retries: int = 3, + timeout: int = 120, + ): + """ + Create a CohereTextEmbedder component. + + :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment + variable COHERE_API_KEY (recommended). + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are + `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, + `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. + :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to + `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both + cases, input is discarded until the remaining input is exactly the maximum input token length for the model. + If NONE is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use + AsyncClient for applications with many concurrent calls. 
+ :param max_retries: Maximum number of retries for requests, defaults to `3`. + :param timeout: Request timeout in seconds, defaults to `120`. + """ + + if api_key is None: + try: + api_key = os.environ["COHERE_API_KEY"] + except KeyError as error_msg: + msg = ( + "CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment " + "variable COHERE_API_KEY (recommended) or by passing it explicitly." + ) + raise ValueError(msg) from error_msg + + self.api_key = api_key + self.model_name = model_name + self.api_base_url = api_base_url + self.truncate = truncate + self.use_async_client = use_async_client + self.max_retries = max_retries + self.timeout = timeout + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary omitting the api_key field. + """ + return default_to_dict( + self, + model_name=self.model_name, + api_base_url=self.api_base_url, + truncate=self.truncate, + use_async_client=self.use_async_client, + max_retries=self.max_retries, + timeout=self.timeout, + ) + + @component.output_types(embedding=List[float], metadata=Dict[str, Any]) + def run(self, text: str): + """Embed a string.""" + if not isinstance(text, str): + msg = ( + "CohereTextEmbedder expects a string as input." + "In case you want to embed a list of Documents, please use the CohereDocumentEmbedder." 
+ ) + raise TypeError(msg) + + # Establish connection to API + + if self.use_async_client: + cohere_client = AsyncClient( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + embedding, metadata = asyncio.run(get_async_response(cohere_client, [text], self.model_name, self.truncate)) + else: + cohere_client = Client( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + embedding, metadata = get_response(cohere_client, [text], self.model_name, self.truncate) + + return {"embedding": embedding[0], "metadata": metadata} diff --git a/integrations/cohere/src/cohere_haystack/embedders/utils.py b/integrations/cohere/src/cohere_haystack/embedders/utils.py new file mode 100644 index 000000000..a3511008b --- /dev/null +++ b/integrations/cohere/src/cohere_haystack/embedders/utils.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Tuple + +from cohere import AsyncClient, Client, CohereError +from tqdm import tqdm + +API_BASE_URL = "https://api.cohere.ai/v1/embed" + + +async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, truncate): + all_embeddings: List[List[float]] = [] + metadata: Dict[str, Any] = {} + try: + response = await cohere_async_client.embed(texts=texts, model=model_name, truncate=truncate) + if response.meta is not None: + metadata = response.meta + for emb in response.embeddings: + all_embeddings.append(emb) + + return all_embeddings, metadata + + except CohereError as error_response: + msg = error_response.message + raise ValueError(msg) from error_response + + +def get_response( + cohere_client: Client, texts: List[str], model_name, truncate, batch_size=32, progress_bar=False +) -> Tuple[List[List[float]], Dict[str, Any]]: + """ + We support batching with the sync client. 
+ """ + all_embeddings: List[List[float]] = [] + metadata: Dict[str, Any] = {} + + try: + for i in tqdm( + range(0, len(texts), batch_size), + disable=not progress_bar, + desc="Calculating embeddings", + ): + batch = texts[i : i + batch_size] + response = cohere_client.embed(batch, model=model_name, truncate=truncate) + for emb in response.embeddings: + all_embeddings.append(emb) + embeddings = [list(map(float, emb)) for emb in response.embeddings] + all_embeddings.extend(embeddings) + if response.meta is not None: + metadata = response.meta + + return all_embeddings, metadata + + except CohereError as error_response: + msg = error_response.message + raise ValueError(msg) from error_response diff --git a/integrations/cohere/src/cohere_haystack/generator.py b/integrations/cohere/src/cohere_haystack/generator.py index 4b18fb75d..a07225804 100644 --- a/integrations/cohere/src/cohere_haystack/generator.py +++ b/integrations/cohere/src/cohere_haystack/generator.py @@ -4,13 +4,11 @@ import logging import os import sys -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, cast +from cohere import COHERE_API_URL, Client +from cohere.responses import Generations from haystack import DeserializationError, component, default_from_dict, default_to_dict -from haystack.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install cohere'") as cohere_import: - from cohere import COHERE_API_URL, Client logger = logging.getLogger(__name__) @@ -75,8 +73,6 @@ def __init__( - 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. The format is {token_id: bias} where bias is a float between -10 and 10. 
""" - cohere_import.check() - if not api_key: api_key = os.environ.get("COHERE_API_KEY") if not api_key: @@ -159,7 +155,7 @@ def run(self, prompt: str): self._check_truncated_answers(metadata) return {"replies": replies, "metadata": metadata} - metadata = [{"finish_reason": resp.finish_reason} for resp in response] + metadata = [{"finish_reason": resp.finish_reason} for resp in cast(Generations, response)] replies = [resp.text for resp in response] self._check_truncated_answers(metadata) return {"replies": replies, "metadata": metadata} diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index d267847a4..9462f364d 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -4,9 +4,12 @@ import os import pytest +from cohere import COHERE_API_URL from cohere_haystack.generator import CohereGenerator +pytestmark = pytest.mark.generators + def default_streaming_callback(chunk): """ @@ -16,16 +19,13 @@ def default_streaming_callback(chunk): print(chunk.text, flush=True, end="") # noqa: T201 -@pytest.mark.integration class TestCohereGenerator: def test_init_default(self): - import cohere - component = CohereGenerator(api_key="test-api-key") assert component.api_key == "test-api-key" assert component.model_name == "command" assert component.streaming_callback is None - assert component.api_base_url == cohere.COHERE_API_URL + assert component.api_base_url == COHERE_API_URL assert component.model_parameters == {} def test_init_with_parameters(self): @@ -45,8 +45,6 @@ def test_init_with_parameters(self): assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self): - import cohere - component = CohereGenerator(api_key="test-api-key") data = component.to_dict() assert data == { @@ -54,7 +52,7 @@ def test_to_dict_default(self): "init_parameters": { "model_name": "command", 
"streaming_callback": None, - "api_base_url": cohere.COHERE_API_URL, + "api_base_url": COHERE_API_URL, }, } @@ -112,7 +110,7 @@ def test_from_dict(self, monkeypatch): "streaming_callback": "tests.test_cohere_generators.default_streaming_callback", }, } - component = CohereGenerator.from_dict(data) + component: CohereGenerator = CohereGenerator.from_dict(data) assert component.api_key == "test-key" assert component.model_name == "command" assert component.streaming_callback == default_streaming_callback @@ -134,7 +132,7 @@ def test_check_truncated_answers(self, caplog): ) @pytest.mark.integration def test_cohere_generator_run(self): - component = CohereGenerator(api_key=os.environ.get("COHERE_API_KEY")) + component = CohereGenerator() results = component.run(prompt="What's the capital of France?") assert len(results["replies"]) == 1 assert "Paris" in results["replies"][0] @@ -149,7 +147,7 @@ def test_cohere_generator_run(self): def test_cohere_generator_run_wrong_model_name(self): import cohere - component = CohereGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) + component = CohereGenerator(model_name="something-obviously-wrong") with pytest.raises( cohere.CohereAPIError, match="model not found, make sure the correct model ID was used and that you have access to the model.", @@ -171,7 +169,7 @@ def __call__(self, chunk): return chunk callback = Callback() - component = CohereGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback) + component = CohereGenerator(streaming_callback=callback) results = component.run(prompt="What's the capital of France?") assert len(results["replies"]) == 1 diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py new file mode 100644 index 000000000..d6309704c --- /dev/null +++ b/integrations/cohere/tests/test_document_embedder.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# 
SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +from cohere import COHERE_API_URL +from haystack import Document + +from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder + +pytestmark = pytest.mark.embedders + + +class TestCohereDocumentEmbedder: + def test_init_default(self): + embedder = CohereDocumentEmbedder(api_key="test-api-key") + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-english-v2.0" + assert embedder.api_base_url == COHERE_API_URL + assert embedder.truncate == "END" + assert embedder.use_async_client is False + assert embedder.max_retries == 3 + assert embedder.timeout == 120 + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.metadata_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + + def test_init_with_parameters(self): + embedder = CohereDocumentEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + batch_size=64, + progress_bar=False, + metadata_fields_to_embed=["test_field"], + embedding_separator="-", + ) + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.api_base_url == "https://custom-api-base-url.com" + assert embedder.truncate == "START" + assert embedder.use_async_client is True + assert embedder.max_retries == 5 + assert embedder.timeout == 60 + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.metadata_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == "-" + + def test_to_dict(self): + embedder_component = CohereDocumentEmbedder(api_key="test-api-key") + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", + 
"init_parameters": { + "model_name": "embed-english-v2.0", + "api_base_url": COHERE_API_URL, + "truncate": "END", + "use_async_client": False, + "max_retries": 3, + "timeout": 120, + "batch_size": 32, + "progress_bar": True, + "metadata_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + + def test_to_dict_with_custom_init_parameters(self): + embedder_component = CohereDocumentEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + batch_size=64, + progress_bar=False, + metadata_fields_to_embed=["text_field"], + embedding_separator="-", + ) + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", + "init_parameters": { + "model_name": "embed-multilingual-v2.0", + "api_base_url": "https://custom-api-base-url.com", + "truncate": "START", + "use_async_client": True, + "max_retries": 5, + "timeout": 60, + "batch_size": 64, + "progress_bar": False, + "metadata_fields_to_embed": ["text_field"], + "embedding_separator": "-", + }, + } + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_run(self): + embedder = CohereDocumentEmbedder() + + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + result = embedder.run(docs) + docs_with_embeddings = result["documents"] + + assert isinstance(docs_with_embeddings, list) + assert len(docs_with_embeddings) == len(docs) + for doc in docs_with_embeddings: + assert isinstance(doc.embedding, list) + assert isinstance(doc.embedding[0], float) + + def test_run_wrong_input_format(self): + embedder = 
CohereDocumentEmbedder(api_key="test-api-key") + + with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents="text") + with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=[1, 2, 3]) + + assert embedder.run(documents=[]) == {"documents": [], "metadata": {}} diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py new file mode 100644 index 000000000..d2aed79c1 --- /dev/null +++ b/integrations/cohere/tests/test_text_embedder.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +from cohere import COHERE_API_URL + +from cohere_haystack.embedders.text_embedder import CohereTextEmbedder + +pytestmark = pytest.mark.embedders + + +class TestCohereTextEmbedder: + def test_init_default(self): + """ + Test default initialization parameters for CohereTextEmbedder. + """ + embedder = CohereTextEmbedder(api_key="test-api-key") + + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-english-v2.0" + assert embedder.api_base_url == COHERE_API_URL + assert embedder.truncate == "END" + assert embedder.use_async_client is False + assert embedder.max_retries == 3 + assert embedder.timeout == 120 + + def test_init_with_parameters(self): + """ + Test custom initialization parameters for CohereTextEmbedder. 
+ """ + embedder = CohereTextEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + ) + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.api_base_url == "https://custom-api-base-url.com" + assert embedder.truncate == "START" + assert embedder.use_async_client is True + assert embedder.max_retries == 5 + assert embedder.timeout == 60 + + def test_to_dict(self): + """ + Test serialization of this component to a dictionary, using default initialization parameters. + """ + embedder_component = CohereTextEmbedder(api_key="test-api-key") + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", + "init_parameters": { + "model_name": "embed-english-v2.0", + "api_base_url": COHERE_API_URL, + "truncate": "END", + "use_async_client": False, + "max_retries": 3, + "timeout": 120, + }, + } + + def test_to_dict_with_custom_init_parameters(self): + """ + Test serialization of this component to a dictionary, using custom initialization parameters. + """ + embedder_component = CohereTextEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + ) + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", + "init_parameters": { + "model_name": "embed-multilingual-v2.0", + "api_base_url": "https://custom-api-base-url.com", + "truncate": "START", + "use_async_client": True, + "max_retries": 5, + "timeout": 60, + }, + } + + def test_run_wrong_input_format(self): + """ + Test for checking incorrect input when creating embedding. 
+ """ + embedder = CohereTextEmbedder(api_key="test-api-key") + list_integers_input = ["text_snippet_1", "text_snippet_2"] + + with pytest.raises(TypeError, match="CohereTextEmbedder expects a string as input"): + embedder.run(text=list_integers_input) + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_run(self): + embedder = CohereTextEmbedder() + text = "The food was delicious" + result = embedder.run(text) + assert all(isinstance(x, float) for x in result["embedding"]) From 4076949668030eab0ef2eeada80f518067c3aa69 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Thu, 7 Dec 2023 09:43:10 +0100 Subject: [PATCH 06/29] [cohere] fix cohere pypi version badge and add Embedder note (#86) * fix cohere pypi version badge and add Embedder note * increase version --- README.md | 2 +- integrations/cohere/src/cohere_haystack/__about__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index bbd816d98..91032874e 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ deepset-haystack | Package | Type | PyPi Package | Status | | ------------------------------------------------------------------------------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | [chroma-haystack](integrations/chroma/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/chroma-haystack) | [![Test / 
chroma](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml) | -| [cohere-haystack](integrations/cohere/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | +| [cohere-haystack](integrations/cohere/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | | [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | | [gradient-haystack](integrations/gradient/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) | [![Test / gradient](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml) | | [instructor-embedders-haystack](integrations/instructor-embedders/) | Embedder | [![PyPI - 
Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | diff --git a/integrations/cohere/src/cohere_haystack/__about__.py b/integrations/cohere/src/cohere_haystack/__about__.py index 0e4fa27cf..8430bf8d4 100644 --- a/integrations/cohere/src/cohere_haystack/__about__.py +++ b/integrations/cohere/src/cohere_haystack/__about__.py @@ -1,4 +1,4 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.0.1" +__version__ = "0.1.1" From 7e47dff250361cf5546bb22d4cbcec09782af569 Mon Sep 17 00:00:00 2001 From: Ashwin Mathur <97467100+awinml@users.noreply.github.com> Date: Mon, 11 Dec 2023 17:28:58 +0530 Subject: [PATCH 07/29] feat: Add support for V3 Embed models to CohereEmbedders (#89) * Add v3 embed models support * fix model names * Update integrations/cohere/src/cohere_haystack/embedders/document_embedder.py Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> * Remove Optional type for input_type parameter * Apply docstring suggestions from code review Co-authored-by: Daria Fokina --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Co-authored-by: Daria Fokina --- .../embedders/document_embedder.py | 40 ++++++++++++++++--- .../embedders/text_embedder.py | 35 +++++++++++++--- .../src/cohere_haystack/embedders/utils.py | 10 +++-- .../cohere/tests/test_document_embedder.py | 6 +++ .../cohere/tests/test_text_embedder.py | 6 +++ 5 files changed, 83 insertions(+), 14 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py index 
681471947..deec7c20d 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py @@ -16,12 +16,28 @@ class CohereDocumentEmbedder: """ A component for computing Document embeddings using Cohere models. The embedding of each Document is stored in the `embedding` field of the Document. + + Usage Example: + ```python + from haystack import Document + from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder + + doc = Document(content="I love pizza!") + + document_embedder = CohereDocumentEmbedder() + + result = document_embedder.run([doc]) + print(result['documents'][0].embedding) + + # [-0.453125, 1.2236328, 2.0058594, ...] + ``` """ def __init__( self, api_key: Optional[str] = None, model_name: str = "embed-english-v2.0", + input_type: str = "search_document", api_base_url: str = COHERE_API_URL, truncate: str = "END", use_async_client: bool = False, @@ -37,9 +53,15 @@ def __init__( :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). - :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are - `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, - `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: + `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, + `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, + `"embed-multilingual-v2.0"`. This list of all supported models can be found in the + [model documentation](https://docs.cohere.com/docs/models#representation). + :param input_type: Specifies the type of input you're giving to the model. 
Supported values are + "search_document", "search_query", "classification" and "clustering". Defaults to "search_document". Not + required for older versions of the embedding models (meaning anything lower than v3), but is required for more + recent versions (meaning anything bigger than v2). :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both @@ -68,6 +90,7 @@ def __init__( self.api_key = api_key self.model_name = model_name + self.input_type = input_type self.api_base_url = api_base_url self.truncate = truncate self.use_async_client = use_async_client @@ -85,6 +108,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model_name=self.model_name, + input_type=self.input_type, api_base_url=self.api_base_url, truncate=self.truncate, use_async_client=self.use_async_client, @@ -137,14 +161,20 @@ def run(self, documents: List[Document]): self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) all_embeddings, metadata = asyncio.run( - get_async_response(cohere_client, texts_to_embed, self.model_name, self.truncate) + get_async_response(cohere_client, texts_to_embed, self.model_name, self.input_type, self.truncate) ) else: cohere_client = Client( self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) all_embeddings, metadata = get_response( - cohere_client, texts_to_embed, self.model_name, self.truncate, self.batch_size, self.progress_bar + cohere_client, + texts_to_embed, + self.model_name, + self.input_type, + self.truncate, + self.batch_size, + self.progress_bar, ) for doc, embeddings in zip(documents, all_embeddings): diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py 
b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py index 936926b99..5d139427f 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -15,12 +15,27 @@ class CohereTextEmbedder: """ A component for embedding strings using Cohere models. + + Usage Example: + ```python + from cohere_haystack.embedders.text_embedder import CohereTextEmbedder + + text_to_embed = "I love pizza!" + + text_embedder = CohereTextEmbedder() + + print(text_embedder.run(text_to_embed)) + + # {'embedding': [-0.453125, 1.2236328, 2.0058594, ...] + # 'metadata': {'api_version': {'version': '1'}, 'billed_units': {'input_tokens': 4}}} + ``` """ def __init__( self, api_key: Optional[str] = None, model_name: str = "embed-english-v2.0", + input_type: str = "search_document", api_base_url: str = COHERE_API_URL, truncate: str = "END", use_async_client: bool = False, @@ -32,9 +47,15 @@ def __init__( :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). - :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are - `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, - `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: + `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, + `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, + `"embed-multilingual-v2.0"`. This list of all supported models can be found in the + [model documentation](https://docs.cohere.com/docs/models#representation). + :param input_type: Specifies the type of input you're giving to the model. Supported values are + "search_document", "search_query", "classification" and "clustering". 
Defaults to "search_document". Not + required for older versions of the embedding models (meaning anything lower than v3), but is required for more + recent versions (meaning anything bigger than v2). :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both @@ -58,6 +79,7 @@ def __init__( self.api_key = api_key self.model_name = model_name + self.input_type = input_type self.api_base_url = api_base_url self.truncate = truncate self.use_async_client = use_async_client @@ -71,6 +93,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model_name=self.model_name, + input_type=self.input_type, api_base_url=self.api_base_url, truncate=self.truncate, use_async_client=self.use_async_client, @@ -94,11 +117,13 @@ def run(self, text: str): cohere_client = AsyncClient( self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) - embedding, metadata = asyncio.run(get_async_response(cohere_client, [text], self.model_name, self.truncate)) + embedding, metadata = asyncio.run( + get_async_response(cohere_client, [text], self.model_name, self.input_type, self.truncate) + ) else: cohere_client = Client( self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) - embedding, metadata = get_response(cohere_client, [text], self.model_name, self.truncate) + embedding, metadata = get_response(cohere_client, [text], self.model_name, self.input_type, self.truncate) return {"embedding": embedding[0], "metadata": metadata} diff --git a/integrations/cohere/src/cohere_haystack/embedders/utils.py b/integrations/cohere/src/cohere_haystack/embedders/utils.py index a3511008b..165d34acd 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/utils.py +++ 
b/integrations/cohere/src/cohere_haystack/embedders/utils.py @@ -9,11 +9,13 @@ API_BASE_URL = "https://api.cohere.ai/v1/embed" -async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, truncate): +async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate): all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} try: - response = await cohere_async_client.embed(texts=texts, model=model_name, truncate=truncate) + response = await cohere_async_client.embed( + texts=texts, model=model_name, input_type=input_type, truncate=truncate + ) if response.meta is not None: metadata = response.meta for emb in response.embeddings: @@ -27,7 +29,7 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], def get_response( - cohere_client: Client, texts: List[str], model_name, truncate, batch_size=32, progress_bar=False + cohere_client: Client, texts: List[str], model_name, input_type, truncate, batch_size=32, progress_bar=False ) -> Tuple[List[List[float]], Dict[str, Any]]: """ We support batching with the sync client. 
@@ -42,7 +44,7 @@ def get_response( desc="Calculating embeddings", ): batch = texts[i : i + batch_size] - response = cohere_client.embed(batch, model=model_name, truncate=truncate) + response = cohere_client.embed(batch, model=model_name, input_type=input_type, truncate=truncate) for emb in response.embeddings: all_embeddings.append(emb) embeddings = [list(map(float, emb)) for emb in response.embeddings] diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index d6309704c..5b0ad5c3f 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -17,6 +17,7 @@ def test_init_default(self): embedder = CohereDocumentEmbedder(api_key="test-api-key") assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-english-v2.0" + assert embedder.input_type == "search_document" assert embedder.api_base_url == COHERE_API_URL assert embedder.truncate == "END" assert embedder.use_async_client is False @@ -31,6 +32,7 @@ def test_init_with_parameters(self): embedder = CohereDocumentEmbedder( api_key="test-api-key", model_name="embed-multilingual-v2.0", + input_type="search_query", api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, @@ -43,6 +45,7 @@ def test_init_with_parameters(self): ) assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.input_type == "search_query" assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" assert embedder.use_async_client is True @@ -60,6 +63,7 @@ def test_to_dict(self): "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", "init_parameters": { "model_name": "embed-english-v2.0", + "input_type": "search_document", "api_base_url": COHERE_API_URL, "truncate": "END", "use_async_client": False, @@ -76,6 +80,7 @@ def 
test_to_dict_with_custom_init_parameters(self): embedder_component = CohereDocumentEmbedder( api_key="test-api-key", model_name="embed-multilingual-v2.0", + input_type="search_query", api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, @@ -91,6 +96,7 @@ def test_to_dict_with_custom_init_parameters(self): "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", "init_parameters": { "model_name": "embed-multilingual-v2.0", + "input_type": "search_query", "api_base_url": "https://custom-api-base-url.com", "truncate": "START", "use_async_client": True, diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index d2aed79c1..9ec673c98 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -20,6 +20,7 @@ def test_init_default(self): assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-english-v2.0" + assert embedder.input_type == "search_document" assert embedder.api_base_url == COHERE_API_URL assert embedder.truncate == "END" assert embedder.use_async_client is False @@ -33,6 +34,7 @@ def test_init_with_parameters(self): embedder = CohereTextEmbedder( api_key="test-api-key", model_name="embed-multilingual-v2.0", + input_type="search_query", api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, @@ -41,6 +43,7 @@ def test_init_with_parameters(self): ) assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.input_type == "search_query" assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" assert embedder.use_async_client is True @@ -57,6 +60,7 @@ def test_to_dict(self): "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", "init_parameters": { "model_name": "embed-english-v2.0", + "input_type": "search_document", 
"api_base_url": COHERE_API_URL, "truncate": "END", "use_async_client": False, @@ -72,6 +76,7 @@ def test_to_dict_with_custom_init_parameters(self): embedder_component = CohereTextEmbedder( api_key="test-api-key", model_name="embed-multilingual-v2.0", + input_type="search_query", api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, @@ -83,6 +88,7 @@ def test_to_dict_with_custom_init_parameters(self): "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", "init_parameters": { "model_name": "embed-multilingual-v2.0", + "input_type": "search_query", "api_base_url": "https://custom-api-base-url.com", "truncate": "START", "use_async_client": True, From 7b7201b79ac8c57f633d0e4b0bb983bd05856f12 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Mon, 11 Dec 2023 13:33:01 +0100 Subject: [PATCH 08/29] increase version to prepare release (#92) --- integrations/cohere/src/cohere_haystack/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/cohere/src/cohere_haystack/__about__.py b/integrations/cohere/src/cohere_haystack/__about__.py index 8430bf8d4..447ed9770 100644 --- a/integrations/cohere/src/cohere_haystack/__about__.py +++ b/integrations/cohere/src/cohere_haystack/__about__.py @@ -1,4 +1,4 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.1.1" +__version__ = "0.2.0" From f3ec5a015fa2d3ab5ef9a697fb75a23a3b58ae9f Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Mon, 11 Dec 2023 14:47:28 +0100 Subject: [PATCH 09/29] rm unused constant (#91) --- integrations/cohere/src/cohere_haystack/embedders/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/embedders/utils.py b/integrations/cohere/src/cohere_haystack/embedders/utils.py index 165d34acd..1c1049852 100644 --- 
a/integrations/cohere/src/cohere_haystack/embedders/utils.py +++ b/integrations/cohere/src/cohere_haystack/embedders/utils.py @@ -6,8 +6,6 @@ from cohere import AsyncClient, Client, CohereError from tqdm import tqdm -API_BASE_URL = "https://api.cohere.ai/v1/embed" - async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate): all_embeddings: List[List[float]] = [] From 6a9de9a340a60a159d1f8203efc74005a72a05d1 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Date: Mon, 11 Dec 2023 17:57:50 +0100 Subject: [PATCH 10/29] feat: add Jina Embeddings integration (#93) * feat: add Jina Embeddings integration * try to fix problems with missing imports * refactor: apply suggestions from code review Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> * refactor: fix comments * Update document_embedder.py * Update text_embedder.py --------- Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> --- .github/workflows/jina.yml | 56 ++++ README.md | 1 + integrations/jina/LICENSE.txt | 201 +++++++++++++++ integrations/jina/README.md | 32 +++ integrations/jina/pyproject.toml | 161 ++++++++++++ .../jina/src/jina_haystack/__about__.py | 4 + .../jina/src/jina_haystack/__init__.py | 8 + .../src/jina_haystack/document_embedder.py | 180 +++++++++++++ .../jina/src/jina_haystack/text_embedder.py | 106 ++++++++ integrations/jina/tests/__init__.py | 3 + .../jina/tests/test_document_embedder.py | 239 ++++++++++++++++++ integrations/jina/tests/test_text_embedder.py | 100 ++++++++ 12 files changed, 1091 insertions(+) create mode 100644 .github/workflows/jina.yml create mode 100644 integrations/jina/LICENSE.txt create mode 100644 integrations/jina/README.md create mode 100644 integrations/jina/pyproject.toml create mode 100644 integrations/jina/src/jina_haystack/__about__.py create mode 100644 integrations/jina/src/jina_haystack/__init__.py create mode 100644 
integrations/jina/src/jina_haystack/document_embedder.py create mode 100644 integrations/jina/src/jina_haystack/text_embedder.py create mode 100644 integrations/jina/tests/__init__.py create mode 100644 integrations/jina/tests/test_document_embedder.py create mode 100644 integrations/jina/tests/test_text_embedder.py diff --git a/.github/workflows/jina.yml b/.github/workflows/jina.yml new file mode 100644 index 000000000..894456877 --- /dev/null +++ b/.github/workflows/jina.yml @@ -0,0 +1,56 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / jina + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - 'integrations/jina/**' + - '.github/workflows/jina.yml' + +defaults: + run: + working-directory: integrations/jina + +concurrency: + group: jina-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ['3.9', '3.10'] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . 
+ run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Run tests + run: hatch run cov \ No newline at end of file diff --git a/README.md b/README.md index 91032874e..024945b1c 100644 --- a/README.md +++ b/README.md @@ -64,3 +64,4 @@ deepset-haystack | [instructor-embedders-haystack](integrations/instructor-embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | | [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/fileconverter/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured / fileconverter](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml) | 
+| [jina-haystack](integrations/jina/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | diff --git a/integrations/jina/LICENSE.txt b/integrations/jina/LICENSE.txt new file mode 100644 index 000000000..6134ab324 --- /dev/null +++ b/integrations/jina/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2023-present deepset GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/integrations/jina/README.md b/integrations/jina/README.md new file mode 100644 index 000000000..309472562 --- /dev/null +++ b/integrations/jina/README.md @@ -0,0 +1,32 @@ +# jina-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/jina-haystack.svg)](https://pypi.org/project/jina-haystack) + +----- + +**Table of Contents** + +- [jina-haystack](#jina-haystack) + - [Installation](#installation) + - [Usage](#usage) + - [License](#license) + +## Installation + +```console +pip install jina-haystack +``` + +## Usage + +You can use `JinaTextEmbedder` and `JinaDocumentEmbedder` by importing as: + +```python +from jina_haystack.document_embedder import JinaDocumentEmbedder +from jina_haystack.text_embedder import JinaTextEmbedder +``` + +## License + +`jina-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml new file mode 100644 index 000000000..2e35201d7 --- /dev/null +++ b/integrations/jina/pyproject.toml @@ -0,0 +1,161 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "jina-haystack" +dynamic = ["version"] +description = '' +readme = "README.md" +requires-python = ">=3.7" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "Joan Fontanals Martinez", email = "joan.fontanals.martinez@jina.ai" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = ["requests", "haystack-ai"] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/jina/jina-haystack#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/jina/jina-haystack" + +[tool.hatch.version] +path = "src/jina_haystack/__about__.py" + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.7", "3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types 
--non-interactive {args:src/jina_haystack tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["jina_haystack"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source_pkgs = ["jina_haystack", "tests"] +branch = true +parallel = true +omit = [ + "src/jina_haystack/__about__.py", +] + +[tool.coverage.paths] +jina_haystack = ["src/jina_haystack", "*/jina-haystack/src/jina_haystack"] +tests = ["tests", "*/jina-haystack/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "pytest.*" +] +ignore_missing_imports = true diff --git a/integrations/jina/src/jina_haystack/__about__.py b/integrations/jina/src/jina_haystack/__about__.py new file mode 100644 index 000000000..0e4fa27cf --- /dev/null +++ b/integrations/jina/src/jina_haystack/__about__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 
# SPDX-FileCopyrightText: 2023-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, Dict, List, Optional, Tuple

import requests
from haystack import Document, component, default_to_dict
from tqdm import tqdm

# Jina AI embeddings REST endpoint.
JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"


@component
class JinaDocumentEmbedder:
    """
    A component for computing Document embeddings using Jina AI models.
    The embedding of each Document is stored in the `embedding` field of the Document.

    Usage example:
    ```python
    from haystack import Document
    from jina_haystack import JinaDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = JinaDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "jina-embeddings-v2-base-en",
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        metadata_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
    ):
        """
        Create a JinaDocumentEmbedder component.

        :param api_key: The Jina API key. It can be explicitly provided or automatically read from the
            environment variable JINA_API_KEY (recommended).
        :param model_name: The name of the Jina model to use. Check the list of available models on
            `https://jina.ai/embeddings/`
        :param prefix: A string to add to the beginning of each text.
        :param suffix: A string to add to the end of each text.
        :param batch_size: Number of Documents to encode at once.
        :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production
            deployments to keep the logs clean.
        :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text.
        :param embedding_separator: Separator used to concatenate the meta fields to the Document text.
        :raises ValueError: If no API key is provided and JINA_API_KEY is not set.
        """
        # If the user does not provide the API key, fall back to the environment variable.
        if api_key is None:
            try:
                api_key = os.environ["JINA_API_KEY"]
            except KeyError as e:
                msg = (
                    "JinaDocumentEmbedder expects a Jina API key. "
                    "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
                )
                raise ValueError(msg) from e

        self.model_name = model_name
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.metadata_fields_to_embed = metadata_fields_to_embed or []
        self.embedding_separator = embedding_separator
        # One session reused for all requests; the API key lives only in its headers
        # and is deliberately excluded from to_dict() serialization.
        self._session = requests.Session()
        self._session.headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Accept-Encoding": "identity",
                "Content-type": "application/json",
            }
        )

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model_name}

    def to_dict(self) -> Dict[str, Any]:
        """
        This method overrides the default serializer in order to avoid leaking the `api_key` value passed
        to the constructor.
        """
        return default_to_dict(
            self,
            model_name=self.model_name,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            metadata_fields_to_embed=self.metadata_fields_to_embed,
            embedding_separator=self.embedding_separator,
        )

    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
        """
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key])
                for key in self.metadata_fields_to_embed
                if key in doc.meta and doc.meta[key] is not None
            ]
            text_to_embed = (
                self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix
            )
            texts_to_embed.append(text_to_embed)
        return texts_to_embed

    def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches.

        :param texts_to_embed: The texts to send to the Jina API.
        :param batch_size: Number of texts per API request.
        :returns: A tuple of (embeddings in input order, metadata containing the model name
            and token usage accumulated across batches).
        :raises RuntimeError: If an API response contains no embedding data.
        """
        all_embeddings = []
        metadata: Dict[str, Any] = {}
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = texts_to_embed[i : i + batch_size]
            response = self._session.post(JINA_API_URL, json={"input": batch, "model": self.model_name}).json()
            if "data" not in response:
                raise RuntimeError(response["detail"])

            # Sort resulting embeddings by index so they line up with the input order
            sorted_embeddings = sorted(response["data"], key=lambda e: e["index"])
            embeddings = [result["embedding"] for result in sorted_embeddings]
            all_embeddings.extend(embeddings)
            if "model" not in metadata:
                metadata["model"] = response["model"]
            if "usage" not in metadata:
                metadata["usage"] = dict(response["usage"].items())
            else:
                # Accumulate token usage across batches
                metadata["usage"]["prompt_tokens"] += response["usage"]["prompt_tokens"]
                metadata["usage"]["total_tokens"] += response["usage"]["total_tokens"]

        return all_embeddings, metadata

    @component.output_types(documents=List[Document], metadata=Dict[str, Any])
    def run(self, documents: List[Document]):
        """
        Embed a list of Documents.
        The embedding of each Document is stored in the `embedding` field of the Document.

        :param documents: A list of Documents to embed.
        :raises TypeError: If the input is not a list of Documents.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            msg = (
                "JinaDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the JinaTextEmbedder."
            )
            raise TypeError(msg)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, metadata = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "metadata": metadata}
+ """ + # if the user does not provide the API key, check if it is set in the module client + if api_key is None: + try: + api_key = os.environ["JINA_API_KEY"] + except KeyError as e: + msg = ( + "JinaTextEmbedder expects a Jina API key. " + "Set the JINA_API_KEY environment variable (recommended) or pass it explicitly." + ) + raise ValueError(msg) from e + + self.model_name = model_name + self.prefix = prefix + self.suffix = suffix + self._session = requests.Session() + self._session.headers.update( + { + "Authorization": f"Bearer {api_key}", + "Accept-Encoding": "identity", + "Content-type": "application/json", + } + ) + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + return {"model": self.model_name} + + def to_dict(self) -> Dict[str, Any]: + """ + This method overrides the default serializer in order to avoid leaking the `api_key` value passed + to the constructor. + """ + + return default_to_dict(self, model_name=self.model_name, prefix=self.prefix, suffix=self.suffix) + + @component.output_types(embedding=List[float], metadata=Dict[str, Any]) + def run(self, text: str): + """Embed a string.""" + if not isinstance(text, str): + msg = ( + "JinaTextEmbedder expects a string as an input." + "In case you want to embed a list of Documents, please use the JinaDocumentEmbedder." 
+ ) + raise TypeError(msg) + + text_to_embed = self.prefix + text + self.suffix + + resp = self._session.post(JINA_API_URL, json={"input": [text_to_embed], "model": self.model_name}).json() + if "data" not in resp: + raise RuntimeError(resp["detail"]) + + metadata = {"model": resp["model"], "usage": dict(resp["usage"].items())} + embedding = resp["data"][0]["embedding"] + + return {"embedding": embedding, "metadata": metadata} diff --git a/integrations/jina/tests/__init__.py b/integrations/jina/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/jina/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/jina/tests/test_document_embedder.py b/integrations/jina/tests/test_document_embedder.py new file mode 100644 index 000000000..c32ffd500 --- /dev/null +++ b/integrations/jina/tests/test_document_embedder.py @@ -0,0 +1,239 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import json +from unittest.mock import patch + +import pytest +import requests +from haystack import Document + +from jina_haystack import JinaDocumentEmbedder + + +def mock_session_post_response(*args, **kwargs): # noqa: ARG001 + inputs = kwargs["json"]["input"] + model = kwargs["json"]["model"] + mock_response = requests.Response() + mock_response.status_code = 200 + data = [{"object": "embedding", "index": i, "embedding": [0.1, 0.2, 0.3]} for i in range(len(inputs))] + mock_response._content = json.dumps( + {"model": model, "object": "list", "usage": {"total_tokens": 4, "prompt_tokens": 4}, "data": data} + ).encode() + + return mock_response + + +class TestJinaDocumentEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "fake-api-key") + embedder = JinaDocumentEmbedder() + + assert embedder.model_name == "jina-embeddings-v2-base-en" + assert embedder.prefix == "" + 
assert embedder.suffix == "" + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.metadata_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + + def test_init_with_parameters(self): + embedder = JinaDocumentEmbedder( + api_key="fake-api-key", + model_name="model", + prefix="prefix", + suffix="suffix", + batch_size=64, + progress_bar=False, + metadata_fields_to_embed=["test_field"], + embedding_separator=" | ", + ) + assert embedder.model_name == "model" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.metadata_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == " | " + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("JINA_API_KEY", raising=False) + with pytest.raises(ValueError, match="JinaDocumentEmbedder expects a Jina API key"): + JinaDocumentEmbedder() + + def test_to_dict(self): + component = JinaDocumentEmbedder(api_key="fake-api-key") + data = component.to_dict() + assert data == { + "type": "jina_haystack.document_embedder.JinaDocumentEmbedder", + "init_parameters": { + "model_name": "jina-embeddings-v2-base-en", + "prefix": "", + "suffix": "", + "batch_size": 32, + "progress_bar": True, + "metadata_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + + def test_to_dict_with_custom_init_parameters(self): + component = JinaDocumentEmbedder( + api_key="fake-api-key", + model_name="model", + prefix="prefix", + suffix="suffix", + batch_size=64, + progress_bar=False, + metadata_fields_to_embed=["test_field"], + embedding_separator=" | ", + ) + data = component.to_dict() + assert data == { + "type": "jina_haystack.document_embedder.JinaDocumentEmbedder", + "init_parameters": { + "model_name": "model", + "prefix": "prefix", + "suffix": "suffix", + "batch_size": 64, + "progress_bar": False, + "metadata_fields_to_embed": ["test_field"], 
+ "embedding_separator": " | ", + }, + } + + def test_prepare_texts_to_embed_w_metadata(self): + documents = [ + Document(content=f"document number {i}:\ncontent", meta={"meta_field": f"meta_value {i}"}) for i in range(5) + ] + + embedder = JinaDocumentEmbedder( + api_key="fake-api-key", metadata_fields_to_embed=["meta_field"], embedding_separator=" | " + ) + + prepared_texts = embedder._prepare_texts_to_embed(documents) + + # note that newline is replaced by space + assert prepared_texts == [ + "meta_value 0 | document number 0:\ncontent", + "meta_value 1 | document number 1:\ncontent", + "meta_value 2 | document number 2:\ncontent", + "meta_value 3 | document number 3:\ncontent", + "meta_value 4 | document number 4:\ncontent", + ] + + def test_prepare_texts_to_embed_w_suffix(self): + documents = [Document(content=f"document number {i}") for i in range(5)] + + embedder = JinaDocumentEmbedder(api_key="fake-api-key", prefix="my_prefix ", suffix=" my_suffix") + + prepared_texts = embedder._prepare_texts_to_embed(documents) + + assert prepared_texts == [ + "my_prefix document number 0 my_suffix", + "my_prefix document number 1 my_suffix", + "my_prefix document number 2 my_suffix", + "my_prefix document number 3 my_suffix", + "my_prefix document number 4 my_suffix", + ] + + def test_embed_batch(self): + texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] + + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): + embedder = JinaDocumentEmbedder(api_key="fake-api-key", model_name="model") + + embeddings, metadata = embedder._embed_batch(texts_to_embed=texts, batch_size=2) + + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) + for embedding in embeddings: + assert isinstance(embedding, list) + assert len(embedding) == 3 + assert all(isinstance(x, float) for x in embedding) + + assert metadata == {"model": "model", "usage": {"prompt_tokens": 3 * 4, "total_tokens": 3 * 4}} + + def test_run(self): + docs = [ + 
Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + model = "jina-embeddings-v2-base-en" + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): + embedder = JinaDocumentEmbedder( + api_key="fake-api-key", + model_name=model, + prefix="prefix ", + suffix=" suffix", + metadata_fields_to_embed=["topic"], + embedding_separator=" | ", + ) + + result = embedder.run(documents=docs) + + documents_with_embeddings = result["documents"] + metadata = result["metadata"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 3 + assert all(isinstance(x, float) for x in doc.embedding) + assert metadata == {"model": model, "usage": {"prompt_tokens": 4, "total_tokens": 4}} + + def test_run_custom_batch_size(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + model = "jina-embeddings-v2-base-en" + with patch("requests.sessions.Session.post", side_effect=mock_session_post_response): + embedder = JinaDocumentEmbedder( + api_key="fake-api-key", + model_name=model, + prefix="prefix ", + suffix=" suffix", + metadata_fields_to_embed=["topic"], + embedding_separator=" | ", + batch_size=1, + ) + + result = embedder.run(documents=docs) + + documents_with_embeddings = result["documents"] + metadata = result["metadata"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 3 + assert all(isinstance(x, float) for x in doc.embedding) + + 
assert metadata == {"model": model, "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}} + + def test_run_wrong_input_format(self): + embedder = JinaDocumentEmbedder(api_key="fake-api-key") + + string_input = "text" + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError, match="JinaDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=string_input) + + with pytest.raises(TypeError, match="JinaDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=list_integers_input) + + def test_run_on_empty_list(self): + embedder = JinaDocumentEmbedder(api_key="fake-api-key") + + empty_list_input = [] + result = embedder.run(documents=empty_list_input) + + assert result["documents"] is not None + assert not result["documents"] # empty list diff --git a/integrations/jina/tests/test_text_embedder.py b/integrations/jina/tests/test_text_embedder.py new file mode 100644 index 000000000..14476487e --- /dev/null +++ b/integrations/jina/tests/test_text_embedder.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import json +from unittest.mock import patch + +import pytest +import requests + +from jina_haystack import JinaTextEmbedder + + +class TestJinaTextEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("JINA_API_KEY", "fake-api-key") + embedder = JinaTextEmbedder() + + assert embedder.model_name == "jina-embeddings-v2-base-en" + assert embedder.prefix == "" + assert embedder.suffix == "" + + def test_init_with_parameters(self): + embedder = JinaTextEmbedder( + api_key="fake-api-key", + model_name="model", + prefix="prefix", + suffix="suffix", + ) + assert embedder.model_name == "model" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("JINA_API_KEY", raising=False) + with pytest.raises(ValueError, match="JinaTextEmbedder 
expects a Jina API key"): + JinaTextEmbedder() + + def test_to_dict(self): + component = JinaTextEmbedder(api_key="fake-api-key") + data = component.to_dict() + assert data == { + "type": "jina_haystack.text_embedder.JinaTextEmbedder", + "init_parameters": { + "model_name": "jina-embeddings-v2-base-en", + "prefix": "", + "suffix": "", + }, + } + + def test_to_dict_with_custom_init_parameters(self): + component = JinaTextEmbedder( + api_key="fake-api-key", + model_name="model", + prefix="prefix", + suffix="suffix", + ) + data = component.to_dict() + assert data == { + "type": "jina_haystack.text_embedder.JinaTextEmbedder", + "init_parameters": { + "model_name": "model", + "prefix": "prefix", + "suffix": "suffix", + }, + } + + def test_run(self): + model = "jina-embeddings-v2-base-en" + with patch("requests.sessions.Session.post") as mock_post: + # Configure the mock to return a specific response + mock_response = requests.Response() + mock_response.status_code = 200 + mock_response._content = json.dumps( + { + "model": "jina-embeddings-v2-base-en", + "object": "list", + "usage": {"total_tokens": 6, "prompt_tokens": 6}, + "data": [{"object": "embedding", "index": 0, "embedding": [0.1, 0.2, 0.3]}], + } + ).encode() + + mock_post.return_value = mock_response + + embedder = JinaTextEmbedder(api_key="fake-api-key", model_name=model, prefix="prefix ", suffix=" suffix") + result = embedder.run(text="The food was delicious") + + assert len(result["embedding"]) == 3 + assert all(isinstance(x, float) for x in result["embedding"]) + assert result["metadata"] == { + "model": "jina-embeddings-v2-base-en", + "usage": {"prompt_tokens": 6, "total_tokens": 6}, + } + + def test_run_wrong_input_format(self): + embedder = JinaTextEmbedder(api_key="fake-api-key") + + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError, match="JinaTextEmbedder expects a string as an input"): + embedder.run(text=list_integers_input) From 22653de66d9c33f395a54b4cf070c575ecbbef2e Mon Sep 17 
00:00:00 2001 From: Julian Risch Date: Mon, 11 Dec 2023 19:19:45 +0100 Subject: [PATCH 11/29] Update labeler.yml to breaking changes in version 5 (#84) * Update labeler.yml to breaking changes in version 5 * Update labeler.yml * use `any` in labeler config * Update labeler.yml * bump * revert * checkout * try * try * revert --------- Co-authored-by: Massimiliano Pippi --- .github/labeler.yml | 37 +++++++++++++++++++++++------------ .github/workflows/labeler.yml | 2 +- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/.github/labeler.yml b/.github/labeler.yml index 5234f6a51..0b4e14660 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,24 +1,35 @@ # Integrations integration:chroma: -- integrations/chroma/**/* +- changed-files: + - any-glob-to-any-file: 'integrations/chroma/**/*' + integration:elasticsearch: -- integrations/elasticsearch/**/* +- changed-files: + - any-glob-to-any-file: 'integrations/elasticsearch/**/*' + integration:gradient: -- integrations/gradient/**/* +- changed-files: + - any-glob-to-any-file: 'integrations/gradient/**/*' + integration:instructor-embedders: -- integrations/instructor-embedders/**/* +- changed-files: + - any-glob-to-any-file: 'integrations/instructor-embedders/**/*' + integration:opensearch: -- integrations/opensearch/**/* +- changed-files: + - any-glob-to-any-file: 'integrations/opensearch/**/*' + integration:unstructured-fileconverter: -- integrations/unstructured/fileconverter/**/* +- changed-files: + - any-glob-to-any-file: 'integrations/unstructured/fileconverter/**/*' + # Topics topic:CI: -- .github/* -- .github/**/* +- changed-files: + - any-glob-to-any-file: ['.github/*', '.github/**/*'] topic:DX: -- CONTRIBUTING.md -- .pre-commit-config.yaml -- .gitignore -- requirements.txt +- changed-files: + - any-glob-to-any-file: ['CONTRIBUTING.md', '.pre-commit-config.yaml', '.gitignore', 'requirements.txt'] topic:security: -- SECURITY.md +- changed-files: + - any-glob-to-any-file: ['SECURITY.md'] diff 
--git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index d3e9adbd9..2af558297 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -10,6 +10,6 @@ jobs: triage: runs-on: ubuntu-latest steps: - - uses: actions/labeler@v4 + - uses: actions/labeler@v5 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" From 09829289c44ad6af141b85b731d54271d6e78d14 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Mon, 11 Dec 2023 22:56:11 +0100 Subject: [PATCH 12/29] add missing integrations to the labeler (#94) --- .github/labeler.yml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/labeler.yml b/.github/labeler.yml index 0b4e14660..5c1e12fa8 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -3,6 +3,10 @@ integration:chroma: - changed-files: - any-glob-to-any-file: 'integrations/chroma/**/*' +integration:cohere: +- changed-files: + - any-glob-to-any-file: 'integrations/cohere/**/*' + integration:elasticsearch: - changed-files: - any-glob-to-any-file: 'integrations/elasticsearch/**/*' @@ -15,6 +19,10 @@ integration:instructor-embedders: - changed-files: - any-glob-to-any-file: 'integrations/instructor-embedders/**/*' +integration:jina: +- changed-files: + - any-glob-to-any-file: 'integrations/jina/**/*' + integration:opensearch: - changed-files: - any-glob-to-any-file: 'integrations/opensearch/**/*' From e166c2b5401ade559c464afaec3bbcb6f7eaa17f Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Tue, 12 Dec 2023 11:15:18 +0100 Subject: [PATCH 13/29] chore: fix Wheel package for instructor-embedders, uniform tests and CI (#95) * fix wheel * use hatch to run tests * bump version * fix linter --- .github/workflows/instructor_embedders.yml | 20 +++----- .../instructor_embedders/__init__.py | 7 --- .../__about__.py | 2 +- .../instructor_embedders_haystack/__init__.py | 7 +++ .../embedding_backend/__init__.py | 0 .../embedding_backend/instructor_backend.py | 0 .../instructor_document_embedder.py | 4 +- 
.../instructor_text_embedder.py | 4 +- .../instructor-embedders/pyproject.toml | 48 ++++++++++++++++--- .../tests/test_instructor_backend.py | 20 ++++---- .../test_instructor_document_embedder.py | 25 +++------- .../tests/test_instructor_embedders.py | 7 --- .../tests/test_instructor_text_embedder.py | 24 +++------- 13 files changed, 84 insertions(+), 84 deletions(-) delete mode 100644 integrations/instructor-embedders/instructor_embedders/__init__.py rename integrations/instructor-embedders/{instructor_embedders => instructor_embedders_haystack}/__about__.py (83%) create mode 100644 integrations/instructor-embedders/instructor_embedders_haystack/__init__.py rename integrations/instructor-embedders/{instructor_embedders => instructor_embedders_haystack}/embedding_backend/__init__.py (100%) rename integrations/instructor-embedders/{instructor_embedders => instructor_embedders_haystack}/embedding_backend/instructor_backend.py (100%) rename integrations/instructor-embedders/{instructor_embedders => instructor_embedders_haystack}/instructor_document_embedder.py (97%) rename integrations/instructor-embedders/{instructor_embedders => instructor_embedders_haystack}/instructor_text_embedder.py (96%) delete mode 100644 integrations/instructor-embedders/tests/test_instructor_embedders.py diff --git a/.github/workflows/instructor_embedders.yml b/.github/workflows/instructor_embedders.yml index 0c7765f11..05ecfb05f 100644 --- a/.github/workflows/instructor_embedders.yml +++ b/.github/workflows/instructor_embedders.yml @@ -24,19 +24,11 @@ jobs: with: python-version: '3.10' - - name: Ruff - uses: chartboost/ruff-action@v1 - with: - src: integrations/instructor-embedders - - - name: Install instructor-embedders - run: | - pip install -e .[dev] + - name: Install Hatch + run: pip install --upgrade hatch - - name: Run unit tests - run: | - pytest -v -m unit + - name: Lint + run: hatch run lint:all - - name: Run integration tests - run: | - pytest -v -m integration + - name: Run 
tests + run: hatch run cov diff --git a/integrations/instructor-embedders/instructor_embedders/__init__.py b/integrations/instructor-embedders/instructor_embedders/__init__.py deleted file mode 100644 index 99bed8705..000000000 --- a/integrations/instructor-embedders/instructor_embedders/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -from instructor_embedders.instructor_document_embedder import InstructorDocumentEmbedder -from instructor_embedders.instructor_text_embedder import InstructorTextEmbedder - -__all__ = ["InstructorDocumentEmbedder", "InstructorTextEmbedder"] diff --git a/integrations/instructor-embedders/instructor_embedders/__about__.py b/integrations/instructor-embedders/instructor_embedders_haystack/__about__.py similarity index 83% rename from integrations/instructor-embedders/instructor_embedders/__about__.py rename to integrations/instructor-embedders/instructor_embedders_haystack/__about__.py index bccfd8317..447ed9770 100644 --- a/integrations/instructor-embedders/instructor_embedders/__about__.py +++ b/integrations/instructor-embedders/instructor_embedders_haystack/__about__.py @@ -1,4 +1,4 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.1.0" +__version__ = "0.2.0" diff --git a/integrations/instructor-embedders/instructor_embedders_haystack/__init__.py b/integrations/instructor-embedders/instructor_embedders_haystack/__init__.py new file mode 100644 index 000000000..88e2e9df2 --- /dev/null +++ b/integrations/instructor-embedders/instructor_embedders_haystack/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from instructor_embedders_haystack.instructor_document_embedder import InstructorDocumentEmbedder +from instructor_embedders_haystack.instructor_text_embedder import InstructorTextEmbedder + +__all__ = 
["InstructorDocumentEmbedder", "InstructorTextEmbedder"] diff --git a/integrations/instructor-embedders/instructor_embedders/embedding_backend/__init__.py b/integrations/instructor-embedders/instructor_embedders_haystack/embedding_backend/__init__.py similarity index 100% rename from integrations/instructor-embedders/instructor_embedders/embedding_backend/__init__.py rename to integrations/instructor-embedders/instructor_embedders_haystack/embedding_backend/__init__.py diff --git a/integrations/instructor-embedders/instructor_embedders/embedding_backend/instructor_backend.py b/integrations/instructor-embedders/instructor_embedders_haystack/embedding_backend/instructor_backend.py similarity index 100% rename from integrations/instructor-embedders/instructor_embedders/embedding_backend/instructor_backend.py rename to integrations/instructor-embedders/instructor_embedders_haystack/embedding_backend/instructor_backend.py diff --git a/integrations/instructor-embedders/instructor_embedders/instructor_document_embedder.py b/integrations/instructor-embedders/instructor_embedders_haystack/instructor_document_embedder.py similarity index 97% rename from integrations/instructor-embedders/instructor_embedders/instructor_document_embedder.py rename to integrations/instructor-embedders/instructor_embedders_haystack/instructor_document_embedder.py index ba3f6c9b3..ecf3594e6 100644 --- a/integrations/instructor-embedders/instructor_embedders/instructor_document_embedder.py +++ b/integrations/instructor-embedders/instructor_embedders_haystack/instructor_document_embedder.py @@ -5,7 +5,7 @@ from haystack import Document, component, default_from_dict, default_to_dict -from instructor_embedders.embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory +from instructor_embedders_haystack.embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory @component @@ -19,7 +19,7 @@ class InstructorDocumentEmbedder: # To use this component, install the 
"instructor-embedders-haystack" package. # pip install instructor-embedders-haystack - from instructor_embedders.instructor_document_embedder import InstructorDocumentEmbedder + from instructor_embedders_haystack.instructor_document_embedder import InstructorDocumentEmbedder from haystack.dataclasses import Document diff --git a/integrations/instructor-embedders/instructor_embedders/instructor_text_embedder.py b/integrations/instructor-embedders/instructor_embedders_haystack/instructor_text_embedder.py similarity index 96% rename from integrations/instructor-embedders/instructor_embedders/instructor_text_embedder.py rename to integrations/instructor-embedders/instructor_embedders_haystack/instructor_text_embedder.py index 043d562d5..6c665ab3b 100644 --- a/integrations/instructor-embedders/instructor_embedders/instructor_text_embedder.py +++ b/integrations/instructor-embedders/instructor_embedders_haystack/instructor_text_embedder.py @@ -5,7 +5,7 @@ from haystack import component, default_from_dict, default_to_dict -from instructor_embedders.embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory +from instructor_embedders_haystack.embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory @component @@ -18,7 +18,7 @@ class InstructorTextEmbedder: # To use this component, install the "instructor-embedders-haystack" package. # pip install instructor-embedders-haystack - from instructor_embedders.instructor_text_embedder import InstructorTextEmbedder + from instructor_embedders_haystack.instructor_text_embedder import InstructorTextEmbedder text = "It clearly says online this will work on a Mac OS system. The disk comes and it does not, only Windows. Do Not order this if you have a Mac!!" 
instruction = ( diff --git a/integrations/instructor-embedders/pyproject.toml b/integrations/instructor-embedders/pyproject.toml index b4b8ac3ce..c10f43703 100644 --- a/integrations/instructor-embedders/pyproject.toml +++ b/integrations/instructor-embedders/pyproject.toml @@ -50,12 +50,12 @@ dependencies = [ dev = ["pytest"] [project.urls] -Documentation = "https://github.com/deepset-ai/haystack-extras/tree/main/components/instructor-embedders#readme" -Issues = "https://github.com/deepset-ai/haystack-extras/issues" -Source = "https://github.com/deepset-ai/haystack-extras/tree/main/components/instructor-embedders" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/components/instructor-embedders#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/components/instructor-embedders" [tool.hatch.version] -path = "instructor_embedders/__about__.py" +path = "instructor_embedders_haystack/__about__.py" [tool.hatch.envs.default] dependencies = ["pytest", "pytest-cov"] @@ -63,9 +63,33 @@ dependencies = ["pytest", "pytest-cov"] [tool.hatch.envs.default.scripts] cov = "pytest --cov-report=term-missing --cov-config=pyproject.toml --cov=instructor-embedders --cov=tests" no-cov = "cov --no-cov" +test = "pytest {args:tests}" [[tool.hatch.envs.test.matrix]] -python = ["37", "38", "39", "310", "311"] +python = ["38", "39", "310", "311"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive {args:instructor_embedders_haystack tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] [tool.coverage.run] branch = true @@ -76,7 +100,7 @@ omit = ["instructor_embedders/__about__.py"] 
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [tool.ruff] -target-version = "py37" +target-version = "py38" line-length = 120 select = [ "A", @@ -132,8 +156,18 @@ ban-relative-imports = "all" [tool.pytest.ini_options] minversion = "6.0" addopts = "--strict-markers" -markers = ["integration: integration tests", "unit: unit tests"] +markers = ["integration: integration tests"] log_cli = true [tool.black] line-length = 120 + +[[tool.mypy.overrides]] +module = [ + "instructor_embedders_haystack.*", + "InstructorEmbedding.*", + "haystack.*", + "pytest.*", + "numpy.*", +] +ignore_missing_imports = true \ No newline at end of file diff --git a/integrations/instructor-embedders/tests/test_instructor_backend.py b/integrations/instructor-embedders/tests/test_instructor_backend.py index 6cd9d8b77..27e31317a 100644 --- a/integrations/instructor-embedders/tests/test_instructor_backend.py +++ b/integrations/instructor-embedders/tests/test_instructor_backend.py @@ -1,12 +1,9 @@ from unittest.mock import patch -import pytest +from instructor_embedders_haystack.embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory -from instructor_embedders.embedding_backend.instructor_backend import _InstructorEmbeddingBackendFactory - -@pytest.mark.unit -@patch("instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR") +@patch("instructor_embedders_haystack.embedding_backend.instructor_backend.INSTRUCTOR") def test_factory_behavior(mock_instructor): # noqa: ARG001 embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="hkunlp/instructor-large", device="cpu" @@ -19,9 +16,11 @@ def test_factory_behavior(mock_instructor): # noqa: ARG001 assert same_embedding_backend is embedding_backend assert another_embedding_backend is not embedding_backend + # restore the factory state + _InstructorEmbeddingBackendFactory._instances = {} + -@pytest.mark.unit 
-@patch("instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR") +@patch("instructor_embedders_haystack.embedding_backend.instructor_backend.INSTRUCTOR") def test_model_initialization(mock_instructor): _InstructorEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token="huggingface_auth_token" @@ -29,10 +28,11 @@ def test_model_initialization(mock_instructor): mock_instructor.assert_called_once_with( model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token="huggingface_auth_token" ) + # restore the factory state + _InstructorEmbeddingBackendFactory._instances = {} -@pytest.mark.unit -@patch("instructor_embedders.embedding_backend.instructor_backend.INSTRUCTOR") +@patch("instructor_embedders_haystack.embedding_backend.instructor_backend.INSTRUCTOR") def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 embedding_backend = _InstructorEmbeddingBackendFactory.get_embedding_backend( model_name_or_path="hkunlp/instructor-base" @@ -42,3 +42,5 @@ def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 embedding_backend.embed(data=data, normalize_embeddings=True) embedding_backend.model.encode.assert_called_once_with(data, normalize_embeddings=True) + # restore the factory state + _InstructorEmbeddingBackendFactory._instances = {} diff --git a/integrations/instructor-embedders/tests/test_instructor_document_embedder.py b/integrations/instructor-embedders/tests/test_instructor_document_embedder.py index 6d9434976..53466ced2 100644 --- a/integrations/instructor-embedders/tests/test_instructor_document_embedder.py +++ b/integrations/instructor-embedders/tests/test_instructor_document_embedder.py @@ -4,11 +4,10 @@ import pytest from haystack import Document -from instructor_embedders.instructor_document_embedder import InstructorDocumentEmbedder +from instructor_embedders_haystack.instructor_document_embedder import InstructorDocumentEmbedder class 
TestInstructorDocumentEmbedder: - @pytest.mark.unit def test_init_default(self): """ Test default initialization parameters for InstructorDocumentEmbedder. @@ -24,7 +23,6 @@ def test_init_default(self): assert embedder.metadata_fields_to_embed == [] assert embedder.embedding_separator == "\n" - @pytest.mark.unit def test_init_with_parameters(self): """ Test custom initialization parameters for InstructorDocumentEmbedder. @@ -50,7 +48,6 @@ def test_init_with_parameters(self): assert embedder.metadata_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " - @pytest.mark.unit def test_to_dict(self): """ Test serialization of InstructorDocumentEmbedder to a dictionary, using default initialization parameters. @@ -58,7 +55,7 @@ def test_to_dict(self): embedder = InstructorDocumentEmbedder(model_name_or_path="hkunlp/instructor-base") embedder_dict = embedder.to_dict() assert embedder_dict == { - "type": "instructor_embedders.instructor_document_embedder.InstructorDocumentEmbedder", + "type": "instructor_embedders_haystack.instructor_document_embedder.InstructorDocumentEmbedder", "init_parameters": { "model_name_or_path": "hkunlp/instructor-base", "device": "cpu", @@ -72,7 +69,6 @@ def test_to_dict(self): }, } - @pytest.mark.unit def test_to_dict_with_custom_init_parameters(self): """ Test serialization of InstructorDocumentEmbedder to a dictionary, using custom initialization parameters. 
@@ -90,7 +86,7 @@ def test_to_dict_with_custom_init_parameters(self): ) embedder_dict = embedder.to_dict() assert embedder_dict == { - "type": "instructor_embedders.instructor_document_embedder.InstructorDocumentEmbedder", + "type": "instructor_embedders_haystack.instructor_document_embedder.InstructorDocumentEmbedder", "init_parameters": { "model_name_or_path": "hkunlp/instructor-base", "device": "cuda", @@ -104,13 +100,12 @@ def test_to_dict_with_custom_init_parameters(self): }, } - @pytest.mark.unit def test_from_dict(self): """ Test deserialization of InstructorDocumentEmbedder from a dictionary, using default initialization parameters. """ embedder_dict = { - "type": "instructor_embedders.instructor_document_embedder.InstructorDocumentEmbedder", + "type": "instructor_embedders_haystack.instructor_document_embedder.InstructorDocumentEmbedder", "init_parameters": { "model_name_or_path": "hkunlp/instructor-base", "device": "cpu", @@ -134,13 +129,12 @@ def test_from_dict(self): assert embedder.metadata_fields_to_embed == [] assert embedder.embedding_separator == "\n" - @pytest.mark.unit def test_from_dict_with_custom_init_parameters(self): """ Test deserialization of InstructorDocumentEmbedder from a dictionary, using custom initialization parameters. 
""" embedder_dict = { - "type": "instructor_embedders.instructor_document_embedder.InstructorDocumentEmbedder", + "type": "instructor_embedders_haystack.instructor_document_embedder.InstructorDocumentEmbedder", "init_parameters": { "model_name_or_path": "hkunlp/instructor-base", "device": "cuda", @@ -164,8 +158,7 @@ def test_from_dict_with_custom_init_parameters(self): assert embedder.metadata_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " - @pytest.mark.unit - @patch("instructor_embedders.instructor_document_embedder._InstructorEmbeddingBackendFactory") + @patch("instructor_embedders_haystack.instructor_document_embedder._InstructorEmbeddingBackendFactory") def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. @@ -177,8 +170,7 @@ def test_warmup(self, mocked_factory): model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token=None ) - @pytest.mark.unit - @patch("instructor_embedders.instructor_document_embedder._InstructorEmbeddingBackendFactory") + @patch("instructor_embedders_haystack.instructor_document_embedder._InstructorEmbeddingBackendFactory") def test_warmup_does_not_reload(self, mocked_factory): """ Test for checking backend instances after multiple warm-ups. @@ -189,7 +181,6 @@ def test_warmup_does_not_reload(self, mocked_factory): embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once() - @pytest.mark.unit def test_embed(self): """ Test for checking output dimensions and embedding dimensions. @@ -209,7 +200,6 @@ def test_embed(self): assert isinstance(doc.embedding, list) assert isinstance(doc.embedding[0], float) - @pytest.mark.unit def test_embed_incorrect_input_format(self): """ Test for checking incorrect input format when creating embedding. 
@@ -225,7 +215,6 @@ def test_embed_incorrect_input_format(self): with pytest.raises(TypeError, match="InstructorDocumentEmbedder expects a list of Documents as input."): embedder.run(documents=list_integers_input) - @pytest.mark.unit def test_embed_metadata(self): """ Test for checking output dimensions and embedding dimensions for documents diff --git a/integrations/instructor-embedders/tests/test_instructor_embedders.py b/integrations/instructor-embedders/tests/test_instructor_embedders.py deleted file mode 100644 index 1abbe3b32..000000000 --- a/integrations/instructor-embedders/tests/test_instructor_embedders.py +++ /dev/null @@ -1,7 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 - - -def test_instructor_embedders(): - assert True diff --git a/integrations/instructor-embedders/tests/test_instructor_text_embedder.py b/integrations/instructor-embedders/tests/test_instructor_text_embedder.py index 4dd1b13af..a4adde771 100644 --- a/integrations/instructor-embedders/tests/test_instructor_text_embedder.py +++ b/integrations/instructor-embedders/tests/test_instructor_text_embedder.py @@ -3,11 +3,10 @@ import numpy as np import pytest -from instructor_embedders.instructor_text_embedder import InstructorTextEmbedder +from instructor_embedders_haystack.instructor_text_embedder import InstructorTextEmbedder class TestInstructorTextEmbedder: - @pytest.mark.unit def test_init_default(self): """ Test default initialization parameters for InstructorTextEmbedder. @@ -21,7 +20,6 @@ def test_init_default(self): assert embedder.progress_bar is True assert embedder.normalize_embeddings is False - @pytest.mark.unit def test_init_with_parameters(self): """ Test custom initialization parameters for InstructorTextEmbedder. 
@@ -43,7 +41,6 @@ def test_init_with_parameters(self): assert embedder.progress_bar is False assert embedder.normalize_embeddings is True - @pytest.mark.unit def test_to_dict(self): """ Test serialization of InstructorTextEmbedder to a dictionary, using default initialization parameters. @@ -51,7 +48,7 @@ def test_to_dict(self): embedder = InstructorTextEmbedder(model_name_or_path="hkunlp/instructor-base") embedder_dict = embedder.to_dict() assert embedder_dict == { - "type": "instructor_embedders.instructor_text_embedder.InstructorTextEmbedder", + "type": "instructor_embedders_haystack.instructor_text_embedder.InstructorTextEmbedder", "init_parameters": { "model_name_or_path": "hkunlp/instructor-base", "device": "cpu", @@ -63,7 +60,6 @@ def test_to_dict(self): }, } - @pytest.mark.unit def test_to_dict_with_custom_init_parameters(self): """ Test serialization of InstructorTextEmbedder to a dictionary, using custom initialization parameters. @@ -79,7 +75,7 @@ def test_to_dict_with_custom_init_parameters(self): ) embedder_dict = embedder.to_dict() assert embedder_dict == { - "type": "instructor_embedders.instructor_text_embedder.InstructorTextEmbedder", + "type": "instructor_embedders_haystack.instructor_text_embedder.InstructorTextEmbedder", "init_parameters": { "model_name_or_path": "hkunlp/instructor-base", "device": "cuda", @@ -91,13 +87,12 @@ def test_to_dict_with_custom_init_parameters(self): }, } - @pytest.mark.unit def test_from_dict(self): """ Test deserialization of InstructorTextEmbedder from a dictionary, using default initialization parameters. 
""" embedder_dict = { - "type": "instructor_embedders.instructor_text_embedder.InstructorTextEmbedder", + "type": "instructor_embedders_haystack.instructor_text_embedder.InstructorTextEmbedder", "init_parameters": { "model_name_or_path": "hkunlp/instructor-base", "device": "cpu", @@ -117,13 +112,12 @@ def test_from_dict(self): assert embedder.progress_bar is True assert embedder.normalize_embeddings is False - @pytest.mark.unit def test_from_dict_with_custom_init_parameters(self): """ Test deserialization of InstructorTextEmbedder from a dictionary, using custom initialization parameters. """ embedder_dict = { - "type": "instructor_embedders.instructor_text_embedder.InstructorTextEmbedder", + "type": "instructor_embedders_haystack.instructor_text_embedder.InstructorTextEmbedder", "init_parameters": { "model_name_or_path": "hkunlp/instructor-base", "device": "cuda", @@ -143,8 +137,7 @@ def test_from_dict_with_custom_init_parameters(self): assert embedder.progress_bar is False assert embedder.normalize_embeddings is True - @pytest.mark.unit - @patch("instructor_embedders.instructor_text_embedder._InstructorEmbeddingBackendFactory") + @patch("instructor_embedders_haystack.instructor_text_embedder._InstructorEmbeddingBackendFactory") def test_warmup(self, mocked_factory): """ Test for checking embedder instances after warm-up. @@ -156,8 +149,7 @@ def test_warmup(self, mocked_factory): model_name_or_path="hkunlp/instructor-base", device="cpu", use_auth_token=None ) - @pytest.mark.unit - @patch("instructor_embedders.instructor_text_embedder._InstructorEmbeddingBackendFactory") + @patch("instructor_embedders_haystack.instructor_text_embedder._InstructorEmbeddingBackendFactory") def test_warmup_does_not_reload(self, mocked_factory): """ Test for checking backend instances after multiple warm-ups. 
@@ -168,7 +160,6 @@ def test_warmup_does_not_reload(self, mocked_factory): embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once() - @pytest.mark.unit def test_embed(self): """ Test for checking output dimensions and embedding dimensions. @@ -185,7 +176,6 @@ def test_embed(self): assert isinstance(embedding, list) assert all(isinstance(emb, float) for emb in embedding) - @pytest.mark.unit def test_run_wrong_incorrect_format(self): """ Test for checking incorrect input format when creating embedding. From 3eb6379f8444adb46dcca62c3e91466f02185b14 Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com> Date: Tue, 12 Dec 2023 11:17:53 +0100 Subject: [PATCH 14/29] fix project urls (#96) --- integrations/chroma/pyproject.toml | 6 +++--- integrations/cohere/pyproject.toml | 6 +++--- integrations/elasticsearch/pyproject.toml | 4 ++-- integrations/gradient/pyproject.toml | 6 +++--- integrations/jina/pyproject.toml | 4 ++-- integrations/unstructured/fileconverter/pyproject.toml | 4 ++-- nodes/text2speech/pyproject.toml | 6 +++--- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index d19461895..8bace9a1b 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -28,9 +28,9 @@ dependencies = [ ] [project.urls] -Documentation = "https://github.com/masci/chroma-haystack#readme" -Issues = "https://github.com/masci/chroma-haystack/issues" -Source = "https://github.com/masci/chroma-haystack" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma" [tool.hatch.version] path = "src/chroma_haystack/__about__.py" diff --git a/integrations/cohere/pyproject.toml 
b/integrations/cohere/pyproject.toml index 5d589df7b..71501d28e 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -30,9 +30,9 @@ dependencies = [ ] [project.urls] -Documentation = "https://github.com/unknown/cohere-haystack#readme" -Issues = "https://github.com/unknown/cohere-haystack/issues" -Source = "https://github.com/unknown/cohere-haystack" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere" [tool.hatch.version] path = "src/cohere_haystack/__about__.py" diff --git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index f67e9cc35..12922fbc2 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -29,9 +29,9 @@ dependencies = [ ] [project.urls] -Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/document_stores/elasticsearch#readme" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/elasticsearch#readme" Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" -Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/document_stores/elasticsearch" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/elasticsearch" [tool.hatch.version] path = "src/elasticsearch_haystack/__about__.py" diff --git a/integrations/gradient/pyproject.toml b/integrations/gradient/pyproject.toml index 79a39a384..ae9e047c7 100644 --- a/integrations/gradient/pyproject.toml +++ b/integrations/gradient/pyproject.toml @@ -30,9 +30,9 @@ dependencies = [ ] [project.urls] -Documentation = "https://github.com/unknown/gradient-haystack#readme" -Issues = 
"https://github.com/unknown/gradient-haystack/issues" -Source = "https://github.com/unknown/gradient-haystack" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/gradient#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/gradient" [tool.hatch.version] path = "src/gradient_haystack/__about__.py" diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index 2e35201d7..a6a57efda 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -27,9 +27,9 @@ classifiers = [ dependencies = ["requests", "haystack-ai"] [project.urls] -Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/jina/jina-haystack#readme" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/jina#readme" Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" -Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/jina/jina-haystack" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/jina" [tool.hatch.version] path = "src/jina_haystack/__about__.py" diff --git a/integrations/unstructured/fileconverter/pyproject.toml b/integrations/unstructured/fileconverter/pyproject.toml index faaba8d71..c19409cde 100644 --- a/integrations/unstructured/fileconverter/pyproject.toml +++ b/integrations/unstructured/fileconverter/pyproject.toml @@ -30,9 +30,9 @@ dependencies = [ ] [project.urls] -Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/components/converters/unstructured_fileconverter#readme" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/unstructured/fileconverter#readme" Issues = 
"https://github.com/deepset-ai/haystack-core-integrations/issues" -Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/components/converters/unstructured_fileconverter" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/unstructured/fileconverter" [tool.hatch.version] path = "src/unstructured_fileconverter_haystack/__about__.py" diff --git a/nodes/text2speech/pyproject.toml b/nodes/text2speech/pyproject.toml index 4dfd65e1d..cc39e8780 100644 --- a/nodes/text2speech/pyproject.toml +++ b/nodes/text2speech/pyproject.toml @@ -43,9 +43,9 @@ dynamic = ["version"] dev = ["pytest", "pylint", "black"] [project.urls] -Documentation = "https://github.com/deepset-ai/haystack-extras/tree/main/nodes/text2speech#readme" -Issues = "https://github.com/deepset-ai/haystack-extras/issues" -Source = "https://github.com/deepset-ai/haystack-extras/tree/main/nodes/text2speech" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/nodes/text2speech#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/nodes/text2speech" [tool.hatch.version] path = "text2speech/__about__.py" From 8454bda062c78f67e3209c63d272ecb04569fdbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bilge=20Y=C3=BCcel?= Date: Tue, 12 Dec 2023 16:48:32 +0300 Subject: [PATCH 15/29] Update README.md (#97) --- integrations/opensearch/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/opensearch/README.md b/integrations/opensearch/README.md index 9c6dd6e6c..40a2f8eaa 100644 --- a/integrations/opensearch/README.md +++ b/integrations/opensearch/README.md @@ -1,4 +1,4 @@ 
-[![test](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/document_stores_opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/document_stores_opensearch.yml) +[![test](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) From d45eca2854fbfaa4bdecfb31ab540e2a5f5bb660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bilge=20Y=C3=BCcel?= Date: Wed, 13 Dec 2023 12:03:52 +0300 Subject: [PATCH 16/29] Change default 'input_type' for CohereTextEmbedder (#99) * Change default 'input_type' for CohereTextEmbedder * Update tests after updating the default value of input_type --- .../src/cohere_haystack/embedders/text_embedder.py | 2 +- integrations/cohere/tests/test_text_embedder.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py index 5d139427f..f21060965 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -35,7 +35,7 @@ def __init__( self, api_key: Optional[str] = None, model_name: str = "embed-english-v2.0", - input_type: str = "search_document", + input_type: str = "search_query", api_base_url: str = COHERE_API_URL, truncate: str = "END", use_async_client: bool = False, diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index 9ec673c98..46f77cb43 100644 --- 
a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -20,7 +20,7 @@ def test_init_default(self): assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-english-v2.0" - assert embedder.input_type == "search_document" + assert embedder.input_type == "search_query" assert embedder.api_base_url == COHERE_API_URL assert embedder.truncate == "END" assert embedder.use_async_client is False @@ -34,7 +34,7 @@ def test_init_with_parameters(self): embedder = CohereTextEmbedder( api_key="test-api-key", model_name="embed-multilingual-v2.0", - input_type="search_query", + input_type="classification", api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, @@ -43,7 +43,7 @@ def test_init_with_parameters(self): ) assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-multilingual-v2.0" - assert embedder.input_type == "search_query" + assert embedder.input_type == "classification" assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" assert embedder.use_async_client is True @@ -60,7 +60,7 @@ def test_to_dict(self): "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", "init_parameters": { "model_name": "embed-english-v2.0", - "input_type": "search_document", + "input_type": "search_query", "api_base_url": COHERE_API_URL, "truncate": "END", "use_async_client": False, @@ -76,7 +76,7 @@ def test_to_dict_with_custom_init_parameters(self): embedder_component = CohereTextEmbedder( api_key="test-api-key", model_name="embed-multilingual-v2.0", - input_type="search_query", + input_type="classification", api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, @@ -88,7 +88,7 @@ def test_to_dict_with_custom_init_parameters(self): "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", "init_parameters": { "model_name": "embed-multilingual-v2.0", - 
"input_type": "search_query", + "input_type": "classification", "api_base_url": "https://custom-api-base-url.com", "truncate": "START", "use_async_client": True, From 824f57c921dca8900f5ca981e45decdc03f28fc0 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Thu, 14 Dec 2023 13:12:37 +0100 Subject: [PATCH 17/29] Add Google Vertex AI integration (#101) * Add Google Vertex AI integration * Add Gemini * Add label --- .github/labeler.yml | 54 +++-- .github/workflows/google_vertex.yml | 56 +++++ README.md | 10 +- integrations/google-vertex/LICENSE.txt | 201 ++++++++++++++++++ integrations/google-vertex/README.md | 43 ++++ integrations/google-vertex/pyproject.toml | 174 +++++++++++++++ .../src/google_vertex_haystack/__about__.py | 4 + .../src/google_vertex_haystack/__init__.py | 3 + .../generators/__init__.py | 3 + .../generators/captioner.py | 53 +++++ .../generators/code_generator.py | 54 +++++ .../generators/gemini.py | 105 +++++++++ .../generators/image_generator.py | 55 +++++ .../generators/question_answering.py | 53 +++++ .../generators/text_generator.py | 81 +++++++ integrations/google-vertex/tests/__init__.py | 3 + .../google-vertex/tests/test_captioner.py | 76 +++++++ .../tests/test_code_generator.py | 79 +++++++ .../tests/test_image_generator.py | 92 ++++++++ .../tests/test_question_answering.py | 83 ++++++++ .../tests/test_text_generator.py | 96 +++++++++ 21 files changed, 1354 insertions(+), 24 deletions(-) create mode 100644 .github/workflows/google_vertex.yml create mode 100644 integrations/google-vertex/LICENSE.txt create mode 100644 integrations/google-vertex/README.md create mode 100644 integrations/google-vertex/pyproject.toml create mode 100644 integrations/google-vertex/src/google_vertex_haystack/__about__.py create mode 100644 integrations/google-vertex/src/google_vertex_haystack/__init__.py create mode 100644 integrations/google-vertex/src/google_vertex_haystack/generators/__init__.py create mode 
100644 integrations/google-vertex/src/google_vertex_haystack/generators/captioner.py create mode 100644 integrations/google-vertex/src/google_vertex_haystack/generators/code_generator.py create mode 100644 integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py create mode 100644 integrations/google-vertex/src/google_vertex_haystack/generators/image_generator.py create mode 100644 integrations/google-vertex/src/google_vertex_haystack/generators/question_answering.py create mode 100644 integrations/google-vertex/src/google_vertex_haystack/generators/text_generator.py create mode 100644 integrations/google-vertex/tests/__init__.py create mode 100644 integrations/google-vertex/tests/test_captioner.py create mode 100644 integrations/google-vertex/tests/test_code_generator.py create mode 100644 integrations/google-vertex/tests/test_image_generator.py create mode 100644 integrations/google-vertex/tests/test_question_answering.py create mode 100644 integrations/google-vertex/tests/test_text_generator.py diff --git a/.github/labeler.yml b/.github/labeler.yml index 5c1e12fa8..e1e9727f6 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -1,43 +1,53 @@ # Integrations integration:chroma: -- changed-files: - - any-glob-to-any-file: 'integrations/chroma/**/*' + - changed-files: + - any-glob-to-any-file: "integrations/chroma/**/*" integration:cohere: -- changed-files: - - any-glob-to-any-file: 'integrations/cohere/**/*' + - changed-files: + - any-glob-to-any-file: "integrations/cohere/**/*" integration:elasticsearch: -- changed-files: - - any-glob-to-any-file: 'integrations/elasticsearch/**/*' + - changed-files: + - any-glob-to-any-file: "integrations/elasticsearch/**/*" + +integration:google-vertex: + - changed-files: + - any-glob-to-any-file: "integrations/google-vertex/**/*" integration:gradient: -- changed-files: - - any-glob-to-any-file: 'integrations/gradient/**/*' + - changed-files: + - any-glob-to-any-file: "integrations/gradient/**/*" 
integration:instructor-embedders: -- changed-files: - - any-glob-to-any-file: 'integrations/instructor-embedders/**/*' + - changed-files: + - any-glob-to-any-file: "integrations/instructor-embedders/**/*" integration:jina: -- changed-files: - - any-glob-to-any-file: 'integrations/jina/**/*' + - changed-files: + - any-glob-to-any-file: "integrations/jina/**/*" integration:opensearch: -- changed-files: - - any-glob-to-any-file: 'integrations/opensearch/**/*' + - changed-files: + - any-glob-to-any-file: "integrations/opensearch/**/*" integration:unstructured-fileconverter: -- changed-files: - - any-glob-to-any-file: 'integrations/unstructured/fileconverter/**/*' + - changed-files: + - any-glob-to-any-file: "integrations/unstructured/fileconverter/**/*" # Topics topic:CI: -- changed-files: - - any-glob-to-any-file: ['.github/*', '.github/**/*'] + - changed-files: + - any-glob-to-any-file: [".github/*", ".github/**/*"] topic:DX: -- changed-files: - - any-glob-to-any-file: ['CONTRIBUTING.md', '.pre-commit-config.yaml', '.gitignore', 'requirements.txt'] + - changed-files: + - any-glob-to-any-file: + [ + "CONTRIBUTING.md", + ".pre-commit-config.yaml", + ".gitignore", + "requirements.txt", + ] topic:security: -- changed-files: - - any-glob-to-any-file: ['SECURITY.md'] + - changed-files: + - any-glob-to-any-file: ["SECURITY.md"] diff --git a/.github/workflows/google_vertex.yml b/.github/workflows/google_vertex.yml new file mode 100644 index 000000000..7be3973bf --- /dev/null +++ b/.github/workflows/google_vertex.yml @@ -0,0 +1,56 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / google-vertex + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/google-vertex/**" + - ".github/workflows/google-vertex.yml" + +defaults: + run: + working-directory: integrations/google-vertex + +concurrency: + group: 
google-vertex-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.10"] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Run tests + run: hatch run cov diff --git a/README.md b/README.md index 024945b1c..395f59db4 100644 --- a/README.md +++ b/README.md @@ -11,16 +11,19 @@ You will need `hatch` to work on or create new integrations. Run `pip install ha All the integrations are self contained, so the first step before working on one is to `cd` into the proper folder. For example, to work on the Chroma Document Store, from the root of the repo: + ```sh $ cd integrations/chroma ``` From there, you can run the tests with `hatch`, that will take care of setting up an isolated Python environment: + ```sh hatch run test ``` Similarly, to run the linters: + ```sh hatch run lint:all ``` @@ -31,11 +34,13 @@ hatch run lint:all > you're integrating Haystack with. For example, a deepset integration would be named as `deepset-haystack`. 
To create a new integration, from the root of the repo change directory into `integrations`: + ```sh cd integrations ``` From there, use `hatch` to create the scaffold of the new integration: + ```sh $ hatch --config hatch.toml new -i Project name: deepset-haystack @@ -58,10 +63,11 @@ deepset-haystack | Package | Type | PyPi Package | Status | | ------------------------------------------------------------------------------- | ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | [chroma-haystack](integrations/chroma/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/chroma-haystack) | [![Test / chroma](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml) | -| [cohere-haystack](integrations/cohere/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | +| [cohere-haystack](integrations/cohere/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / 
cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | | [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | +| [google-vertex-haystack](integrations/google-vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google-vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google-vertex.yml) | | [gradient-haystack](integrations/gradient/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) | [![Test / gradient](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml) | | [instructor-embedders-haystack](integrations/instructor-embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | | [opensearch-haystack](integrations/opensearch/) | Document Store | 
[![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/fileconverter/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured / fileconverter](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml) | -| [jina-haystack](integrations/jina/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | +| [jina-haystack](integrations/jina/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | diff --git a/integrations/google-vertex/LICENSE.txt b/integrations/google-vertex/LICENSE.txt new file mode 100644 index 000000000..6134ab324 --- /dev/null +++ b/integrations/google-vertex/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023-present deepset GmbH + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/integrations/google-vertex/README.md b/integrations/google-vertex/README.md new file mode 100644 index 000000000..17a445c02 --- /dev/null +++ b/integrations/google-vertex/README.md @@ -0,0 +1,43 @@ +# google-vertex-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) + +--- + +**Table of Contents** + +- [google-vertex-haystack](#google-vertex-haystack) + - [Installation](#installation) + - [Contributing](#contributing) + - [License](#license) + +## Installation + +```console +pip install google-vertex-haystack +``` + +## Contributing + +`hatch` is the best way to interact with this project, to install it: + +```sh +pip install hatch +``` + +With `hatch` installed, to run all the tests: + +``` +hatch run test +``` + +To run the linters `ruff` and `mypy`: + +``` +hatch run lint:all +``` + +## License + +`google-vertex-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
diff --git a/integrations/google-vertex/pyproject.toml b/integrations/google-vertex/pyproject.toml new file mode 100644 index 000000000..2455b4fa9 --- /dev/null +++ b/integrations/google-vertex/pyproject.toml @@ -0,0 +1,174 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "google-vertex-haystack" +dynamic = ["version"] +description = '' +readme = "README.md" +requires-python = ">=3.7" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "deepset GmbH", email = "info@deepset.ai" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "haystack-ai", + "google-cloud-aiplatform", +] + +[project.urls] +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google-vertex#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google-vertex" + +[tool.hatch.version] +path = "src/google_vertex_haystack/__about__.py" + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.7", "3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive 
{args:src/google_vertex_haystack tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["google_vertex_haystack"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source_pkgs = ["google_vertex_haystack", "tests"] +branch = true +parallel = true +omit = [ + "src/google_vertex_haystack/__about__.py", +] + +[tool.coverage.paths] +google_vertex_haystack = ["src/google_vertex_haystack", "*/google-vertex-haystack/src/google_vertex_haystack"] +tests = ["tests", "*/google-vertex-haystack/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[[tool.mypy.overrides]] +module = [ + "vertexai.*", + "haystack.*", + "pytest.*", + "numpy.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +addopts = "--strict-markers" +markers = [ + "integration: integration tests", + "embedders: embedders tests", + "generators: generators tests", +] +log_cli = true \ No newline at end of file diff --git 
a/integrations/google-vertex/src/google_vertex_haystack/__about__.py b/integrations/google-vertex/src/google_vertex_haystack/__about__.py new file mode 100644 index 000000000..0e4fa27cf --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/__about__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +__version__ = "0.0.1" diff --git a/integrations/google-vertex/src/google_vertex_haystack/__init__.py b/integrations/google-vertex/src/google_vertex_haystack/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/__init__.py b/integrations/google-vertex/src/google_vertex_haystack/generators/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/captioner.py b/integrations/google-vertex/src/google_vertex_haystack/generators/captioner.py new file mode 100644 index 000000000..83322b33b --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/captioner.py @@ -0,0 +1,53 @@ +import logging +from typing import Any, Dict, List, Optional + +import vertexai +from haystack.core.component import component +from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.dataclasses.byte_stream import ByteStream +from vertexai.vision_models import Image, ImageTextModel + +logger = logging.getLogger(__name__) + + +@component +class VertexAIImageCaptioner: + def __init__(self, *, model: str 
= "imagetext", project_id: str, location: Optional[str] = None, **kwargs): + """ + Generate image captions using a Google Vertex AI model. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official Google documentation: + https://cloud.google.com/docs/authentication/provide-credentials-adc + + :param project_id: ID of the GCP project to use. + :param model: Name of the model to use, defaults to "imagetext". + :param location: The default location to use when making API calls, if not set uses us-central-1. + Defaults to None. + :param kwargs: Additional keyword arguments to pass to the model. + For a list of supported arguments see the `ImageTextModel.get_captions()` documentation. + """ + + # Login to GCP. This will fail if user has not set up their gcloud SDK + vertexai.init(project=project_id, location=location) + + self._model_name = model + self._project_id = project_id + self._location = location + self._kwargs = kwargs + + self._model = ImageTextModel.from_pretrained(self._model_name) + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageCaptioner": + return default_from_dict(cls, data) + + @component.output_types(captions=List[str]) + def run(self, image: ByteStream): + captions = self._model.get_captions(image=Image(image.data), **self._kwargs) + return {"captions": captions} diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/code_generator.py b/integrations/google-vertex/src/google_vertex_haystack/generators/code_generator.py new file mode 100644 index 000000000..1914af289 --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/code_generator.py @@ -0,0 +1,54 @@ +import logging +from typing import Any, Dict, List, Optional + +import vertexai +from 
haystack.core.component import component +from haystack.core.serialization import default_from_dict, default_to_dict +from vertexai.language_models import CodeGenerationModel + +logger = logging.getLogger(__name__) + + +@component +class VertexAICodeGenerator: + def __init__(self, *, model: str = "code-bison", project_id: str, location: Optional[str] = None, **kwargs): + """ + Generate code using a Google Vertex AI model. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official Google documentation: + https://cloud.google.com/docs/authentication/provide-credentials-adc + + :param project_id: ID of the GCP project to use. + :param model: Name of the model to use, defaults to "text-bison". + :param location: The default location to use when making API calls, if not set uses us-central-1. + Defaults to None. + :param kwargs: Additional keyword arguments to pass to the model. + For a list of supported arguments see the `TextGenerationModel.predict()` documentation. + """ + + # Login to GCP. 
This will fail if user has not set up their gcloud SDK + vertexai.init(project=project_id, location=location) + + self._model_name = model + self._project_id = project_id + self._location = location + self._kwargs = kwargs + + self._model = CodeGenerationModel.from_pretrained(self._model_name) + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "VertexAICodeGenerator": + return default_from_dict(cls, data) + + @component.output_types(answers=List[str]) + def run(self, prefix: str, suffix: Optional[str] = None): + res = self._model.predict(prefix=prefix, suffix=suffix, **self._kwargs) + # Handle the case where the model returns multiple candidates + answers = [c.text for c in res.candidates] if hasattr(res, "candidates") else [res.text] + return {"answers": answers} diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py b/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py new file mode 100644 index 000000000..b01dc6795 --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py @@ -0,0 +1,105 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +import vertexai +from haystack.core.component import component +from haystack.core.component.types import Variadic +from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.dataclasses.byte_stream import ByteStream +from vertexai.preview.generative_models import ( + Content, + GenerativeModel, + Part, +) + +logger = logging.getLogger(__name__) + + +@component +class GeminiGenerator: + def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, location: Optional[str] = None, **kwargs): + """ + Multi modal generator using Gemini model via Google Vertex AI. 
+ + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official Google documentation: + https://cloud.google.com/docs/authentication/provide-credentials-adc + + :param project_id: ID of the GCP project to use. + :param model: Name of the model to use, defaults to "gemini-pro-vision". + :param location: The default location to use when making API calls, if not set uses us-central-1. + Defaults to None. + :param kwargs: Additional keyword arguments to pass to the model. + For a list of supported arguments see the `GenerativeModel.generate_content()` documentation. + """ + + # Login to GCP. This will fail if user has not set up their gcloud SDK + vertexai.init(project=project_id, location=location) + + self._model_name = model + self._project_id = project_id + self._location = location + self._kwargs = kwargs + + if kwargs.get("stream"): + msg = "The `stream` parameter is not supported by the Gemini generator." + raise ValueError(msg) + + self._model = GenerativeModel(self._model_name) + + def to_dict(self) -> Dict[str, Any]: + # TODO: This is not fully implemented yet + return default_to_dict( + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GeminiGenerator": + # TODO: This is not fully implemented yet + return default_from_dict(cls, data) + + def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: + if isinstance(part, str): + return Part.from_text(part) + elif isinstance(part, ByteStream): + return Part.from_data(part.data, part.mime_type) + elif isinstance(part, Part): + return part + else: + msg = f"Unsupported type {type(part)} for part {part}" + raise ValueError(msg) + + @component.output_types(answers=List[Union[str, Dict[str, str]]]) + def run(self, parts: Variadic[List[Union[str, ByteStream, Part]]]): + converted_parts = [self._convert_part(p) for p in parts] + + contents = 
[Content(parts=converted_parts, role="user")] + res = self._model.generate_content(contents=contents, **self._kwargs) + self._model.start_chat() + answers = [] + for candidate in res.candidates: + for part in candidate.content.parts: + if part._raw_part.text != "": + answers.append(part.text) + elif part.function_call is not None: + function_call = { + "name": part.function_call.name, + "args": dict(part.function_call.args.items()), + } + answers.append(function_call) + + return {"answers": answers} + + +# generator = GeminiGenerator(project_id="infinite-byte-223810") +# res = generator.run(["What can you do for me?"]) +# res +# another_res = generator.run(["Can you solve this math problems?", "2 + 2", "3 + 3", "1 / 1"]) +# another_res["answers"] +# from pathlib import Path + +# image = ByteStream.from_file_path( +# Path("/Users/silvanocerza/Downloads/photo_2023-11-07_11-45-42.jpg"), mime_type="image/jpeg" +# ) +# res = generator.run(["What is this about?", image]) +# res["answers"] diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/image_generator.py b/integrations/google-vertex/src/google_vertex_haystack/generators/image_generator.py new file mode 100644 index 000000000..67d270347 --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/image_generator.py @@ -0,0 +1,55 @@ +import logging +from typing import Any, Dict, List, Optional + +import vertexai +from haystack.core.component import component +from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.dataclasses.byte_stream import ByteStream +from vertexai.preview.vision_models import ImageGenerationModel + +logger = logging.getLogger(__name__) + + +@component +class VertexAIImageGenerator: + def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs): + """ + Generates images using a Google Vertex AI model. 
+ + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official Google documentation: + https://cloud.google.com/docs/authentication/provide-credentials-adc + + :param project_id: ID of the GCP project to use. + :param model: Name of the model to use, defaults to "imagetext". + :param location: The default location to use when making API calls, if not set uses us-central-1. + Defaults to None. + :param kwargs: Additional keyword arguments to pass to the model. + For a list of supported arguments see the `ImageGenerationModel.generate_images()` documentation. + """ + + # Login to GCP. This will fail if user has not set up their gcloud SDK + vertexai.init(project=project_id, location=location) + + self._model_name = model + self._project_id = project_id + self._location = location + self._kwargs = kwargs + + self._model = ImageGenerationModel.from_pretrained(self._model_name) + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageGenerator": + return default_from_dict(cls, data) + + @component.output_types(images=List[ByteStream]) + def run(self, prompt: str, negative_prompt: Optional[str] = None): + negative_prompt = negative_prompt or self._kwargs.get("negative_prompt") + res = self._model.generate_images(prompt=prompt, negative_prompt=negative_prompt, **self._kwargs) + images = [ByteStream(data=i._image_bytes, metadata=i.generation_parameters) for i in res.images] + return {"images": images} diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/question_answering.py b/integrations/google-vertex/src/google_vertex_haystack/generators/question_answering.py new file mode 100644 index 000000000..276364227 --- /dev/null +++ 
b/integrations/google-vertex/src/google_vertex_haystack/generators/question_answering.py @@ -0,0 +1,53 @@ +import logging +from typing import Any, Dict, List, Optional + +import vertexai +from haystack.core.component import component +from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.dataclasses.byte_stream import ByteStream +from vertexai.vision_models import Image, ImageTextModel + +logger = logging.getLogger(__name__) + + +@component +class VertexAIImageQA: + def __init__(self, *, model: str = "imagetext", project_id: str, location: Optional[str] = None, **kwargs): + """ + Answers questions about an image using a Google Vertex AI model. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official Google documentation: + https://cloud.google.com/docs/authentication/provide-credentials-adc + + :param project_id: ID of the GCP project to use. + :param model: Name of the model to use, defaults to "imagetext". + :param location: The default location to use when making API calls, if not set uses us-central-1. + Defaults to None. + :param kwargs: Additional keyword arguments to pass to the model. + For a list of supported arguments see the `ImageTextModel.ask_question()` documentation. + """ + + # Login to GCP. 
This will fail if user has not set up their gcloud SDK + vertexai.init(project=project_id, location=location) + + self._model_name = model + self._project_id = project_id + self._location = location + self._kwargs = kwargs + + self._model = ImageTextModel.from_pretrained(self._model_name) + + def to_dict(self) -> Dict[str, Any]: + return default_to_dict( + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "VertexAIImageQA": + return default_from_dict(cls, data) + + @component.output_types(answers=List[str]) + def run(self, image: ByteStream, question: str): + answers = self._model.ask_question(image=Image(image.data), question=question, **self._kwargs) + return {"answers": answers} diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/text_generator.py b/integrations/google-vertex/src/google_vertex_haystack/generators/text_generator.py new file mode 100644 index 000000000..6022bcf4f --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/text_generator.py @@ -0,0 +1,81 @@ +import importlib +import logging +from dataclasses import fields +from typing import Any, Dict, List, Optional + +import vertexai +from haystack.core.component import component +from haystack.core.serialization import default_from_dict, default_to_dict +from vertexai.language_models import TextGenerationModel + +logger = logging.getLogger(__name__) + + +@component +class VertexAITextGenerator: + def __init__(self, *, model: str = "text-bison", project_id: str, location: Optional[str] = None, **kwargs): + """ + Generate text using a Google Vertex AI model. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official Google documentation: + https://cloud.google.com/docs/authentication/provide-credentials-adc + + :param project_id: ID of the GCP project to use. 
+ :param model: Name of the model to use, defaults to "text-bison". + :param location: The default location to use when making API calls, if not set uses us-central-1. + Defaults to None. + :param kwargs: Additional keyword arguments to pass to the model. + For a list of supported arguments see the `TextGenerationModel.predict()` documentation. + """ + + # Login to GCP. This will fail if user has not set up their gcloud SDK + vertexai.init(project=project_id, location=location) + + self._model_name = model + self._project_id = project_id + self._location = location + self._kwargs = kwargs + + self._model = TextGenerationModel.from_pretrained(self._model_name) + + def to_dict(self) -> Dict[str, Any]: + data = default_to_dict( + self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs + ) + + if (grounding_source := data["init_parameters"].get("grounding_source")) is not None: + # Handle the grounding source dataclasses + class_type = f"{grounding_source.__module__}.{grounding_source.__class__.__name__}" + init_fields = {f.name: getattr(grounding_source, f.name) for f in fields(grounding_source) if f.init} + data["init_parameters"]["grounding_source"] = { + "type": class_type, + "init_parameters": init_fields, + } + + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "VertexAITextGenerator": + if (grounding_source := data["init_parameters"].get("grounding_source")) is not None: + module_name, class_name = grounding_source["type"].rsplit(".", 1) + module = importlib.import_module(module_name) + data["init_parameters"]["grounding_source"] = getattr(module, class_name)( + **grounding_source["init_parameters"] + ) + return default_from_dict(cls, data) + + @component.output_types(answers=List[str], safety_attributes=Dict[str, float], citations=List[Dict[str, Any]]) + def run(self, prompt: str): + res = self._model.predict(prompt=prompt, **self._kwargs) + + answers = [] + safety_attributes = [] + citations 
= [] + + for prediction in res.raw_prediction_response.predictions: + answers.append(prediction["content"]) + safety_attributes.append(prediction["safetyAttributes"]) + citations.append(prediction["citationMetadata"]["citations"]) + + return {"answers": answers, "safety_attributes": safety_attributes, "citations": citations} diff --git a/integrations/google-vertex/tests/__init__.py b/integrations/google-vertex/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/google-vertex/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/google-vertex/tests/test_captioner.py b/integrations/google-vertex/tests/test_captioner.py new file mode 100644 index 000000000..bc7e4f829 --- /dev/null +++ b/integrations/google-vertex/tests/test_captioner.py @@ -0,0 +1,76 @@ +from unittest.mock import Mock, patch + +from haystack.dataclasses.byte_stream import ByteStream + +from google_vertex_haystack.generators.captioner import VertexAIImageCaptioner + + +@patch("google_vertex_haystack.generators.captioner.vertexai") +@patch("google_vertex_haystack.generators.captioner.ImageTextModel") +def test_init(mock_model_class, mock_vertexai): + captioner = VertexAIImageCaptioner( + model="imagetext", project_id="myproject-123456", number_of_results=1, language="it" + ) + mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None) + mock_model_class.from_pretrained.assert_called_once_with("imagetext") + assert captioner._model_name == "imagetext" + assert captioner._project_id == "myproject-123456" + assert captioner._location is None + assert captioner._kwargs == {"number_of_results": 1, "language": "it"} + + +@patch("google_vertex_haystack.generators.captioner.vertexai") +@patch("google_vertex_haystack.generators.captioner.ImageTextModel") +def test_to_dict(_mock_model_class, _mock_vertexai): + captioner = 
VertexAIImageCaptioner( + model="imagetext", project_id="myproject-123456", number_of_results=1, language="it" + ) + assert captioner.to_dict() == { + "type": "google_vertex_haystack.generators.captioner.VertexAIImageCaptioner", + "init_parameters": { + "model": "imagetext", + "project_id": "myproject-123456", + "location": None, + "number_of_results": 1, + "language": "it", + }, + } + + +@patch("google_vertex_haystack.generators.captioner.vertexai") +@patch("google_vertex_haystack.generators.captioner.ImageTextModel") +def test_from_dict(_mock_model_class, _mock_vertexai): + captioner = VertexAIImageCaptioner.from_dict( + { + "type": "google_vertex_haystack.generators.captioner.VertexAIImageCaptioner", + "init_parameters": { + "model": "imagetext", + "project_id": "myproject-123456", + "number_of_results": 1, + "language": "it", + }, + } + ) + assert captioner._model_name == "imagetext" + assert captioner._project_id == "myproject-123456" + assert captioner._location is None + assert captioner._kwargs == {"number_of_results": 1, "language": "it"} + assert captioner._model is not None + + +@patch("google_vertex_haystack.generators.captioner.vertexai") +@patch("google_vertex_haystack.generators.captioner.ImageTextModel") +def test_run_calls_get_captions(mock_model_class, _mock_vertexai): + mock_model = Mock() + mock_model_class.from_pretrained.return_value = mock_model + captioner = VertexAIImageCaptioner( + model="imagetext", project_id="myproject-123456", number_of_results=1, language="it" + ) + + image = ByteStream(data=b"image data") + captioner.run(image=image) + mock_model.get_captions.assert_called_once() + assert len(mock_model.get_captions.call_args.kwargs) == 3 + assert mock_model.get_captions.call_args.kwargs["image"]._image_bytes == image.data + assert mock_model.get_captions.call_args.kwargs["number_of_results"] == 1 + assert mock_model.get_captions.call_args.kwargs["language"] == "it" diff --git a/integrations/google-vertex/tests/test_code_generator.py 
b/integrations/google-vertex/tests/test_code_generator.py new file mode 100644 index 000000000..c2a2e5aa9 --- /dev/null +++ b/integrations/google-vertex/tests/test_code_generator.py @@ -0,0 +1,79 @@ +from unittest.mock import Mock, patch + +from vertexai.language_models import TextGenerationResponse + +from google_vertex_haystack.generators.code_generator import VertexAICodeGenerator + + +@patch("google_vertex_haystack.generators.code_generator.vertexai") +@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel") +def test_init(mock_model_class, mock_vertexai): + generator = VertexAICodeGenerator( + model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5 + ) + mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None) + mock_model_class.from_pretrained.assert_called_once_with("code-bison") + assert generator._model_name == "code-bison" + assert generator._project_id == "myproject-123456" + assert generator._location is None + assert generator._kwargs == {"candidate_count": 3, "temperature": 0.5} + + +@patch("google_vertex_haystack.generators.code_generator.vertexai") +@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel") +def test_to_dict(_mock_model_class, _mock_vertexai): + generator = VertexAICodeGenerator( + model="code-bison", project_id="myproject-123456", candidate_count=3, temperature=0.5 + ) + assert generator.to_dict() == { + "type": "google_vertex_haystack.generators.code_generator.VertexAICodeGenerator", + "init_parameters": { + "model": "code-bison", + "project_id": "myproject-123456", + "location": None, + "candidate_count": 3, + "temperature": 0.5, + }, + } + + +@patch("google_vertex_haystack.generators.code_generator.vertexai") +@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel") +def test_from_dict(_mock_model_class, _mock_vertexai): + generator = VertexAICodeGenerator.from_dict( + { + "type": 
"google_vertex_haystack.generators.code_generator.VertexAICodeGenerator", + "init_parameters": { + "model": "code-bison", + "project_id": "myproject-123456", + "candidate_count": 2, + "temperature": 0.5, + }, + } + ) + assert generator._model_name == "code-bison" + assert generator._project_id == "myproject-123456" + assert generator._location is None + assert generator._kwargs == {"candidate_count": 2, "temperature": 0.5} + assert generator._model is not None + + +@patch("google_vertex_haystack.generators.code_generator.vertexai") +@patch("google_vertex_haystack.generators.code_generator.CodeGenerationModel") +def test_run_calls_predict(mock_model_class, _mock_vertexai): + mock_model = Mock() + mock_model.predict.return_value = TextGenerationResponse("answer", None) + mock_model_class.from_pretrained.return_value = mock_model + generator = VertexAICodeGenerator( + model="code-bison", project_id="myproject-123456", candidate_count=1, temperature=0.5 + ) + + prefix = "def print_json(data):\n" + generator.run(prefix=prefix) + + mock_model.predict.assert_called_once() + assert len(mock_model.predict.call_args.kwargs) == 4 + assert mock_model.predict.call_args.kwargs["prefix"] == prefix + assert mock_model.predict.call_args.kwargs["suffix"] is None + assert mock_model.predict.call_args.kwargs["candidate_count"] == 1 + assert mock_model.predict.call_args.kwargs["temperature"] == 0.5 diff --git a/integrations/google-vertex/tests/test_image_generator.py b/integrations/google-vertex/tests/test_image_generator.py new file mode 100644 index 000000000..1c5381a48 --- /dev/null +++ b/integrations/google-vertex/tests/test_image_generator.py @@ -0,0 +1,92 @@ +from unittest.mock import Mock, patch + +from vertexai.preview.vision_models import ImageGenerationResponse + +from google_vertex_haystack.generators.image_generator import VertexAIImageGenerator + + +@patch("google_vertex_haystack.generators.image_generator.vertexai") 
+@patch("google_vertex_haystack.generators.image_generator.ImageGenerationModel") +def test_init(mock_model_class, mock_vertexai): + generator = VertexAIImageGenerator( + model="imagetext", + project_id="myproject-123456", + guidance_scale=12, + number_of_images=3, + ) + mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None) + mock_model_class.from_pretrained.assert_called_once_with("imagetext") + assert generator._model_name == "imagetext" + assert generator._project_id == "myproject-123456" + assert generator._location is None + assert generator._kwargs == { + "guidance_scale": 12, + "number_of_images": 3, + } + + +@patch("google_vertex_haystack.generators.image_generator.vertexai") +@patch("google_vertex_haystack.generators.image_generator.ImageGenerationModel") +def test_to_dict(_mock_model_class, _mock_vertexai): + generator = VertexAIImageGenerator( + model="imagetext", + project_id="myproject-123456", + guidance_scale=12, + number_of_images=3, + ) + assert generator.to_dict() == { + "type": "google_vertex_haystack.generators.image_generator.VertexAIImageGenerator", + "init_parameters": { + "model": "imagetext", + "project_id": "myproject-123456", + "location": None, + "guidance_scale": 12, + "number_of_images": 3, + }, + } + + +@patch("google_vertex_haystack.generators.image_generator.vertexai") +@patch("google_vertex_haystack.generators.image_generator.ImageGenerationModel") +def test_from_dict(_mock_model_class, _mock_vertexai): + generator = VertexAIImageGenerator.from_dict( + { + "type": "google_vertex_haystack.generators.image_generator.VertexAIImageGenerator", + "init_parameters": { + "model": "imagetext", + "project_id": "myproject-123456", + "location": None, + "guidance_scale": 12, + "number_of_images": 3, + }, + } + ) + assert generator._model_name == "imagetext" + assert generator._project_id == "myproject-123456" + assert generator._location is None + assert generator._kwargs == { + "guidance_scale": 12, + 
"number_of_images": 3, + } + + +@patch("google_vertex_haystack.generators.image_generator.vertexai") +@patch("google_vertex_haystack.generators.image_generator.ImageGenerationModel") +def test_run_calls_generate_images(mock_model_class, _mock_vertexai): + mock_model = Mock() + mock_model.generate_images.return_value = ImageGenerationResponse(images=[]) + mock_model_class.from_pretrained.return_value = mock_model + generator = VertexAIImageGenerator( + model="imagetext", + project_id="myproject-123456", + guidance_scale=12, + number_of_images=3, + ) + + prompt = "Generate an image of a dog" + negative_prompt = "Generate an image of a cat" + generator.run(prompt=prompt, negative_prompt=negative_prompt) + + mock_model.generate_images.assert_called_once_with( + prompt=prompt, negative_prompt=negative_prompt, guidance_scale=12, number_of_images=3 + ) diff --git a/integrations/google-vertex/tests/test_question_answering.py b/integrations/google-vertex/tests/test_question_answering.py new file mode 100644 index 000000000..3495afcb2 --- /dev/null +++ b/integrations/google-vertex/tests/test_question_answering.py @@ -0,0 +1,83 @@ +from unittest.mock import Mock, patch + +from haystack.dataclasses.byte_stream import ByteStream + +from google_vertex_haystack.generators.question_answering import VertexAIImageQA + + +@patch("google_vertex_haystack.generators.question_answering.vertexai") +@patch("google_vertex_haystack.generators.question_answering.ImageTextModel") +def test_init(mock_model_class, mock_vertexai): + generator = VertexAIImageQA( + model="imagetext", + project_id="myproject-123456", + number_of_results=3, + ) + mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None) + mock_model_class.from_pretrained.assert_called_once_with("imagetext") + assert generator._model_name == "imagetext" + assert generator._project_id == "myproject-123456" + assert generator._location is None + assert generator._kwargs == {"number_of_results": 3} + + 
+@patch("google_vertex_haystack.generators.question_answering.vertexai") +@patch("google_vertex_haystack.generators.question_answering.ImageTextModel") +def test_to_dict(_mock_model_class, _mock_vertexai): + generator = VertexAIImageQA( + model="imagetext", + project_id="myproject-123456", + number_of_results=3, + ) + assert generator.to_dict() == { + "type": "google_vertex_haystack.generators.question_answering.VertexAIImageQA", + "init_parameters": { + "model": "imagetext", + "project_id": "myproject-123456", + "location": None, + "number_of_results": 3, + }, + } + + +@patch("google_vertex_haystack.generators.question_answering.vertexai") +@patch("google_vertex_haystack.generators.question_answering.ImageTextModel") +def test_from_dict(_mock_model_class, _mock_vertexai): + generator = VertexAIImageQA.from_dict( + { + "type": "google_vertex_haystack.generators.question_answering.VertexAIImageQA", + "init_parameters": { + "model": "imagetext", + "project_id": "myproject-123456", + "location": None, + "number_of_results": 3, + }, + } + ) + assert generator._model_name == "imagetext" + assert generator._project_id == "myproject-123456" + assert generator._location is None + assert generator._kwargs == {"number_of_results": 3} + + +@patch("google_vertex_haystack.generators.question_answering.vertexai") +@patch("google_vertex_haystack.generators.question_answering.ImageTextModel") +def test_run_calls_ask_question(mock_model_class, _mock_vertexai): + mock_model = Mock() + mock_model.ask_question.return_value = [] + mock_model_class.from_pretrained.return_value = mock_model + generator = VertexAIImageQA( + model="imagetext", + project_id="myproject-123456", + number_of_results=3, + ) + + image = ByteStream(data=b"image data") + question = "What is this?" 
+ generator.run(image=image, question=question) + + mock_model.ask_question.assert_called_once() + assert len(mock_model.ask_question.call_args.kwargs) == 3 + assert mock_model.ask_question.call_args.kwargs["image"]._image_bytes == image.data + assert mock_model.ask_question.call_args.kwargs["number_of_results"] == 3 + assert mock_model.ask_question.call_args.kwargs["question"] == question diff --git a/integrations/google-vertex/tests/test_text_generator.py b/integrations/google-vertex/tests/test_text_generator.py new file mode 100644 index 000000000..f2edbfc3b --- /dev/null +++ b/integrations/google-vertex/tests/test_text_generator.py @@ -0,0 +1,96 @@ +from unittest.mock import MagicMock, Mock, patch + +from vertexai.language_models import GroundingSource + +from google_vertex_haystack.generators.text_generator import VertexAITextGenerator + + +@patch("google_vertex_haystack.generators.text_generator.vertexai") +@patch("google_vertex_haystack.generators.text_generator.TextGenerationModel") +def test_init(mock_model_class, mock_vertexai): + grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1") + generator = VertexAITextGenerator( + model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source + ) + mock_vertexai.init.assert_called_once_with(project="myproject-123456", location=None) + mock_model_class.from_pretrained.assert_called_once_with("text-bison") + assert generator._model_name == "text-bison" + assert generator._project_id == "myproject-123456" + assert generator._location is None + assert generator._kwargs == {"temperature": 0.2, "grounding_source": grounding_source} + + +@patch("google_vertex_haystack.generators.text_generator.vertexai") +@patch("google_vertex_haystack.generators.text_generator.TextGenerationModel") +def test_to_dict(_mock_model_class, _mock_vertexai): + grounding_source = GroundingSource.VertexAISearch("1234", "us-central-1") + generator = VertexAITextGenerator( + 
model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source + ) + assert generator.to_dict() == { + "type": "google_vertex_haystack.generators.text_generator.VertexAITextGenerator", + "init_parameters": { + "model": "text-bison", + "project_id": "myproject-123456", + "location": None, + "temperature": 0.2, + "grounding_source": { + "type": "vertexai.language_models._language_models.VertexAISearch", + "init_parameters": { + "location": "us-central-1", + "data_store_id": "1234", + "project": None, + "disable_attribution": False, + }, + }, + }, + } + + +@patch("google_vertex_haystack.generators.text_generator.vertexai") +@patch("google_vertex_haystack.generators.text_generator.TextGenerationModel") +def test_from_dict(_mock_model_class, _mock_vertexai): + generator = VertexAITextGenerator.from_dict( + { + "type": "google_vertex_haystack.generators.text_generator.VertexAITextGenerator", + "init_parameters": { + "model": "text-bison", + "project_id": "myproject-123456", + "location": None, + "temperature": 0.2, + "grounding_source": { + "type": "vertexai.language_models._language_models.VertexAISearch", + "init_parameters": { + "location": "us-central-1", + "data_store_id": "1234", + "project": None, + "disable_attribution": False, + }, + }, + }, + } + ) + assert generator._model_name == "text-bison" + assert generator._project_id == "myproject-123456" + assert generator._location is None + assert generator._kwargs == { + "temperature": 0.2, + "grounding_source": GroundingSource.VertexAISearch("1234", "us-central-1"), + } + + +@patch("google_vertex_haystack.generators.text_generator.vertexai") +@patch("google_vertex_haystack.generators.text_generator.TextGenerationModel") +def test_run_calls_get_captions(mock_model_class, _mock_vertexai): + mock_model = Mock() + mock_model.predict.return_value = MagicMock() + mock_model_class.from_pretrained.return_value = mock_model + grounding_source = GroundingSource.VertexAISearch("1234", 
"us-central-1") + generator = VertexAITextGenerator( + model="text-bison", project_id="myproject-123456", temperature=0.2, grounding_source=grounding_source + ) + + prompt = "What is the answer?" + generator.run(prompt=prompt) + + mock_model.predict.assert_called_once_with(prompt=prompt, temperature=0.2, grounding_source=grounding_source) From e215185b1cd9d20340a1fe1f56f3c1ba47fb6d74 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:14:49 +0100 Subject: [PATCH 18/29] Update Google Vertex integration (#102) * Pin google-cloud-aiplatform version * Make generation kwargs explicit and implement serialization * Implement GeminiChatGenerator * Fix linting and remove dead code * Bump version --- integrations/google-vertex/pyproject.toml | 2 +- .../src/google_vertex_haystack/__about__.py | 2 +- .../generators/chat/gemini.py | 179 ++++++++++++++++++ .../generators/gemini.py | 108 ++++++++--- 4 files changed, 261 insertions(+), 30 deletions(-) create mode 100644 integrations/google-vertex/src/google_vertex_haystack/generators/chat/gemini.py diff --git a/integrations/google-vertex/pyproject.toml b/integrations/google-vertex/pyproject.toml index 2455b4fa9..a34853a60 100644 --- a/integrations/google-vertex/pyproject.toml +++ b/integrations/google-vertex/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "google-cloud-aiplatform", + "google-cloud-aiplatform>=1.38", ] [project.urls] diff --git a/integrations/google-vertex/src/google_vertex_haystack/__about__.py b/integrations/google-vertex/src/google_vertex_haystack/__about__.py index 0e4fa27cf..d4a92df1b 100644 --- a/integrations/google-vertex/src/google_vertex_haystack/__about__.py +++ b/integrations/google-vertex/src/google_vertex_haystack/__about__.py @@ -1,4 +1,4 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -__version__ = "0.0.1" +__version__ = "0.0.2" diff --git 
a/integrations/google-vertex/src/google_vertex_haystack/generators/chat/gemini.py b/integrations/google-vertex/src/google_vertex_haystack/generators/chat/gemini.py new file mode 100644 index 000000000..765706301 --- /dev/null +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/chat/gemini.py @@ -0,0 +1,179 @@ +import logging +from typing import Any, Dict, List, Optional, Union + +import vertexai +from haystack.core.component import component +from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.dataclasses.byte_stream import ByteStream +from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from vertexai.preview.generative_models import ( + Content, + FunctionDeclaration, + GenerationConfig, + GenerativeModel, + HarmBlockThreshold, + HarmCategory, + Part, + Tool, +) + +logger = logging.getLogger(__name__) + + +@component +class GeminiChatGenerator: + def __init__( + self, + *, + model: str = "gemini-pro", + project_id: str, + location: Optional[str] = None, + generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, + safety_settings: Optional[Dict[HarmCategory, HarmBlockThreshold]] = None, + tools: Optional[List[Tool]] = None, + ): + """ + Multi modal generator using Gemini model via Google Vertex AI. + + Authenticates using Google Cloud Application Default Credentials (ADCs). + For more information see the official Google documentation: + https://cloud.google.com/docs/authentication/provide-credentials-adc + + :param project_id: ID of the GCP project to use. + :param model: Name of the model to use, defaults to "gemini-pro". + :param location: The default location to use when making API calls, if not set uses us-central-1. + Defaults to None. + :param generation_config: The generation config to use, either a GenerationConfig object or a dictionary of parameters. + :param safety_settings: The safety settings to use, a dictionary of HarmCategory to HarmBlockThreshold. :param tools: The tools to use. + """ + + # Login to GCP. 
This will fail if user has not set up their gcloud SDK + vertexai.init(project=project_id, location=location) + + self._model_name = model + self._project_id = project_id + self._location = location + self._model = GenerativeModel(self._model_name) + + self._generation_config = generation_config + self._safety_settings = safety_settings + self._tools = tools + + def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: + return { + "name": function._raw_function_declaration.name, + "parameters": function._raw_function_declaration.parameters, + "description": function._raw_function_declaration.description, + } + + def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]: + return { + "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations], + } + + def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(config, dict): + return config + return { + "temperature": config._raw_generation_config.temperature, + "top_p": config._raw_generation_config.top_p, + "top_k": config._raw_generation_config.top_k, + "candidate_count": config._raw_generation_config.candidate_count, + "max_output_tokens": config._raw_generation_config.max_output_tokens, + "stop_sequences": config._raw_generation_config.stop_sequences, + } + + def to_dict(self) -> Dict[str, Any]: + data = default_to_dict( + self, + model=self._model_name, + project_id=self._project_id, + location=self._location, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, + ) + if (tools := data["init_parameters"].get("tools")) is not None: + data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) + return data + + @classmethod + def from_dict(cls, 
data: Dict[str, Any]) -> "GeminiChatGenerator": + if (tools := data["init_parameters"].get("tools")) is not None: + data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) + + return default_from_dict(cls, data) + + def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: + if isinstance(part, str): + return Part.from_text(part) + elif isinstance(part, ByteStream): + return Part.from_data(part.data, part.mime_type) + elif isinstance(part, Part): + return part + else: + msg = f"Unsupported type {type(part)} for part {part}" + raise ValueError(msg) + + def _message_to_part(self, message: ChatMessage) -> Part: + if message.role == ChatRole.SYSTEM and message.name: + p = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) + for k, v in message.content.items(): + p.function_call.args[k] = v + return p + elif message.role == ChatRole.SYSTEM: + return Part.from_text(message.content) + elif message.role == ChatRole.FUNCTION: + return Part.from_function_response(name=message.name, response=message.content) + elif message.role == ChatRole.USER: + return self._convert_part(message.content) + + def _message_to_content(self, message: ChatMessage) -> Content: + if message.role == ChatRole.SYSTEM and message.name: + part = Part.from_dict({"function_call": {"name": message.name, "args": {}}}) + for k, v in message.content.items(): + part.function_call.args[k] = v + elif message.role == ChatRole.SYSTEM: + part = Part.from_text(message.content) + elif message.role == ChatRole.FUNCTION: + part = Part.from_function_response(name=message.name, response=message.content) + elif message.role == ChatRole.USER: + part = self._convert_part(message.content) + else: + msg = f"Unsupported message role {message.role}" + raise ValueError(msg) + role = "user" if 
message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model" + return Content(parts=[part], role=role) + + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage]): + history = [self._message_to_content(m) for m in messages[:-1]] + session = self._model.start_chat(history=history) + + new_message = self._message_to_part(messages[-1]) + res = session.send_message( + content=new_message, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, + ) + + replies = [] + for candidate in res.candidates: + for part in candidate.content.parts: + if part._raw_part.text != "": + replies.append(ChatMessage.from_system(part.text)) + elif part.function_call is not None: + replies.append( + ChatMessage( + content=dict(part.function_call.args.items()), + role=ChatRole.SYSTEM, + name=part.function_call.name, + ) + ) + + return {"replies": replies} diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py b/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py index b01dc6795..aa8f9d9e8 100644 --- a/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py @@ -8,8 +8,13 @@ from haystack.dataclasses.byte_stream import ByteStream from vertexai.preview.generative_models import ( Content, + FunctionDeclaration, + GenerationConfig, GenerativeModel, + HarmBlockThreshold, + HarmCategory, Part, + Tool, ) logger = logging.getLogger(__name__) @@ -17,7 +22,16 @@ @component class GeminiGenerator: - def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, location: Optional[str] = None, **kwargs): + def __init__( + self, + *, + model: str = "gemini-pro-vision", + project_id: str, + location: Optional[str] = None, + generation_config: Optional[Union[GenerationConfig, Dict[str, Any]]] = None, + safety_settings: Optional[Dict[HarmCategory, 
HarmBlockThreshold]] = None, + tools: Optional[List[Tool]] = None, + ): """ Multi modal generator using Gemini model via Google Vertex AI. @@ -29,8 +43,19 @@ def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, locatio :param model: Name of the model to use, defaults to "gemini-pro-vision". :param location: The default location to use when making API calls, if not set uses us-central-1. Defaults to None. - :param kwargs: Additional keyword arguments to pass to the model. - For a list of supported arguments see the `GenerativeModel.generate_content()` documentation. + :param generation_config: The generation config to use, defaults to None. + Can either be a GenerationConfig object or a dictionary of parameters. + Accepted fields are: + - temperature + - top_p + - top_k + - candidate_count + - max_output_tokens + - stop_sequences + :param safety_settings: The safety settings to use, defaults to None. + A dictionary of HarmCategory to HarmBlockThreshold. + :param tools: The tools to use, defaults to None. + A list of Tool objects that can be used to modify the generation process. """ # Login to GCP. This will fail if user has not set up their gcloud SDK @@ -39,23 +64,59 @@ def __init__(self, *, model: str = "gemini-pro-vision", project_id: str, locatio self._model_name = model self._project_id = project_id self._location = location - self._kwargs = kwargs - - if kwargs.get("stream"): - msg = "The `stream` parameter is not supported by the Gemini generator." 
- raise ValueError(msg) - self._model = GenerativeModel(self._model_name) + self._generation_config = generation_config + self._safety_settings = safety_settings + self._tools = tools + + def _function_to_dict(self, function: FunctionDeclaration) -> Dict[str, Any]: + return { + "name": function._raw_function_declaration.name, + "parameters": function._raw_function_declaration.parameters, + "description": function._raw_function_declaration.description, + } + + def _tool_to_dict(self, tool: Tool) -> Dict[str, Any]: + return { + "function_declarations": [self._function_to_dict(f) for f in tool._raw_tool.function_declarations], + } + + def _generation_config_to_dict(self, config: Union[GenerationConfig, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(config, dict): + return config + return { + "temperature": config._raw_generation_config.temperature, + "top_p": config._raw_generation_config.top_p, + "top_k": config._raw_generation_config.top_k, + "candidate_count": config._raw_generation_config.candidate_count, + "max_output_tokens": config._raw_generation_config.max_output_tokens, + "stop_sequences": config._raw_generation_config.stop_sequences, + } + def to_dict(self) -> Dict[str, Any]: - # TODO: This is not fully implemented yet - return default_to_dict( - self, model=self._model_name, project_id=self._project_id, location=self._location, **self._kwargs + data = default_to_dict( + self, + model=self._model_name, + project_id=self._project_id, + location=self._location, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, ) + if (tools := data["init_parameters"].get("tools")) is not None: + data["init_parameters"]["tools"] = [self._tool_to_dict(t) for t in tools] + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = self._generation_config_to_dict(generation_config) + return data @classmethod def from_dict(cls, data: Dict[str, 
Any]) -> "GeminiGenerator": - # TODO: This is not fully implemented yet + if (tools := data["init_parameters"].get("tools")) is not None: + data["init_parameters"]["tools"] = [Tool.from_dict(t) for t in tools] + if (generation_config := data["init_parameters"].get("generation_config")) is not None: + data["init_parameters"]["generation_config"] = GenerationConfig.from_dict(generation_config) + return default_from_dict(cls, data) def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: @@ -74,7 +135,12 @@ def run(self, parts: Variadic[List[Union[str, ByteStream, Part]]]): converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")] - res = self._model.generate_content(contents=contents, **self._kwargs) + res = self._model.generate_content( + contents=contents, + generation_config=self._generation_config, + safety_settings=self._safety_settings, + tools=self._tools, + ) self._model.start_chat() answers = [] for candidate in res.candidates: @@ -89,17 +155,3 @@ def run(self, parts: Variadic[List[Union[str, ByteStream, Part]]]): answers.append(function_call) return {"answers": answers} - - -# generator = GeminiGenerator(project_id="infinite-byte-223810") -# res = generator.run(["What can you do for me?"]) -# res -# another_res = generator.run(["Can you solve this math problems?", "2 + 2", "3 + 3", "1 / 1"]) -# another_res["answers"] -# from pathlib import Path - -# image = ByteStream.from_file_path( -# Path("/Users/silvanocerza/Downloads/photo_2023-11-07_11-45-42.jpg"), mime_type="image/jpeg" -# ) -# res = generator.run(["What is this about?", image]) -# res["answers"] From 5945cdd2d267e8e9407a5abe077854f3140096a5 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 15 Dec 2023 12:26:39 +0100 Subject: [PATCH 19/29] pin chroma version (#104) --- integrations/chroma/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/chroma/pyproject.toml 
b/integrations/chroma/pyproject.toml index 8bace9a1b..b50723f4c 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "chromadb", + "chromadb<0.4.20", # FIXME: investigate why filtering tests broke on 0.4.20 ] [project.urls] From 71e20eb2dda909559dcbdce48b51df054a97107b Mon Sep 17 00:00:00 2001 From: Anush Date: Fri, 15 Dec 2023 19:02:20 +0530 Subject: [PATCH 20/29] feat: Add Qdrant integration (#98) * feat: qdrant-haystack * ci: Create qdrant.yml * docs: Update README.md, mypy overrides (#1) * docs: Update README.md * chore: mypy overrides * docs: README.md typo fix * chore: update pyproject.toml (#2) * chore: pin pyproject.toml version * Update pyproject.toml * Update pyproject.toml * Delete integrations/qdrant/src/qdrant_haystack/utils.py * Rename test_dict_convertors.py to test_dict_converters.py --------- Co-authored-by: Julian Risch --- .github/workflows/qdrant.yml | 56 +++ integrations/qdrant/LICENSE.txt | 73 +++ integrations/qdrant/README.md | 28 ++ integrations/qdrant/pyproject.toml | 171 +++++++ .../qdrant/src/qdrant_haystack/__about__.py | 4 + .../qdrant/src/qdrant_haystack/__init__.py | 7 + .../qdrant/src/qdrant_haystack/converters.py | 56 +++ .../src/qdrant_haystack/document_store.py | 458 +++++++++++++++++ .../qdrant/src/qdrant_haystack/filters.py | 211 ++++++++ .../qdrant/src/qdrant_haystack/retriever.py | 98 ++++ integrations/qdrant/tests/__init__.py | 3 + integrations/qdrant/tests/test_converters.py | 53 ++ .../qdrant/tests/test_dict_converters.py | 102 ++++ .../qdrant/tests/test_document_store.py | 42 ++ integrations/qdrant/tests/test_filters.py | 115 +++++ .../qdrant/tests/test_legacy_filters.py | 459 ++++++++++++++++++ integrations/qdrant/tests/test_retriever.py | 113 +++++ 17 files changed, 2049 insertions(+) create mode 100644 .github/workflows/qdrant.yml create mode 100644 integrations/qdrant/LICENSE.txt create mode 100644 
integrations/qdrant/README.md create mode 100644 integrations/qdrant/pyproject.toml create mode 100644 integrations/qdrant/src/qdrant_haystack/__about__.py create mode 100644 integrations/qdrant/src/qdrant_haystack/__init__.py create mode 100644 integrations/qdrant/src/qdrant_haystack/converters.py create mode 100644 integrations/qdrant/src/qdrant_haystack/document_store.py create mode 100644 integrations/qdrant/src/qdrant_haystack/filters.py create mode 100644 integrations/qdrant/src/qdrant_haystack/retriever.py create mode 100644 integrations/qdrant/tests/__init__.py create mode 100644 integrations/qdrant/tests/test_converters.py create mode 100644 integrations/qdrant/tests/test_dict_converters.py create mode 100644 integrations/qdrant/tests/test_document_store.py create mode 100644 integrations/qdrant/tests/test_filters.py create mode 100644 integrations/qdrant/tests/test_legacy_filters.py create mode 100644 integrations/qdrant/tests/test_retriever.py diff --git a/.github/workflows/qdrant.yml b/.github/workflows/qdrant.yml new file mode 100644 index 000000000..2bbf4f63a --- /dev/null +++ b/.github/workflows/qdrant.yml @@ -0,0 +1,56 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / qdrant + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - 'integrations/qdrant/**' + - '.github/workflows/qdrant.yml' + +defaults: + run: + working-directory: integrations/qdrant + +concurrency: + group: qdrant-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ['3.9', 
'3.10'] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Run tests + run: hatch run cov diff --git a/integrations/qdrant/LICENSE.txt b/integrations/qdrant/LICENSE.txt new file mode 100644 index 000000000..137069b82 --- /dev/null +++ b/integrations/qdrant/LICENSE.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
+ +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 
+ +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
+ +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/integrations/qdrant/README.md b/integrations/qdrant/README.md new file mode 100644 index 000000000..18bf5b7a8 --- /dev/null +++ b/integrations/qdrant/README.md @@ -0,0 +1,28 @@ +# qdrant-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg)](https://pypi.org/project/qdrant-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/qdrant-haystack.svg)](https://pypi.org/project/qdrant-haystack) + +----- + +**Table of Contents** + +- [Installation](#installation) +- [License](#license) + +## Installation + +```console +pip install qdrant-haystack +``` + +## Testing +The test suites use Qdrant's in-memory instance. No additional steps required. + +```console +hatch run test +``` + +## License + +`qdrant-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml new file mode 100644 index 000000000..f8209e0c9 --- /dev/null +++ b/integrations/qdrant/pyproject.toml @@ -0,0 +1,171 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "qdrant-haystack" +dynamic = ["version"] +description = 'An integration of Qdrant ANN vector database backend with Haystack' +readme = "README.md" +requires-python = ">=3.7" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "Kacper Łukawski", email = "kacper.lukawski@qdrant.com" }, + { name = "Anush Shetty", email = "anush.shetty@qdrant.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "haystack-ai", + "qdrant-client", +] + +[project.urls] +Source = "https://github.com/deepset-ai/haystack-core-integrations" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/qdrant/README.md" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" + +[tool.hatch.version] +path = "src/qdrant_haystack/__about__.py" + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.7", "3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + 
"ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive {args:src/qdrant_haystack tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["qdrant_haystack"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source_pkgs = ["qdrant_haystack", "tests"] +branch = true +parallel = true +omit = [ + "src/qdrant_haystack/__about__.py", +] + +[tool.coverage.paths] +qdrant_haystack = ["src/qdrant_haystack", "*/qdrant-haystack/src/qdrant_haystack"] +tests = ["tests", "*/qdrant-haystack/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "pytest.*", + "qdrant_client.*", + "numpy", + "grpc" +] +ignore_missing_imports = true diff --git a/integrations/qdrant/src/qdrant_haystack/__about__.py 
b/integrations/qdrant/src/qdrant_haystack/__about__.py new file mode 100644 index 000000000..0e4fa27cf --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/__about__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +__version__ = "0.0.1" diff --git a/integrations/qdrant/src/qdrant_haystack/__init__.py b/integrations/qdrant/src/qdrant_haystack/__init__.py new file mode 100644 index 000000000..765ced0ef --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from qdrant_haystack.document_store import QdrantDocumentStore + +__all__ = ("QdrantDocumentStore",) diff --git a/integrations/qdrant/src/qdrant_haystack/converters.py b/integrations/qdrant/src/qdrant_haystack/converters.py new file mode 100644 index 000000000..3fb6dabd6 --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/converters.py @@ -0,0 +1,56 @@ +import uuid +from typing import List, Union + +from haystack.dataclasses import Document +from qdrant_client.http import models as rest + + +class HaystackToQdrant: + """A converter from Haystack to Qdrant types.""" + + UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d") + + def documents_to_batch( + self, + documents: List[Document], + *, + embedding_field: str, + ) -> List[rest.PointStruct]: + points = [] + for document in documents: + payload = document.to_dict(flatten=False) + vector = payload.pop(embedding_field) or {} + _id = self.convert_id(payload.get("id")) + + point = rest.PointStruct( + payload=payload, + vector=vector, + id=_id, + ) + points.append(point) + return points + + def convert_id(self, _id: str) -> str: + """ + Converts any string into a UUID-like format in a deterministic way. + + Qdrant does not accept any string as an id, so an internal id has to be + generated for each point. This is a deterministic way of doing so. 
+ """ + return uuid.uuid5(self.UUID_NAMESPACE, _id).hex + + +QdrantPoint = Union[rest.ScoredPoint, rest.Record] + + +class QdrantToHaystack: + def __init__(self, content_field: str, name_field: str, embedding_field: str): + self.content_field = content_field + self.name_field = name_field + self.embedding_field = embedding_field + + def point_to_document(self, point: QdrantPoint) -> Document: + payload = {**point.payload} + payload["embedding"] = point.vector if hasattr(point, "vector") else None + payload["score"] = point.score if hasattr(point, "score") else None + return Document.from_dict(payload) diff --git a/integrations/qdrant/src/qdrant_haystack/document_store.py b/integrations/qdrant/src/qdrant_haystack/document_store.py new file mode 100644 index 000000000..c4b709332 --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/document_store.py @@ -0,0 +1,458 @@ +import inspect +import logging +from itertools import islice +from typing import Any, ClassVar, Dict, Generator, List, Optional, Set, Union + +import numpy as np +import qdrant_client +from grpc import RpcError +from haystack import default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores import DuplicatePolicy +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError +from haystack.utils.filters import convert +from qdrant_client import grpc +from qdrant_client.http import models as rest +from qdrant_client.http.exceptions import UnexpectedResponse +from tqdm import tqdm + +from qdrant_haystack.converters import HaystackToQdrant, QdrantToHaystack +from qdrant_haystack.filters import QdrantFilterConverter + +logger = logging.getLogger(__name__) + + +class QdrantStoreError(DocumentStoreError): + pass + + +FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]] + + +def get_batches_from_generator(iterable, n): + """ + Batch elements of an iterable into fixed-length chunks or blocks. 
+ """ + it = iter(iterable) + x = tuple(islice(it, n)) + while x: + yield x + x = tuple(islice(it, n)) + + +class QdrantDocumentStore: + SIMILARITY: ClassVar[Dict[str, str]] = { + "cosine": rest.Distance.COSINE, + "dot_product": rest.Distance.DOT, + "l2": rest.Distance.EUCLID, + } + + def __init__( + self, + location: Optional[str] = None, + url: Optional[str] = None, + port: int = 6333, + grpc_port: int = 6334, + prefer_grpc: bool = False, # noqa: FBT001, FBT002 + https: Optional[bool] = None, + api_key: Optional[str] = None, + prefix: Optional[str] = None, + timeout: Optional[float] = None, + host: Optional[str] = None, + path: Optional[str] = None, + index: str = "Document", + embedding_dim: int = 768, + content_field: str = "content", + name_field: str = "name", + embedding_field: str = "embedding", + similarity: str = "cosine", + return_embedding: bool = False, # noqa: FBT001, FBT002 + progress_bar: bool = True, # noqa: FBT001, FBT002 + duplicate_documents: str = "overwrite", + recreate_index: bool = False, # noqa: FBT001, FBT002 + shard_number: Optional[int] = None, + replication_factor: Optional[int] = None, + write_consistency_factor: Optional[int] = None, + on_disk_payload: Optional[bool] = None, + hnsw_config: Optional[dict] = None, + optimizers_config: Optional[dict] = None, + wal_config: Optional[dict] = None, + quantization_config: Optional[dict] = None, + init_from: Optional[dict] = None, + wait_result_from_api: bool = True, # noqa: FBT001, FBT002 + metadata: Optional[dict] = None, + write_batch_size: int = 100, + scroll_size: int = 10_000, + ): + super().__init__() + + metadata = metadata or {} + self.client = qdrant_client.QdrantClient( + location=location, + url=url, + port=port, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + https=https, + api_key=api_key, + prefix=prefix, + timeout=timeout, + host=host, + path=path, + metadata=metadata, + ) + + # Store the Qdrant client specific attributes + self.location = location + self.url = url + 
self.port = port + self.grpc_port = grpc_port + self.prefer_grpc = prefer_grpc + self.https = https + self.api_key = api_key + self.prefix = prefix + self.timeout = timeout + self.host = host + self.path = path + self.metadata = metadata + + # Store the Qdrant collection specific attributes + self.shard_number = shard_number + self.replication_factor = replication_factor + self.write_consistency_factor = write_consistency_factor + self.on_disk_payload = on_disk_payload + self.hnsw_config = hnsw_config + self.optimizers_config = optimizers_config + self.wal_config = wal_config + self.quantization_config = quantization_config + self.init_from = init_from + self.wait_result_from_api = wait_result_from_api + self.recreate_index = recreate_index + + # Make sure the collection is properly set up + self._set_up_collection(index, embedding_dim, recreate_index, similarity) + + self.embedding_dim = embedding_dim + self.content_field = content_field + self.name_field = name_field + self.embedding_field = embedding_field + self.similarity = similarity + self.index = index + self.return_embedding = return_embedding + self.progress_bar = progress_bar + self.duplicate_documents = duplicate_documents + self.qdrant_filter_converter = QdrantFilterConverter() + self.haystack_to_qdrant_converter = HaystackToQdrant() + self.qdrant_to_haystack = QdrantToHaystack( + content_field, + name_field, + embedding_field, + ) + self.write_batch_size = write_batch_size + self.scroll_size = scroll_size + + def count_documents(self) -> int: + try: + response = self.client.count( + collection_name=self.index, + ) + return response.count + except (UnexpectedResponse, ValueError): + # Qdrant local raises ValueError if the collection is not found, but + # with the remote server UnexpectedResponse is raised. Until that's unified, + # we need to catch both. 
+ return 0 + + def filter_documents( + self, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + if filters and not isinstance(filters, dict): + msg = "Filter must be a dictionary" + raise ValueError(msg) + + if filters and "operator" not in filters: + filters = convert(filters) + return list( + self.get_documents_generator( + filters, + ) + ) + + def write_documents( + self, + documents: List[Document], + policy: DuplicatePolicy = DuplicatePolicy.FAIL, + ): + for doc in documents: + if not isinstance(doc, Document): + msg = f"DocumentStore.write_documents() expects a list of Documents but got an element of {type(doc)}." + raise ValueError(msg) + self._set_up_collection(self.index, self.embedding_dim, False, self.similarity) + + if len(documents) == 0: + logger.warning("Calling QdrantDocumentStore.write_documents() with empty list") + return + + document_objects = self._handle_duplicate_documents( + documents=documents, + index=self.index, + policy=policy, + ) + + batched_documents = get_batches_from_generator(document_objects, self.write_batch_size) + with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar: + for document_batch in batched_documents: + batch = self.haystack_to_qdrant_converter.documents_to_batch( + document_batch, + embedding_field=self.embedding_field, + ) + + self.client.upsert( + collection_name=self.index, + points=batch, + wait=self.wait_result_from_api, + ) + + progress_bar.update(self.write_batch_size) + return len(document_objects) + + def delete_documents(self, ids: List[str]): + ids = [self.haystack_to_qdrant_converter.convert_id(_id) for _id in ids] + try: + self.client.delete( + collection_name=self.index, + points_selector=ids, + wait=self.wait_result_from_api, + ) + except KeyError: + logger.warning( + "Called QdrantDocumentStore.delete_documents() on a non-existing ID", + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "QdrantDocumentStore": + return 
default_from_dict(cls, data) + + def to_dict(self) -> Dict[str, Any]: + params = inspect.signature(self.__init__).parameters # type: ignore + # All the __init__ params must be set as attributes + # Set as init_parms without default values + init_params = {k: getattr(self, k) for k in params} + return default_to_dict( + self, + **init_params, + ) + + def get_documents_generator( + self, + filters: Optional[Dict[str, Any]] = None, + ) -> Generator[Document, None, None]: + index = self.index + qdrant_filters = self.qdrant_filter_converter.convert(filters) + + next_offset = None + stop_scrolling = False + while not stop_scrolling: + records, next_offset = self.client.scroll( + collection_name=index, + scroll_filter=qdrant_filters, + limit=self.scroll_size, + offset=next_offset, + with_payload=True, + with_vectors=True, + ) + stop_scrolling = next_offset is None or ( + isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == "" + ) + + for record in records: + yield self.qdrant_to_haystack.point_to_document(record) + + def get_documents_by_id( + self, + ids: List[str], + index: Optional[str] = None, + ) -> List[Document]: + index = index or self.index + + documents: List[Document] = [] + + ids = [self.haystack_to_qdrant_converter.convert_id(_id) for _id in ids] + records = self.client.retrieve( + collection_name=index, + ids=ids, + with_payload=True, + with_vectors=True, + ) + + for record in records: + documents.append(self.qdrant_to_haystack.point_to_document(record)) + return documents + + def query_by_embedding( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = True, # noqa: FBT001, FBT002 + return_embedding: bool = False, # noqa: FBT001, FBT002 + ) -> List[Document]: + qdrant_filters = self.qdrant_filter_converter.convert(filters) + + points = self.client.search( + collection_name=self.index, + query_vector=query_embedding, + query_filter=qdrant_filters, + 
limit=top_k, + with_vectors=return_embedding, + ) + + results = [self.qdrant_to_haystack.point_to_document(point) for point in points] + if scale_score: + for document in results: + score = document.score + if self.similarity == "cosine": + score = (score + 1) / 2 + else: + score = float(1 / (1 + np.exp(-score / 100))) + document.score = score + return results + + def _get_distance(self, similarity: str) -> rest.Distance: + try: + return self.SIMILARITY[similarity] + except KeyError as ke: + msg = ( + f"Provided similarity '{similarity}' is not supported by Qdrant " + f"document store. Please choose one of the options: " + f"{', '.join(self.SIMILARITY.keys())}" + ) + raise QdrantStoreError(msg) from ke + + def _set_up_collection( + self, + collection_name: str, + embedding_dim: int, + recreate_collection: bool, # noqa: FBT001 + similarity: str, + ): + distance = self._get_distance(similarity) + + if recreate_collection: + # There is no need to verify the current configuration of that + # collection. It might be just recreated again. + self._recreate_collection(collection_name, distance, embedding_dim) + return + + try: + # Check if the collection already exists and validate its + # current configuration with the parameters. + collection_info = self.client.get_collection(collection_name) + except (UnexpectedResponse, RpcError, ValueError): + # That indicates the collection does not exist, so it can be + # safely created with any configuration. + # + # Qdrant local raises ValueError if the collection is not found, but + # with the remote server UnexpectedResponse / RpcError is raised. + # Until that's unified, we need to catch both. 
+ self._recreate_collection(collection_name, distance, embedding_dim) + return + + current_distance = collection_info.config.params.vectors.distance + current_vector_size = collection_info.config.params.vectors.size + + if current_distance != distance: + msg = ( + f"Collection '{collection_name}' already exists in Qdrant, " + f"but it is configured with a similarity '{current_distance.name}'. " + f"If you want to use that collection, but with a different " + f"similarity, please set `recreate_collection=True` argument." + ) + raise ValueError(msg) + + if current_vector_size != embedding_dim: + msg = ( + f"Collection '{collection_name}' already exists in Qdrant, " + f"but it is configured with a vector size '{current_vector_size}'. " + f"If you want to use that collection, but with a different " + f"vector size, please set `recreate_collection=True` argument." + ) + raise ValueError(msg) + + def _recreate_collection(self, collection_name: str, distance, embedding_dim: int): + self.client.recreate_collection( + collection_name=collection_name, + vectors_config=rest.VectorParams( + size=embedding_dim, + distance=distance, + ), + shard_number=self.shard_number, + replication_factor=self.replication_factor, + write_consistency_factor=self.write_consistency_factor, + on_disk_payload=self.on_disk_payload, + hnsw_config=self.hnsw_config, + optimizers_config=self.optimizers_config, + wal_config=self.wal_config, + quantization_config=self.quantization_config, + init_from=self.init_from, + ) + + def _handle_duplicate_documents( + self, + documents: List[Document], + index: Optional[str] = None, + policy: DuplicatePolicy = None, + ): + """ + Checks whether any of the passed documents is already existing in the chosen index and returns a list of + documents that are not in the index yet. + + :param documents: A list of Haystack Document objects. + :param index: name of the index + :param duplicate_documents: Handle duplicates document based on parameter options. 
+ Parameter options : ( 'skip','overwrite','fail') + skip (default option): Ignore the duplicates documents + overwrite: Update any existing documents with the same ID when adding documents. + fail: an error is raised if the document ID of the document being added already + exists. + :return: A list of Haystack Document objects. + """ + + index = index or self.index + if policy in (DuplicatePolicy.SKIP, DuplicatePolicy.FAIL): + documents = self._drop_duplicate_documents(documents, index) + documents_found = self.get_documents_by_id(ids=[doc.id for doc in documents], index=index) + ids_exist_in_db: List[str] = [doc.id for doc in documents_found] + + if len(ids_exist_in_db) > 0 and policy == DuplicatePolicy.FAIL: + msg = f"Document with ids '{', '.join(ids_exist_in_db)} already exists in index = '{index}'." + raise DuplicateDocumentError(msg) + + documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents)) + + return documents + + def _drop_duplicate_documents(self, documents: List[Document], index: Optional[str] = None) -> List[Document]: + """ + Drop duplicates documents based on same hash ID + + :param documents: A list of Haystack Document objects. + :param index: name of the index + :return: A list of Haystack Document objects. 
+ """ + _hash_ids: Set = set() + _documents: List[Document] = [] + + for document in documents: + if document.id in _hash_ids: + logger.info( + "Duplicate Documents: Document with id '%s' already exists in index '%s'", + document.id, + index or self.index, + ) + continue + _documents.append(document) + _hash_ids.add(document.id) + + return _documents diff --git a/integrations/qdrant/src/qdrant_haystack/filters.py b/integrations/qdrant/src/qdrant_haystack/filters.py new file mode 100644 index 000000000..cc6b2b6a5 --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/filters.py @@ -0,0 +1,211 @@ +from abc import ABC, abstractmethod +from typing import Any, List, Optional, Union + +from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError +from qdrant_client.http import models + +from qdrant_haystack.converters import HaystackToQdrant + +COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys() +LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys() + + +class BaseFilterConverter(ABC): + """Converts Haystack filters to a format accepted by an external tool.""" + + @abstractmethod + def convert(self, filter_term: Optional[Union[List[dict], dict]]) -> Optional[Any]: + raise NotImplementedError + + +class QdrantFilterConverter(BaseFilterConverter): + """Converts Haystack filters to the format used by Qdrant.""" + + def __init__(self): + self.haystack_to_qdrant_converter = HaystackToQdrant() + + def convert( + self, + filter_term: Optional[Union[List[dict], dict]] = None, + ) -> Optional[models.Filter]: + if not filter_term: + return None + + must_clauses, should_clauses, must_not_clauses = [], [], [] + + if isinstance(filter_term, dict): + filter_term = [filter_term] + + for item in filter_term: + operator = item.get("operator") + if operator is None: + msg = "Operator not found in filters" + raise FilterError(msg) + + if operator in LOGICAL_OPERATORS and "conditions" not in item: + msg = f"'conditions' not found for '{operator}'" + raise 
FilterError(msg) + + if operator == "AND": + must_clauses.append(self.convert(item.get("conditions", []))) + elif operator == "OR": + should_clauses.append(self.convert(item.get("conditions", []))) + elif operator == "NOT": + must_not_clauses.append(self.convert(item.get("conditions", []))) + elif operator in COMPARISON_OPERATORS: + field = item.get("field") + value = item.get("value") + if field is None or value is None: + msg = f"'field' or 'value' not found for '{operator}'" + raise FilterError(msg) + + must_clauses.extend( + self._parse_comparison_operation(comparison_operation=operator, key=field, value=value) + ) + else: + msg = f"Unknown operator {operator} used in filters" + raise FilterError(msg) + + payload_filter = models.Filter( + must=must_clauses or None, + should=should_clauses or None, + must_not=must_not_clauses or None, + ) + + filter_result = self._squeeze_filter(payload_filter) + + return filter_result + + def _parse_comparison_operation( + self, comparison_operation: str, key: str, value: Union[dict, List, str, float] + ) -> List[models.Condition]: + conditions: List[models.Condition] = [] + + condition_builder_mapping = { + "==": self._build_eq_condition, + "in": self._build_in_condition, + "!=": self._build_ne_condition, + "not in": self._build_nin_condition, + ">": self._build_gt_condition, + ">=": self._build_gte_condition, + "<": self._build_lt_condition, + "<=": self._build_lte_condition, + } + + condition_builder = condition_builder_mapping.get(comparison_operation) + + if condition_builder is None: + msg = f"Unknown operator {comparison_operation} used in filters" + raise ValueError(msg) + + conditions.append(condition_builder(key, value)) + + return conditions + + def _build_eq_condition(self, key: str, value: models.ValueVariants) -> models.Condition: + if isinstance(value, str) and " " in value: + models.FieldCondition(key=key, match=models.MatchText(text=value)) + return models.FieldCondition(key=key, 
match=models.MatchValue(value=value)) + + def _build_in_condition(self, key: str, value: List[models.ValueVariants]) -> models.Condition: + if not isinstance(value, list): + msg = f"Value {value} is not a list" + raise FilterError(msg) + return models.Filter( + should=[ + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " not in item + else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + for item in value + ] + ) + + def _build_ne_condition(self, key: str, value: models.ValueVariants) -> models.Condition: + return models.Filter( + must_not=[ + models.FieldCondition(key=key, match=models.MatchText(text=value)) + if isinstance(value, str) and " " not in value + else models.FieldCondition(key=key, match=models.MatchValue(value=value)) + ] + ) + + def _build_nin_condition(self, key: str, value: List[models.ValueVariants]) -> models.Condition: + if not isinstance(value, list): + msg = f"Value {value} is not a list" + raise FilterError(msg) + return models.Filter( + must_not=[ + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " not in item + else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + for item in value + ] + ) + + def _build_lt_condition(self, key: str, value: float) -> models.Condition: + if not isinstance(value, (int, float)): + msg = f"Value {value} is not an int or float" + raise FilterError(msg) + return models.FieldCondition(key=key, range=models.Range(lt=value)) + + def _build_lte_condition(self, key: str, value: float) -> models.Condition: + if not isinstance(value, (int, float)): + msg = f"Value {value} is not an int or float" + raise FilterError(msg) + return models.FieldCondition(key=key, range=models.Range(lte=value)) + + def _build_gt_condition(self, key: str, value: float) -> models.Condition: + if not isinstance(value, (int, float)): + msg = f"Value {value} is not an int or float" + raise FilterError(msg) 
+ return models.FieldCondition(key=key, range=models.Range(gt=value)) + + def _build_gte_condition(self, key: str, value: float) -> models.Condition: + if not isinstance(value, (int, float)): + msg = f"Value {value} is not an int or float" + raise FilterError(msg) + return models.FieldCondition(key=key, range=models.Range(gte=value)) + + def _build_has_id_condition(self, id_values: List[models.ExtendedPointId]) -> models.HasIdCondition: + return models.HasIdCondition( + has_id=[ + # Ids are converted into their internal representation + self.haystack_to_qdrant_converter.convert_id(item) + for item in id_values + ] + ) + + def _squeeze_filter(self, payload_filter: models.Filter) -> models.Filter: + """ + Simplify given payload filter, if the nested structure might be unnested. + That happens if there is a single clause in that filter. + :param payload_filter: + :return: + """ + filter_parts = { + "must": payload_filter.must, + "should": payload_filter.should, + "must_not": payload_filter.must_not, + } + + total_clauses = sum(len(x) for x in filter_parts.values() if x is not None) + if total_clauses == 0 or total_clauses > 1: + return payload_filter + + # Payload filter has just a single clause provided (either must, should + # or must_not). If that single clause is also of a models.Filter type, + # then it might be returned instead. + for part_name, filter_part in filter_parts.items(): + if not filter_part: + continue + + subfilter = filter_part[0] + if not isinstance(subfilter, models.Filter): + # The inner statement is a simple condition like models.FieldCondition + # so it cannot be simplified. 
+ continue + + if subfilter.must: + return models.Filter(**{part_name: subfilter.must}) + + return payload_filter diff --git a/integrations/qdrant/src/qdrant_haystack/retriever.py b/integrations/qdrant/src/qdrant_haystack/retriever.py new file mode 100644 index 000000000..054ba96ac --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/retriever.py @@ -0,0 +1,98 @@ +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict + +from qdrant_haystack import QdrantDocumentStore + + +@component +class QdrantRetriever: + """ + A component for retrieving documents from an QdrantDocumentStore. + """ + + def __init__( + self, + document_store: QdrantDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = True, # noqa: FBT001, FBT002 + return_embedding: bool = False, # noqa: FBT001, FBT002 + ): + """ + Create a QdrantRetriever component. + + :param document_store: An instance of QdrantDocumentStore. + :param filters: A dictionary with filters to narrow down the search space. Default is None. + :param top_k: The maximum number of documents to retrieve. Default is 10. + :param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True. + :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False. + + :raises ValueError: If 'document_store' is not an instance of QdrantDocumentStore. + """ + + if not isinstance(document_store, QdrantDocumentStore): + msg = "document_store must be an instance of QdrantDocumentStore" + raise ValueError(msg) + + self._document_store = document_store + + self._filters = filters + self._top_k = top_k + self._scale_score = scale_score + self._return_embedding = return_embedding + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. 
+ """ + d = default_to_dict( + self, + document_store=self._document_store, + filters=self._filters, + top_k=self._top_k, + scale_score=self._scale_score, + return_embedding=self._return_embedding, + ) + d["init_parameters"]["document_store"] = self._document_store.to_dict() + + return d + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "QdrantRetriever": + """ + Deserialize this component from a dictionary. + """ + document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) + data["init_parameters"]["document_store"] = document_store + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + scale_score: Optional[bool] = None, + return_embedding: Optional[bool] = None, + ): + """ + Run the Embedding Retriever on the given input data. + + :param query_embedding: Embedding of the query. + :param filters: A dictionary with filters to narrow down the search space. + :param top_k: The maximum number of documents to return. + :param scale_score: Whether to scale the scores of the retrieved documents or not. + :param return_embedding: Whether to return the embedding of the retrieved Documents. + :return: The retrieved documents. 
+ + """ + docs = self._document_store.query_by_embedding( + query_embedding=query_embedding, + filters=filters or self._filters, + top_k=top_k or self._top_k, + scale_score=scale_score or self._scale_score, + return_embedding=return_embedding or self._return_embedding, + ) + + return {"documents": docs} diff --git a/integrations/qdrant/tests/__init__.py b/integrations/qdrant/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/qdrant/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/qdrant/tests/test_converters.py b/integrations/qdrant/tests/test_converters.py new file mode 100644 index 000000000..dc4866293 --- /dev/null +++ b/integrations/qdrant/tests/test_converters.py @@ -0,0 +1,53 @@ +import numpy as np +import pytest +from qdrant_client.http import models as rest + +from qdrant_haystack.converters import HaystackToQdrant, QdrantToHaystack + +CONTENT_FIELD = "content" +NAME_FIELD = "name" +EMBEDDING_FIELD = "vector" + + +@pytest.fixture +def haystack_to_qdrant() -> HaystackToQdrant: + return HaystackToQdrant() + + +@pytest.fixture +def qdrant_to_haystack() -> QdrantToHaystack: + return QdrantToHaystack( + content_field=CONTENT_FIELD, + name_field=NAME_FIELD, + embedding_field=EMBEDDING_FIELD, + ) + + +def test_convert_id_is_deterministic(haystack_to_qdrant: HaystackToQdrant): + first_id = haystack_to_qdrant.convert_id("test-id") + second_id = haystack_to_qdrant.convert_id("test-id") + assert first_id == second_id + + +def test_point_to_document_reverts_proper_structure_from_record( + qdrant_to_haystack: QdrantToHaystack, +): + point = rest.Record( + id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", + payload={ + "id": "my-id", + "id_hash_keys": ["content"], + "content": "Lorem ipsum", + "content_type": "text", + "meta": { + "test_field": 1, + }, + }, + vector=[1.0, 0.0, 0.0, 0.0], + ) + document = 
qdrant_to_haystack.point_to_document(point) + assert "my-id" == document.id + assert "Lorem ipsum" == document.content + assert "text" == document.content_type + assert {"test_field": 1} == document.meta + assert 0.0 == np.sum(np.array([1.0, 0.0, 0.0, 0.0]) - document.embedding) diff --git a/integrations/qdrant/tests/test_dict_converters.py b/integrations/qdrant/tests/test_dict_converters.py new file mode 100644 index 000000000..1a211655c --- /dev/null +++ b/integrations/qdrant/tests/test_dict_converters.py @@ -0,0 +1,102 @@ +from qdrant_haystack import QdrantDocumentStore + + +def test_to_dict(): + document_store = QdrantDocumentStore(location=":memory:", index="test") + + expected = { + "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "init_parameters": { + "location": ":memory:", + "url": None, + "port": 6333, + "grpc_port": 6334, + "prefer_grpc": False, + "https": None, + "api_key": None, + "prefix": None, + "timeout": None, + "host": None, + "path": None, + "index": "test", + "embedding_dim": 768, + "content_field": "content", + "name_field": "name", + "embedding_field": "embedding", + "similarity": "cosine", + "return_embedding": False, + "progress_bar": True, + "duplicate_documents": "overwrite", + "recreate_index": False, + "shard_number": None, + "replication_factor": None, + "write_consistency_factor": None, + "on_disk_payload": None, + "hnsw_config": None, + "optimizers_config": None, + "wal_config": None, + "quantization_config": None, + "init_from": None, + "wait_result_from_api": True, + "metadata": {}, + "write_batch_size": 100, + "scroll_size": 10000, + }, + } + + assert document_store.to_dict() == expected + + +def test_from_dict(): + document_store = QdrantDocumentStore.from_dict( + { + "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "init_parameters": { + "location": ":memory:", + "index": "test", + "embedding_dim": 768, + "content_field": "content", + "name_field": "name", + "embedding_field": "embedding", + 
"similarity": "cosine", + "return_embedding": False, + "progress_bar": True, + "duplicate_documents": "overwrite", + "recreate_index": True, + "shard_number": None, + "quantization_config": None, + "init_from": None, + "wait_result_from_api": True, + "metadata": {}, + "write_batch_size": 1000, + "scroll_size": 10000, + }, + } + ) + + assert all( + [ + document_store.index == "test", + document_store.content_field == "content", + document_store.name_field == "name", + document_store.embedding_field == "embedding", + document_store.similarity == "cosine", + document_store.return_embedding is False, + document_store.progress_bar, + document_store.duplicate_documents == "overwrite", + document_store.recreate_index is True, + document_store.shard_number is None, + document_store.replication_factor is None, + document_store.write_consistency_factor is None, + document_store.on_disk_payload is None, + document_store.hnsw_config is None, + document_store.optimizers_config is None, + document_store.wal_config is None, + document_store.quantization_config is None, + document_store.init_from is None, + document_store.wait_result_from_api, + document_store.metadata == {}, + document_store.write_batch_size == 1000, + document_store.scroll_size == 10000, + ] + ) diff --git a/integrations/qdrant/tests/test_document_store.py b/integrations/qdrant/tests/test_document_store.py new file mode 100644 index 000000000..bbc16b9df --- /dev/null +++ b/integrations/qdrant/tests/test_document_store.py @@ -0,0 +1,42 @@ +from typing import List + +import pytest +from haystack import Document +from haystack.document_stores import DuplicatePolicy +from haystack.document_stores.errors import DuplicateDocumentError +from haystack.testing.document_store import ( + CountDocumentsTest, + DeleteDocumentsTest, + WriteDocumentsTest, +) + +from qdrant_haystack import QdrantDocumentStore + + +class TestQdrantStoreBaseTests(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest): + @pytest.fixture + 
def document_store(self) -> QdrantDocumentStore: + return QdrantDocumentStore( + ":memory:", + recreate_index=True, + return_embedding=True, + wait_result_from_api=True, + ) + + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + Assert that two lists of Documents are equal. + This is used in every test. + """ + + # Check that the lengths of the lists are the same + assert len(received) == len(expected) + + # Check that the sets are equal, meaning the content and IDs match regardless of order + assert {doc.id for doc in received} == {doc.id for doc in expected} + + def test_write_documents(self, document_store: QdrantDocumentStore): + docs = [Document(id="1")] + assert document_store.write_documents(docs) == 1 + with pytest.raises(DuplicateDocumentError): + document_store.write_documents(docs, DuplicatePolicy.FAIL) diff --git a/integrations/qdrant/tests/test_filters.py b/integrations/qdrant/tests/test_filters.py new file mode 100644 index 000000000..a25f4a672 --- /dev/null +++ b/integrations/qdrant/tests/test_filters.py @@ -0,0 +1,115 @@ +from typing import List + +import pytest +from haystack import Document +from haystack.testing.document_store import FilterDocumentsTest +from haystack.utils.filters import FilterError + +from qdrant_haystack import QdrantDocumentStore + + +class TestQdrantStoreBaseTests(FilterDocumentsTest): + @pytest.fixture + def document_store(self) -> QdrantDocumentStore: + return QdrantDocumentStore( + ":memory:", + recreate_index=True, + return_embedding=True, + wait_result_from_api=True, + ) + + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + Assert that two lists of Documents are equal. + This is used in every test. 
+ """ + + # Check that the lengths of the lists are the same + assert len(received) == len(expected) + + # Check that the sets are equal, meaning the content and IDs match regardless of order + assert {doc.id for doc in received} == {doc.id for doc in expected} + + def test_not_operator(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents( + filters={ + "operator": "NOT", + "conditions": [ + {"field": "meta.number", "operator": "==", "value": 100}, + {"field": "meta.name", "operator": "==", "value": "name_0"}, + ], + } + ) + self.assert_documents_are_equal( + result, + [d for d in filterable_docs if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0")], + ) + + # ======== OVERRIDES FOR NONE VALUED FILTERS ======== + + def test_comparison_equal_with_none(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "==", "value": None}) + self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") is None]) + + def test_comparison_not_equal_with_none(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "!=", "value": None}) + self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") is not None]) + + def test_comparison_greater_than_with_none(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + result = document_store.filter_documents(filters={"field": "meta.number", "operator": ">", "value": None}) + self.assert_documents_are_equal(result, []) + + def test_comparison_greater_than_equal_with_none(self, document_store, 
filterable_docs): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + result = document_store.filter_documents(filters={"field": "meta.number", "operator": ">=", "value": None}) + self.assert_documents_are_equal(result, []) + + def test_comparison_less_than_with_none(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "<", "value": None}) + self.assert_documents_are_equal(result, []) + + def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + result = document_store.filter_documents(filters={"field": "meta.number", "operator": "<=", "value": None}) + self.assert_documents_are_equal(result, []) + + # ======== ========================== ======== + + @pytest.mark.skip(reason="Qdrant doesn't support comparision with dataframe") + def test_comparison_equal_with_dataframe(self, document_store, filterable_docs): + ... + + @pytest.mark.skip(reason="Qdrant doesn't support comparision with dataframe") + def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs): + ... + + @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates") + def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs): + ... + + @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates") + def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs): + ... + + @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates") + def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs): + ... 
+ + @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates") + def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs): + ... + + @pytest.mark.skip(reason="Cannot distinguish errors yet") + def test_missing_top_level_operator_key(self, document_store, filterable_docs): + ... diff --git a/integrations/qdrant/tests/test_legacy_filters.py b/integrations/qdrant/tests/test_legacy_filters.py new file mode 100644 index 000000000..6cb78c653 --- /dev/null +++ b/integrations/qdrant/tests/test_legacy_filters.py @@ -0,0 +1,459 @@ +from typing import List + +import pytest +from haystack import Document +from haystack.document_stores import DocumentStore +from haystack.testing.document_store import LegacyFilterDocumentsTest +from haystack.utils.filters import FilterError + +from qdrant_haystack import QdrantDocumentStore + +# The tests below are from haystack.testing.document_store.LegacyFilterDocumentsTest +# Updated to include `meta` prefix for filter keys wherever necessary +# And skip tests that are not supported in Qdrant(Dataframes, embeddings) + + +class TestQdrantLegacyFilterDocuments(LegacyFilterDocumentsTest): + """ + Utility class to test a Document Store `filter_documents` method using different types of legacy filters + """ + + @pytest.fixture + def document_store(self) -> QdrantDocumentStore: + return QdrantDocumentStore( + ":memory:", + recreate_index=True, + return_embedding=True, + wait_result_from_api=True, + ) + + def assert_documents_are_equal(self, received: List[Document], expected: List[Document]): + """ + Assert that two lists of Documents are equal. + This is used in every test. 
+ """ + + # Check that the lengths of the lists are the same + assert len(received) == len(expected) + + # Check that the sets are equal, meaning the content and IDs match regardless of order + assert {doc.id for doc in received} == {doc.id for doc in expected} + + def test_filter_simple_metadata_value(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": "100"}) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) + + @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") + def test_filter_document_dataframe(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + def test_eq_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": {"$eq": "100"}}) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) + + def test_eq_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": "100"}) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) + + @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") + def test_eq_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") + def test_eq_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... 
+ + # LegacyFilterDocumentsNotEqualTest + + def test_ne_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": {"$ne": "100"}}) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") != "100"]) + + @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") + def test_ne_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") + def test_ne_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + # LegacyFilterDocumentsInTest + + def test_filter_simple_list_single_element(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": ["100"]}) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) + + def test_filter_simple_list_one_value(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": ["100"]}) + self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100"]]) + + def test_filter_simple_list(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": ["100", "123"]}) + self.assert_documents_are_equal( + result, + [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]], + ) + + def test_incorrect_filter_value(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) 
+ result = document_store.filter_documents(filters={"meta.page": ["nope"]}) + self.assert_documents_are_equal(result, []) + + def test_in_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": {"$in": ["100", "123", "n.a."]}}) + self.assert_documents_are_equal( + result, + [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]], + ) + + def test_in_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": ["100", "123", "n.a."]}) + self.assert_documents_are_equal( + result, + [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]], + ) + + @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") + def test_in_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") + def test_in_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + # LegacyFilterDocumentsNotInTest + + @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") + def test_nin_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") + def test_nin_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... 
+ + def test_nin_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.page": {"$nin": ["100", "123", "n.a."]}}) + self.assert_documents_are_equal( + result, + [doc for doc in filterable_docs if doc.meta.get("page") not in ["100", "123"]], + ) + + # LegacyFilterDocumentsGreaterThanTest + + def test_gt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$gt": 0.0}}) + self.assert_documents_are_equal( + result, + [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] > 0], + ) + + def test_gt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"meta.page": {"$gt": "100"}}) + + @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") + def test_gt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") + def test_gt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... 
+ + # LegacyFilterDocumentsGreaterThanEqualTest + + def test_gte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$gte": -2}}) + self.assert_documents_are_equal( + result, + [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] >= -2], + ) + + def test_gte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"meta.page": {"$gte": "100"}}) + + @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") + def test_gte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") + def test_gte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + # LegacyFilterDocumentsLessThanTest + + def test_lt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$lt": 0.0}}) + self.assert_documents_are_equal( + result, + [doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] < 0], + ) + + def test_lt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"meta.page": {"$lt": "100"}}) + + @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") + def test_lt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... 
+ + @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") + def test_lt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + # LegacyFilterDocumentsLessThanEqualTest + + def test_lte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0}}) + self.assert_documents_are_equal( + result, + [doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] <= 2.0], + ) + + def test_lte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + with pytest.raises(FilterError): + document_store.filter_documents(filters={"meta.page": {"$lte": "100"}}) + + @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant") + def test_lte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... + + @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant") + def test_lte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]): + ... 
+ + # LegacyFilterDocumentsSimpleLogicalTest + + def test_filter_simple_or(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters = { + "$or": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "meta.number": {"$lt": 1.0}, + } + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if (doc.meta.get("number") is not None and doc.meta["number"] < 1) + or doc.meta.get("name") in ["name_0", "name_1"] + ], + ) + + def test_filter_simple_implicit_and_with_multi_key_dict( + self, document_store: DocumentStore, filterable_docs: List[Document] + ): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0, "$gte": 0.0}}) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if "number" in doc.meta and doc.meta["number"] >= 0.0 and doc.meta["number"] <= 2.0 + ], + ) + + def test_filter_simple_explicit_and_with_list(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 + ], + ) + + def test_filter_simple_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0, "$gte": 0}}) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 + ], + ) + + # LegacyFilterDocumentsNestedLogicalTest( + + def 
test_filter_nested_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters_simplified = { + "meta.number": {"$lte": 2, "$gte": 0}, + "meta.name": ["name_0", "name_1"], + } + result = document_store.filter_documents(filters=filters_simplified) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + "number" in doc.meta + and doc.meta["number"] <= 2 + and doc.meta["number"] >= 0 + and doc.meta.get("name") in ["name_0", "name_1"] + ) + ], + ) + + def test_filter_nested_or(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters = { + "$or": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "meta.number": {"$lt": 1.0}, + } + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + doc.meta.get("name") in ["name_0", "name_1"] + or (doc.meta.get("number") is not None and doc.meta["number"] < 1) + ) + ], + ) + + def test_filter_nested_and_or_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters_simplified = { + "$and": { + "meta.page": {"$eq": "123"}, + "$or": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "meta.number": {"$lt": 1.0}, + }, + } + } + result = document_store.filter_documents(filters=filters_simplified) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + doc.meta.get("page") in ["123"] + and ( + doc.meta.get("name") in ["name_0", "name_1"] + or ("number" in doc.meta and doc.meta["number"] < 1) + ) + ) + ], + ) + + def test_filter_nested_and_or_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters_simplified = { + "meta.page": {"$eq": "123"}, + "$or": { + 
"meta.name": {"$in": ["name_0", "name_1"]}, + "meta.number": {"$lt": 1.0}, + }, + } + result = document_store.filter_documents(filters=filters_simplified) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + doc.meta.get("page") in ["123"] + and ( + doc.meta.get("name") in ["name_0", "name_1"] + or ("number" in doc.meta and doc.meta["number"] < 1) + ) + ) + ], + ) + + def test_filter_nested_or_and(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters_simplified = { + "$or": { + "meta.number": {"$lt": 1}, + "$and": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "$not": {"meta.chapter": {"$eq": "intro"}}, + }, + } + } + result = document_store.filter_documents(filters=filters_simplified) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + (doc.meta.get("number") is not None and doc.meta["number"] < 1) + or (doc.meta.get("name") in ["name_0", "name_1"] and (doc.meta.get("chapter") != "intro")) + ) + ], + ) + + def test_filter_nested_multiple_identical_operators_same_level( + self, document_store: DocumentStore, filterable_docs: List[Document] + ): + document_store.write_documents(filterable_docs) + filters = { + "$or": [ + { + "$and": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "meta.page": "100", + } + }, + { + "$and": { + "meta.chapter": {"$in": ["intro", "abstract"]}, + "meta.page": "123", + } + }, + ] + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + (doc.meta.get("name") in ["name_0", "name_1"] and doc.meta.get("page") == "100") + or (doc.meta.get("chapter") in ["intro", "abstract"] and doc.meta.get("page") == "123") + ) + ], + ) + + def test_no_filter_not_empty(self, document_store: DocumentStore): + docs = [Document(content="test doc")] + document_store.write_documents(docs) + 
self.assert_documents_are_equal(document_store.filter_documents(), docs) + self.assert_documents_are_equal(document_store.filter_documents(filters={}), docs) diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py new file mode 100644 index 000000000..22eabfaad --- /dev/null +++ b/integrations/qdrant/tests/test_retriever.py @@ -0,0 +1,113 @@ +from typing import List + +from haystack.dataclasses import Document +from haystack.testing.document_store import ( + FilterableDocsFixtureMixin, + _random_embeddings, +) + +from qdrant_haystack import QdrantDocumentStore +from qdrant_haystack.retriever import QdrantRetriever + + +class TestQdrantRetriever(FilterableDocsFixtureMixin): + def test_init_default(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantRetriever(document_store=document_store) + assert retriever._document_store == document_store + assert retriever._filters is None + assert retriever._top_k == 10 + assert retriever._return_embedding is False + + def test_to_dict(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "qdrant_haystack.retriever.QdrantRetriever", + "init_parameters": { + "document_store": { + "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "init_parameters": { + "location": ":memory:", + "url": None, + "port": 6333, + "grpc_port": 6334, + "prefer_grpc": False, + "https": None, + "api_key": None, + "prefix": None, + "timeout": None, + "host": None, + "path": None, + "index": "test", + "embedding_dim": 768, + "content_field": "content", + "name_field": "name", + "embedding_field": "embedding", + "similarity": "cosine", + "return_embedding": False, + "progress_bar": True, + "duplicate_documents": "overwrite", + "recreate_index": False, + "shard_number": None, + "replication_factor": None, + 
"write_consistency_factor": None, + "on_disk_payload": None, + "hnsw_config": None, + "optimizers_config": None, + "wal_config": None, + "quantization_config": None, + "init_from": None, + "wait_result_from_api": True, + "metadata": {}, + "write_batch_size": 100, + "scroll_size": 10000, + }, + }, + "filters": None, + "top_k": 10, + "scale_score": True, + "return_embedding": False, + }, + } + + def test_from_dict(self): + data = { + "type": "qdrant_haystack.retriever.QdrantRetriever", + "init_parameters": { + "document_store": { + "init_parameters": {"location": ":memory:", "index": "test"}, + "type": "qdrant_haystack.document_store.QdrantDocumentStore", + }, + "filters": None, + "top_k": 5, + "scale_score": False, + "return_embedding": True, + }, + } + retriever = QdrantRetriever.from_dict(data) + assert isinstance(retriever._document_store, QdrantDocumentStore) + assert retriever._document_store.index == "test" + assert retriever._filters is None + assert retriever._top_k == 5 + assert retriever._scale_score is False + assert retriever._return_embedding is True + + def test_run(self, filterable_docs: List[Document]): + document_store = QdrantDocumentStore(location=":memory:", index="Boi") + + document_store.write_documents(filterable_docs) + + retriever = QdrantRetriever(document_store=document_store) + + results: List[Document] = retriever.run(query_embedding=_random_embeddings(768)) + + assert len(results["documents"]) == 10 # type: ignore + + results = retriever.run(query_embedding=_random_embeddings(768), top_k=5, return_embedding=False) + + assert len(results["documents"]) == 5 # type: ignore + + for document in results["documents"]: # type: ignore + assert document.embedding is None From 743a154dbede3eb0929541ff4370a98746f67c25 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 15 Dec 2023 14:45:47 +0100 Subject: [PATCH 21/29] chore: pin unstructured api version (#105) * pin unstructured api version * Update unstructured_fileconverter.yml * try --- 
.github/workflows/unstructured_fileconverter.yml | 2 +- integrations/unstructured/fileconverter/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unstructured_fileconverter.yml b/.github/workflows/unstructured_fileconverter.yml index 83e355ac6..ee70510e9 100644 --- a/.github/workflows/unstructured_fileconverter.yml +++ b/.github/workflows/unstructured_fileconverter.yml @@ -36,7 +36,7 @@ jobs: --health-cmd "curl --fail http://localhost:8000/healthcheck || exit 1" --health-interval 10s --health-timeout 1s - --health-retries 10 + --health-retries 10 steps: - uses: actions/checkout@v4 diff --git a/integrations/unstructured/fileconverter/pyproject.toml b/integrations/unstructured/fileconverter/pyproject.toml index c19409cde..bb977bdaa 100644 --- a/integrations/unstructured/fileconverter/pyproject.toml +++ b/integrations/unstructured/fileconverter/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ dependencies = [ # we distribute the preview version of Haystack 2.0 under the package "haystack-ai" "haystack-ai", - "unstructured", + "unstructured<0.11.4", # FIXME: investigate why 0.11.4 broke the tests ] [project.urls] From 67771ae13298f6b0824e3432129d59a618aade6f Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 15 Dec 2023 14:47:41 +0100 Subject: [PATCH 22/29] fix labeller config (#106) --- .github/labeler.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/labeler.yml b/.github/labeler.yml index e1e9727f6..3b26050b9 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -2,38 +2,47 @@ integration:chroma: - changed-files: - any-glob-to-any-file: "integrations/chroma/**/*" + - any-glob-to-any-file: ".github/workflows/chroma.yml" integration:cohere: - changed-files: - any-glob-to-any-file: "integrations/cohere/**/*" + - any-glob-to-any-file: ".github/workflows/cohere.yml" integration:elasticsearch: - changed-files: - any-glob-to-any-file: "integrations/elasticsearch/**/*" + - 
any-glob-to-any-file: ".github/workflows/elasticsearch.yml" integration:google-vertex: - changed-files: - any-glob-to-any-file: "integrations/google-vertex/**/*" + - any-glob-to-any-file: ".github/workflows/google_vertex.yml" integration:gradient: - changed-files: - any-glob-to-any-file: "integrations/gradient/**/*" + - any-glob-to-any-file: ".github/workflows/gradient.yml" integration:instructor-embedders: - changed-files: - any-glob-to-any-file: "integrations/instructor-embedders/**/*" + - any-glob-to-any-file: ".github/workflows/instructor_embedders.yml" integration:jina: - changed-files: - any-glob-to-any-file: "integrations/jina/**/*" + - any-glob-to-any-file: ".github/workflows/jina.yml" integration:opensearch: - changed-files: - any-glob-to-any-file: "integrations/opensearch/**/*" + - any-glob-to-any-file: ".github/workflows/opensearch.yml" integration:unstructured-fileconverter: - changed-files: - any-glob-to-any-file: "integrations/unstructured/fileconverter/**/*" + - any-glob-to-any-file: ".github/workflows/unstructured_fileconverter.yml" # Topics topic:CI: From 6615a06f5a1d32a24bead5ecc7230f06d1827dc8 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 15 Dec 2023 15:01:52 +0100 Subject: [PATCH 23/29] refactor: use `hatch_vcs` to manage integrations versioning (#103) * move versioning to vcs for all the integrations * qdrant --- integrations/chroma/pyproject.toml | 7 ++++++- integrations/chroma/src/chroma_haystack/__about__.py | 4 ---- integrations/cohere/pyproject.toml | 9 +++++++-- integrations/cohere/src/cohere_haystack/__about__.py | 4 ---- integrations/elasticsearch/pyproject.toml | 9 +++++++-- .../src/elasticsearch_haystack/__about__.py | 4 ---- integrations/google-vertex/pyproject.toml | 9 +++++++-- .../src/google_vertex_haystack/__about__.py | 4 ---- integrations/gradient/pyproject.toml | 9 +++++++-- integrations/gradient/src/gradient_haystack/__about__.py | 4 ---- .../instructor_embedders_haystack/__about__.py | 4 ---- 
integrations/instructor-embedders/pyproject.toml | 9 +++++++-- integrations/jina/pyproject.toml | 9 +++++++-- integrations/jina/src/jina_haystack/__about__.py | 4 ---- integrations/opensearch/pyproject.toml | 9 +++++++-- .../opensearch/src/opensearch_haystack/__about__.py | 4 ---- integrations/qdrant/pyproject.toml | 9 +++++++-- integrations/qdrant/src/qdrant_haystack/__about__.py | 4 ---- integrations/unstructured/fileconverter/pyproject.toml | 9 +++++++-- .../src/unstructured_fileconverter_haystack/__about__.py | 4 ---- 20 files changed, 69 insertions(+), 59 deletions(-) delete mode 100644 integrations/chroma/src/chroma_haystack/__about__.py delete mode 100644 integrations/cohere/src/cohere_haystack/__about__.py delete mode 100644 integrations/elasticsearch/src/elasticsearch_haystack/__about__.py delete mode 100644 integrations/google-vertex/src/google_vertex_haystack/__about__.py delete mode 100644 integrations/gradient/src/gradient_haystack/__about__.py delete mode 100644 integrations/instructor-embedders/instructor_embedders_haystack/__about__.py delete mode 100644 integrations/jina/src/jina_haystack/__about__.py delete mode 100644 integrations/opensearch/src/opensearch_haystack/__about__.py delete mode 100644 integrations/qdrant/src/qdrant_haystack/__about__.py delete mode 100644 integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/__about__.py diff --git a/integrations/chroma/pyproject.toml b/integrations/chroma/pyproject.toml index b50723f4c..fdabcf245 100644 --- a/integrations/chroma/pyproject.toml +++ b/integrations/chroma/pyproject.toml @@ -33,7 +33,12 @@ Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/chroma" [tool.hatch.version] -path = "src/chroma_haystack/__about__.py" +source = "vcs" +tag-pattern = 'integrations\/chroma-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/chroma-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/chroma/src/chroma_haystack/__about__.py b/integrations/chroma/src/chroma_haystack/__about__.py deleted file mode 100644 index 63065fc23..000000000 --- a/integrations/chroma/src/chroma_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.8.1" diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index 71501d28e..0e62d540b 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] @@ -35,7 +35,12 @@ Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere" [tool.hatch.version] -path = "src/cohere_haystack/__about__.py" +source = "vcs" +tag-pattern = 'integrations\/cohere-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/cohere-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/cohere/src/cohere_haystack/__about__.py b/integrations/cohere/src/cohere_haystack/__about__.py deleted file mode 100644 index 447ed9770..000000000 --- a/integrations/cohere/src/cohere_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.2.0" diff --git a/integrations/elasticsearch/pyproject.toml b/integrations/elasticsearch/pyproject.toml index 12922fbc2..17c9158b9 100644 --- a/integrations/elasticsearch/pyproject.toml +++ b/integrations/elasticsearch/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] @@ -34,7 +34,12 @@ Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/elasticsearch" [tool.hatch.version] -path = "src/elasticsearch_haystack/__about__.py" +source = "vcs" +tag-pattern = 'integrations\/elasticsearch-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/elasticsearch-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/elasticsearch/src/elasticsearch_haystack/__about__.py b/integrations/elasticsearch/src/elasticsearch_haystack/__about__.py deleted file mode 100644 index 8430bf8d4..000000000 --- a/integrations/elasticsearch/src/elasticsearch_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.1.1" diff --git a/integrations/google-vertex/pyproject.toml b/integrations/google-vertex/pyproject.toml index a34853a60..b773742e2 100644 --- a/integrations/google-vertex/pyproject.toml +++ b/integrations/google-vertex/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] @@ -34,7 +34,12 @@ Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/google-vertex" [tool.hatch.version] -path = "src/google_vertex_haystack/__about__.py" +source = "vcs" +tag-pattern = 'integrations\/google-vertex-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/google-vertex-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/google-vertex/src/google_vertex_haystack/__about__.py b/integrations/google-vertex/src/google_vertex_haystack/__about__.py deleted file mode 100644 index d4a92df1b..000000000 --- a/integrations/google-vertex/src/google_vertex_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.0.2" diff --git a/integrations/gradient/pyproject.toml b/integrations/gradient/pyproject.toml index ae9e047c7..5a0930cc4 100644 --- a/integrations/gradient/pyproject.toml +++ b/integrations/gradient/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] @@ -35,7 +35,12 @@ Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/gradient" [tool.hatch.version] -path = "src/gradient_haystack/__about__.py" +source = "vcs" +tag-pattern = 'integrations\/gradient-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/gradient-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/gradient/src/gradient_haystack/__about__.py b/integrations/gradient/src/gradient_haystack/__about__.py deleted file mode 100644 index bccfd8317..000000000 --- a/integrations/gradient/src/gradient_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.1.0" diff --git a/integrations/instructor-embedders/instructor_embedders_haystack/__about__.py b/integrations/instructor-embedders/instructor_embedders_haystack/__about__.py deleted file mode 100644 index 447ed9770..000000000 --- a/integrations/instructor-embedders/instructor_embedders_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.2.0" diff --git a/integrations/instructor-embedders/pyproject.toml b/integrations/instructor-embedders/pyproject.toml index c10f43703..31834c259 100644 --- a/integrations/instructor-embedders/pyproject.toml +++ b/integrations/instructor-embedders/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] @@ -55,7 +55,12 @@ Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/components/instructor-embedders" [tool.hatch.version] -path = "instructor_embedders_haystack/__about__.py" +source = "vcs" +tag-pattern = 'integrations\/instructor-embedders-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/instructor-embedders-v[0-9]*"' [tool.hatch.envs.default] dependencies = ["pytest", "pytest-cov"] diff --git a/integrations/jina/pyproject.toml b/integrations/jina/pyproject.toml index a6a57efda..0fa01a7ab 100644 --- a/integrations/jina/pyproject.toml +++ b/integrations/jina/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] @@ -32,7 +32,12 @@ Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/jina" [tool.hatch.version] -path = "src/jina_haystack/__about__.py" +source = "vcs" +tag-pattern = 'integrations\/jina-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/jina-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/jina/src/jina_haystack/__about__.py b/integrations/jina/src/jina_haystack/__about__.py deleted file mode 100644 index 0e4fa27cf..000000000 --- a/integrations/jina/src/jina_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.0.1" diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index fb02b3b63..cb63ee2f5 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] @@ -34,7 +34,12 @@ Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/opensearch" [tool.hatch.version] -path = "src/opensearch_haystack/__about__.py" +source = "vcs" 
+tag-pattern = 'integrations\/opensearch-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." +git_describe_command = 'git describe --tags --match="integrations/opensearch-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/opensearch/src/opensearch_haystack/__about__.py b/integrations/opensearch/src/opensearch_haystack/__about__.py deleted file mode 100644 index 8430bf8d4..000000000 --- a/integrations/opensearch/src/opensearch_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.1.1" diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index f8209e0c9..d1086fcdf 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] @@ -36,7 +36,12 @@ Documentation = "https://github.com/deepset-ai/haystack-core-integrations/blob/m Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" [tool.hatch.version] -path = "src/qdrant_haystack/__about__.py" +source = "vcs" +tag-pattern = 'integrations\/qdrant-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../.." 
+git_describe_command = 'git describe --tags --match="integrations/qdrant-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/qdrant/src/qdrant_haystack/__about__.py b/integrations/qdrant/src/qdrant_haystack/__about__.py deleted file mode 100644 index 0e4fa27cf..000000000 --- a/integrations/qdrant/src/qdrant_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.0.1" diff --git a/integrations/unstructured/fileconverter/pyproject.toml b/integrations/unstructured/fileconverter/pyproject.toml index bb977bdaa..97d3e068c 100644 --- a/integrations/unstructured/fileconverter/pyproject.toml +++ b/integrations/unstructured/fileconverter/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] @@ -35,7 +35,12 @@ Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/unstructured/fileconverter" [tool.hatch.version] -path = "src/unstructured_fileconverter_haystack/__about__.py" +source = "vcs" +tag-pattern = 'integrations\/unstructured-fileconverter-v(?P.*)' + +[tool.hatch.version.raw-options] +root = "../../.." 
+git_describe_command = 'git describe --tags --match="integrations/unstructured-fileconverter-v[0-9]*"' [tool.hatch.envs.default] dependencies = [ diff --git a/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/__about__.py b/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/__about__.py deleted file mode 100644 index 7200d918c..000000000 --- a/integrations/unstructured/fileconverter/src/unstructured_fileconverter_haystack/__about__.py +++ /dev/null @@ -1,4 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present deepset GmbH -# -# SPDX-License-Identifier: Apache-2.0 -__version__ = "0.0.4" From f0d352514d5bfda0a4fd07c2a4de8531aff6818f Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 15 Dec 2023 15:07:47 +0100 Subject: [PATCH 24/29] ci: add qdrant to the labeller (#108) --- .github/labeler.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/labeler.yml b/.github/labeler.yml index 3b26050b9..f2dcedad2 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -39,6 +39,11 @@ integration:opensearch: - any-glob-to-any-file: "integrations/opensearch/**/*" - any-glob-to-any-file: ".github/workflows/opensearch.yml" +integration:qdrant: + - changed-files: + - any-glob-to-any-file: "integrations/qdrant/**/*" + - any-glob-to-any-file: ".github/workflows/qdrant.yml" + integration:unstructured-fileconverter: - changed-files: - any-glob-to-any-file: "integrations/unstructured/fileconverter/**/*" From 87d21bbd161da5f8f2556372a136a4d21418d76e Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Fri, 15 Dec 2023 15:46:59 +0100 Subject: [PATCH 25/29] Add qdrant integration to README (#109) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 395f59db4..e773d707d 100644 --- a/README.md +++ b/README.md @@ -69,5 +69,6 @@ deepset-haystack | [gradient-haystack](integrations/gradient/) | Embedder, Generator | [![PyPI - 
Version](https://img.shields.io/pypi/v/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) | [![Test / gradient](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml) | | [instructor-embedders-haystack](integrations/instructor-embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | | [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | +| [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/fileconverter/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured / 
fileconverter](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml) | | [jina-haystack](integrations/jina/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | From 27f55ff807fbd738fda44ecb8893addb24093ded Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 15 Dec 2023 16:09:16 +0100 Subject: [PATCH 26/29] update issue template --- .../ISSUE_TEMPLATE/new-integration-proposal.md | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/new-integration-proposal.md b/.github/ISSUE_TEMPLATE/new-integration-proposal.md index d12ed2085..21964186c 100644 --- a/.github/ISSUE_TEMPLATE/new-integration-proposal.md +++ b/.github/ISSUE_TEMPLATE/new-integration-proposal.md @@ -9,23 +9,25 @@ assignees: '' ## Summary and motivation -Briefly explain the feature request: why do we need this feature? What use cases does it support? - -## Alternatives - -A clear and concise description of any alternative solutions or features you've considered. +Briefly explain the request: why do we need this integration? What use cases does it support? ## Detailed design -Explain the design in enough detail for somebody familiar with Haystack to understand, and for somebody familiar with the implementation to implement. Get into specifics and corner-cases, and include examples of how the feature is used. Also, if there's any new terminology involved, define it here. 
+Explain the design in enough detail for somebody familiar with Haystack to understand, and for somebody familiar with +the implementation to implement. Get into specifics and corner-cases, and include examples of how the feature is used. +Also, if there's any new terminology involved, define it here. ## Checklist -If the feature request is accepted, ensure the following checklist is complete before closing the issue. +If the request is accepted, ensure the following checklist is complete before closing this issue. -- [ ] The package has been released on PyPI +- [ ] The code was merged in the `main` branch +- [ ] Docs are published at https://docs.haystack.deepset.ai/ - [ ] There is a Github workflow running the tests for the integration nightly and at every PR - [ ] A label named like `integration:` has been added to this repo - [ ] The [labeler.yml](https://github.com/deepset-ai/haystack-core-integrations/blob/main/.github/labeler.yml) file has been updated +- [ ] The package has been released on PyPI - [ ] An integration tile has been added to https://github.com/deepset-ai/haystack-integrations - [ ] The integration has been listed in the [Inventory section](https://github.com/deepset-ai/haystack-core-integrations#inventory) of this repo README +- [ ] There is an example available to demonstrate the feature +- [ ] The feature was announced through social media \ No newline at end of file From d6501a8ffcd171634917ff9f2c8ac722ec0c48c0 Mon Sep 17 00:00:00 2001 From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> Date: Fri, 15 Dec 2023 19:08:33 +0100 Subject: [PATCH 27/29] Fix GeminiGenerator input type (#111) --- .../src/google_vertex_haystack/generators/gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py b/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py index aa8f9d9e8..995a52614 100644 --- 
a/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py +++ b/integrations/google-vertex/src/google_vertex_haystack/generators/gemini.py @@ -131,7 +131,7 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) @component.output_types(answers=List[Union[str, Dict[str, str]]]) - def run(self, parts: Variadic[List[Union[str, ByteStream, Part]]]): + def run(self, parts: Variadic[Union[str, ByteStream, Part]]): converted_parts = [self._convert_part(p) for p in parts] contents = [Content(parts=converted_parts, role="user")] From 3805d5b2317e87e73392072db27fd03d96cf9d38 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 15 Dec 2023 22:14:16 +0100 Subject: [PATCH 28/29] fix vertex badge --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e773d707d..9c47d9a98 100644 --- a/README.md +++ b/README.md @@ -65,10 +65,10 @@ deepset-haystack | [chroma-haystack](integrations/chroma/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/chroma-haystack.svg)](https://pypi.org/project/chroma-haystack) | [![Test / chroma](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/chroma.yml) | | [cohere-haystack](integrations/cohere/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/cohere-haystack.svg)](https://pypi.org/project/cohere-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/cohere.yml) | | [elasticsearch-haystack](integrations/elasticsearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) | [![Test / 
elasticsearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/elasticsearch.yml) | -| [google-vertex-haystack](integrations/google-vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google-vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google-vertex.yml) | +| [google-vertex-haystack](integrations/google-vertex/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/google-vertex-haystack.svg)](https://pypi.org/project/google-vertex-haystack) | [![Test / google-vertex](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/google_vertex.yml) | | [gradient-haystack](integrations/gradient/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) | [![Test / gradient](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml) | | [instructor-embedders-haystack](integrations/instructor-embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | | [opensearch-haystack](integrations/opensearch/) | 
Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | -| [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | +| [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/fileconverter/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured / fileconverter](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml) | | [jina-haystack](integrations/jina/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / 
cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) | From 36607c74ec686e192dbfb8f7a9bf2483e56375a5 Mon Sep 17 00:00:00 2001 From: Massimiliano Pippi Date: Fri, 15 Dec 2023 22:23:09 +0100 Subject: [PATCH 29/29] adjust badge color --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9c47d9a98..2dc547050 100644 --- a/README.md +++ b/README.md @@ -69,6 +69,6 @@ deepset-haystack | [gradient-haystack](integrations/gradient/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/gradient-haystack.svg)](https://pypi.org/project/gradient-haystack) | [![Test / gradient](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/gradient.yml) | | [instructor-embedders-haystack](integrations/instructor-embedders/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/instructor-embedders-haystack.svg)](https://pypi.org/project/instructor-embedders-haystack) | [![Test / instructor-embedders](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/instructor_embedders.yml) | | [opensearch-haystack](integrations/opensearch/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/opensearch-haystack.svg)](https://pypi.org/project/opensearch-haystack) | [![Test / opensearch](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/opensearch.yml) | -| [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - 
Version](https://img.shields.io/pypi/v/qdrant-haystack.svg)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | +| [qdrant-haystack](integrations/qdrant/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg?color=orange)](https://pypi.org/project/qdrant-haystack) | [![Test / qdrant](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/qdrant.yml) | | [unstructured-fileconverter-haystack](integrations/unstructured/fileconverter/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured / fileconverter](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured_fileconverter.yml) | | [jina-haystack](integrations/jina/) | Embedder | [![PyPI - Version](https://img.shields.io/pypi/v/jina-haystack.svg)](https://pypi.org/project/jina-haystack) | [![Test / cohere](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/jina.yml) |