Refactor optimum namespacing + bug fixes (#469)
* refactor!: Move components to the `optimum` submodule

* fix: Resolve secret before passing to `hf_hub`

* build: Increment `sentence-transformers` pin to support weighted mean pooling

* Linting

* Pydocs config

* Fix secret serialization and pooling mode resolution

* Fix docstring
shadeMe authored Feb 22, 2024
1 parent 9b63027 commit a2a69f6
Showing 10 changed files with 180 additions and 126 deletions.
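For orientation, the namespacing refactor moves the embedder modules from `haystack_integrations.components.embedders.*` to `haystack_integrations.components.embedders.optimum.*` (see the updated pydoc module list in the first file below). A minimal before/after import sketch, assuming the classes are re-exported from the new subpackage's `__init__`; note that the updated usage docstrings in this diff spell the path as `haystack_integrations.components.optimum.embedders`, so check the package's actual exports:

```python
# Before this commit (old flat namespace, as in the previous docstrings):
# from haystack_integrations.components.embedders import OptimumTextEmbedder

# After this commit -- assumed re-export from the new `optimum` subpackage:
from haystack_integrations.components.embedders.optimum import OptimumTextEmbedder

embedder = OptimumTextEmbedder(model="sentence-transformers/all-mpnet-base-v2")
embedder.warm_up()  # loads the ONNX model through optimum / ONNX Runtime
result = embedder.run(text="I love pizza!")
print(len(result["embedding"]))
```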
9 changes: 5 additions & 4 deletions integrations/optimum/pydoc/config.yml
@@ -3,10 +3,9 @@ loaders:
search_path: [../src]
modules:
[
"haystack_integrations.components.embedders.optimum_backend",
"haystack_integrations.components.embedders.optimum_document_embedder",
"haystack_integrations.components.embedders.optimum_text_embedder",
"haystack_integrations.components.embedders.pooling",
"haystack_integrations.components.embedders.optimum.optimum_document_embedder",
"haystack_integrations.components.embedders.optimum.optimum_text_embedder",
"haystack_integrations.components.embedders.optimum.pooling",
]
ignore_when_discovered: ["__init__"]
processors:
@@ -15,6 +14,8 @@ processors:
documented_only: true
do_not_filter_modules: false
skip_empty_modules: true
- type: filter
expression: "name not in ['Pooling', 'POOLING_MODES_MAP', 'INVERSE_POOLING_MODES_MAP', 'HFPoolingMode']"
- type: smart
- type: crossref
renderer:
123 changes: 52 additions & 71 deletions integrations/optimum/pyproject.toml
@@ -32,8 +32,8 @@ dependencies = [
# on which model is loaded from HF Hub.
# Ref: https://github.com/huggingface/optimum/blob/8651c0ca1cccf095458bc80329dec9df4601edb4/optimum/exporters/onnx/__main__.py#L164
# "sentence-transformers" has been added, since most embedding models use it
"sentence-transformers>=2.2.0",
"optimum[onnxruntime]"
"sentence-transformers>=2.3",
"optimum[onnxruntime]",
]

[project.urls]
@@ -53,22 +53,12 @@ root = "../.."
git_describe_command = 'git describe --tags --match="integrations/optimum-v[0-9]*"'

[tool.hatch.envs.default]
dependencies = [
"coverage[toml]>=6.5",
"pytest",
"haystack-pydoc-tools"
]
dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
test-cov = "coverage run -m pytest {args:tests}"
cov-report = [
"- coverage combine",
"coverage report",
]
cov = [
"test-cov",
"cov-report",
]
cov-report = ["- coverage combine", "coverage report"]
cov = ["test-cov", "cov-report"]
docs = ["pydoc-markdown pydoc/config.yml"]

[[tool.hatch.envs.all.matrix]]
@@ -77,27 +67,13 @@ python = ["3.8", "3.9", "3.10", "3.11", "3.12"]

[tool.hatch.envs.lint]
detached = true
dependencies = [
"black>=23.1.0",
"mypy>=1.0.0",
"ruff>=0.0.243",
]
dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]

[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
style = [
"ruff {args:.}",
"black --check --diff {args:.}",
]
fmt = [
"black {args:.}",
"ruff --fix {args:.}",
"style",
]
all = [
"style",
"typing",
]
style = ["ruff {args:.}", "black --check --diff {args:.}"]
fmt = ["black {args:.}", "ruff --fix {args:.}", "style"]
all = ["style", "typing"]

[tool.hatch.metadata]
allow-direct-references = true
@@ -114,43 +90,52 @@ skip-string-normalization = true
target-version = "py37"
line-length = 120
select = [
"A",
"ARG",
"B",
"C",
"DTZ",
"E",
"EM",
"F",
"I",
"ICN",
"ISC",
"N",
"PLC",
"PLE",
"PLR",
"PLW",
"Q",
"RUF",
"S",
"T",
"TID",
"UP",
"W",
"YTT",
"A",
"ARG",
"B",
"C",
"DTZ",
"E",
"EM",
"F",
"I",
"ICN",
"ISC",
"N",
"PLC",
"PLE",
"PLR",
"PLW",
"Q",
"RUF",
"S",
"T",
"TID",
"UP",
"W",
"YTT",
]
ignore = [
# Allow non-abstract empty methods in abstract base classes
"B027",
# Ignore checks for possible passwords
"S105", "S106", "S107",
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
# Allow non-abstract empty methods in abstract base classes
"B027",
# Ignore checks for possible passwords
"S105",
"S106",
"S107",
# Ignore complexity
"C901",
"PLR0911",
"PLR0912",
"PLR0913",
"PLR0915",
# Asserts
"S101",
]
unfixable = [
# Don't touch unused imports
"F401",
# Don't touch unused imports
"F401",
]
extend-exclude = ["tests", "example"]

[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "parents"
@@ -172,11 +157,7 @@ optimum = ["src/haystack_integrations", "*/optimum/src/haystack_integrations"]
tests = ["tests", "*/optimum/tests"]

[tool.coverage.report]
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]

[[tool.mypy.overrides]]
module = [
@@ -188,7 +169,7 @@ module = [
"torch.*",
"transformers.*",
"huggingface_hub.*",
"sentence_transformers.*"
"sentence_transformers.*",
]
ignore_missing_imports = true

@@ -3,11 +3,13 @@
import numpy as np
import torch
from haystack.utils.auth import Secret
from haystack_integrations.components.embedders.pooling import Pooling, PoolingMode
from optimum.onnxruntime import ORTModelForFeatureExtraction
from tqdm import tqdm
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForFeatureExtraction

from .pooling import Pooling, PoolingMode


class OptimumEmbeddingBackend:
"""
@@ -3,8 +3,9 @@
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
from haystack_integrations.components.embedders.optimum_backend import OptimumEmbeddingBackend
from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode

from .optimum_backend import OptimumEmbeddingBackend
from .pooling import HFPoolingMode, PoolingMode


@component
@@ -18,7 +19,7 @@ class OptimumDocumentEmbedder:
Usage example:
```python
from haystack.dataclasses import Document
from haystack_integrations.components.embedders import OptimumDocumentEmbedder
from haystack_integrations.components.optimum.embedders import OptimumDocumentEmbedder
doc = Document(content="I love pizza!")
@@ -114,13 +115,16 @@ def __init__(
self.model = model

self.token = token
token = token.resolve_value() if token else None
resolved_token = token.resolve_value() if token else None

if isinstance(pooling_mode, str):
self.pooling_mode: Optional[PoolingMode] = None
if isinstance(pooling_mode, PoolingMode):
self.pooling_mode = pooling_mode
elif isinstance(pooling_mode, str):
self.pooling_mode = PoolingMode.from_str(pooling_mode)
# Infer pooling mode from model config if not provided,
if pooling_mode is None:
self.pooling_mode = HFPoolingMode.get_pooling_mode(model, token)
else:
self.pooling_mode = HFPoolingMode.get_pooling_mode(model, resolved_token)

# Raise error if pooling mode is not found in model config and not specified by user
if self.pooling_mode is None:
modes = {e.value: e for e in PoolingMode}
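The change above is part of the pooling-mode resolution fix: an explicit `PoolingMode` enum is now accepted as-is, a string is parsed via `PoolingMode.from_str`, and `None` falls back to reading the mode from the model's Hugging Face config using the resolved token. A hedged usage sketch — the import paths are assumptions based on the pydoc module list and the package's re-exports, and constructing the component requires access to the Hugging Face Hub:

```python
# Assumed import paths; adjust to the package's actual exports.
from haystack_integrations.components.embedders.optimum import OptimumDocumentEmbedder
from haystack_integrations.components.embedders.optimum.pooling import PoolingMode

model = "sentence-transformers/all-MiniLM-L6-v2"

# All three should resolve to mean pooling for this model, whose
# sentence-transformers config specifies it:
embedder_enum = OptimumDocumentEmbedder(model=model, pooling_mode=PoolingMode.MEAN)
embedder_str = OptimumDocumentEmbedder(model=model, pooling_mode="mean")
embedder_inferred = OptimumDocumentEmbedder(model=model)  # read from model config
```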
@@ -144,7 +148,7 @@ def __init__(
# Check if the model_kwargs contain the parameters, otherwise, populate them with values from init parameters
model_kwargs.setdefault("model_id", model)
model_kwargs.setdefault("provider", onnx_execution_provider)
model_kwargs.setdefault("use_auth_token", token)
model_kwargs.setdefault("use_auth_token", resolved_token)

self.model_kwargs = model_kwargs
self.embedding_backend = None
@@ -162,6 +166,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
assert self.pooling_mode is not None
serialization_dict = default_to_dict(
self,
model=self.model,
@@ -179,7 +184,7 @@ def to_dict(self) -> Dict[str, Any]:
)

model_kwargs = serialization_dict["init_parameters"]["model_kwargs"]
model_kwargs.pop("token", None)
model_kwargs.pop("use_auth_token", None)

serialize_hf_model_kwargs(model_kwargs)
return serialization_dict
@@ -3,8 +3,9 @@
from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFModelType, check_valid_model, deserialize_hf_model_kwargs, serialize_hf_model_kwargs
from haystack_integrations.components.embedders.optimum_backend import OptimumEmbeddingBackend
from haystack_integrations.components.embedders.pooling import HFPoolingMode, PoolingMode

from .optimum_backend import OptimumEmbeddingBackend
from .pooling import HFPoolingMode, PoolingMode


@component
@@ -15,7 +16,7 @@ class OptimumTextEmbedder:
Usage example:
```python
from haystack_integrations.components.embedders import OptimumTextEmbedder
from haystack_integrations.components.optimum.embedders import OptimumTextEmbedder
text_to_embed = "I love pizza!"
@@ -101,13 +102,16 @@ def __init__(
self.model = model

self.token = token
token = token.resolve_value() if token else None
resolved_token = token.resolve_value() if token else None

if isinstance(pooling_mode, str):
self.pooling_mode: Optional[PoolingMode] = None
if isinstance(pooling_mode, PoolingMode):
self.pooling_mode = pooling_mode
elif isinstance(pooling_mode, str):
self.pooling_mode = PoolingMode.from_str(pooling_mode)
# Infer pooling mode from model config if not provided,
if pooling_mode is None:
self.pooling_mode = HFPoolingMode.get_pooling_mode(model, token)
else:
self.pooling_mode = HFPoolingMode.get_pooling_mode(model, resolved_token)

# Raise error if pooling mode is not found in model config and not specified by user
if self.pooling_mode is None:
modes = {e.value: e for e in PoolingMode}
@@ -127,7 +131,7 @@ def __init__(
# Check if the model_kwargs contain the parameters, otherwise, populate them with values from init parameters
model_kwargs.setdefault("model_id", model)
model_kwargs.setdefault("provider", onnx_execution_provider)
model_kwargs.setdefault("use_auth_token", token)
model_kwargs.setdefault("use_auth_token", resolved_token)

self.model_kwargs = model_kwargs
self.embedding_backend = None
@@ -145,6 +149,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
"""
assert self.pooling_mode is not None
serialization_dict = default_to_dict(
self,
model=self.model,
@@ -158,7 +163,7 @@ def to_dict(self) -> Dict[str, Any]:
)

model_kwargs = serialization_dict["init_parameters"]["model_kwargs"]
model_kwargs.pop("token", None)
model_kwargs.pop("use_auth_token", None)

serialize_hf_model_kwargs(model_kwargs)
return serialization_dict
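The `model_kwargs.pop(...)` change above is the secret-serialization fix: the resolved token string is stored in `model_kwargs` under `use_auth_token` (see `__init__`), so that key, not `token`, has to be dropped before the component is serialized. A hedged round-trip sketch — the import path is an assumption, and the constructor needs Hub access:

```python
from haystack.utils import Secret

# Assumed import path; adjust to the package's actual exports.
from haystack_integrations.components.embedders.optimum import OptimumTextEmbedder

embedder = OptimumTextEmbedder(
    model="sentence-transformers/all-mpnet-base-v2",
    token=Secret.from_env_var("HF_API_TOKEN", strict=False),
)

data = embedder.to_dict()
# The resolved token string must not leak into the serialized model_kwargs.
assert "use_auth_token" not in data["init_parameters"]["model_kwargs"]
```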
@@ -3,7 +3,6 @@
from typing import Optional

import torch
from haystack.utils import Secret
from huggingface_hub import hf_hub_download
from sentence_transformers.models import Pooling as PoolingLayer

@@ -59,7 +58,7 @@ class HFPoolingMode:
"""

@staticmethod
def get_pooling_mode(model: str, token: Optional[Secret] = None) -> Optional[PoolingMode]:
def get_pooling_mode(model: str, token: Optional[str] = None) -> Optional[PoolingMode]:
"""
Gets the pooling mode of the model from the Hugging Face Hub.
@@ -103,18 +102,15 @@ class Pooling:
:param model_output: The output of the embedding model.
"""

def __init__(self, pooling_mode: PoolingMode, attention_mask: torch.tensor, model_output: torch.tensor):
def __init__(self, pooling_mode: PoolingMode, attention_mask: torch.Tensor, model_output: torch.Tensor):
self.pooling_mode = pooling_mode
self.attention_mask = attention_mask
self.model_output = model_output

def pool_embeddings(self) -> torch.tensor:
def pool_embeddings(self) -> torch.Tensor:
"""
Perform pooling on the output of the embedding model.
:param pooling_mode: The pooling mode to use.
:param attention_mask: The attention mask of the tokenized text.
:param model_output: The output of the embedding model.
:return: The embeddings of the text after pooling.
"""
pooling_func_map = {
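The `Secret` → `str` signature change in `get_pooling_mode` reflects the "resolve secret before passing to `hf_hub`" fix: callers now resolve the `Secret` once and pass the plain token string, which is forwarded to `hf_hub_download`. A minimal sketch, assuming `HFPoolingMode` is importable from the pooling module listed in the pydoc config:

```python
from typing import Optional

from haystack.utils import Secret

# Assumed import path, taken from the pydoc module list above.
from haystack_integrations.components.embedders.optimum.pooling import HFPoolingMode

token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False)
resolved_token: Optional[str] = token.resolve_value() if token else None

# get_pooling_mode now takes the resolved string and forwards it to the Hub
# when downloading the model's sentence-transformers pooling config.
pooling_mode = HFPoolingMode.get_pooling_mode(
    "sentence-transformers/all-mpnet-base-v2", token=resolved_token
)
print(pooling_mode)
```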
4 changes: 2 additions & 2 deletions integrations/optimum/tests/test_optimum_backend.py
@@ -1,6 +1,6 @@
import pytest
from haystack_integrations.components.embedders.optimum_backend import OptimumEmbeddingBackend
from haystack_integrations.components.embedders.pooling import PoolingMode
from haystack_integrations.components.embedders.optimum.optimum_backend import OptimumEmbeddingBackend
from haystack_integrations.components.embedders.optimum.pooling import PoolingMode


@pytest.fixture