feat: SentenceTransformersDocumentEmbedder supports config_kwargs (#8433)

* initial import

* adding release notes
davidsbatista authored and LastRemote committed Oct 24, 2024
1 parent 94c368a commit bf0a2e5
Showing 5 changed files with 23 additions and 7 deletions.
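In user code, the new parameter is passed straight through the component's constructor and only takes effect when the model is loaded. A minimal sketch of the intended usage (the model name is illustrative; the config flag mirrors the one used in the tests below):

from haystack.components.embedders import SentenceTransformersDocumentEmbedder

# config_kwargs is forwarded to AutoConfig.from_pretrained when the model is loaded;
# which options are meaningful depends on the model's config class.
embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    config_kwargs={"use_memory_efficient_attention": True},
)
embedder.warm_up()  # the model is only loaded here, so config_kwargs takes effect at warm-up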
@@ -29,6 +29,7 @@ def get_embedding_backend(
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
):
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"

@@ -42,6 +43,7 @@ def get_embedding_backend(
truncate_dim=truncate_dim,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs,
)
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
@@ -61,6 +63,7 @@ def __init__(
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
):
sentence_transformers_import.check()
self.model = SentenceTransformer(
@@ -71,6 +74,7 @@ def __init__(
truncate_dim=truncate_dim,
model_kwargs=model_kwargs,
tokenizer_kwargs=tokenizer_kwargs,
config_kwargs=config_kwargs,
)

def embed(self, data: List[str], **kwargs) -> List[List[float]]:
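The backend above simply forwards these keyword arguments to the sentence-transformers constructor. In recent sentence-transformers releases (3.x) the equivalent direct call looks roughly like this; the flag value is again only illustrative:

from sentence_transformers import SentenceTransformer

# model_kwargs, tokenizer_kwargs and config_kwargs map onto the AutoModel, AutoTokenizer
# and AutoConfig `from_pretrained` calls inside sentence-transformers, respectively.
model = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    config_kwargs={"use_memory_efficient_attention": True},
)
embeddings = model.encode(["a short test sentence"])
print(embeddings.shape)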
@@ -54,6 +54,7 @@ def __init__( # noqa: PLR0913
truncate_dim: Optional[int] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
config_kwargs: Optional[Dict[str, Any]] = None,
precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
):
"""
@@ -96,10 +97,12 @@ def __init__( # noqa: PLR0913
:param tokenizer_kwargs:
Additional keyword arguments for `AutoTokenizer.from_pretrained` when loading the tokenizer.
Refer to specific model documentation for available kwargs.
:param config_kwargs:
Additional keyword arguments for `AutoConfig.from_pretrained` when loading the model configuration.
:param precision:
The precision to use for the embeddings.
All non-float32 precisions are quantized embeddings.
Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy.
Quantized embeddings are smaller and faster to compute, but may have a lower accuracy.
They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks.
"""

@@ -117,6 +120,7 @@ def __init__( # noqa: PLR0913
self.truncate_dim = truncate_dim
self.model_kwargs = model_kwargs
self.tokenizer_kwargs = tokenizer_kwargs
self.config_kwargs = config_kwargs
self.embedding_backend = None
self.precision = precision

@@ -149,6 +153,7 @@ def to_dict(self) -> Dict[str, Any]:
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
precision=self.precision,
)
if serialization_dict["init_parameters"].get("model_kwargs") is not None:
@@ -186,6 +191,7 @@ def warm_up(self):
truncate_dim=self.truncate_dim,
model_kwargs=self.model_kwargs,
tokenizer_kwargs=self.tokenizer_kwargs,
config_kwargs=self.config_kwargs,
)
if self.tokenizer_kwargs and self.tokenizer_kwargs.get("model_max_length"):
self.embedding_backend.model.max_seq_length = self.tokenizer_kwargs["model_max_length"]
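As the docstring added above notes, these extra arguments end up in AutoConfig.from_pretrained. Outside of Haystack, the same effect can be sketched directly with transformers; the flag is illustrative and only honored by config classes that define a matching attribute:

from transformers import AutoConfig

# Entries from config_kwargs become extra keyword arguments on this call; keys without a
# matching attribute on the resolved config class are typically ignored rather than raising.
config = AutoConfig.from_pretrained(
    "sentence-transformers/all-MiniLM-L6-v2",
    use_memory_efficient_attention=True,
)
print(type(config).__name__)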
@@ -0,0 +1,4 @@
---
enhancements:
- |
SentenceTransformersDocumentEmbedder now supports config_kwargs for additional parameters when loading the model configuration
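Because config_kwargs is now included in to_dict (and in the factory call inside warm_up), it survives a serialization round-trip, which is what the tests below exercise. A short sketch of that round-trip from user code, with an illustrative flag:

from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    config_kwargs={"use_memory_efficient_attention": True},
)
data = embedder.to_dict()
# The nested init_parameters now carry config_kwargs, so the restored component
# rebuilds its backend with the same configuration at warm_up().
restored = SentenceTransformersDocumentEmbedder.from_dict(data)
assert restored.config_kwargs == {"use_memory_efficient_attention": True}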
@@ -79,6 +79,7 @@ def test_to_dict(self):
"truncate_dim": None,
"model_kwargs": None,
"tokenizer_kwargs": None,
"config_kwargs": None,
"precision": "float32",
},
}
@@ -99,6 +100,7 @@ def test_to_dict_with_custom_init_parameters(self):
truncate_dim=256,
model_kwargs={"torch_dtype": torch.float32},
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": True},
precision="int8",
)
data = component.to_dict()
@@ -120,6 +122,7 @@ def test_to_dict_with_custom_init_parameters(self):
"truncate_dim": 256,
"model_kwargs": {"torch_dtype": "torch.float32"},
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": True},
"precision": "int8",
},
}
@@ -140,6 +143,7 @@ def test_from_dict(self):
"truncate_dim": 256,
"model_kwargs": {"torch_dtype": "torch.float32"},
"tokenizer_kwargs": {"model_max_length": 512},
"config_kwargs": {"use_memory_efficient_attention": True},
"precision": "int8",
}
component = SentenceTransformersDocumentEmbedder.from_dict(
@@ -162,6 +166,7 @@ def test_from_dict(self):
assert component.truncate_dim == 256
assert component.model_kwargs == {"torch_dtype": torch.float32}
assert component.tokenizer_kwargs == {"model_max_length": 512}
assert component.config_kwargs == {"use_memory_efficient_attention": True}
assert component.precision == "int8"

def test_from_dict_no_default_parameters(self):
@@ -230,6 +235,7 @@ def test_warmup(self, mocked_factory):
token=None,
device=ComponentDevice.from_str("cpu"),
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": True},
)
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
@@ -242,6 +248,7 @@ def test_warmup(self, mocked_factory):
truncate_dim=None,
model_kwargs=None,
tokenizer_kwargs={"model_max_length": 512},
config_kwargs={"use_memory_efficient_attention": True},
)

@patch(
@@ -291,11 +298,8 @@ def test_embed_metadata(self):
model="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
)
embedder.embedding_backend = MagicMock()

documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]

embedder.run(documents=documents)

embedder.embedding_backend.embed.assert_called_once_with(
[
"meta_value 0\ndocument number 0",
@@ -319,11 +323,8 @@ def test_prefix_suffix(self):
embedding_separator="\n",
)
embedder.embedding_backend = MagicMock()

documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]

embedder.run(documents=documents)

embedder.embedding_backend.embed.assert_called_once_with(
[
"my_prefix meta_value 0\ndocument number 0 my_suffix",
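The two tests above check the exact strings handed to the backend: each document's content is joined with the selected meta fields via embedding_separator, then wrapped in the prefix and suffix. A sketch of the same behavior with a real (non-mocked) run; the model name is illustrative:

from haystack import Document
from haystack.components.embedders import SentenceTransformersDocumentEmbedder

embedder = SentenceTransformersDocumentEmbedder(
    model="sentence-transformers/all-MiniLM-L6-v2",
    prefix="my_prefix ",
    suffix=" my_suffix",
    meta_fields_to_embed=["meta_field"],
    embedding_separator="\n",
)
embedder.warm_up()
docs = [Document(content="document number 0", meta={"meta_field": "meta_value 0"})]
result = embedder.run(documents=docs)
# The text actually embedded is "my_prefix meta_value 0\ndocument number 0 my_suffix",
# and each returned document carries its embedding vector.
print(len(result["documents"][0].embedding))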
@@ -42,6 +42,7 @@ def test_model_initialization(mock_sentence_transformer):
truncate_dim=256,
model_kwargs=None,
tokenizer_kwargs=None,
config_kwargs=None,
)


