Skip to content

Commit

Permalink
added truncate_dim to sentence transformers embedder (deepset-ai#8077)
Browse files Browse the repository at this point in the history
* added truncate_dim to sentence transformers embedder

* Update haystack/components/embedders/sentence_transformers_document_embedder.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* Update releasenotes/notes/release-note-2b603a123cd36214.yaml

Co-authored-by: Stefano Fiorucci <[email protected]>

* fixed parameter description

* added test for truncation to text embedder

* fix format

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
  • Loading branch information
nickprock and anakin87 authored Jul 26, 2024
1 parent b2aef21 commit 47f4db8
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,22 @@ class _SentenceTransformersEmbeddingBackendFactory:

@staticmethod
def get_embedding_backend(
model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None, trust_remote_code: bool = False
model: str,
device: Optional[str] = None,
auth_token: Optional[Secret] = None,
trust_remote_code: bool = False,
truncate_dim: Optional[int] = None,
):
embedding_backend_id = f"{model}{device}{auth_token}"
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"

if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
embedding_backend = _SentenceTransformersEmbeddingBackend(
model=model, device=device, auth_token=auth_token, trust_remote_code=trust_remote_code
model=model,
device=device,
auth_token=auth_token,
trust_remote_code=trust_remote_code,
truncate_dim=truncate_dim,
)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
Expand All @@ -46,13 +54,15 @@ def __init__(
device: Optional[str] = None,
auth_token: Optional[Secret] = None,
trust_remote_code: bool = False,
truncate_dim: Optional[int] = None,
):
sentence_transformers_import.check()
self.model = SentenceTransformer(
model_name_or_path=model,
device=device,
use_auth_token=auth_token.resolve_value() if auth_token else None,
trust_remote_code=trust_remote_code,
truncate_dim=truncate_dim,
)

def embed(self, data: List[str], **kwargs) -> List[List[float]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
trust_remote_code: bool = False,
truncate_dim: Optional[int] = None,
):
"""
Create a SentenceTransformersDocumentEmbedder component.
Expand Down Expand Up @@ -73,6 +74,10 @@ def __init__(
:param trust_remote_code:
If `False`, only Hugging Face verified model architectures are allowed.
If `True`, custom models and scripts are allowed.
:param truncate_dim:
The dimension to truncate sentence embeddings to. `None` does no truncation.
If the model has not been trained with Matryoshka Representation Learning,
truncation of embeddings can significantly affect performance.
"""

self.model = model
Expand All @@ -86,6 +91,7 @@ def __init__(
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.trust_remote_code = trust_remote_code
self.truncate_dim = truncate_dim

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -113,6 +119,7 @@ def to_dict(self) -> Dict[str, Any]:
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
trust_remote_code=self.trust_remote_code,
truncate_dim=self.truncate_dim,
)

@classmethod
Expand Down Expand Up @@ -141,6 +148,7 @@ def warm_up(self):
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
truncate_dim=self.truncate_dim,
)

@component.output_types(documents=List[Document])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
progress_bar: bool = True,
normalize_embeddings: bool = False,
trust_remote_code: bool = False,
truncate_dim: Optional[int] = None,
):
"""
Create a SentenceTransformersTextEmbedder component.
Expand Down Expand Up @@ -71,6 +72,10 @@ def __init__(
:param trust_remote_code:
If `False`, permits only Hugging Face verified model architectures.
If `True`, permits custom models and scripts.
:param truncate_dim:
The dimension to truncate sentence embeddings to. `None` does no truncation.
If the model has not been trained with Matryoshka Representation Learning,
truncation of embeddings can significantly affect performance.
"""

self.model = model
Expand All @@ -82,6 +87,7 @@ def __init__(
self.progress_bar = progress_bar
self.normalize_embeddings = normalize_embeddings
self.trust_remote_code = trust_remote_code
self.truncate_dim = truncate_dim

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -107,6 +113,7 @@ def to_dict(self) -> Dict[str, Any]:
progress_bar=self.progress_bar,
normalize_embeddings=self.normalize_embeddings,
trust_remote_code=self.trust_remote_code,
truncate_dim=self.truncate_dim,
)

@classmethod
Expand Down Expand Up @@ -135,6 +142,7 @@ def warm_up(self):
device=self.device.to_torch_str(),
auth_token=self.token,
trust_remote_code=self.trust_remote_code,
truncate_dim=self.truncate_dim,
)

@component.output_types(embedding=List[float])
Expand Down
5 changes: 5 additions & 0 deletions releasenotes/notes/release-note-2b603a123cd36214.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
Add a `truncate_dim` parameter to the Sentence Transformers embedders, which allows truncating
embeddings to a given dimension. This is especially useful for models trained with Matryoshka Representation Learning.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def test_init_default(self):
assert embedder.meta_fields_to_embed == []
assert embedder.embedding_separator == "\n"
assert embedder.trust_remote_code is False
assert embedder.truncate_dim is None

def test_init_with_parameters(self):
embedder = SentenceTransformersDocumentEmbedder(
Expand All @@ -39,6 +40,7 @@ def test_init_with_parameters(self):
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
trust_remote_code=True,
truncate_dim=256,
)
assert embedder.model == "model"
assert embedder.device == ComponentDevice.from_str("cuda:0")
Expand All @@ -51,6 +53,7 @@ def test_init_with_parameters(self):
assert embedder.meta_fields_to_embed == ["test_field"]
assert embedder.embedding_separator == " | "
assert embedder.trust_remote_code
assert embedder.truncate_dim == 256

def test_to_dict(self):
component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
Expand All @@ -69,6 +72,7 @@ def test_to_dict(self):
"embedding_separator": "\n",
"meta_fields_to_embed": [],
"trust_remote_code": False,
"truncate_dim": None,
},
}

Expand All @@ -85,6 +89,7 @@ def test_to_dict_with_custom_init_parameters(self):
meta_fields_to_embed=["meta_field"],
embedding_separator=" - ",
trust_remote_code=True,
truncate_dim=256,
)
data = component.to_dict()

Expand All @@ -102,6 +107,7 @@ def test_to_dict_with_custom_init_parameters(self):
"embedding_separator": " - ",
"trust_remote_code": True,
"meta_fields_to_embed": ["meta_field"],
"truncate_dim": 256,
},
}

Expand All @@ -118,6 +124,7 @@ def test_from_dict(self):
"embedding_separator": " - ",
"meta_fields_to_embed": ["meta_field"],
"trust_remote_code": True,
"truncate_dim": 256,
}
component = SentenceTransformersDocumentEmbedder.from_dict(
{
Expand All @@ -136,6 +143,7 @@ def test_from_dict(self):
assert component.embedding_separator == " - "
assert component.trust_remote_code
assert component.meta_fields_to_embed == ["meta_field"]
assert component.truncate_dim == 256

def test_from_dict_no_default_parameters(self):
component = SentenceTransformersDocumentEmbedder.from_dict(
Expand All @@ -155,6 +163,7 @@ def test_from_dict_no_default_parameters(self):
assert component.embedding_separator == "\n"
assert component.trust_remote_code is False
assert component.meta_fields_to_embed == []
assert component.truncate_dim is None

def test_from_dict_none_device(self):
init_parameters = {
Expand All @@ -169,6 +178,7 @@ def test_from_dict_none_device(self):
"embedding_separator": " - ",
"meta_fields_to_embed": ["meta_field"],
"trust_remote_code": True,
"truncate_dim": None,
}
component = SentenceTransformersDocumentEmbedder.from_dict(
{
Expand All @@ -187,6 +197,7 @@ def test_from_dict_none_device(self):
assert component.embedding_separator == " - "
assert component.trust_remote_code
assert component.meta_fields_to_embed == ["meta_field"]
assert component.truncate_dim is None

@patch(
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
Expand All @@ -198,7 +209,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model", device="cpu", auth_token=None, trust_remote_code=False
model="model", device="cpu", auth_token=None, trust_remote_code=False, truncate_dim=None
)

@patch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,18 @@ def test_factory_behavior(mock_sentence_transformer):
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
def test_model_initialization(mock_sentence_transformer):
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token"), trust_remote_code=True
model="model",
device="cpu",
auth_token=Secret.from_token("fake-api-token"),
trust_remote_code=True,
truncate_dim=256,
)
mock_sentence_transformer.assert_called_once_with(
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token", trust_remote_code=True
model_name_or_path="model",
device="cpu",
use_auth_token="fake-api-token",
trust_remote_code=True,
truncate_dim=256,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_init_default(self):
assert embedder.progress_bar is True
assert embedder.normalize_embeddings is False
assert embedder.trust_remote_code is False
assert embedder.truncate_dim is None

def test_init_with_parameters(self):
embedder = SentenceTransformersTextEmbedder(
Expand All @@ -34,6 +35,7 @@ def test_init_with_parameters(self):
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
truncate_dim=256,
)
assert embedder.model == "model"
assert embedder.device == ComponentDevice.from_str("cuda:0")
Expand All @@ -43,7 +45,8 @@ def test_init_with_parameters(self):
assert embedder.batch_size == 64
assert embedder.progress_bar is False
assert embedder.normalize_embeddings is True
assert embedder.trust_remote_code
assert embedder.trust_remote_code is True
assert embedder.truncate_dim == 256

def test_to_dict(self):
component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
Expand All @@ -60,6 +63,7 @@ def test_to_dict(self):
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
"truncate_dim": None,
},
}

Expand All @@ -74,6 +78,7 @@ def test_to_dict_with_custom_init_parameters(self):
progress_bar=False,
normalize_embeddings=True,
trust_remote_code=True,
truncate_dim=256,
)
data = component.to_dict()
assert data == {
Expand All @@ -88,6 +93,7 @@ def test_to_dict_with_custom_init_parameters(self):
"progress_bar": False,
"normalize_embeddings": True,
"trust_remote_code": True,
"truncate_dim": 256,
},
}

Expand All @@ -109,6 +115,7 @@ def test_from_dict(self):
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
"truncate_dim": None,
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
Expand All @@ -121,6 +128,7 @@ def test_from_dict(self):
assert component.progress_bar is True
assert component.normalize_embeddings is False
assert component.trust_remote_code is False
assert component.truncate_dim is None

def test_from_dict_no_default_parameters(self):
data = {
Expand All @@ -137,6 +145,7 @@ def test_from_dict_no_default_parameters(self):
assert component.progress_bar is True
assert component.normalize_embeddings is False
assert component.trust_remote_code is False
assert component.truncate_dim is None

def test_from_dict_none_device(self):
data = {
Expand All @@ -151,6 +160,7 @@ def test_from_dict_none_device(self):
"progress_bar": True,
"normalize_embeddings": False,
"trust_remote_code": False,
"truncate_dim": 256,
},
}
component = SentenceTransformersTextEmbedder.from_dict(data)
Expand All @@ -163,6 +173,7 @@ def test_from_dict_none_device(self):
assert component.progress_bar is True
assert component.normalize_embeddings is False
assert component.trust_remote_code is False
assert component.truncate_dim == 256

@patch(
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
Expand All @@ -172,7 +183,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model="model", device="cpu", auth_token=None, trust_remote_code=False
model="model", device="cpu", auth_token=None, trust_remote_code=False, truncate_dim=None
)

@patch(
Expand Down Expand Up @@ -206,3 +217,24 @@ def test_run_wrong_input_format(self):

with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"):
embedder.run(text=list_integers_input)

@pytest.mark.integration
def test_run_trunc(self):
    """
    Check that `truncate_dim` shortens the embedding produced by the text embedder.

    sentence-transformers/paraphrase-albert-small-v2 maps sentences & paragraphs to a
    768 dimensional dense vector space, so the default embedder is expected to return
    768 values and the embedder built with `truncate_dim=128` is expected to return 128.
    """
    model_name = "sentence-transformers/paraphrase-albert-small-v2"
    sample = "a nice text to embed"

    # Baseline: no truncation requested, the embedding keeps its full length.
    full_embedder = SentenceTransformersTextEmbedder(model=model_name)
    full_embedder.warm_up()
    full_embedding = full_embedder.run(text=sample)["embedding"]

    # Same checkpoint and text, this time with truncation to 128 dimensions.
    truncated_embedder = SentenceTransformersTextEmbedder(model=model_name, truncate_dim=128)
    truncated_embedder.warm_up()
    truncated_embedding = truncated_embedder.run(text=sample)["embedding"]

    assert len(full_embedding) == 768
    assert len(truncated_embedding) == 128

0 comments on commit 47f4db8

Please sign in to comment.