From 42cc2ae8557a68ae61f22117dbacea46fc39ff67 Mon Sep 17 00:00:00 2001 From: Alper Date: Mon, 7 Oct 2024 10:55:49 +0200 Subject: [PATCH] feat: introduce `model_kwargs` in Sparse Embedders (can be used for BM25 parameters) (#1126) * add support for bm25 in FastEmbed integration * agnostic support for model config params * Apply suggestions from code review Co-authored-by: Stefano Fiorucci * add future readability --------- Co-authored-by: Stefano Fiorucci --- .../embedding_backend/fastembed_backend.py | 20 +++++-- .../fastembed_sparse_document_embedder.py | 5 ++ .../fastembed_sparse_text_embedder.py | 5 ++ .../fastembed/tests/test_fastembed_backend.py | 26 +++++++++ ...test_fastembed_sparse_document_embedder.py | 54 +++++++++++++++++- .../test_fastembed_sparse_text_embedder.py | 56 ++++++++++++++++++- 6 files changed, 160 insertions(+), 6 deletions(-) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index 66f797549..3a68abcfb 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional from haystack.dataclasses.sparse_embedding import SparseEmbedding from tqdm import tqdm @@ -73,14 +73,19 @@ def get_embedding_backend( cache_dir: Optional[str] = None, threads: Optional[int] = None, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, ): - embedding_backend_id = f"{model_name}{cache_dir}{threads}" + embedding_backend_id = f"{model_name}{cache_dir}{threads}{local_files_only}{model_kwargs}" if embedding_backend_id in 
_FastembedSparseEmbeddingBackendFactory._instances: return _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] embedding_backend = _FastembedSparseEmbeddingBackend( - model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + local_files_only=local_files_only, + model_kwargs=model_kwargs, ) _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -97,9 +102,16 @@ def __init__( cache_dir: Optional[str] = None, threads: Optional[int] = None, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, ): + model_kwargs = model_kwargs or {} + self.model = SparseTextEmbedding( - model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + local_files_only=local_files_only, + **model_kwargs, ) def embed(self, data: List[List[str]], progress_bar=True, **kwargs) -> List[SparseEmbedding]: diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py index 4b72389fa..f79f08c90 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py @@ -62,6 +62,7 @@ def __init__( local_files_only: bool = False, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Create an FastembedDocumentEmbedder component. 
@@ -81,6 +82,7 @@ def __init__( :param local_files_only: If `True`, only use the model files in the `cache_dir`. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. :param embedding_separator: Separator used to concatenate the meta fields to the Document content. + :param model_kwargs: Keyword arguments forwarded to the model, e.g. BM25's `k`, `b`, `avg_len`, `language`. """ self.model_name = model @@ -92,6 +94,7 @@ def __init__( self.local_files_only = local_files_only self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator + self.model_kwargs = model_kwargs def to_dict(self) -> Dict[str, Any]: """ @@ -110,6 +113,7 @@ def to_dict(self) -> Dict[str, Any]: local_files_only=self.local_files_only, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + model_kwargs=self.model_kwargs, ) def warm_up(self): @@ -122,6 +126,7 @@ def warm_up(self): cache_dir=self.cache_dir, threads=self.threads, local_files_only=self.local_files_only, + model_kwargs=self.model_kwargs, ) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py index 67348b2bd..2ebab35b4 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py @@ -35,6 +35,7 @@ def __init__( progress_bar: bool = True, parallel: Optional[int] = None, local_files_only: bool = False, + model_kwargs: Optional[Dict[str, Any]] = None, ): """ Create a FastembedSparseTextEmbedder component. @@ -50,6 +51,7 @@ def __init__( If 0, use all available cores. 
If None, don't use data-parallel processing, use default onnxruntime threading instead. :param local_files_only: If `True`, only use the model files in the `cache_dir`. + :param model_kwargs: Keyword arguments forwarded to the model, e.g. BM25's `k`, `b`, `avg_len`, `language`. """ self.model_name = model @@ -58,6 +60,7 @@ def __init__( self.progress_bar = progress_bar self.parallel = parallel self.local_files_only = local_files_only + self.model_kwargs = model_kwargs def to_dict(self) -> Dict[str, Any]: """ @@ -74,6 +77,7 @@ def to_dict(self) -> Dict[str, Any]: progress_bar=self.progress_bar, parallel=self.parallel, local_files_only=self.local_files_only, + model_kwargs=self.model_kwargs, ) def warm_up(self): @@ -86,6 +90,7 @@ def warm_up(self): cache_dir=self.cache_dir, threads=self.threads, local_files_only=self.local_files_only, + model_kwargs=self.model_kwargs, ) @component.output_types(sparse_embedding=SparseEmbedding) diff --git a/integrations/fastembed/tests/test_fastembed_backend.py b/integrations/fastembed/tests/test_fastembed_backend.py index 631d9f1e0..994a6f883 100644 --- a/integrations/fastembed/tests/test_fastembed_backend.py +++ b/integrations/fastembed/tests/test_fastembed_backend.py @@ -2,6 +2,7 @@ from haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend import ( _FastembedEmbeddingBackendFactory, + _FastembedSparseEmbeddingBackendFactory, ) @@ -44,3 +45,28 @@ def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 embedding_backend.model.embed.assert_called_once_with(data) # restore the factory stateTrue _FastembedEmbeddingBackendFactory._instances = {} + + +@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.SparseTextEmbedding") +def test_model_kwargs_initialization(mock_instructor): + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 40, + } + + # Invoke the backend factory with the BM25 
configuration + _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( + model_name="Qdrant/bm25", + model_kwargs=bm25_config, + ) + + # Check if SparseTextEmbedding was called with the correct arguments + mock_instructor.assert_called_once_with( + model_name="Qdrant/bm25", cache_dir=None, threads=None, local_files_only=False, **bm25_config + ) + + # Restore factory state after the test + _FastembedSparseEmbeddingBackendFactory._instances = {} diff --git a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py index d3f2023b8..90e94908d 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py @@ -69,6 +69,7 @@ def test_to_dict(self): "local_files_only": False, "embedding_separator": "\n", "meta_fields_to_embed": [], + "model_kwargs": None, }, } @@ -100,6 +101,7 @@ def test_to_dict_with_custom_init_parameters(self): "local_files_only": True, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", + "model_kwargs": None, }, } @@ -174,7 +176,11 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False + model_name="prithvida/Splade_PP_en_v1", + cache_dir=None, + threads=None, + local_files_only=False, + model_kwargs=None, ) @patch( @@ -275,6 +281,52 @@ def test_embed_metadata(self): parallel=None, ) + def test_init_with_model_kwargs_parameters(self): + """ + Test initialization of FastembedSparseDocumentEmbedder with model_kwargs parameters. 
+ """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 50, + } + + embedder = FastembedSparseDocumentEmbedder( + model="Qdrant/bm25", + model_kwargs=bm25_config, + ) + + assert embedder.model_kwargs == bm25_config + + @pytest.mark.integration + def test_run_with_model_kwargs(self): + """ + Integration test to check the embedding with model_kwargs parameters. + """ + bm42_config = { + "alpha": 0.2, + } + + embedder = FastembedSparseDocumentEmbedder( + model="Qdrant/bm42-all-minilm-l6-v2-attentions", + model_kwargs=bm42_config, + ) + embedder.warm_up() + + doc = Document(content="Example content using BM42") + + result = embedder.run(documents=[doc]) + embedding = result["documents"][0].sparse_embedding + embedding_dict = embedding.to_dict() + + assert isinstance(embedding, SparseEmbedding) + assert isinstance(embedding_dict["indices"], list) + assert isinstance(embedding_dict["values"], list) + assert isinstance(embedding_dict["indices"][0], int) + assert isinstance(embedding_dict["values"][0], float) + @pytest.mark.integration def test_run(self): embedder = FastembedSparseDocumentEmbedder( diff --git a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py index 7e9197493..4f438fd15 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py @@ -54,6 +54,7 @@ def test_to_dict(self): "progress_bar": True, "parallel": None, "local_files_only": False, + "model_kwargs": None, }, } @@ -79,6 +80,7 @@ def test_to_dict_with_custom_init_parameters(self): "progress_bar": False, "parallel": 1, "local_files_only": True, + "model_kwargs": None, }, } @@ -135,7 +137,11 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( 
- model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False + model_name="prithvida/Splade_PP_en_v1", + cache_dir=None, + threads=None, + local_files_only=False, + model_kwargs=None, ) @patch( @@ -195,6 +201,54 @@ def test_run_wrong_incorrect_format(self): with pytest.raises(TypeError, match="FastembedSparseTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) + def test_init_with_model_kwargs_parameters(self): + """ + Test initialization of FastembedSparseTextEmbedder with model_kwargs parameters. + """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 50, + } + + embedder = FastembedSparseTextEmbedder( + model="Qdrant/bm25", + model_kwargs=bm25_config, + ) + + assert embedder.model_kwargs == bm25_config + + @pytest.mark.integration + def test_run_with_model_kwargs(self): + """ + Integration test to check the embedding with model_kwargs parameters. + """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 256.0, + } + + embedder = FastembedSparseTextEmbedder( + model="Qdrant/bm25", + model_kwargs=bm25_config, + ) + embedder.warm_up() + + text = "Example content using BM25" + + result = embedder.run(text=text) + embedding = result["sparse_embedding"] + embedding_dict = embedding.to_dict() + + assert isinstance(embedding, SparseEmbedding) + assert isinstance(embedding_dict["indices"], list) + assert isinstance(embedding_dict["values"], list) + assert isinstance(embedding_dict["indices"][0], int) + assert isinstance(embedding_dict["values"][0], float) + @pytest.mark.integration def test_run(self): embedder = FastembedSparseTextEmbedder(