diff --git a/integrations/optimum/pydoc/config.yml b/integrations/optimum/pydoc/config.yml index 979bb389f..617eb4aed 100644 --- a/integrations/optimum/pydoc/config.yml +++ b/integrations/optimum/pydoc/config.yml @@ -14,8 +14,6 @@ processors: documented_only: true do_not_filter_modules: false skip_empty_modules: true - - type: filter - expression: "name not in ['Pooling', 'POOLING_MODES_MAP', 'INVERSE_POOLING_MODES_MAP', 'HFPoolingMode']" - type: smart - type: crossref renderer: diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py index 4e5ac1535..e2ab2d6b7 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py @@ -4,5 +4,6 @@ from .optimum_document_embedder import OptimumDocumentEmbedder from .optimum_text_embedder import OptimumTextEmbedder +from .pooling import OptimumEmbedderPooling -__all__ = ["OptimumDocumentEmbedder", "OptimumTextEmbedder"] +__all__ = ["OptimumDocumentEmbedder", "OptimumEmbedderPooling", "OptimumTextEmbedder"] diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py new file mode 100644 index 000000000..fc4f0b1ae --- /dev/null +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/_backend.py @@ -0,0 +1,198 @@ +import copy +import json +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs +from huggingface_hub import hf_hub_download +from sentence_transformers.models import Pooling as SentenceTransformerPoolingLayer +from tqdm import tqdm +from transformers import AutoTokenizer + +from optimum.onnxruntime import ORTModelForFeatureExtraction + +from .pooling import OptimumEmbedderPooling + + +@dataclass +class _EmbedderParams: + model: str + token: Optional[Secret] + prefix: str + suffix: str + normalize_embeddings: bool + onnx_execution_provider: str + batch_size: int + progress_bar: bool + pooling_mode: Optional[Union[str, OptimumEmbedderPooling]] + model_kwargs: Optional[Dict[str, Any]] + + def serialize(self) -> Dict[str, Any]: + out = {} + for field in self.__dataclass_fields__.keys(): + out[field] = copy.deepcopy(getattr(self, field)) + + # Fixups. + assert isinstance(self.pooling_mode, OptimumEmbedderPooling) + out["pooling_mode"] = self.pooling_mode.value + out["token"] = self.token.to_dict() if self.token else None + out["model_kwargs"].pop("use_auth_token", None) + serialize_hf_model_kwargs(out["model_kwargs"]) + return out + + @classmethod + def deserialize_inplace(cls, data: Dict[str, Any]) -> Dict[str, Any]: + data["pooling_mode"] = OptimumEmbedderPooling.from_str(data["pooling_mode"]) + deserialize_secrets_inplace(data, keys=["token"]) + deserialize_hf_model_kwargs(data["model_kwargs"]) + return data + + +class _EmbedderBackend: + def __init__(self, params: _EmbedderParams): + check_valid_model(params.model, HFModelType.EMBEDDING, params.token) + resolved_token = params.token.resolve_value() if params.token else None + + if isinstance(params.pooling_mode, str): + params.pooling_mode = OptimumEmbedderPooling.from_str(params.pooling_mode) + elif params.pooling_mode is None: + params.pooling_mode = _pooling_from_model_config(params.model, resolved_token) + + if params.pooling_mode is None: + modes = {e.value: e for e in OptimumEmbedderPooling} + msg = ( + f"Pooling mode not found in model config and not specified by user." + f" Supported modes are: {list(modes.keys())}" + ) + raise ValueError(msg) + + params.model_kwargs = params.model_kwargs or {} + + # Check if the model_kwargs contain the parameters, otherwise, populate them with values from init parameters + params.model_kwargs.setdefault("model_id", params.model) + params.model_kwargs.setdefault("provider", params.onnx_execution_provider) + params.model_kwargs.setdefault("use_auth_token", resolved_token) + + self.params = params + self.model = None + self.tokenizer = None + self.pooling_layer = None + + def warm_up(self): + self.model = ORTModelForFeatureExtraction.from_pretrained(**self.params.model_kwargs, export=True) + self.tokenizer = AutoTokenizer.from_pretrained( + self.params.model, token=self.params.token.resolve_value() if self.params.token else None + ) + + # We need the width of the embeddings to initialize the pooling layer + # so we do a dummy forward pass with the model. + dummy_input = self.tokenizer(["dummy input"], padding=True, truncation=True, return_tensors="pt").to( + self.model.device + ) + dummy_output = self.model(input_ids=dummy_input["input_ids"], attention_mask=dummy_input["attention_mask"]) + width = dummy_output[0].size(dim=2) # BaseModelOutput.last_hidden_state + + self.pooling_layer = SentenceTransformerPoolingLayer( + width, + pooling_mode_cls_token=self.params.pooling_mode == OptimumEmbedderPooling.CLS, + pooling_mode_max_tokens=self.params.pooling_mode == OptimumEmbedderPooling.MAX, + pooling_mode_mean_tokens=self.params.pooling_mode == OptimumEmbedderPooling.MEAN, + pooling_mode_mean_sqrt_len_tokens=self.params.pooling_mode == OptimumEmbedderPooling.MEAN_SQRT_LEN, + pooling_mode_weightedmean_tokens=self.params.pooling_mode == OptimumEmbedderPooling.WEIGHTED_MEAN, + pooling_mode_lasttoken=self.params.pooling_mode == OptimumEmbedderPooling.LAST_TOKEN, + ) + + @property + def parameters(self) -> _EmbedderParams: + return self.params + + def pool_embeddings(self, model_output: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + assert self.pooling_layer is not None + features = {"token_embeddings": model_output, "attention_mask": attention_mask} + pooled_outputs = self.pooling_layer.forward(features) + return pooled_outputs["sentence_embedding"] + + def embed_texts( + self, + texts_to_embed: Union[str, List[str]], + ) -> Union[List[List[float]], List[float]]: + assert self.model is not None + assert self.tokenizer is not None + + if isinstance(texts_to_embed, str): + texts = [texts_to_embed] + else: + texts = texts_to_embed + + device = self.model.device + + # Sorting by length + length_sorted_idx = np.argsort([-len(sen) for sen in texts]) + sentences_sorted = [texts[idx] for idx in length_sorted_idx] + + all_embeddings = [] + for i in tqdm( + range(0, len(sentences_sorted), self.params.batch_size), + disable=not self.params.progress_bar, + desc="Calculating embeddings", + ): + batch = sentences_sorted[i : i + self.params.batch_size] + encoded_input = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device) + model_output = self.model( + input_ids=encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"] + ) + sentence_embeddings = self.pool_embeddings(model_output[0], encoded_input["attention_mask"].to(device)) + all_embeddings.append(sentence_embeddings) + + embeddings = torch.cat(all_embeddings, dim=0) + + if self.params.normalize_embeddings: + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + + embeddings = embeddings.tolist() + + # Reorder embeddings according to original order + reordered_embeddings: List[List[float]] = [None] * len(texts) # type: ignore + for embedding, idx in zip(embeddings, length_sorted_idx): + reordered_embeddings[idx] = embedding + + if isinstance(texts_to_embed, str): + return reordered_embeddings[0] + else: + return reordered_embeddings + + +def _pooling_from_model_config(model: str, token: Optional[str] = None) -> Optional[OptimumEmbedderPooling]: + try: + pooling_config_path = hf_hub_download(repo_id=model, token=token, filename="1_Pooling/config.json") + except Exception as e: + msg = f"An error occurred while downloading the model config: {e}" + raise ValueError(msg) from e + + with open(pooling_config_path) as f: + pooling_config = json.load(f) + + # Filter only those keys that start with "pooling_mode" and are True + true_pooling_modes = [key for key, value in pooling_config.items() if key.startswith("pooling_mode") and value] + + # If exactly one True pooling mode is found, return it + # If no True pooling modes or more than one True pooling mode is found, return None + if len(true_pooling_modes) == 1: + pooling_mode_from_config = true_pooling_modes[0] + pooling_mode = _POOLING_MODES_MAP.get(pooling_mode_from_config) + else: + pooling_mode = None + return pooling_mode + + +_POOLING_MODES_MAP = { + "pooling_mode_cls_token": OptimumEmbedderPooling.CLS, + "pooling_mode_mean_tokens": OptimumEmbedderPooling.MEAN, + "pooling_mode_max_tokens": OptimumEmbedderPooling.MAX, + "pooling_mode_mean_sqrt_len_tokens": OptimumEmbedderPooling.MEAN_SQRT_LEN, + "pooling_mode_weightedmean_tokens": OptimumEmbedderPooling.WEIGHTED_MEAN, + "pooling_mode_lasttoken": OptimumEmbedderPooling.LAST_TOKEN, +} diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_backend.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_backend.py deleted file mode 100644 index f7d7ce7be..000000000 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_backend.py +++ /dev/null @@ -1,101 +0,0 @@ -from typing import Any, Dict, List, Optional, Union - -import numpy as np -import torch -from haystack.utils.auth import Secret -from tqdm import tqdm -from transformers import AutoTokenizer - -from optimum.onnxruntime import ORTModelForFeatureExtraction - -from .pooling import Pooling, PoolingMode - - -class OptimumEmbeddingBackend: - """ - Class to manage Optimum embeddings. - """ - - def __init__(self, model: str, model_kwargs: Dict[str, Any], token: Optional[Secret] = None): - """ - Create an instance of OptimumEmbeddingBackend. - - :param model: A string representing the model id on HF Hub. - :param model_kwargs: Keyword arguments to pass to the model. - :param token: The HuggingFace token to use as HTTP bearer authorization. - """ - # export=True converts the model to ONNX on the fly - self.model = ORTModelForFeatureExtraction.from_pretrained(**model_kwargs, export=True) - self.tokenizer = AutoTokenizer.from_pretrained(model, token=token) - - def embed( - self, - texts_to_embed: Union[str, List[str]], - normalize_embeddings: bool, - pooling_mode: PoolingMode = PoolingMode.MEAN, - progress_bar: bool = False, - batch_size: int = 1, - ) -> Union[List[List[float]], List[float]]: - """ - Embed text or list of texts using the Optimum model. - - :param texts_to_embed: The text or list of texts to embed. - :param normalize_embeddings: Whether to normalize the embeddings to unit length. - :param pooling_mode: The pooling mode to use. - :param progress_bar: Whether to show a progress bar or not. - :param batch_size: Batch size to use. - :return: A single embedding if the input is a single string. A list of embeddings if the input is a list of - strings. - """ - if isinstance(texts_to_embed, str): - texts = [texts_to_embed] - else: - texts = texts_to_embed - - # Determine device for tokenizer output - device = self.model.device - - # Sorting by length - length_sorted_idx = np.argsort([-len(sen) for sen in texts]) - sentences_sorted = [texts[idx] for idx in length_sorted_idx] - - all_embeddings = [] - for i in tqdm( - range(0, len(sentences_sorted), batch_size), disable=not progress_bar, desc="Calculating embeddings" - ): - batch = sentences_sorted[i : i + batch_size] - encoded_input = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device) - - # Only pass required inputs otherwise onnxruntime can raise an error - inputs_to_remove = set(encoded_input.keys()).difference(self.model.inputs_names) - for key in inputs_to_remove: - encoded_input.pop(key) - model_output = self.model(**encoded_input) - - # Pool Embeddings - pooling = Pooling( - pooling_mode=pooling_mode, - attention_mask=encoded_input["attention_mask"].to(device), - model_output=model_output, - ) - sentence_embeddings = pooling.pool_embeddings() - all_embeddings.append(sentence_embeddings) - - embeddings = torch.cat(all_embeddings, dim=0) - - # Normalize all embeddings - if normalize_embeddings: - embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) - - embeddings = embeddings.tolist() - - # Reorder embeddings according to original order - reordered_embeddings: List[List[float]] = [[]] * len(texts) - for embedding, idx in zip(embeddings, length_sorted_idx): - reordered_embeddings[idx] = embedding - - if isinstance(texts_to_embed, str): - # Return the embedding if only one text was passed - return reordered_embeddings[0] - - return reordered_embeddings diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py index 7a9caad71..2f49bd0b3 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py @@ -1,11 +1,10 @@ from typing import Any, Dict, List, Optional, Union from haystack import Document, component, default_from_dict, default_to_dict -from haystack.utils import Secret, deserialize_secrets_inplace -from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs +from haystack.utils import Secret -from .optimum_backend import OptimumEmbeddingBackend -from .pooling import HFPoolingMode, PoolingMode +from ._backend import _EmbedderBackend, _EmbedderParams +from .pooling import OptimumEmbedderPooling @component @@ -53,7 +52,7 @@ def __init__( suffix: str = "", normalize_embeddings: bool = True, onnx_execution_provider: str = "CPUExecutionProvider", - pooling_mode: Optional[Union[str, PoolingMode]] = None, + pooling_mode: Optional[Union[str, OptimumEmbedderPooling]] = None, model_kwargs: Optional[Dict[str, Any]] = None, batch_size: int = 32, progress_bar: bool = True, @@ -90,18 +89,7 @@ def __init__( ) ``` :param pooling_mode: The pooling mode to use. When None, pooling mode will be inferred from the model config. - The supported pooling modes are: - - "cls": Perform CLS Pooling on the output of the embedding model. Uses the first token (CLS token) as text - representations. - - "max": Perform Max Pooling on the output of the embedding model. Uses max in each dimension over all - the tokens. - - "mean": Perform Mean Pooling on the output of the embedding model. - - "mean_sqrt_len": Perform mean-pooling on the output of the embedding model, but divide by sqrt - (input_length). - - "weighted_mean": Perform Weighted (position) Mean Pooling on the output of the embedding model. See - https://arxiv.org/abs/2202.08904. - - "last_token": Perform Last Token Pooling on the output of the embedding model. See - https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005. + Refer to the OptimumEmbedderPooling enum for supported pooling modes. :param model_kwargs: Dictionary containing additional keyword arguments to pass to the model. In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization parameters. @@ -111,92 +99,49 @@ def __init__( :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text. :param embedding_separator: Separator used to concatenate the meta fields to the Document text. """ - check_valid_model(model, HFModelType.EMBEDDING, token) - self.model = model - - self.token = token - resolved_token = token.resolve_value() if token else None - - self.pooling_mode: Optional[PoolingMode] = None - if isinstance(pooling_mode, PoolingMode): - self.pooling_mode = pooling_mode - elif isinstance(pooling_mode, str): - self.pooling_mode = PoolingMode.from_str(pooling_mode) - else: - self.pooling_mode = HFPoolingMode.get_pooling_mode(model, resolved_token) - - # Raise error if pooling mode is not found in model config and not specified by user - if self.pooling_mode is None: - modes = {e.value: e for e in PoolingMode} - msg = ( - f"Pooling mode not found in model config and not specified by user." - f" Supported modes are: {list(modes.keys())}" - ) - raise ValueError(msg) - - self.prefix = prefix - self.suffix = suffix - self.normalize_embeddings = normalize_embeddings - self.onnx_execution_provider = onnx_execution_provider - self.batch_size = batch_size - self.progress_bar = progress_bar + params = _EmbedderParams( + model=model, + token=token, + prefix=prefix, + suffix=suffix, + normalize_embeddings=normalize_embeddings, + onnx_execution_provider=onnx_execution_provider, + batch_size=batch_size, + progress_bar=progress_bar, + pooling_mode=pooling_mode, + model_kwargs=model_kwargs, + ) self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator - model_kwargs = model_kwargs or {} - - # Check if the model_kwargs contain the parameters, otherwise, populate them with values from init parameters - model_kwargs.setdefault("model_id", model) - model_kwargs.setdefault("provider", onnx_execution_provider) - model_kwargs.setdefault("use_auth_token", resolved_token) - - self.model_kwargs = model_kwargs - self.embedding_backend = None + self._backend = _EmbedderBackend(params) + self._initialized = False def warm_up(self): """ Load the embedding backend. """ - if self.embedding_backend is None: - self.embedding_backend = OptimumEmbeddingBackend( - model=self.model, token=self.token, model_kwargs=self.model_kwargs - ) + if self._initialized: + return + + self._backend.warm_up() + self._initialized = True def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. """ - assert self.pooling_mode is not None - serialization_dict = default_to_dict( - self, - model=self.model, - prefix=self.prefix, - suffix=self.suffix, - normalize_embeddings=self.normalize_embeddings, - onnx_execution_provider=self.onnx_execution_provider, - pooling_mode=self.pooling_mode.value, - batch_size=self.batch_size, - progress_bar=self.progress_bar, - meta_fields_to_embed=self.meta_fields_to_embed, - embedding_separator=self.embedding_separator, - model_kwargs=self.model_kwargs, - token=self.token.to_dict() if self.token else None, - ) - - model_kwargs = serialization_dict["init_parameters"]["model_kwargs"] - model_kwargs.pop("use_auth_token", None) - - serialize_hf_model_kwargs(model_kwargs) - return serialization_dict + init_params = self._backend.parameters.serialize() + init_params["meta_fields_to_embed"] = self.meta_fields_to_embed + init_params["embedding_separator"] = self.embedding_separator + return default_to_dict(self, **init_params) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "OptimumDocumentEmbedder": """ Deserialize this component from a dictionary. """ - data["init_parameters"]["pooling_mode"] = PoolingMode.from_str(data["init_parameters"]["pooling_mode"]) - deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) - deserialize_hf_model_kwargs(data["init_parameters"]["model_kwargs"]) + _EmbedderParams.deserialize_inplace(data["init_parameters"]) return default_from_dict(cls, data) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: @@ -210,7 +155,9 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: ] text_to_embed = ( - self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix + self._backend.parameters.prefix + + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + + self._backend.parameters.suffix ) texts_to_embed.append(text_to_embed) @@ -225,6 +172,9 @@ def run(self, documents: List[Document]): :param documents: A list of Documents to embed. :return: A dictionary containing the updated Documents with their embeddings. """ + if not self._initialized: + msg = "The embedding model has not been loaded. Please call warm_up() before running." + raise RuntimeError(msg) if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): msg = ( "OptimumDocumentEmbedder expects a list of Documents as input." @@ -232,24 +182,12 @@ def run(self, documents: List[Document]): ) raise TypeError(msg) - if self.embedding_backend is None: - msg = "The embedding model has not been loaded. Please call warm_up() before running." - raise RuntimeError(msg) - # Return empty list if no documents if not documents: return {"documents": []} texts_to_embed = self._prepare_texts_to_embed(documents=documents) - - embeddings = self.embedding_backend.embed( - texts_to_embed=texts_to_embed, - normalize_embeddings=self.normalize_embeddings, - pooling_mode=self.pooling_mode, - progress_bar=self.progress_bar, - batch_size=self.batch_size, - ) - + embeddings = self._backend.embed_texts(texts_to_embed) for doc, emb in zip(documents, embeddings): doc.embedding = emb diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py index a20d7b5f8..64454bf9f 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py @@ -1,11 +1,10 @@ from typing import Any, Dict, List, Optional, Union from haystack import component, default_from_dict, default_to_dict -from haystack.utils import Secret, deserialize_secrets_inplace -from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs +from haystack.utils import Secret -from .optimum_backend import OptimumEmbeddingBackend -from .pooling import HFPoolingMode, PoolingMode +from ._backend import _EmbedderBackend, _EmbedderParams +from .pooling import OptimumEmbedderPooling @component @@ -49,7 +48,7 @@ def __init__( suffix: str = "", normalize_embeddings: bool = True, onnx_execution_provider: str = "CPUExecutionProvider", - pooling_mode: Optional[Union[str, PoolingMode]] = None, + pooling_mode: Optional[Union[str, OptimumEmbedderPooling]] = None, model_kwargs: Optional[Dict[str, Any]] = None, ): """ @@ -82,100 +81,52 @@ def __init__( ) ``` :param pooling_mode: The pooling mode to use. When None, pooling mode will be inferred from the model config. - The supported pooling modes are: - - "cls": Perform CLS Pooling on the output of the embedding model. Uses the first token (CLS token) as text - representations. - - "max": Perform Max Pooling on the output of the embedding model. Uses max in each dimension over all - the tokens. - - "mean": Perform Mean Pooling on the output of the embedding model. - - "mean_sqrt_len": Perform mean-pooling on the output of the embedding model, but divide by sqrt - (input_length). - - "weighted_mean": Perform Weighted (position) Mean Pooling on the output of the embedding model. See - https://arxiv.org/abs/2202.08904. - - "last_token": Perform Last Token Pooling on the output of the embedding model. See - https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005. + Refer to the OptimumEmbedderPooling enum for supported pooling modes. :param model_kwargs: Dictionary containing additional keyword arguments to pass to the model. In case of duplication, these kwargs override `model`, `onnx_execution_provider`, and `token` initialization parameters. """ - check_valid_model(model, HFModelType.EMBEDDING, token) - self.model = model - - self.token = token - resolved_token = token.resolve_value() if token else None - - self.pooling_mode: Optional[PoolingMode] = None - if isinstance(pooling_mode, PoolingMode): - self.pooling_mode = pooling_mode - elif isinstance(pooling_mode, str): - self.pooling_mode = PoolingMode.from_str(pooling_mode) - else: - self.pooling_mode = HFPoolingMode.get_pooling_mode(model, resolved_token) - - # Raise error if pooling mode is not found in model config and not specified by user - if self.pooling_mode is None: - modes = {e.value: e for e in PoolingMode} - msg = ( - f"Pooling mode not found in model config and not specified by user." - f" Supported modes are: {list(modes.keys())}" - ) - raise ValueError(msg) - - self.prefix = prefix - self.suffix = suffix - self.normalize_embeddings = normalize_embeddings - self.onnx_execution_provider = onnx_execution_provider - - model_kwargs = model_kwargs or {} - - # Check if the model_kwargs contain the parameters, otherwise, populate them with values from init parameters - model_kwargs.setdefault("model_id", model) - model_kwargs.setdefault("provider", onnx_execution_provider) - model_kwargs.setdefault("use_auth_token", resolved_token) - - self.model_kwargs = model_kwargs - self.embedding_backend = None + params = _EmbedderParams( + model=model, + token=token, + prefix=prefix, + suffix=suffix, + normalize_embeddings=normalize_embeddings, + onnx_execution_provider=onnx_execution_provider, + batch_size=1, + progress_bar=False, + pooling_mode=pooling_mode, + model_kwargs=model_kwargs, + ) + self._backend = _EmbedderBackend(params) + self._initialized = False def warm_up(self): """ Load the embedding backend. """ - if self.embedding_backend is None: - self.embedding_backend = OptimumEmbeddingBackend( - model=self.model, token=self.token, model_kwargs=self.model_kwargs - ) + if self._initialized: + return + + self._backend.warm_up() + self._initialized = True def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. """ - assert self.pooling_mode is not None - serialization_dict = default_to_dict( - self, - model=self.model, - prefix=self.prefix, - suffix=self.suffix, - normalize_embeddings=self.normalize_embeddings, - onnx_execution_provider=self.onnx_execution_provider, - pooling_mode=self.pooling_mode.value, - model_kwargs=self.model_kwargs, - token=self.token.to_dict() if self.token else None, - ) - - model_kwargs = serialization_dict["init_parameters"]["model_kwargs"] - model_kwargs.pop("use_auth_token", None) - - serialize_hf_model_kwargs(model_kwargs) - return serialization_dict + init_params = self._backend.parameters.serialize() + # Remove init params that are not provided to the text embedder. + init_params.pop("batch_size") + init_params.pop("progress_bar") + return default_to_dict(self, **init_params) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "OptimumTextEmbedder": """ Deserialize this component from a dictionary. """ - data["init_parameters"]["pooling_mode"] = PoolingMode.from_str(data["init_parameters"]["pooling_mode"]) - deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) - deserialize_hf_model_kwargs(data["init_parameters"]["model_kwargs"]) + _EmbedderParams.deserialize_inplace(data["init_parameters"]) return default_from_dict(cls, data) @component.output_types(embedding=List[float]) @@ -186,6 +137,10 @@ def run(self, text: str): :param text: The text to embed. :return: The embeddings of the text. """ + if not self._initialized: + msg = "The embedding model has not been loaded. Please call warm_up() before running." + raise RuntimeError(msg) + if not isinstance(text, str): msg = ( "OptimumTextEmbedder expects a string as an input. " @@ -193,14 +148,6 @@ def run(self, text: str): ) raise TypeError(msg) - if self.embedding_backend is None: - msg = "The embedding model has not been loaded. Please call warm_up() before running." - raise RuntimeError(msg) - - text_to_embed = self.prefix + text + self.suffix - - embedding = self.embedding_backend.embed( - texts_to_embed=text_to_embed, normalize_embeddings=self.normalize_embeddings, pooling_mode=self.pooling_mode - ) - + text_to_embed = self._backend.parameters.prefix + text + self._backend.parameters.suffix + embedding = self._backend.embed_texts(text_to_embed) return {"embedding": embedding} diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py index df8f8fb4e..c4d195b8e 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py @@ -1,29 +1,39 @@ -import json from enum import Enum -from typing import Optional -import torch -from huggingface_hub import hf_hub_download -from sentence_transformers.models import Pooling as PoolingLayer - -class PoolingMode(Enum): +class OptimumEmbedderPooling(Enum): """ Pooling Modes support by the Optimum Embedders. """ + #: Perform CLS Pooling on the output of the embedding model + #: using the first token (CLS token). CLS = "cls" + + #: Perform Mean Pooling on the output of the embedding model. MEAN = "mean" + + #: Perform Max Pooling on the output of the embedding model + #: using the maximum value in each dimension over all the tokens. MAX = "max" + + #: Perform mean-pooling on the output of the embedding model but + #: divide by the square root of the sequence length. MEAN_SQRT_LEN = "mean_sqrt_len" + + #: Perform weighted (position) mean pooling on the output of the + #: embedding model. See https://arxiv.org/abs/2202.08904. WEIGHTED_MEAN = "weighted_mean" + + #: Perform Last Token Pooling on the output of the embedding model. + #: See https://arxiv.org/abs/2202.08904 & https://arxiv.org/abs/2201.10005. LAST_TOKEN = "last_token" def __str__(self): return self.value @classmethod - def from_str(cls, string: str) -> "PoolingMode": + def from_str(cls, string: str) -> "OptimumEmbedderPooling": """ Create a pooling mode from a string. @@ -32,102 +42,9 @@ def from_str(cls, string: str) -> "PoolingMode": :returns: The pooling mode. """ - enum_map = {e.value: e for e in PoolingMode} + enum_map = {e.value: e for e in OptimumEmbedderPooling} pooling_mode = enum_map.get(string) if pooling_mode is None: msg = f"Unknown Pooling mode '{string}'. Supported modes are: {list(enum_map.keys())}" raise ValueError(msg) return pooling_mode - - -POOLING_MODES_MAP = { - "pooling_mode_cls_token": PoolingMode.CLS, - "pooling_mode_mean_tokens": PoolingMode.MEAN, - "pooling_mode_max_tokens": PoolingMode.MAX, - "pooling_mode_mean_sqrt_len_tokens": PoolingMode.MEAN_SQRT_LEN, - "pooling_mode_weightedmean_tokens": PoolingMode.WEIGHTED_MEAN, - "pooling_mode_lasttoken": PoolingMode.LAST_TOKEN, -} - -INVERSE_POOLING_MODES_MAP = {mode: name for name, mode in POOLING_MODES_MAP.items()} - - -class HFPoolingMode: - """ - Gets the pooling mode of the model from the Hugging Face Hub. - """ - - @staticmethod - def get_pooling_mode(model: str, token: Optional[str] = None) -> Optional[PoolingMode]: - """ - Gets the pooling mode of the model from the Hugging Face Hub. - - :param model: - The model to get the pooling mode for. - :param token: - The HuggingFace token to use as HTTP bearer authorization. - :returns: - The pooling mode. - """ - try: - pooling_config_path = hf_hub_download(repo_id=model, token=token, filename="1_Pooling/config.json") - - with open(pooling_config_path) as f: - pooling_config = json.load(f) - - # Filter only those keys that start with "pooling_mode" and are True - true_pooling_modes = [ - key for key, value in pooling_config.items() if key.startswith("pooling_mode") and value - ] - - # If exactly one True pooling mode is found, return it - if len(true_pooling_modes) == 1: - pooling_mode_from_config = true_pooling_modes[0] - pooling_mode = POOLING_MODES_MAP.get(pooling_mode_from_config) - # If no True pooling modes or more than one True pooling mode is found, return None - else: - pooling_mode = None - return pooling_mode - except Exception as e: - msg = f"An error occurred while inferring the pooling mode from the model config: {e}" - raise ValueError(msg) from e - - -class Pooling: - """ - Class to manage pooling of the embeddings. - - :param pooling_mode: The pooling mode to use. - :param attention_mask: The attention mask of the tokenized text. - :param model_output: The output of the embedding model. - """ - - def __init__(self, pooling_mode: PoolingMode, attention_mask: torch.Tensor, model_output: torch.Tensor): - self.pooling_mode = pooling_mode - self.attention_mask = attention_mask - self.model_output = model_output - - def pool_embeddings(self) -> torch.Tensor: - """ - Perform pooling on the output of the embedding model. - - :return: The embeddings of the text after pooling. - """ - pooling_func_map = { - INVERSE_POOLING_MODES_MAP[self.pooling_mode]: True, - } - # By default, sentence-transformers uses mean pooling - # If multiple pooling methods are specified, the output dimension of the embeddings is scaled by the number of - # pooling methods selected - if self.pooling_mode != PoolingMode.MEAN: - pooling_func_map[INVERSE_POOLING_MODES_MAP[PoolingMode.MEAN]] = False - - # First element of model_output contains all token embeddings - token_embeddings = self.model_output[0] - word_embedding_dimension = token_embeddings.size(dim=2) - pooling = PoolingLayer(word_embedding_dimension=word_embedding_dimension, **pooling_func_map) - features = {"token_embeddings": token_embeddings, "attention_mask": self.attention_mask} - pooled_outputs = pooling.forward(features) - embeddings = pooled_outputs["sentence_embedding"] - - return embeddings diff --git a/integrations/optimum/tests/test_optimum_backend.py b/integrations/optimum/tests/test_optimum_backend.py deleted file mode 100644 index cfa2ac08d..000000000 --- a/integrations/optimum/tests/test_optimum_backend.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest -from haystack_integrations.components.embedders.optimum.optimum_backend import OptimumEmbeddingBackend -from haystack_integrations.components.embedders.optimum.pooling import PoolingMode - - -@pytest.fixture -def backend(): - model = "sentence-transformers/all-mpnet-base-v2" - model_kwargs = {"model_id": model} - backend = OptimumEmbeddingBackend(model=model, model_kwargs=model_kwargs, token=None) - return backend - - -class TestOptimumBackend: - def test_embed_output_order(self, backend): - texts_to_embed = ["short text", "text that is longer than the other", "medium length text"] - embeddings = backend.embed(texts_to_embed, normalize_embeddings=False, pooling_mode=PoolingMode.MEAN) - - # Compute individual embeddings in order - expected_embeddings = [] - for text in texts_to_embed: - expected_embeddings.append(backend.embed(text, normalize_embeddings=False, pooling_mode=PoolingMode.MEAN)) - - # Assert that the embeddings are in the same order - assert embeddings == expected_embeddings - - def test_run_pooling_modes(self, backend): - for pooling_mode in PoolingMode: - embedding = backend.embed("test text", normalize_embeddings=False, pooling_mode=pooling_mode) - - assert len(embedding) == 768 - assert all(isinstance(x, float) for x in embedding) diff --git a/integrations/optimum/tests/test_optimum_document_embedder.py b/integrations/optimum/tests/test_optimum_document_embedder.py index 3c8cc90b3..bcbccd533 100644 --- a/integrations/optimum/tests/test_optimum_document_embedder.py +++ b/integrations/optimum/tests/test_optimum_document_embedder.py @@ -4,14 +4,15 @@ from haystack.dataclasses import Document from haystack.utils.auth import Secret from haystack_integrations.components.embedders.optimum import OptimumDocumentEmbedder -from haystack_integrations.components.embedders.optimum.pooling import PoolingMode +from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling from huggingface_hub.utils import RepositoryNotFoundError +import copy @pytest.fixture def mock_check_valid_model(): with patch( - "haystack_integrations.components.embedders.optimum.optimum_document_embedder.check_valid_model", + "haystack_integrations.components.embedders.optimum._backend.check_valid_model", MagicMock(return_value=None), ) as mock: yield mock @@ -20,8 +21,8 @@ def mock_check_valid_model(): @pytest.fixture def mock_get_pooling_mode(): with patch( - "haystack_integrations.components.embedders.optimum.optimum_text_embedder.HFPoolingMode.get_pooling_mode", - MagicMock(return_value=PoolingMode.MEAN), + "haystack_integrations.components.embedders.optimum._backend._pooling_from_model_config", + MagicMock(return_value=OptimumEmbedderPooling.MEAN), ) as mock: yield mock @@ -31,18 +32,18 @@ def test_init_default(self, monkeypatch, mock_check_valid_model, mock_get_poolin monkeypatch.setenv("HF_API_TOKEN", "fake-api-token") embedder = OptimumDocumentEmbedder() - assert embedder.model == "sentence-transformers/all-mpnet-base-v2" - assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) - assert embedder.prefix == "" - assert embedder.suffix == "" - assert embedder.normalize_embeddings is True - assert embedder.onnx_execution_provider == "CPUExecutionProvider" - assert embedder.pooling_mode == PoolingMode.MEAN - assert embedder.batch_size == 32 - assert embedder.progress_bar is True + assert embedder._backend.parameters.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder._backend.parameters.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert embedder._backend.parameters.prefix == "" + assert embedder._backend.parameters.suffix == "" + assert embedder._backend.parameters.normalize_embeddings is True + assert embedder._backend.parameters.onnx_execution_provider == "CPUExecutionProvider" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN + assert embedder._backend.parameters.batch_size == 32 + assert embedder._backend.parameters.progress_bar is True assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" - assert embedder.model_kwargs == { + assert embedder._backend.parameters.model_kwargs == { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", "use_auth_token": "fake-api-token", @@ -64,18 +65,18 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 model_kwargs={"trust_remote_code": True}, ) - assert embedder.model == "sentence-transformers/all-minilm-l6-v2" - assert embedder.token == Secret.from_token("fake-api-token") - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 - assert embedder.progress_bar is False + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder._backend.parameters.token == Secret.from_token("fake-api-token") + assert embedder._backend.parameters.prefix == "prefix" + assert embedder._backend.parameters.suffix == "suffix" + assert embedder._backend.parameters.batch_size == 64 + assert embedder._backend.parameters.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " - assert embedder.normalize_embeddings is False - assert embedder.onnx_execution_provider == "CUDAExecutionProvider" - assert embedder.pooling_mode == PoolingMode.MAX - assert embedder.model_kwargs == { + assert embedder._backend.parameters.normalize_embeddings is False + assert embedder._backend.parameters.onnx_execution_provider == "CUDAExecutionProvider" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MAX + assert embedder._backend.parameters.model_kwargs == { "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", "provider": "CUDAExecutionProvider", @@ -108,18 +109,18 @@ def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): } embedder = OptimumDocumentEmbedder.from_dict(data) - assert embedder.model == "sentence-transformers/all-mpnet-base-v2" - assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) - assert embedder.prefix == "" - assert embedder.suffix == "" - assert embedder.normalize_embeddings is True - assert embedder.onnx_execution_provider == "CPUExecutionProvider" - assert embedder.pooling_mode == PoolingMode.MEAN - assert embedder.batch_size == 32 - assert embedder.progress_bar is True + assert embedder._backend.parameters.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder._backend.parameters.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert embedder._backend.parameters.prefix == "" + assert embedder._backend.parameters.suffix == "" + assert embedder._backend.parameters.normalize_embeddings is True + assert embedder._backend.parameters.onnx_execution_provider == "CPUExecutionProvider" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN + assert embedder._backend.parameters.batch_size == 32 + assert embedder._backend.parameters.progress_bar is True assert embedder.meta_fields_to_embed == [] assert embedder.embedding_separator == "\n" - assert embedder.model_kwargs == { + assert embedder._backend.parameters.model_kwargs == { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", "use_auth_token": None, @@ -167,18 +168,18 @@ def test_to_and_from_dict_with_custom_init_parameters( } embedder = OptimumDocumentEmbedder.from_dict(data) - assert embedder.model == "sentence-transformers/all-minilm-l6-v2" - assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False) - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.batch_size == 64 - assert embedder.progress_bar is False + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder._backend.parameters.token == Secret.from_env_var("ENV_VAR", strict=False) + assert embedder._backend.parameters.prefix == "prefix" + assert embedder._backend.parameters.suffix == "suffix" + assert embedder._backend.parameters.batch_size == 64 + assert embedder._backend.parameters.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] assert embedder.embedding_separator == " | " - assert embedder.normalize_embeddings is False - assert embedder.onnx_execution_provider == "CUDAExecutionProvider" - assert embedder.pooling_mode == PoolingMode.MAX - assert embedder.model_kwargs == { + assert embedder._backend.parameters.normalize_embeddings is False + assert embedder._backend.parameters.onnx_execution_provider == "CUDAExecutionProvider" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MAX + assert embedder._backend.parameters.model_kwargs == { "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", "provider": "CUDAExecutionProvider", @@ -202,14 +203,14 @@ def test_infer_pooling_mode_from_str(self): Test that the pooling mode is correctly inferred from a string. The pooling mode is "mean" as per the model config. """ - for pooling_mode in PoolingMode: + for pooling_mode in OptimumEmbedderPooling: embedder = OptimumDocumentEmbedder( model="sentence-transformers/all-minilm-l6-v2", pooling_mode=pooling_mode.value, ) - assert embedder.model == "sentence-transformers/all-minilm-l6-v2" - assert embedder.pooling_mode == pooling_mode + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder._backend.parameters.pooling_mode == pooling_mode @pytest.mark.integration def test_default_pooling_mode_when_config_not_found(self, mock_check_valid_model): # noqa: ARG002 @@ -226,8 +227,8 @@ def test_infer_pooling_mode_from_hf(self): pooling_mode=None, ) - assert embedder.model == "sentence-transformers/all-minilm-l6-v2" - assert embedder.pooling_mode == PoolingMode.MEAN + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN def test_prepare_texts_to_embed_w_metadata(self, mock_check_valid_model): # noqa: ARG002 documents = [ @@ -300,7 +301,9 @@ def test_run(self): docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + Document(content="Every planet we reach is dead", meta={"topic": "Monkeys"}), ] + docs_copy = copy.deepcopy(docs) embedder = OptimumDocumentEmbedder( model="sentence-transformers/all-mpnet-base-v2", @@ -313,6 +316,7 @@ def test_run(self): embedder.warm_up() result = embedder.run(documents=docs) + expected = [embedder.run([d]) for d in docs_copy] documents_with_embeddings = result["documents"] @@ -323,3 +327,6 @@ def test_run(self): assert isinstance(doc.embedding, list) assert len(doc.embedding) == 768 assert all(isinstance(x, float) for x in doc.embedding) + + # Check order + assert [d.embedding for d in docs_copy] == [d.embedding for d in docs] diff --git a/integrations/optimum/tests/test_optimum_text_embedder.py b/integrations/optimum/tests/test_optimum_text_embedder.py index 58a5597c3..ce5bc2ffb 100644 --- a/integrations/optimum/tests/test_optimum_text_embedder.py +++ b/integrations/optimum/tests/test_optimum_text_embedder.py @@ -3,14 +3,14 @@ import pytest from haystack.utils.auth import Secret from haystack_integrations.components.embedders.optimum import OptimumTextEmbedder -from haystack_integrations.components.embedders.optimum.pooling import PoolingMode +from haystack_integrations.components.embedders.optimum.pooling import OptimumEmbedderPooling from huggingface_hub.utils import RepositoryNotFoundError @pytest.fixture def mock_check_valid_model(): with patch( - "haystack_integrations.components.embedders.optimum.optimum_text_embedder.check_valid_model", + "haystack_integrations.components.embedders.optimum._backend.check_valid_model", MagicMock(return_value=None), ) as mock: yield mock @@ -19,8 +19,8 @@ def mock_check_valid_model(): @pytest.fixture def mock_get_pooling_mode(): with patch( - "haystack_integrations.components.embedders.optimum.optimum_text_embedder.HFPoolingMode.get_pooling_mode", - MagicMock(return_value=PoolingMode.MEAN), + "haystack_integrations.components.embedders.optimum._backend._pooling_from_model_config", + MagicMock(return_value=OptimumEmbedderPooling.MEAN), ) as mock: yield mock @@ -30,14 +30,14 @@ def test_init_default(self, monkeypatch, mock_check_valid_model, mock_get_poolin monkeypatch.setenv("HF_API_TOKEN", "fake-api-token") embedder = OptimumTextEmbedder() - assert embedder.model == "sentence-transformers/all-mpnet-base-v2" - assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) - assert embedder.prefix == "" - assert embedder.suffix == "" - assert embedder.normalize_embeddings is True - assert embedder.onnx_execution_provider == "CPUExecutionProvider" - assert embedder.pooling_mode == PoolingMode.MEAN - assert embedder.model_kwargs == { + assert embedder._backend.parameters.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder._backend.parameters.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert embedder._backend.parameters.prefix == "" + assert embedder._backend.parameters.suffix == "" + assert embedder._backend.parameters.normalize_embeddings is True + assert embedder._backend.parameters.onnx_execution_provider == "CPUExecutionProvider" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN + assert embedder._backend.parameters.model_kwargs == { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", "use_auth_token": "fake-api-token", @@ -55,14 +55,14 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 model_kwargs={"trust_remote_code": True}, ) - assert embedder.model == "sentence-transformers/all-minilm-l6-v2" - assert embedder.token == Secret.from_token("fake-api-token") - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.normalize_embeddings is False - assert embedder.onnx_execution_provider == "CUDAExecutionProvider" - assert embedder.pooling_mode == PoolingMode.MAX - assert embedder.model_kwargs == { + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder._backend.parameters.token == Secret.from_token("fake-api-token") + assert embedder._backend.parameters.prefix == "prefix" + assert embedder._backend.parameters.suffix == "suffix" + assert embedder._backend.parameters.normalize_embeddings is False + assert embedder._backend.parameters.onnx_execution_provider == "CUDAExecutionProvider" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MAX + assert embedder._backend.parameters.model_kwargs == { "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", "provider": "CUDAExecutionProvider", @@ -91,14 +91,14 @@ def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): } embedder = OptimumTextEmbedder.from_dict(data) - assert embedder.model == "sentence-transformers/all-mpnet-base-v2" - assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) - assert embedder.prefix == "" - assert embedder.suffix == "" - assert embedder.normalize_embeddings is True - assert embedder.onnx_execution_provider == "CPUExecutionProvider" - assert embedder.pooling_mode == PoolingMode.MEAN - assert embedder.model_kwargs == { + assert embedder._backend.parameters.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder._backend.parameters.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert embedder._backend.parameters.prefix == "" + assert embedder._backend.parameters.suffix == "" + assert embedder._backend.parameters.normalize_embeddings is True + assert embedder._backend.parameters.onnx_execution_provider == "CPUExecutionProvider" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN + assert embedder._backend.parameters.model_kwargs == { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", "use_auth_token": None, @@ -136,14 +136,14 @@ def test_to_and_from_dict_with_custom_init_parameters(self, mock_check_valid_mod } embedder = OptimumTextEmbedder.from_dict(data) - assert embedder.model == "sentence-transformers/all-minilm-l6-v2" - assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False) - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.normalize_embeddings is False - assert embedder.onnx_execution_provider == "CUDAExecutionProvider" - assert embedder.pooling_mode == PoolingMode.MAX - assert embedder.model_kwargs == { + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder._backend.parameters.token == Secret.from_env_var("ENV_VAR", strict=False) + assert embedder._backend.parameters.prefix == "prefix" + assert embedder._backend.parameters.suffix == "suffix" + assert embedder._backend.parameters.normalize_embeddings is False + assert embedder._backend.parameters.onnx_execution_provider == "CUDAExecutionProvider" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MAX + assert embedder._backend.parameters.model_kwargs == { "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", "provider": "CUDAExecutionProvider", @@ -165,14 +165,14 @@ def test_infer_pooling_mode_from_str(self): Test that the pooling mode is correctly inferred from a string. The pooling mode is "mean" as per the model config. """ - for pooling_mode in PoolingMode: + for pooling_mode in OptimumEmbedderPooling: embedder = OptimumTextEmbedder( model="sentence-transformers/all-minilm-l6-v2", pooling_mode=pooling_mode.value, ) - assert embedder.model == "sentence-transformers/all-minilm-l6-v2" - assert embedder.pooling_mode == pooling_mode + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder._backend.parameters.pooling_mode == pooling_mode @pytest.mark.integration def test_default_pooling_mode_when_config_not_found(self, mock_check_valid_model): # noqa: ARG002 @@ -189,8 +189,8 @@ def test_infer_pooling_mode_from_hf(self): pooling_mode=None, ) - assert embedder.model == "sentence-transformers/all-minilm-l6-v2" - assert embedder.pooling_mode == PoolingMode.MEAN + assert embedder._backend.parameters.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder._backend.parameters.pooling_mode == OptimumEmbedderPooling.MEAN def test_run_wrong_input_format(self, mock_check_valid_model): # noqa: ARG002 embedder = OptimumTextEmbedder( @@ -207,14 +207,16 @@ def test_run_wrong_input_format(self, mock_check_valid_model): # noqa: ARG002 @pytest.mark.integration def test_run(self): - embedder = OptimumTextEmbedder( - model="sentence-transformers/all-mpnet-base-v2", - prefix="prefix ", - suffix=" suffix", - ) - embedder.warm_up() + for pooling_mode in OptimumEmbedderPooling: + embedder = OptimumTextEmbedder( + model="sentence-transformers/all-mpnet-base-v2", + prefix="prefix ", + suffix=" suffix", + pooling_mode=pooling_mode, + ) + embedder.warm_up() - result = embedder.run(text="The food was delicious") + result = embedder.run(text="The food was delicious") - assert len(result["embedding"]) == 768 - assert all(isinstance(x, float) for x in result["embedding"]) + assert len(result["embedding"]) == 768 + assert all(isinstance(x, float) for x in result["embedding"])