diff --git a/integrations/optimum/pydoc/config.yml b/integrations/optimum/pydoc/config.yml index 5fb353b5d..979bb389f 100644 --- a/integrations/optimum/pydoc/config.yml +++ b/integrations/optimum/pydoc/config.yml @@ -3,10 +3,9 @@ loaders: search_path: [../src] modules: [ - "haystack_integrations.components.embedders.optimum_backend", - "haystack_integrations.components.embedders.optimum_document_embedder", - "haystack_integrations.components.embedders.optimum_text_embedder", - "haystack_integrations.components.embedders.pooling", + "haystack_integrations.components.embedders.optimum.optimum_document_embedder", + "haystack_integrations.components.embedders.optimum.optimum_text_embedder", + "haystack_integrations.components.embedders.optimum.pooling", ] ignore_when_discovered: ["__init__"] processors: @@ -15,6 +14,8 @@ processors: documented_only: true do_not_filter_modules: false skip_empty_modules: true + - type: filter + expression: "name not in ['Pooling', 'POOLING_MODES_MAP', 'INVERSE_POOLING_MODES_MAP', 'HFPoolingMode']" - type: smart - type: crossref renderer: diff --git a/integrations/optimum/pyproject.toml b/integrations/optimum/pyproject.toml index 17bca2597..ae89210f6 100644 --- a/integrations/optimum/pyproject.toml +++ b/integrations/optimum/pyproject.toml @@ -32,8 +32,8 @@ dependencies = [ # on which model is loaded from HF Hub. # Ref: https://github.com/huggingface/optimum/blob/8651c0ca1cccf095458bc80329dec9df4601edb4/optimum/exporters/onnx/__main__.py#L164 # "sentence-transformers" has been added, since most embedding models use it - "sentence-transformers>=2.2.0", - "optimum[onnxruntime]" + "sentence-transformers>=2.3", + "optimum[onnxruntime]", ] [project.urls] @@ -53,22 +53,12 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/optimum-v[0-9]*"' [tool.hatch.envs.default] -dependencies = [ - "coverage[toml]>=6.5", - "pytest", - "haystack-pydoc-tools" -] +dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] @@ -77,27 +67,13 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.hatch.metadata] allow-direct-references = true @@ -114,43 +90,52 @@ skip-string-normalization = true target-version = "py37" line-length = 120 select = [ - "A", - "ARG", - "B", - "C", - "DTZ", - "E", - "EM", - "F", - "I", - "ICN", - "ISC", - "N", - "PLC", - "PLE", - "PLR", - "PLW", - "Q", - "RUF", - "S", - "T", - "TID", - "UP", - "W", - "YTT", + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", ] ignore = [ - # Allow non-abstract empty methods in abstract base classes - "B027", - # Ignore checks for possible 
passwords - "S105", "S106", "S107", - # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + # Allow non-abstract empty methods in abstract base classes + "B027", + # Ignore checks for possible passwords + "S105", + "S106", + "S107", + # Ignore complexity + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + # Asserts + "S101", ] unfixable = [ - # Don't touch unused imports - "F401", + # Don't touch unused imports + "F401", ] +extend-exclude = ["tests", "example"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "parents" @@ -172,11 +157,7 @@ optimum = ["src/haystack_integrations", "*/optimum/src/haystack_integrations"] tests = ["tests", "*/optimum/tests"] [tool.coverage.report] -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ @@ -188,7 +169,7 @@ module = [ "torch.*", "transformers.*", "huggingface_hub.*", - "sentence_transformers.*" + "sentence_transformers.*", ] ignore_missing_imports = true diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/__init__.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py similarity index 100% rename from integrations/optimum/src/haystack_integrations/components/embedders/__init__.py rename to integrations/optimum/src/haystack_integrations/components/embedders/optimum/__init__.py diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_backend.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_backend.py similarity index 97% rename from integrations/optimum/src/haystack_integrations/components/embedders/optimum_backend.py rename to integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_backend.py index 55d0e6f20..f7d7ce7be 100644 --- 
a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_backend.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_backend.py @@ -3,11 +3,13 @@ import numpy as np import torch from haystack.utils.auth import Secret -from haystack_integrations.components.embedders.pooling import Pooling, PoolingMode -from optimum.onnxruntime import ORTModelForFeatureExtraction from tqdm import tqdm from transformers import AutoTokenizer +from optimum.onnxruntime import ORTModelForFeatureExtraction + +from .pooling import Pooling, PoolingMode + class OptimumEmbeddingBackend: """ diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py similarity index 94% rename from integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py rename to integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py index f0310f872..7a9caad71 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_document_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_document_embedder.py @@ -3,8 +3,9 @@ from haystack import Document, component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs -from haystack_integrations.components.embedders.optimum_backend import OptimumEmbeddingBackend -from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode + +from .optimum_backend import OptimumEmbeddingBackend +from .pooling import HFPoolingMode, PoolingMode @component @@ -18,7 +19,7 @@ class OptimumDocumentEmbedder: Usage example: ```python from 
haystack.dataclasses import Document - from haystack_integrations.components.embedders import OptimumDocumentEmbedder + from haystack_integrations.components.embedders.optimum import OptimumDocumentEmbedder doc = Document(content="I love pizza!") @@ -114,13 +115,16 @@ def __init__( self.model = model self.token = token - token = token.resolve_value() if token else None + resolved_token = token.resolve_value() if token else None - if isinstance(pooling_mode, str): + self.pooling_mode: Optional[PoolingMode] = None + if isinstance(pooling_mode, PoolingMode): + self.pooling_mode = pooling_mode + elif isinstance(pooling_mode, str): self.pooling_mode = PoolingMode.from_str(pooling_mode) - # Infer pooling mode from model config if not provided, - if pooling_mode is None: - self.pooling_mode = HFPoolingMode.get_pooling_mode(model, token) + else: + self.pooling_mode = HFPoolingMode.get_pooling_mode(model, resolved_token) + # Raise error if pooling mode is not found in model config and not specified by user if self.pooling_mode is None: modes = {e.value: e for e in PoolingMode} @@ -144,7 +148,7 @@ def __init__( # Check if the model_kwargs contain the parameters, otherwise, populate them with values from init parameters model_kwargs.setdefault("model_id", model) model_kwargs.setdefault("provider", onnx_execution_provider) - model_kwargs.setdefault("use_auth_token", token) + model_kwargs.setdefault("use_auth_token", resolved_token) self.model_kwargs = model_kwargs self.embedding_backend = None @@ -162,6 +166,7 @@ def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. 
""" + assert self.pooling_mode is not None serialization_dict = default_to_dict( self, model=self.model, @@ -179,7 +184,7 @@ def to_dict(self) -> Dict[str, Any]: ) model_kwargs = serialization_dict["init_parameters"]["model_kwargs"] - model_kwargs.pop("token", None) + model_kwargs.pop("use_auth_token", None) serialize_hf_model_kwargs(model_kwargs) return serialization_dict diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py similarity index 93% rename from integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py rename to integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py index 8a33a9403..a20d7b5f8 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/optimum_text_embedder.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/optimum_text_embedder.py @@ -3,8 +3,9 @@ from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs -from haystack_integrations.components.embedders.optimum_backend import OptimumEmbeddingBackend -from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode + +from .optimum_backend import OptimumEmbeddingBackend +from .pooling import HFPoolingMode, PoolingMode @component @@ -15,7 +16,7 @@ class OptimumTextEmbedder: Usage example: ```python - from haystack_integrations.components.embedders import OptimumTextEmbedder + from haystack_integrations.components.optimum.embedders import OptimumTextEmbedder text_to_embed = "I love pizza!" 
@@ -101,13 +102,16 @@ def __init__( self.model = model self.token = token - token = token.resolve_value() if token else None + resolved_token = token.resolve_value() if token else None - if isinstance(pooling_mode, str): + self.pooling_mode: Optional[PoolingMode] = None + if isinstance(pooling_mode, PoolingMode): + self.pooling_mode = pooling_mode + elif isinstance(pooling_mode, str): self.pooling_mode = PoolingMode.from_str(pooling_mode) - # Infer pooling mode from model config if not provided, - if pooling_mode is None: - self.pooling_mode = HFPoolingMode.get_pooling_mode(model, token) + else: + self.pooling_mode = HFPoolingMode.get_pooling_mode(model, resolved_token) + # Raise error if pooling mode is not found in model config and not specified by user if self.pooling_mode is None: modes = {e.value: e for e in PoolingMode} @@ -127,7 +131,7 @@ def __init__( # Check if the model_kwargs contain the parameters, otherwise, populate them with values from init parameters model_kwargs.setdefault("model_id", model) model_kwargs.setdefault("provider", onnx_execution_provider) - model_kwargs.setdefault("use_auth_token", token) + model_kwargs.setdefault("use_auth_token", resolved_token) self.model_kwargs = model_kwargs self.embedding_backend = None @@ -145,6 +149,7 @@ def to_dict(self) -> Dict[str, Any]: """ Serialize this component to a dictionary. 
""" + assert self.pooling_mode is not None serialization_dict = default_to_dict( self, model=self.model, @@ -158,7 +163,7 @@ def to_dict(self) -> Dict[str, Any]: ) model_kwargs = serialization_dict["init_parameters"]["model_kwargs"] - model_kwargs.pop("token", None) + model_kwargs.pop("use_auth_token", None) serialize_hf_model_kwargs(model_kwargs) return serialization_dict diff --git a/integrations/optimum/src/haystack_integrations/components/embedders/pooling.py b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py similarity index 91% rename from integrations/optimum/src/haystack_integrations/components/embedders/pooling.py rename to integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py index 8f0cddc22..df8f8fb4e 100644 --- a/integrations/optimum/src/haystack_integrations/components/embedders/pooling.py +++ b/integrations/optimum/src/haystack_integrations/components/embedders/optimum/pooling.py @@ -3,7 +3,6 @@ from typing import Optional import torch -from haystack.utils import Secret from huggingface_hub import hf_hub_download from sentence_transformers.models import Pooling as PoolingLayer @@ -59,7 +58,7 @@ class HFPoolingMode: """ @staticmethod - def get_pooling_mode(model: str, token: Optional[Secret] = None) -> Optional[PoolingMode]: + def get_pooling_mode(model: str, token: Optional[str] = None) -> Optional[PoolingMode]: """ Gets the pooling mode of the model from the Hugging Face Hub. @@ -103,18 +102,15 @@ class Pooling: :param model_output: The output of the embedding model. 
""" - def __init__(self, pooling_mode: PoolingMode, attention_mask: torch.tensor, model_output: torch.tensor): + def __init__(self, pooling_mode: PoolingMode, attention_mask: torch.Tensor, model_output: torch.Tensor): self.pooling_mode = pooling_mode self.attention_mask = attention_mask self.model_output = model_output - def pool_embeddings(self) -> torch.tensor: + def pool_embeddings(self) -> torch.Tensor: """ Perform pooling on the output of the embedding model. - :param pooling_mode: The pooling mode to use. - :param attention_mask: The attention mask of the tokenized text. - :param model_output: The output of the embedding model. :return: The embeddings of the text after pooling. """ pooling_func_map = { diff --git a/integrations/optimum/tests/test_optimum_backend.py b/integrations/optimum/tests/test_optimum_backend.py index 8ef61fd37..cfa2ac08d 100644 --- a/integrations/optimum/tests/test_optimum_backend.py +++ b/integrations/optimum/tests/test_optimum_backend.py @@ -1,6 +1,6 @@ import pytest -from haystack_integrations.components.embedders.optimum_backend import OptimumEmbeddingBackend -from haystack_integrations.components.embedders.pooling import PoolingMode +from haystack_integrations.components.embedders.optimum.optimum_backend import OptimumEmbeddingBackend +from haystack_integrations.components.embedders.optimum.pooling import PoolingMode @pytest.fixture diff --git a/integrations/optimum/tests/test_optimum_document_embedder.py b/integrations/optimum/tests/test_optimum_document_embedder.py index f61fea1d3..3c8cc90b3 100644 --- a/integrations/optimum/tests/test_optimum_document_embedder.py +++ b/integrations/optimum/tests/test_optimum_document_embedder.py @@ -3,15 +3,15 @@ import pytest from haystack.dataclasses import Document from haystack.utils.auth import Secret -from haystack_integrations.components.embedders import OptimumDocumentEmbedder -from haystack_integrations.components.embedders.pooling import PoolingMode +from 
haystack_integrations.components.embedders.optimum import OptimumDocumentEmbedder +from haystack_integrations.components.embedders.optimum.pooling import PoolingMode from huggingface_hub.utils import RepositoryNotFoundError @pytest.fixture def mock_check_valid_model(): with patch( - "haystack_integrations.components.embedders.optimum_document_embedder.check_valid_model", + "haystack_integrations.components.embedders.optimum.optimum_document_embedder.check_valid_model", MagicMock(return_value=None), ) as mock: yield mock @@ -20,7 +20,7 @@ def mock_check_valid_model(): @pytest.fixture def mock_get_pooling_mode(): with patch( - "haystack_integrations.components.embedders.optimum_text_embedder.HFPoolingMode.get_pooling_mode", + "haystack_integrations.components.embedders.optimum.optimum_document_embedder.HFPoolingMode.get_pooling_mode", MagicMock(return_value=PoolingMode.MEAN), ) as mock: yield mock @@ -82,12 +82,12 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 "use_auth_token": "fake-api-token", } - def test_to_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 + def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 component = OptimumDocumentEmbedder() data = component.to_dict() assert data == { - "type": "haystack_integrations.components.embedders.optimum_document_embedder.OptimumDocumentEmbedder", + "type": "haystack_integrations.components.embedders.optimum.optimum_document_embedder.OptimumDocumentEmbedder", "init_parameters": { "model": "sentence-transformers/all-mpnet-base-v2", "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, @@ -103,12 +103,31 @@ def test_to_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: "model_kwargs": { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", - "use_auth_token": None, }, }, } - def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model, 
mock_get_pooling_mode): # noqa: ARG002 + embedder = OptimumDocumentEmbedder.from_dict(data) + assert embedder.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.normalize_embeddings is True + assert embedder.onnx_execution_provider == "CPUExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MEAN + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + assert embedder.model_kwargs == { + "model_id": "sentence-transformers/all-mpnet-base-v2", + "provider": "CPUExecutionProvider", + "use_auth_token": None, + } + + def test_to_and_from_dict_with_custom_init_parameters( + self, mock_check_valid_model, mock_get_pooling_mode + ): # noqa: ARG002 component = OptimumDocumentEmbedder( model="sentence-transformers/all-minilm-l6-v2", token=Secret.from_env_var("ENV_VAR", strict=False), @@ -126,7 +145,7 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model, mock_ data = component.to_dict() assert data == { - "type": "haystack_integrations.components.embedders.optimum_document_embedder.OptimumDocumentEmbedder", + "type": "haystack_integrations.components.embedders.optimum.optimum_document_embedder.OptimumDocumentEmbedder", "init_parameters": { "model": "sentence-transformers/all-minilm-l6-v2", "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, @@ -143,11 +162,29 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model, mock_ "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", "provider": "CUDAExecutionProvider", - "use_auth_token": None, }, }, } + embedder = OptimumDocumentEmbedder.from_dict(data) + assert embedder.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder.token == 
Secret.from_env_var("ENV_VAR", strict=False) + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == " | " + assert embedder.normalize_embeddings is False + assert embedder.onnx_execution_provider == "CUDAExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MAX + assert embedder.model_kwargs == { + "trust_remote_code": True, + "model_id": "sentence-transformers/all-minilm-l6-v2", + "provider": "CUDAExecutionProvider", + "use_auth_token": None, + } + def test_initialize_with_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") with pytest.raises(RepositoryNotFoundError): diff --git a/integrations/optimum/tests/test_optimum_text_embedder.py b/integrations/optimum/tests/test_optimum_text_embedder.py index 9932d1dbf..58a5597c3 100644 --- a/integrations/optimum/tests/test_optimum_text_embedder.py +++ b/integrations/optimum/tests/test_optimum_text_embedder.py @@ -2,15 +2,15 @@ import pytest from haystack.utils.auth import Secret -from haystack_integrations.components.embedders import OptimumTextEmbedder -from haystack_integrations.components.embedders.pooling import PoolingMode +from haystack_integrations.components.embedders.optimum import OptimumTextEmbedder +from haystack_integrations.components.embedders.optimum.pooling import PoolingMode from huggingface_hub.utils import RepositoryNotFoundError @pytest.fixture def mock_check_valid_model(): with patch( - "haystack_integrations.components.embedders.optimum_text_embedder.check_valid_model", + "haystack_integrations.components.embedders.optimum.optimum_text_embedder.check_valid_model", MagicMock(return_value=None), ) as mock: yield mock @@ -19,7 +19,7 @@ def mock_check_valid_model(): @pytest.fixture def mock_get_pooling_mode(): with patch( - 
"haystack_integrations.components.embedders.optimum_text_embedder.HFPoolingMode.get_pooling_mode", + "haystack_integrations.components.embedders.optimum.optimum_text_embedder.HFPoolingMode.get_pooling_mode", MagicMock(return_value=PoolingMode.MEAN), ) as mock: yield mock @@ -69,12 +69,12 @@ def test_init_with_parameters(self, mock_check_valid_model): # noqa: ARG002 "use_auth_token": "fake-api-token", } - def test_to_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 + def test_to_and_from_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: ARG002 component = OptimumTextEmbedder() data = component.to_dict() assert data == { - "type": "haystack_integrations.components.embedders.optimum_text_embedder.OptimumTextEmbedder", + "type": "haystack_integrations.components.embedders.optimum.optimum_text_embedder.OptimumTextEmbedder", "init_parameters": { "model": "sentence-transformers/all-mpnet-base-v2", "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, @@ -86,12 +86,25 @@ def test_to_dict(self, mock_check_valid_model, mock_get_pooling_mode): # noqa: "model_kwargs": { "model_id": "sentence-transformers/all-mpnet-base-v2", "provider": "CPUExecutionProvider", - "use_auth_token": None, }, }, } - def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # noqa: ARG002 + embedder = OptimumTextEmbedder.from_dict(data) + assert embedder.model == "sentence-transformers/all-mpnet-base-v2" + assert embedder.token == Secret.from_env_var("HF_API_TOKEN", strict=False) + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.normalize_embeddings is True + assert embedder.onnx_execution_provider == "CPUExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MEAN + assert embedder.model_kwargs == { + "model_id": "sentence-transformers/all-mpnet-base-v2", + "provider": "CPUExecutionProvider", + "use_auth_token": None, + } + + def 
test_to_and_from_dict_with_custom_init_parameters(self, mock_check_valid_model): # noqa: ARG002 component = OptimumTextEmbedder( model="sentence-transformers/all-minilm-l6-v2", token=Secret.from_env_var("ENV_VAR", strict=False), @@ -105,7 +118,7 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # n data = component.to_dict() assert data == { - "type": "haystack_integrations.components.embedders.optimum_text_embedder.OptimumTextEmbedder", + "type": "haystack_integrations.components.embedders.optimum.optimum_text_embedder.OptimumTextEmbedder", "init_parameters": { "model": "sentence-transformers/all-minilm-l6-v2", "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, @@ -118,11 +131,25 @@ def test_to_dict_with_custom_init_parameters(self, mock_check_valid_model): # n "trust_remote_code": True, "model_id": "sentence-transformers/all-minilm-l6-v2", "provider": "CUDAExecutionProvider", - "use_auth_token": None, }, }, } + embedder = OptimumTextEmbedder.from_dict(data) + assert embedder.model == "sentence-transformers/all-minilm-l6-v2" + assert embedder.token == Secret.from_env_var("ENV_VAR", strict=False) + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert embedder.normalize_embeddings is False + assert embedder.onnx_execution_provider == "CUDAExecutionProvider" + assert embedder.pooling_mode == PoolingMode.MAX + assert embedder.model_kwargs == { + "trust_remote_code": True, + "model_id": "sentence-transformers/all-minilm-l6-v2", + "provider": "CUDAExecutionProvider", + "use_auth_token": None, + } + def test_initialize_with_invalid_model(self, mock_check_valid_model): mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") with pytest.raises(RepositoryNotFoundError):