From ad5a29039be7a68bb18eb1597da8e708ab6c793f Mon Sep 17 00:00:00 2001 From: Ashwin Mathur <97467100+awinml@users.noreply.github.com> Date: Wed, 21 Feb 2024 18:52:08 +0530 Subject: [PATCH] feat: Add Optimum Embedders (#379) * Add Optimum Embedders * Add CI workflow * Update dependencies * Add embedding backend * Add pooling sub module; Update pyproject.toml with dependency info * Remove backend factory; Address review comments * Add API docs generation workflow * Add additional tests * Update order to 185 for API docs * Update integrations/optimum/pydoc/config.yml Co-authored-by: Daria Fokina --------- Co-authored-by: Daria Fokina --- .github/labeler.yml | 5 + .github/workflows/optimum.yml | 60 ++++ integrations/optimum/LICENSE.txt | 73 +++++ integrations/optimum/README.md | 30 ++ integrations/optimum/pydoc/config.yml | 32 ++ integrations/optimum/pyproject.toml | 202 ++++++++++++ .../components/embedders/__init__.py | 8 + .../components/embedders/optimum_backend.py | 99 ++++++ .../embedders/optimum_document_embedder.py | 251 +++++++++++++++ .../embedders/optimum_text_embedder.py | 201 ++++++++++++ .../components/embedders/pooling.py | 137 +++++++++ integrations/optimum/tests/__init__.py | 3 + .../optimum/tests/test_optimum_backend.py | 32 ++ .../tests/test_optimum_document_embedder.py | 288 ++++++++++++++++++ .../tests/test_optimum_text_embedder.py | 193 ++++++++++++ 15 files changed, 1614 insertions(+) create mode 100644 .github/workflows/optimum.yml create mode 100644 integrations/optimum/LICENSE.txt create mode 100644 integrations/optimum/README.md create mode 100644 integrations/optimum/pydoc/config.yml create mode 100644 integrations/optimum/pyproject.toml create mode 100644 integrations/optimum/src/haystack_integrations/components/embedders/__init__.py create mode 100644 integrations/optimum/src/haystack_integrations/components/embedders/optimum_backend.py create mode 100644 integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py create mode 100644 integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py create mode 100644 integrations/optimum/src/haystack_integrations/components/embedders/pooling.py create mode 100644 integrations/optimum/tests/__init__.py create mode 100644 integrations/optimum/tests/test_optimum_backend.py create mode 100644 integrations/optimum/tests/test_optimum_document_embedder.py create mode 100644 integrations/optimum/tests/test_optimum_text_embedder.py diff --git a/.github/labeler.yml b/.github/labeler.yml index 5f1b76912..a60f02c67 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -84,6 +84,11 @@ integration:opensearch: - any-glob-to-any-file: "integrations/opensearch/**/*" - any-glob-to-any-file: ".github/workflows/opensearch.yml" +integration:optimum: + - changed-files: + - any-glob-to-any-file: "integrations/optimum/**/*" + - any-glob-to-any-file: ".github/workflows/optimum.yml" + integration:pgvector: - changed-files: - any-glob-to-any-file: "integrations/pgvector/**/*" diff --git a/.github/workflows/optimum.yml b/.github/workflows/optimum.yml new file mode 100644 index 000000000..3b0d137da --- /dev/null +++ b/.github/workflows/optimum.yml @@ -0,0 +1,60 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / optimum + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - "integrations/optimum/**" + - 
".github/workflows/optimum.yml" + +defaults: + run: + working-directory: integrations/optimum + +concurrency: + group: optimum-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.9", "3.10"] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . + run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Generate docs + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run docs + + - name: Run tests + run: hatch run cov diff --git a/integrations/optimum/LICENSE.txt b/integrations/optimum/LICENSE.txt new file mode 100644 index 000000000..137069b82 --- /dev/null +++ b/integrations/optimum/LICENSE.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
+ +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/integrations/optimum/README.md b/integrations/optimum/README.md
new file mode 100644
index 000000000..1438f6e92
--- /dev/null
+++ b/integrations/optimum/README.md
@@ -0,0 +1,30 @@
+# optimum
+
+[![PyPI - Version](https://img.shields.io/pypi/v/optimum.svg)](https://pypi.org/project/optimum-haystack)
+[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/optimum.svg)](https://pypi.org/project/optimum-haystack)
+
+-----
+
+Component to embed strings and Documents using models loaded with the HuggingFace Optimum library. This component is designed to run inference on models seamlessly using the high-speed ONNX Runtime. A minimal example is shown in the [Usage](#usage) section below.
+
+**Table of Contents**
+
+- [Installation](#installation)
+- [License](#license)
+- [Usage](#usage)
+
+## Installation
+
+To use the ONNX Runtime on CPU, install the CPU version:
+```console
+pip install optimum-haystack[cpu]
+```
+
+For the GPU runtimes:
+```console
+pip install optimum-haystack[gpu]
+```
+
+
+## License
+
+`optimum-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license.
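+
+## Usage
+
+The snippet below is a minimal sketch of both embedders, mirroring the usage examples in the component docstrings; the model id and printed values are illustrative.
+
+```python
+from haystack.dataclasses import Document
+from haystack_integrations.components.embedders import OptimumDocumentEmbedder, OptimumTextEmbedder
+
+# Embed a single string
+text_embedder = OptimumTextEmbedder(model="sentence-transformers/all-mpnet-base-v2")
+text_embedder.warm_up()  # the model is converted to ONNX and loaded during warm-up
+print(text_embedder.run("I love pizza!")["embedding"])
+
+# Embed a list of Documents; each embedding is stored in the Document's `embedding` field
+document_embedder = OptimumDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2")
+document_embedder.warm_up()
+result = document_embedder.run([Document(content="I love pizza!")])
+print(result["documents"][0].embedding)
+```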
diff --git a/integrations/optimum/pydoc/config.yml b/integrations/optimum/pydoc/config.yml
new file mode 100644
index 000000000..5fb353b5d
--- /dev/null
+++ b/integrations/optimum/pydoc/config.yml
@@ -0,0 +1,32 @@
+loaders:
+  - type: haystack_pydoc_tools.loaders.CustomPythonLoader
+    search_path: [../src]
+    modules:
+      [
+        "haystack_integrations.components.embedders.optimum_backend",
+        "haystack_integrations.components.embedders.optimum_document_embedder",
+        "haystack_integrations.components.embedders.optimum_text_embedder",
+        "haystack_integrations.components.embedders.pooling",
+      ]
+    ignore_when_discovered: ["__init__"]
+processors:
+  - type: filter
+    expression:
+    documented_only: true
+    do_not_filter_modules: false
+    skip_empty_modules: true
+  - type: smart
+  - type: crossref
+renderer:
+  type: haystack_pydoc_tools.renderers.ReadmePreviewRenderer
+  excerpt: Optimum integration for Haystack
+  category_slug: integrations-api
+  title: Optimum
+  slug: integrations-optimum
+  order: 185
+  markdown:
+    descriptive_class_title: false
+    descriptive_module_title: true
+    add_method_class_prefix: true
+    add_member_class_prefix: false
+  filename: _readme_optimum.md
diff --git a/integrations/optimum/pyproject.toml b/integrations/optimum/pyproject.toml
new file mode 100644
index 000000000..17bca2597
--- /dev/null
+++ b/integrations/optimum/pyproject.toml
@@ -0,0 +1,202 @@
+[build-system]
+requires = ["hatchling", "hatch-vcs"]
+build-backend = "hatchling.build"
+
+[project]
+name = "optimum-haystack"
+dynamic = ["version"]
+description = "Component to embed strings and Documents using models loaded with the HuggingFace Optimum library. This component is designed to run inference on models seamlessly using the high-speed ONNX Runtime."
+readme = "README.md"
+requires-python = ">=3.8"
+license = "Apache-2.0"
+keywords = []
+authors = [
+  { name = "deepset GmbH", email = "info@deepset.ai" },
+  { name = "Ashwin Mathur", email = "" },
+]
+classifiers = [
+  "Development Status :: 4 - Beta",
+  "Programming Language :: Python",
+  "Programming Language :: Python :: 3.8",
+  "Programming Language :: Python :: 3.9",
+  "Programming Language :: Python :: 3.10",
+  "Programming Language :: Python :: 3.11",
+  "Programming Language :: Python :: Implementation :: CPython",
+  "Programming Language :: Python :: Implementation :: PyPy",
+]
+dependencies = [
+  "haystack-ai",
+  "transformers[sentencepiece]",
+  # The main export function of Optimum into ONNX has hidden dependencies.
+  # It depends on either "sentence-transformers", "diffusers" or "timm", based
+  # on which model is loaded from HF Hub.
+  # Ref: https://github.com/huggingface/optimum/blob/8651c0ca1cccf095458bc80329dec9df4601edb4/optimum/exporters/onnx/__main__.py#L164
+  # "sentence-transformers" has been added, since most embedding models use it
+  "sentence-transformers>=2.2.0",
+  "optimum[onnxruntime]"
+]
+
+[project.urls]
+Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/optimum#readme"
+Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues"
+Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/optimum"
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/haystack_integrations"]
+
+[tool.hatch.version]
+source = "vcs"
+tag-pattern = 'integrations\/optimum-v(?P<version>.*)'
+
+[tool.hatch.version.raw-options]
+root = "../.."
+git_describe_command = 'git describe --tags --match="integrations/optimum-v[0-9]*"' + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "haystack-pydoc-tools" +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] +docs = ["pydoc-markdown pydoc/config.yml"] + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] + + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] + +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.ruff.lint.isort] +known-first-party = ["src"] + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "parents" + +[tool.ruff.lint.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] +# Examples can print their output +"examples/**" = ["T201"] +"tests/**" = ["T201"] + +[tool.coverage.run] +source_pkgs = ["optimum", "tests"] +branch = true +parallel = true + +[tool.coverage.paths] +optimum = ["src/haystack_integrations", "*/optimum/src/haystack_integrations"] +tests = ["tests", "*/optimum/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "haystack_integrations.*", + "pytest.*", + "numpy.*", + "optimum.*", + "torch.*", + "transformers.*", + "huggingface_hub.*", + "sentence_transformers.*" +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +addopts = ["--strict-markers", "-vv"] +markers = [ + "integration: integration tests", + "unit: unit tests", + "embedders: embedders tests", +] +log_cli = true diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/__init__.py b/integrations/optimum/src/haystack_integrations/components/embedders/__init__.py new file mode 100644 index 000000000..4e5ac1535 --- /dev/null +++ b/integrations/optimum/src/haystack_integrations/components/embedders/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2024-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .optimum_document_embedder import OptimumDocumentEmbedder +from .optimum_text_embedder import OptimumTextEmbedder + +__all__ = ["OptimumDocumentEmbedder", "OptimumTextEmbedder"] diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_backend.py 
b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_backend.py new file mode 100644 index 000000000..55d0e6f20 --- /dev/null +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_backend.py @@ -0,0 +1,99 @@ +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from haystack.utils.auth import Secret +from haystack_integrations.components.embedders.pooling import Pooling, PoolingMode +from optimum.onnxruntime import ORTModelForFeatureExtraction +from tqdm import tqdm +from transformers import AutoTokenizer + + +class OptimumEmbeddingBackend: + """ + Class to manage Optimum embeddings. + """ + + def __init__(self, model: str, model_kwargs: Dict[str, Any], token: Optional[Secret] = None): + """ + Create an instance of OptimumEmbeddingBackend. + + :param model: A string representing the model id on HF Hub. + :param model_kwargs: Keyword arguments to pass to the model. + :param token: The HuggingFace token to use as HTTP bearer authorization. + """ + # export=True converts the model to ONNX on the fly + self.model = ORTModelForFeatureExtraction.from_pretrained(**model_kwargs, export=True) + self.tokenizer = AutoTokenizer.from_pretrained(model, token=token) + + def embed( + self, + texts_to_embed: Union[str, List[str]], + normalize_embeddings: bool, + pooling_mode: PoolingMode = PoolingMode.MEAN, + progress_bar: bool = False, + batch_size: int = 1, + ) -> Union[List[List[float]], List[float]]: + """ + Embed text or list of texts using the Optimum model. + + :param texts_to_embed: The text or list of texts to embed. + :param normalize_embeddings: Whether to normalize the embeddings to unit length. + :param pooling_mode: The pooling mode to use. + :param progress_bar: Whether to show a progress bar or not. + :param batch_size: Batch size to use. + :return: A single embedding if the input is a single string. A list of embeddings if the input is a list of + strings. 
+ """ + if isinstance(texts_to_embed, str): + texts = [texts_to_embed] + else: + texts = texts_to_embed + + # Determine device for tokenizer output + device = self.model.device + + # Sorting by length + length_sorted_idx = np.argsort([-len(sen) for sen in texts]) + sentences_sorted = [texts[idx] for idx in length_sorted_idx] + + all_embeddings = [] + for i in tqdm( + range(0, len(sentences_sorted), batch_size), disable=not progress_bar, desc="Calculating embeddings" + ): + batch = sentences_sorted[i : i + batch_size] + encoded_input = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device) + + # Only pass required inputs otherwise onnxruntime can raise an error + inputs_to_remove = set(encoded_input.keys()).difference(self.model.inputs_names) + for key in inputs_to_remove: + encoded_input.pop(key) + model_output = self.model(**encoded_input) + + # Pool Embeddings + pooling = Pooling( + pooling_mode=pooling_mode, + attention_mask=encoded_input["attention_mask"].to(device), + model_output=model_output, + ) + sentence_embeddings = pooling.pool_embeddings() + all_embeddings.append(sentence_embeddings) + + embeddings = torch.cat(all_embeddings, dim=0) + + # Normalize all embeddings + if normalize_embeddings: + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + + embeddings = embeddings.tolist() + + # Reorder embeddings according to original order + reordered_embeddings: List[List[float]] = [[]] * len(texts) + for embedding, idx in zip(embeddings, length_sorted_idx): + reordered_embeddings[idx] = embedding + + if isinstance(texts_to_embed, str): + # Return the embedding if only one text was passed + return reordered_embeddings[0] + + return reordered_embeddings diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py new file mode 100644 index 000000000..f0310f872 --- /dev/null +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py @@ -0,0 +1,251 @@ +from typing import Any, Dict, List, Optional, Union + +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs +from haystack_integrations.components.embedders.optimum_backend import OptimumEmbeddingBackend +from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode + + +@component +class OptimumDocumentEmbedder: + """ + A component for computing Document embeddings using models loaded with the HuggingFace Optimum library. + This component is designed to seamlessly inference models using the high speed ONNX runtime. + + The embedding of each Document is stored in the `embedding` field of the Document. + + Usage example: + ```python + from haystack.dataclasses import Document + from haystack_integrations.components.embedders import OptimumDocumentEmbedder + + doc = Document(content="I love pizza!") + + document_embedder = OptimumDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2") + document_embedder.warm_up() + + result = document_embedder.run([doc]) + print(result["documents"][0].embedding) + + # [0.017020374536514282, -0.023255806416273117, ...] 
+    ```
+
+    Key Features and Compatibility:
+    - **Primary Compatibility**: Designed to work seamlessly with any embedding model present on the Hugging Face
+      Hub.
+    - **Conversion to ONNX**: The models are converted to ONNX using the HuggingFace Optimum library. This is
+      performed in real-time, during the warm-up step.
+    - **Accelerated Inference on GPU**: Supports using different execution providers such as CUDA and TensorRT, to
+      accelerate ONNX Runtime inference on GPUs.
+      Simply pass the execution provider as the onnx_execution_provider parameter. Additional parameters can be
+      passed to the model using the model_kwargs parameter.
+      For more details refer to the HuggingFace documentation:
+      https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu.
+    """
+
+    def __init__(
+        self,
+        model: str = "sentence-transformers/all-mpnet-base-v2",
+        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),  # noqa: B008
+        prefix: str = "",
+        suffix: str = "",
+        normalize_embeddings: bool = True,
+        onnx_execution_provider: str = "CPUExecutionProvider",
+        pooling_mode: Optional[Union[str, PoolingMode]] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+        batch_size: int = 32,
+        progress_bar: bool = True,
+        meta_fields_to_embed: Optional[List[str]] = None,
+        embedding_separator: str = "\n",
+    ):
+        """
+        Create an OptimumDocumentEmbedder component.
+
+        :param model: A string representing the model id on HF Hub.
+        :param token: The HuggingFace token to use as HTTP bearer authorization.
+        :param prefix: A string to add to the beginning of each text.
+        :param suffix: A string to add to the end of each text.
+        :param normalize_embeddings: Whether to normalize the embeddings to unit length.
+        :param onnx_execution_provider: The execution provider to use for ONNX models. See
+            https://onnxruntime.ai/docs/execution-providers/ for possible providers.
+
+            Note: Using the TensorRT execution provider
+            TensorRT requires building its inference engine ahead of inference, which takes some time due to model
+            optimization and node fusion. To avoid rebuilding the engine every time the model is loaded, ONNX Runtime
+            provides a pair of options to save the engine: `trt_engine_cache_enable` and `trt_engine_cache_path`. We
+            recommend setting these two provider options using the model_kwargs parameter when using the TensorRT
+            execution provider. The usage is as follows:
+            ```python
+            embedder = OptimumDocumentEmbedder(
+                model="sentence-transformers/all-mpnet-base-v2",
+                onnx_execution_provider="TensorrtExecutionProvider",
+                model_kwargs={
+                    "provider_options": {
+                        "trt_engine_cache_enable": True,
+                        "trt_engine_cache_path": "tmp/trt_cache",
+                    }
+                },
+            )
+            ```
+        :param pooling_mode: The pooling mode to use. When None, the pooling mode will be inferred from the model
+            config. The supported pooling modes are:
+            - "cls": Perform CLS Pooling on the output of the embedding model. Uses the first token (CLS token) as
+              text representations.
+            - "max": Perform Max Pooling on the output of the embedding model. Uses max in each dimension over all
+              the tokens.
+            - "mean": Perform Mean Pooling on the output of the embedding model.
+            - "mean_sqrt_len": Perform mean-pooling on the output of the embedding model, but divide by
+              sqrt(input_length).
+            - "weighted_mean": Perform Weighted (position) Mean Pooling on the output of the embedding model. See
+              https://arxiv.org/abs/2202.08904.
+            - "last_token": Perform Last Token Pooling on the output of the embedding model. See
+              https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005.
+        :param model_kwargs: Dictionary containing additional keyword arguments to pass to the model.
+            In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token`
+            initialization parameters.
+        :param batch_size: Number of Documents to encode at once.
+        :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production
+            deployments to keep the logs clean.
+        :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
+        :param embedding_separator: Separator used to concatenate the meta fields to the Document text.
+        """
+        check_valid_model(model, HFModelType.EMBEDDING, token)
+        self.model = model
+
+        self.token = token
+        token = token.resolve_value() if token else None
+
+        if isinstance(pooling_mode, str):
+            self.pooling_mode = PoolingMode.from_str(pooling_mode)
+        elif isinstance(pooling_mode, PoolingMode):
+            self.pooling_mode = pooling_mode
+        else:
+            # Infer the pooling mode from the model config when it is not provided
+            self.pooling_mode = HFPoolingMode.get_pooling_mode(model, token)
+        # Raise an error if the pooling mode is neither found in the model config nor specified by the user
+        if self.pooling_mode is None:
+            modes = {e.value: e for e in PoolingMode}
+            msg = (
+                f"Pooling mode not found in model config and not specified by user."
+                f" Supported modes are: {list(modes.keys())}"
+            )
+            raise ValueError(msg)
+
+        self.prefix = prefix
+        self.suffix = suffix
+        self.normalize_embeddings = normalize_embeddings
+        self.onnx_execution_provider = onnx_execution_provider
+        self.batch_size = batch_size
+        self.progress_bar = progress_bar
+        self.meta_fields_to_embed = meta_fields_to_embed or []
+        self.embedding_separator = embedding_separator
+
+        model_kwargs = model_kwargs or {}
+
+        # Check if the model_kwargs contain the parameters; otherwise, populate them with values from init parameters
+        model_kwargs.setdefault("model_id", model)
+        model_kwargs.setdefault("provider", onnx_execution_provider)
+        model_kwargs.setdefault("use_auth_token", token)
+
+        self.model_kwargs = model_kwargs
+        self.embedding_backend = None
+
+    def warm_up(self):
+        """
+        Load the embedding backend.
+        """
+        if self.embedding_backend is None:
+            self.embedding_backend = OptimumEmbeddingBackend(
+                model=self.model, token=self.token, model_kwargs=self.model_kwargs
+            )
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+        """
+        serialization_dict = default_to_dict(
+            self,
+            model=self.model,
+            prefix=self.prefix,
+            suffix=self.suffix,
+            normalize_embeddings=self.normalize_embeddings,
+            onnx_execution_provider=self.onnx_execution_provider,
+            pooling_mode=self.pooling_mode.value,
+            batch_size=self.batch_size,
+            progress_bar=self.progress_bar,
+            meta_fields_to_embed=self.meta_fields_to_embed,
+            embedding_separator=self.embedding_separator,
+            model_kwargs=self.model_kwargs,
+            token=self.token.to_dict() if self.token else None,
+        )
+
+        model_kwargs = serialization_dict["init_parameters"]["model_kwargs"]
+        model_kwargs.pop("token", None)
+
+        serialize_hf_model_kwargs(model_kwargs)
+        return serialization_dict
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "OptimumDocumentEmbedder":
+        """
+        Deserialize this component from a dictionary.
+ """ + data["init_parameters"]["pooling_mode"] = PoolingMode.from_str(data["init_parameters"]["pooling_mode"]) + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_hf_model_kwargs(data["init_parameters"]["model_kwargs"]) + return default_from_dict(cls, data) + + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. + """ + texts_to_embed = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None + ] + + text_to_embed = ( + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix + ) + + texts_to_embed.append(text_to_embed) + return texts_to_embed + + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + """ + Embed a list of Documents. + The embedding of each Document is stored in the `embedding` field of the Document. + + :param documents: A list of Documents to embed. + :return: A dictionary containing the updated Documents with their embeddings. + """ + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + msg = ( + "OptimumDocumentEmbedder expects a list of Documents as input." + " In case you want to embed a string, please use the OptimumTextEmbedder." + ) + raise TypeError(msg) + + if self.embedding_backend is None: + msg = "The embedding model has not been loaded. Please call warm_up() before running." + raise RuntimeError(msg) + + # Return empty list if no documents + if not documents: + return {"documents": []} + + texts_to_embed = self._prepare_texts_to_embed(documents=documents) + + embeddings = self.embedding_backend.embed( + texts_to_embed=texts_to_embed, + normalize_embeddings=self.normalize_embeddings, + pooling_mode=self.pooling_mode, + progress_bar=self.progress_bar, + batch_size=self.batch_size, + ) + + for doc, emb in zip(documents, embeddings): + doc.embedding = emb + + return {"documents": documents} diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py new file mode 100644 index 000000000..8a33a9403 --- /dev/null +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py @@ -0,0 +1,201 @@ +from typing import Any, Dict, List, Optional, Union + +from haystack import component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs +from haystack_integrations.components.embedders.optimum_backend import OptimumEmbeddingBackend +from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode + + +@component +class OptimumTextEmbedder: + """ + A component to embed text using models loaded with the HuggingFace Optimum library. + This component is designed to seamlessly inference models using the high speed ONNX runtime. + + Usage example: + ```python + from haystack_integrations.components.embedders import OptimumTextEmbedder + + text_to_embed = "I love pizza!" 
+
+    text_embedder = OptimumTextEmbedder(model="sentence-transformers/all-mpnet-base-v2")
+    text_embedder.warm_up()
+
+    print(text_embedder.run(text_to_embed))
+
+    # {'embedding': [-0.07804739475250244, 0.1498992145061493, ...]}
+    ```
+
+    Key Features and Compatibility:
+    - **Primary Compatibility**: Designed to work seamlessly with any embedding model present on the Hugging Face
+      Hub.
+    - **Conversion to ONNX**: The models are converted to ONNX using the HuggingFace Optimum library. This is
+      performed in real-time, during the warm-up step.
+    - **Accelerated Inference on GPU**: Supports using different execution providers such as CUDA and TensorRT, to
+      accelerate ONNX Runtime inference on GPUs.
+      Simply pass the execution provider as the onnx_execution_provider parameter. Additional parameters can be
+      passed to the model using the model_kwargs parameter.
+      For more details refer to the HuggingFace documentation:
+      https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu.
+    """
+
+    def __init__(
+        self,
+        model: str = "sentence-transformers/all-mpnet-base-v2",
+        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),  # noqa: B008
+        prefix: str = "",
+        suffix: str = "",
+        normalize_embeddings: bool = True,
+        onnx_execution_provider: str = "CPUExecutionProvider",
+        pooling_mode: Optional[Union[str, PoolingMode]] = None,
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Create an OptimumTextEmbedder component.
+
+        :param model: A string representing the model id on HF Hub.
+        :param token: The HuggingFace token to use as HTTP bearer authorization.
+        :param prefix: A string to add to the beginning of each text.
+        :param suffix: A string to add to the end of each text.
+        :param normalize_embeddings: Whether to normalize the embeddings to unit length.
+        :param onnx_execution_provider: The execution provider to use for ONNX models. See
+            https://onnxruntime.ai/docs/execution-providers/ for possible providers.
+
+            Note: Using the TensorRT execution provider
+            TensorRT requires building its inference engine ahead of inference, which takes some time due to model
+            optimization and node fusion. To avoid rebuilding the engine every time the model is loaded, ONNX Runtime
+            provides a pair of options to save the engine: `trt_engine_cache_enable` and `trt_engine_cache_path`. We
+            recommend setting these two provider options using the model_kwargs parameter when using the TensorRT
+            execution provider. The usage is as follows:
+            ```python
+            embedder = OptimumTextEmbedder(
+                model="sentence-transformers/all-mpnet-base-v2",
+                onnx_execution_provider="TensorrtExecutionProvider",
+                model_kwargs={
+                    "provider_options": {
+                        "trt_engine_cache_enable": True,
+                        "trt_engine_cache_path": "tmp/trt_cache",
+                    }
+                },
+            )
+            ```
+        :param pooling_mode: The pooling mode to use. When None, the pooling mode will be inferred from the model
+            config. The supported pooling modes are:
+            - "cls": Perform CLS Pooling on the output of the embedding model. Uses the first token (CLS token) as
+              text representations.
+            - "max": Perform Max Pooling on the output of the embedding model. Uses max in each dimension over all
+              the tokens.
+            - "mean": Perform Mean Pooling on the output of the embedding model.
+            - "mean_sqrt_len": Perform mean-pooling on the output of the embedding model, but divide by
+              sqrt(input_length).
+            - "weighted_mean": Perform Weighted (position) Mean Pooling on the output of the embedding model. See
+              https://arxiv.org/abs/2202.08904.
+ - "last_token": Perform Last Token Pooling on the output of the embedding model. See + https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005. + :param model_kwargs: Dictionary containing additional keyword arguments to pass to the model. + In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization + parameters. + """ + check_valid_model(model, HFModelType.EMBEDDING, token) + self.model = model + + self.token = token + token = token.resolve_value() if token else None + + if isinstance(pooling_mode, str): + self.pooling_mode = PoolingMode.from_str(pooling_mode) + # Infer pooling mode from model config if not provided, + if pooling_mode is None: + self.pooling_mode = HFPoolingMode.get_pooling_mode(model, token) + # Raise error if pooling mode is not found in model config and not specified by user + if self.pooling_mode is None: + modes = {e.value: e for e in PoolingMode} + msg = ( + f"Pooling mode not found in model config and not specified by user." + f" Supported modes are: {list(modes.keys())}" + ) + raise ValueError(msg) + + self.prefix = prefix + self.suffix = suffix + self.normalize_embeddings = normalize_embeddings + self.onnx_execution_provider = onnx_execution_provider + + model_kwargs = model_kwargs or {} + + # Check if the model_kwargs contain the parameters, otherwise, populate them with values from init parameters + model_kwargs.setdefault("model_id", model) + model_kwargs.setdefault("provider", onnx_execution_provider) + model_kwargs.setdefault("use_auth_token", token) + + self.model_kwargs = model_kwargs + self.embedding_backend = None + + def warm_up(self): + """ + Load the embedding backend. + """ + if self.embedding_backend is None: + self.embedding_backend = OptimumEmbeddingBackend( + model=self.model, token=self.token, model_kwargs=self.model_kwargs + ) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + serialization_dict = default_to_dict( + self, + model=self.model, + prefix=self.prefix, + suffix=self.suffix, + normalize_embeddings=self.normalize_embeddings, + onnx_execution_provider=self.onnx_execution_provider, + pooling_mode=self.pooling_mode.value, + model_kwargs=self.model_kwargs, + token=self.token.to_dict() if self.token else None, + ) + + model_kwargs = serialization_dict["init_parameters"]["model_kwargs"] + model_kwargs.pop("token", None) + + serialize_hf_model_kwargs(model_kwargs) + return serialization_dict + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "OptimumTextEmbedder": + """ + Deserialize this component from a dictionary. + """ + data["init_parameters"]["pooling_mode"] = PoolingMode.from_str(data["init_parameters"]["pooling_mode"]) + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_hf_model_kwargs(data["init_parameters"]["model_kwargs"]) + return default_from_dict(cls, data) + + @component.output_types(embedding=List[float]) + def run(self, text: str): + """ + Embed a string. + + :param text: The text to embed. + :return: The embeddings of the text. + """ + if not isinstance(text, str): + msg = ( + "OptimumTextEmbedder expects a string as an input. " + "In case you want to embed a list of Documents, please use the OptimumDocumentEmbedder." + ) + raise TypeError(msg) + + if self.embedding_backend is None: + msg = "The embedding model has not been loaded. Please call warm_up() before running." 
+            raise RuntimeError(msg)
+
+        text_to_embed = self.prefix + text + self.suffix
+
+        embedding = self.embedding_backend.embed(
+            texts_to_embed=text_to_embed, normalize_embeddings=self.normalize_embeddings, pooling_mode=self.pooling_mode
+        )
+
+        return {"embedding": embedding}
diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/pooling.py b/integrations/optimum/src/haystack_integrations/components/embedders/pooling.py
new file mode 100644
index 000000000..8f0cddc22
--- /dev/null
+++ b/integrations/optimum/src/haystack_integrations/components/embedders/pooling.py
@@ -0,0 +1,137 @@
+import json
+from enum import Enum
+from typing import Optional
+
+import torch
+from huggingface_hub import hf_hub_download
+from sentence_transformers.models import Pooling as PoolingLayer
+
+
+class PoolingMode(Enum):
+    """
+    Pooling modes supported by the Optimum Embedders.
+    """
+
+    CLS = "cls"
+    MEAN = "mean"
+    MAX = "max"
+    MEAN_SQRT_LEN = "mean_sqrt_len"
+    WEIGHTED_MEAN = "weighted_mean"
+    LAST_TOKEN = "last_token"
+
+    def __str__(self):
+        return self.value
+
+    @classmethod
+    def from_str(cls, string: str) -> "PoolingMode":
+        """
+        Create a pooling mode from a string.
+
+        :param string:
+            The string to convert.
+        :returns:
+            The pooling mode.
+        """
+        enum_map = {e.value: e for e in PoolingMode}
+        pooling_mode = enum_map.get(string)
+        if pooling_mode is None:
+            msg = f"Unknown Pooling mode '{string}'. Supported modes are: {list(enum_map.keys())}"
+            raise ValueError(msg)
+        return pooling_mode
+
+
+POOLING_MODES_MAP = {
+    "pooling_mode_cls_token": PoolingMode.CLS,
+    "pooling_mode_mean_tokens": PoolingMode.MEAN,
+    "pooling_mode_max_tokens": PoolingMode.MAX,
+    "pooling_mode_mean_sqrt_len_tokens": PoolingMode.MEAN_SQRT_LEN,
+    "pooling_mode_weightedmean_tokens": PoolingMode.WEIGHTED_MEAN,
+    "pooling_mode_lasttoken": PoolingMode.LAST_TOKEN,
+}
+
+INVERSE_POOLING_MODES_MAP = {mode: name for name, mode in POOLING_MODES_MAP.items()}
+
+
+class HFPoolingMode:
+    """
+    Gets the pooling mode of the model from the Hugging Face Hub.
+    """
+
+    @staticmethod
+    def get_pooling_mode(model: str, token: Optional[str] = None) -> Optional[PoolingMode]:
+        """
+        Gets the pooling mode of the model from the Hugging Face Hub.
+
+        :param model:
+            The model to get the pooling mode for.
+        :param token:
+            The resolved HuggingFace token to use as HTTP bearer authorization.
+        :returns:
+            The pooling mode.
+        """
+        try:
+            pooling_config_path = hf_hub_download(repo_id=model, token=token, filename="1_Pooling/config.json")
+
+            with open(pooling_config_path) as f:
+                pooling_config = json.load(f)
+
+            # Filter only those keys that start with "pooling_mode" and are True
+            true_pooling_modes = [
+                key for key, value in pooling_config.items() if key.startswith("pooling_mode") and value
+            ]
+
+            # If exactly one True pooling mode is found, return it
+            if len(true_pooling_modes) == 1:
+                pooling_mode_from_config = true_pooling_modes[0]
+                pooling_mode = POOLING_MODES_MAP.get(pooling_mode_from_config)
+            # If no True pooling modes or more than one True pooling mode is found, return None
+            else:
+                pooling_mode = None
+            return pooling_mode
+        except Exception as e:
+            msg = f"An error occurred while inferring the pooling mode from the model config: {e}"
+            raise ValueError(msg) from e
+
+
+class Pooling:
+    """
+    Class to manage pooling of the embeddings.
+
+    :param pooling_mode: The pooling mode to use.
+    :param attention_mask: The attention mask of the tokenized text.
+    :param model_output: The output of the embedding model.
+    """
+
+    def __init__(self, pooling_mode: PoolingMode, attention_mask: torch.Tensor, model_output: torch.Tensor):
+        self.pooling_mode = pooling_mode
+        self.attention_mask = attention_mask
+        self.model_output = model_output
+
+    def pool_embeddings(self) -> torch.Tensor:
+        """
+        Perform pooling on the output of the embedding model.
+
+        :return: The embeddings of the text after pooling.
+        """
+        pooling_func_map = {
+            INVERSE_POOLING_MODES_MAP[self.pooling_mode]: True,
+        }
+        # By default, sentence-transformers uses mean pooling.
+        # If multiple pooling methods are specified, the output dimension of the embeddings is scaled by the
+        # number of pooling methods selected, so mean pooling is explicitly disabled for other modes.
+        if self.pooling_mode != PoolingMode.MEAN:
+            pooling_func_map[INVERSE_POOLING_MODES_MAP[PoolingMode.MEAN]] = False
+
+        # The first element of model_output contains all token embeddings
+        token_embeddings = self.model_output[0]
+        word_embedding_dimension = token_embeddings.size(dim=2)
+        pooling = PoolingLayer(word_embedding_dimension=word_embedding_dimension, **pooling_func_map)
+        features = {"token_embeddings": token_embeddings, "attention_mask": self.attention_mask}
+        pooled_outputs = pooling.forward(features)
+        embeddings = pooled_outputs["sentence_embedding"]
+
+        return embeddings
diff --git a/integrations/optimum/tests/__init__.py b/integrations/optimum/tests/__init__.py
new file mode 100644
index 000000000..6b5e14dc1
--- /dev/null
+++ b/integrations/optimum/tests/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: 2024-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
diff --git a/integrations/optimum/tests/test_optimum_backend.py b/integrations/optimum/tests/test_optimum_backend.py
new file mode 100644
index 000000000..8ef61fd37
--- /dev/null
+++ b/integrations/optimum/tests/test_optimum_backend.py
@@ -0,0 +1,32 @@
+import pytest
+from haystack_integrations.components.embedders.optimum_backend import OptimumEmbeddingBackend
+from haystack_integrations.components.embedders.pooling import PoolingMode
+
+
+@pytest.fixture
+def backend():
+    model = "sentence-transformers/all-mpnet-base-v2"
+    model_kwargs = {"model_id": model}
+    backend = OptimumEmbeddingBackend(model=model, model_kwargs=model_kwargs, token=None)
+    return backend
+
+
+class TestOptimumBackend:
+    def test_embed_output_order(self, backend):
+        texts_to_embed = ["short text", "text that is longer than the other", "medium length text"]
+        embeddings = backend.embed(texts_to_embed, normalize_embeddings=False, pooling_mode=PoolingMode.MEAN)
+
+        # Compute individual embeddings in order
+        expected_embeddings = []
+        for text in texts_to_embed:
+            expected_embeddings.append(backend.embed(text, normalize_embeddings=False, pooling_mode=PoolingMode.MEAN))
+
+        # Assert that the embeddings are in the same order
+        assert embeddings == expected_embeddings
+
+    def test_run_pooling_modes(self, backend):
+        for pooling_mode in PoolingMode:
+            embedding = backend.embed("test text", normalize_embeddings=False, pooling_mode=pooling_mode)
+
+            assert len(embedding) == 768
+            assert all(isinstance(x, float) for x in embedding)
diff --git a/integrations/optimum/tests/test_optimum_document_embedder.py b/integrations/optimum/tests/test_optimum_document_embedder.py
new file mode 100644
index 000000000..f61fea1d3
--- /dev/null
+++
b/integrations/optimum/tests/test_optimum_document_embedder.py @@ -0,0 +1,288 @@ +from unittest.mock import MagicMock, patch + +import pytest +from haystack.dataclasses import Document +from haystack.utils.auth import Secret +from haystack_integrations.components.embedders import OptimumDocumentEmbedder +from haystack_integrations.components.embedders.pooling import PoolingMode +from huggingface_hub.utils import RepositoryNotFoundError + + +@pytest.fixture +def mock_check_valid_model(): + with patch( + "haystack_integrations.components.embedders.optimum_document_embedder.check_valid_model", + MagicMock(return_value=None), + ) as mock: + yield mock + + +@pytest.fixture +def mock_get_pooling_mode(): + with patch( + "haystack_integrations.components.embedders.optimum_text_embedder.HFPoolingMode.get_pooling_mode", + MagicMock(return_value=PoolingMode.MEAN), + ) as mock: + yield mock + + +class TestOptimumDocumentEmbedder: + def test_init_default(self, monkeypatch, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 + monkeypatch.setenv("HF_API_TOKEN", "fake-api-token") + embedder = OptimumDocumentEmbedder() + + assert embedder.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.normalize_embeddings is True + assert embedder.onnx_execution_provider == "CPUExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MEAN + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + assert embedder.model_kwargs == { + "model_id": "sentence-transformers/all-mpnet-base-v2", + "provider": "CPUExecutionProvider", + "use_auth_token": "fake-api-token", + } + + def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 + embedder = OptimumDocumentEmbedder( + model="sentence-transformers/all-minilm-l6-v2", + token=Secret.from_token("fake-api-token"), + prefix="prefix", + suffix="suffix", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + normalize_embeddings=False, + pooling_mode="max", + onnx_execution_provider="CUDAExecutionProvider", + model_kwargs={"trust_remote_code": True}, + ) + + assert embedder.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder.token == Secret.from_token("fake-api-token") + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == " | " + assert embedder.normalize_embeddings is False + assert embedder.onnx_execution_provider == "CUDAExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MAX + assert embedder.model_kwargs == { + "trust_remote_code": True, + "model_id": "sentence-transformers/all-minilm-l6-v2", + "provider": "CUDAExecutionProvider", + "use_auth_token": "fake-api-token", + } + + def test_to_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 + component = OptimumDocumentEmbedder() + data = component.to_dict() + + assert data == { + "type": "haystack_integrations.components.embedders.optimum_document_embedder.OptimumDocumentEmbedder", + "init_parameters": { + "model": "sentence-transformers/all-mpnet-base-v2", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + 
"prefix": "", + "suffix": "", + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "normalize_embeddings": True, + "onnx_execution_provider": "CPUExecutionProvider", + "pooling_mode": "mean", + "model_kwargs": { + "model_id": "sentence-transformers/all-mpnet-base-v2", + "provider": "CPUExecutionProvider", + "use_auth_token": None, + }, + }, + } + + def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 + component = OptimumDocumentEmbedder( + model="sentence-transformers/all-minilm-l6-v2", + token=Secret.from_env_var("ENV_VAR", strict=False), + prefix="prefix", + suffix="suffix", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + normalize_embeddings=False, + onnx_execution_provider="CUDAExecutionProvider", + pooling_mode="max", + model_kwargs={"trust_remote_code": True}, + ) + data = component.to_dict() + + assert data == { + "type": "haystack_integrations.components.embedders.optimum_document_embedder.OptimumDocumentEmbedder", + "init_parameters": { + "model": "sentence-transformers/all-minilm-l6-v2", + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "prefix": "prefix", + "suffix": "suffix", + "batch_size": 64, + "progress_bar": False, + "meta_fields_to_embed": ["test_field"], + "embedding_separator": " | ", + "normalize_embeddings": False, + "onnx_execution_provider": "CUDAExecutionProvider", + "pooling_mode": "max", + "model_kwargs": { + "trust_remote_code": True, + "model_id": "sentence-transformers/all-minilm-l6-v2", + "provider": "CUDAExecutionProvider", + "use_auth_token": None, + }, + }, + } + + def test_initialize_with_invalid_model(self, mock_check_valid_model): + mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") + with pytest.raises(RepositoryNotFoundError): + OptimumDocumentEmbedder(model="invalid_model_id") + + def test_initialize_with_invalid_pooling_mode(self, mock_check_valid_model): # noqa: ARG002 + mock_get_pooling_mode.side_effect = ValueError("Invalid pooling mode") + with pytest.raises(ValueError): + OptimumDocumentEmbedder( + model="sentence-transformers/all-mpnet-base-v2", pooling_mode="Invalid_pooling_mode" + ) + + def test_infer_pooling_mode_from_str(self): + """ + Test that the pooling mode is correctly inferred from a string. + The pooling mode is "mean" as per the model config. 
+        """
+        for pooling_mode in PoolingMode:
+            embedder = OptimumDocumentEmbedder(
+                model="sentence-transformers/all-minilm-l6-v2",
+                pooling_mode=pooling_mode.value,
+            )
+
+            assert embedder.model == "sentence-transformers/all-minilm-l6-v2"
+            assert embedder.pooling_mode == pooling_mode
+
+    @pytest.mark.integration
+    def test_default_pooling_mode_when_config_not_found(self, mock_check_valid_model):  # noqa: ARG002
+        with pytest.raises(ValueError):
+            OptimumDocumentEmbedder(
+                model="embedding_model_finetuned",
+                pooling_mode=None,
+            )
+
+    @pytest.mark.integration
+    def test_infer_pooling_mode_from_hf(self):
+        embedder = OptimumDocumentEmbedder(
+            model="sentence-transformers/all-minilm-l6-v2",
+            pooling_mode=None,
+        )
+
+        assert embedder.model == "sentence-transformers/all-minilm-l6-v2"
+        assert embedder.pooling_mode == PoolingMode.MEAN
+
+    def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model):  # noqa: ARG002
+        documents = [
+            Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
+        ]
+
+        embedder = OptimumDocumentEmbedder(
+            model="sentence-transformers/all-minilm-l6-v2",
+            meta_fields_to_embed=["meta_field"],
+            embedding_separator=" | ",
+            pooling_mode="mean",
+        )
+
+        prepared_texts = embedder._prepare_texts_to_embed(documents)
+
+        assert prepared_texts == [
+            "meta_value 0 | document number 0: content",
+            "meta_value 1 | document number 1: content",
+            "meta_value 2 | document number 2: content",
+            "meta_value 3 | document number 3: content",
+            "meta_value 4 | document number 4: content",
+        ]
+
+    def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):  # noqa: ARG002
+        documents = [Document(content=f"document number {i}") for i in range(5)]
+
+        embedder = OptimumDocumentEmbedder(
+            model="sentence-transformers/all-minilm-l6-v2",
+            prefix="my_prefix ",
+            suffix=" my_suffix",
+            pooling_mode="mean",
+        )
+
+        prepared_texts = embedder._prepare_texts_to_embed(documents)
+
+        assert prepared_texts == [
+            "my_prefix document number 0 my_suffix",
+            "my_prefix document number 1 my_suffix",
+            "my_prefix document number 2 my_suffix",
+            "my_prefix document number 3 my_suffix",
+            "my_prefix document number 4 my_suffix",
+        ]
+
+    def test_run_wrong_input_format(self, mock_check_valid_model):  # noqa: ARG002
+        embedder = OptimumDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2", pooling_mode="mean")
+        embedder.warm_up()
+        # wrong formats
+        string_input = "text"
+        list_integers_input = [1, 2, 3]
+
+        with pytest.raises(TypeError, match="OptimumDocumentEmbedder expects a list of Documents as input"):
+            embedder.run(documents=string_input)
+
+        with pytest.raises(TypeError, match="OptimumDocumentEmbedder expects a list of Documents as input"):
+            embedder.run(documents=list_integers_input)
+
+    def test_run_on_empty_list(self, mock_check_valid_model):  # noqa: ARG002
+        embedder = OptimumDocumentEmbedder(
+            model="sentence-transformers/all-mpnet-base-v2",
+        )
+        embedder.warm_up()
+        empty_list_input = []
+        result = embedder.run(documents=empty_list_input)
+
+        assert result["documents"] is not None
+        assert not result["documents"]  # empty list
+
+    @pytest.mark.integration
+    def test_run(self):
+        docs = [
+            Document(content="I love cheese", meta={"topic": "Cuisine"}),
+            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
+        ]
+
+        embedder = OptimumDocumentEmbedder(
+            model="sentence-transformers/all-mpnet-base-v2",
+            prefix="prefix ",
+            suffix=" suffix",
+            meta_fields_to_embed=["topic"],
+            embedding_separator=" | ",
+            batch_size=1,
+        )
+        embedder.warm_up()
+
+        result = embedder.run(documents=docs)
+
+        documents_with_embeddings = result["documents"]
+
+        assert isinstance(documents_with_embeddings, list)
+        assert len(documents_with_embeddings) == len(docs)
+        for doc in documents_with_embeddings:
+            assert isinstance(doc, Document)
+            assert isinstance(doc.embedding, list)
+            assert len(doc.embedding) == 768
+            assert all(isinstance(x, float) for x in doc.embedding)
diff --git a/integrations/optimum/tests/test_optimum_text_embedder.py b/integrations/optimum/tests/test_optimum_text_embedder.py
new file mode 100644
index 000000000..9932d1dbf
--- /dev/null
+++ b/integrations/optimum/tests/test_optimum_text_embedder.py
@@ -0,0 +1,193 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from haystack.utils.auth import Secret
+from haystack_integrations.components.embedders import OptimumTextEmbedder
+from haystack_integrations.components.embedders.pooling import PoolingMode
+from huggingface_hub.utils import RepositoryNotFoundError
+
+
+@pytest.fixture
+def mock_check_valid_model():
+    with patch(
+        "haystack_integrations.components.embedders.optimum_text_embedder.check_valid_model",
+        MagicMock(return_value=None),
+    ) as mock:
+        yield mock
+
+
+@pytest.fixture
+def mock_get_pooling_mode():
+    with patch(
+        "haystack_integrations.components.embedders.optimum_text_embedder.HFPoolingMode.get_pooling_mode",
+        MagicMock(return_value=PoolingMode.MEAN),
+    ) as mock:
+        yield mock
+
+
+class TestOptimumTextEmbedder:
+    def test_init_default(self, monkeypatch, mock_check_valid_model, mock_get_pooling_mode):  # noqa: ARG002
+        monkeypatch.setenv("HF_API_TOKEN", "fake-api-token")
+        embedder = OptimumTextEmbedder()
+
+        assert embedder.model == "sentence-transformers/all-mpnet-base-v2"
+        assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
+        assert embedder.prefix == ""
+        assert embedder.suffix == ""
+        assert embedder.normalize_embeddings is True
+        assert embedder.onnx_execution_provider == "CPUExecutionProvider"
+        assert embedder.pooling_mode == PoolingMode.MEAN
+        assert embedder.model_kwargs == {
+            "model_id": "sentence-transformers/all-mpnet-base-v2",
+            "provider": "CPUExecutionProvider",
+            "use_auth_token": "fake-api-token",
+        }
+
+    def test_init_with_parameters(self, mock_check_valid_model):  # noqa: ARG002
+        embedder = OptimumTextEmbedder(
+            model="sentence-transformers/all-minilm-l6-v2",
+            token=Secret.from_token("fake-api-token"),
+            prefix="prefix",
+            suffix="suffix",
+            normalize_embeddings=False,
+            pooling_mode="max",
+            onnx_execution_provider="CUDAExecutionProvider",
+            model_kwargs={"trust_remote_code": True},
+        )
+
+        assert embedder.model == "sentence-transformers/all-minilm-l6-v2"
+        assert embedder.token == Secret.from_token("fake-api-token")
+        assert embedder.prefix == "prefix"
+        assert embedder.suffix == "suffix"
+        assert embedder.normalize_embeddings is False
+        assert embedder.onnx_execution_provider == "CUDAExecutionProvider"
+        assert embedder.pooling_mode == PoolingMode.MAX
+        assert embedder.model_kwargs == {
+            "trust_remote_code": True,
+            "model_id": "sentence-transformers/all-minilm-l6-v2",
+            "provider": "CUDAExecutionProvider",
+            "use_auth_token": "fake-api-token",
+        }
+
+    def test_to_dict(self, mock_check_valid_model, mock_get_pooling_mode):  # noqa: ARG002
+        component = OptimumTextEmbedder()
+        data = component.to_dict()
+
+        assert data == {
+            "type": "haystack_integrations.components.embedders.optimum_text_embedder.OptimumTextEmbedder",
"init_parameters": { + "model": "sentence-transformers/all-mpnet-base-v2", + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "prefix": "", + "suffix": "", + "normalize_embeddings": True, + "onnx_execution_provider": "CPUExecutionProvider", + "pooling_mode": "mean", + "model_kwargs": { + "model_id": "sentence-transformers/all-mpnet-base-v2", + "provider": "CPUExecutionProvider", + "use_auth_token": None, + }, + }, + } + + def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # noqa: ARG002 + component = OptimumTextEmbedder( + model="sentence-transformers/all-minilm-l6-v2", + token=Secret.from_env_var("ENV_VAR", strict=False), + prefix="prefix", + suffix="suffix", + normalize_embeddings=False, + onnx_execution_provider="CUDAExecutionProvider", + pooling_mode="max", + model_kwargs={"trust_remote_code": True}, + ) + data = component.to_dict() + + assert data == { + "type": "haystack_integrations.components.embedders.optimum_text_embedder.OptimumTextEmbedder", + "init_parameters": { + "model": "sentence-transformers/all-minilm-l6-v2", + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + "prefix": "prefix", + "suffix": "suffix", + "normalize_embeddings": False, + "onnx_execution_provider": "CUDAExecutionProvider", + "pooling_mode": "max", + "model_kwargs": { + "trust_remote_code": True, + "model_id": "sentence-transformers/all-minilm-l6-v2", + "provider": "CUDAExecutionProvider", + "use_auth_token": None, + }, + }, + } + + def test_initialize_with_invalid_model(self, mock_check_valid_model): + mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") + with pytest.raises(RepositoryNotFoundError): + OptimumTextEmbedder(model="invalid_model_id", pooling_mode="max") + + def test_initialize_with_invalid_pooling_mode(self, mock_check_valid_model): # noqa: ARG002 + mock_get_pooling_mode.side_effect = ValueError("Invalid pooling mode") + with pytest.raises(ValueError): + OptimumTextEmbedder(model="sentence-transformers/all-mpnet-base-v2", pooling_mode="Invalid_pooling_mode") + + def test_infer_pooling_mode_from_str(self): + """ + Test that the pooling mode is correctly inferred from a string. + The pooling mode is "mean" as per the model config. 
+        """
+        for pooling_mode in PoolingMode:
+            embedder = OptimumTextEmbedder(
+                model="sentence-transformers/all-minilm-l6-v2",
+                pooling_mode=pooling_mode.value,
+            )
+
+            assert embedder.model == "sentence-transformers/all-minilm-l6-v2"
+            assert embedder.pooling_mode == pooling_mode
+
+    @pytest.mark.integration
+    def test_default_pooling_mode_when_config_not_found(self, mock_check_valid_model):  # noqa: ARG002
+        with pytest.raises(ValueError):
+            OptimumTextEmbedder(
+                model="embedding_model_finetuned",
+                pooling_mode=None,
+            )
+
+    @pytest.mark.integration
+    def test_infer_pooling_mode_from_hf(self):
+        embedder = OptimumTextEmbedder(
+            model="sentence-transformers/all-minilm-l6-v2",
+            pooling_mode=None,
+        )
+
+        assert embedder.model == "sentence-transformers/all-minilm-l6-v2"
+        assert embedder.pooling_mode == PoolingMode.MEAN
+
+    def test_run_wrong_input_format(self, mock_check_valid_model):  # noqa: ARG002
+        embedder = OptimumTextEmbedder(
+            model="sentence-transformers/all-mpnet-base-v2",
+            token=Secret.from_token("fake-api-token"),
+            pooling_mode="mean",
+        )
+        embedder.warm_up()
+
+        list_integers_input = [1, 2, 3]
+
+        with pytest.raises(TypeError, match="OptimumTextEmbedder expects a string as an input"):
+            embedder.run(text=list_integers_input)
+
+    @pytest.mark.integration
+    def test_run(self):
+        embedder = OptimumTextEmbedder(
+            model="sentence-transformers/all-mpnet-base-v2",
+            prefix="prefix ",
+            suffix=" suffix",
+        )
+        embedder.warm_up()
+
+        result = embedder.run(text="The food was delicious")
+
+        assert len(result["embedding"]) == 768
+        assert all(isinstance(x, float) for x in result["embedding"])
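For reviewers: a minimal usage sketch of the two components this patch adds, stitched together only from what the tests above already assert: warm_up() is called before run(), the document embedder returns its results under the "documents" key, the text embedder under the "embedding" key, and sentence-transformers/all-mpnet-base-v2 yields 768-dimensional vectors. The script is illustrative rather than part of the patch, and it assumes network access to the Hugging Face Hub to download the model.

    from haystack.dataclasses import Document

    from haystack_integrations.components.embedders import (
        OptimumDocumentEmbedder,
        OptimumTextEmbedder,
    )

    # Embed a small corpus of Documents, e.g. before writing them to a document store.
    doc_embedder = OptimumDocumentEmbedder(model="sentence-transformers/all-mpnet-base-v2")
    doc_embedder.warm_up()  # the tests always call warm_up() before run()
    docs = doc_embedder.run(documents=[Document(content="I love pizza!")])["documents"]
    assert len(docs[0].embedding) == 768  # embedding size of all-mpnet-base-v2

    # Embed a single query string with the matching text embedder.
    text_embedder = OptimumTextEmbedder(model="sentence-transformers/all-mpnet-base-v2")
    text_embedder.warm_up()
    embedding = text_embedder.run(text="What do I love?")["embedding"]
    assert len(embedding) == 768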