From 1b8507a5e32dee399634af50dcf786cb8400ca16 Mon Sep 17 00:00:00 2001 From: alperkaya Date: Fri, 4 Oct 2024 22:26:30 +0200 Subject: [PATCH] add support for bm25 in FastEmbed integration --- .../embedding_backend/fastembed_backend.py | 14 ++-- .../fastembed_sparse_document_embedder.py | 5 ++ .../fastembed_sparse_text_embedder.py | 5 ++ .../fastembed/tests/test_fastembed_backend.py | 26 +++++++ ...test_fastembed_sparse_document_embedder.py | 68 ++++++++++++++++++- .../test_fastembed_sparse_text_embedder.py | 68 ++++++++++++++++++- 6 files changed, 180 insertions(+), 6 deletions(-) diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py index 66f797549..19de89558 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/embedding_backend/fastembed_backend.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional from haystack.dataclasses.sparse_embedding import SparseEmbedding from tqdm import tqdm @@ -73,14 +73,15 @@ def get_embedding_backend( cache_dir: Optional[str] = None, threads: Optional[int] = None, local_files_only: bool = False, + bm25: Optional[Dict[str, Any]] = None, ): - embedding_backend_id = f"{model_name}{cache_dir}{threads}" + embedding_backend_id = f"{model_name}{cache_dir}{threads}{bm25}" if embedding_backend_id in _FastembedSparseEmbeddingBackendFactory._instances: return _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] embedding_backend = _FastembedSparseEmbeddingBackend( - model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only, bm25=bm25 ) _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend return embedding_backend @@ -97,9 +98,14 @@ def __init__( cache_dir: Optional[str] = None, threads: Optional[int] = None, local_files_only: bool = False, + bm25: Optional[Dict[str, Any]] = None, ): self.model = SparseTextEmbedding( - model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only + model_name=model_name, + cache_dir=cache_dir, + threads=threads, + local_files_only=local_files_only, + **(bm25 if bm25 else {}), ) def embed(self, data: List[List[str]], progress_bar=True, **kwargs) -> List[SparseEmbedding]: diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py index 4b72389fa..743cf218d 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py @@ -62,6 +62,7 @@ def __init__( local_files_only: bool = False, meta_fields_to_embed: Optional[List[str]] = None, embedding_separator: str = "\n", + bm25: Optional[Dict[str, Any]] = None, ): """ Create an FastembedDocumentEmbedder component. @@ -81,6 +82,7 @@ def __init__( :param local_files_only: If `True`, only use the model files in the `cache_dir`. :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content. :param embedding_separator: Separator used to concatenate the meta fields to the Document content. + :param bm25: Dictionary containing BM25 parameters (`k`, `b`, `avg_len`, `language`, `token_max_length`). """ self.model_name = model @@ -92,6 +94,7 @@ def __init__( self.local_files_only = local_files_only self.meta_fields_to_embed = meta_fields_to_embed or [] self.embedding_separator = embedding_separator + self.bm25 = bm25 if model == "Qdrant/bm25" else None def to_dict(self) -> Dict[str, Any]: """ @@ -110,6 +113,7 @@ def to_dict(self) -> Dict[str, Any]: local_files_only=self.local_files_only, meta_fields_to_embed=self.meta_fields_to_embed, embedding_separator=self.embedding_separator, + bm25=self.bm25, ) def warm_up(self): @@ -122,6 +126,7 @@ def warm_up(self): cache_dir=self.cache_dir, threads=self.threads, local_files_only=self.local_files_only, + bm25=self.bm25, ) def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py index 67348b2bd..659642939 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py @@ -35,6 +35,7 @@ def __init__( progress_bar: bool = True, parallel: Optional[int] = None, local_files_only: bool = False, + bm25: Optional[Dict[str, Any]] = None, ): """ Create a FastembedSparseTextEmbedder component. @@ -50,6 +51,7 @@ def __init__( If 0, use all available cores. If None, don't use data-parallel processing, use default onnxruntime threading instead. :param local_files_only: If `True`, only use the model files in the `cache_dir`. + :param bm25: Dictionary containing BM25 parameters (`k`, `b`, `avg_len`, `language`, `token_max_length`). """ self.model_name = model @@ -58,6 +60,7 @@ def __init__( self.progress_bar = progress_bar self.parallel = parallel self.local_files_only = local_files_only + self.bm25 = bm25 if model == "Qdrant/bm25" else None def to_dict(self) -> Dict[str, Any]: """ @@ -74,6 +77,7 @@ def to_dict(self) -> Dict[str, Any]: progress_bar=self.progress_bar, parallel=self.parallel, local_files_only=self.local_files_only, + bm25=self.bm25, ) def warm_up(self): @@ -86,6 +90,7 @@ def warm_up(self): cache_dir=self.cache_dir, threads=self.threads, local_files_only=self.local_files_only, + bm25=self.bm25, ) @component.output_types(sparse_embedding=SparseEmbedding) diff --git a/integrations/fastembed/tests/test_fastembed_backend.py b/integrations/fastembed/tests/test_fastembed_backend.py index 631d9f1e0..ec547c4fb 100644 --- a/integrations/fastembed/tests/test_fastembed_backend.py +++ b/integrations/fastembed/tests/test_fastembed_backend.py @@ -2,6 +2,7 @@ from haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend import ( _FastembedEmbeddingBackendFactory, + _FastembedSparseEmbeddingBackendFactory, ) @@ -44,3 +45,28 @@ def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001 embedding_backend.model.embed.assert_called_once_with(data) # restore the factory stateTrue _FastembedEmbeddingBackendFactory._instances = {} + + +@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.SparseTextEmbedding") +def test_bm25_model_initialization(mock_instructor): + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 40, + } + + # Invoke the backend factory with the BM25 configuration + _FastembedSparseEmbeddingBackendFactory.get_embedding_backend( + model_name="Qdrant/bm25", + bm25=bm25_config, + ) + + # Check if SparseTextEmbedding was called with the correct arguments + mock_instructor.assert_called_once_with( + model_name="Qdrant/bm25", cache_dir=None, threads=None, local_files_only=False, **bm25_config + ) + + # Restore factory state after the test + _FastembedSparseEmbeddingBackendFactory._instances = {} diff --git a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py index d3f2023b8..1d266a1aa 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py @@ -69,6 +69,7 @@ def test_to_dict(self): "local_files_only": False, "embedding_separator": "\n", "meta_fields_to_embed": [], + "bm25": None, }, } @@ -100,6 +101,7 @@ def test_to_dict_with_custom_init_parameters(self): "local_files_only": True, "meta_fields_to_embed": ["test_field"], "embedding_separator": " | ", + "bm25": None, }, } @@ -174,7 +176,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False + model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False, bm25=None ) @patch( @@ -275,6 +277,70 @@ def test_embed_metadata(self): parallel=None, ) + def test_init_with_bm25_parameters(self): + """ + Test initialization of FastembedSparseDocumentEmbedder with BM25 parameters. + """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 50, + } + + embedder = FastembedSparseDocumentEmbedder( + model="Qdrant/bm25", + bm25=bm25_config, + ) + + assert embedder.bm25 == bm25_config + + def test_bm25_not_passed_for_non_bm25_model(self): + """ + Test that BM25 parameters are not used if model is not "Qdrant/bm25". + """ + bm25_config = { + "k": 1.5, + "b": 0.9, + "avg_len": 250.0, + } + + embedder = FastembedSparseDocumentEmbedder( + model="prithvida/Splade_PP_en_v1", + bm25=bm25_config, + ) + assert embedder.bm25 is None + + @pytest.mark.integration + def test_run_with_bm25(self): + """ + Integration test to check the embedding with bm25 parameters. + """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 256.0, + } + + embedder = FastembedSparseDocumentEmbedder( + model="Qdrant/bm25", + bm25=bm25_config, + ) + embedder.warm_up() + + doc = Document(content="Example content using BM25") + + result = embedder.run(documents=[doc]) + embedding = result["documents"][0].sparse_embedding + embedding_dict = embedding.to_dict() + + assert isinstance(embedding, SparseEmbedding) + assert isinstance(embedding_dict["indices"], list) + assert isinstance(embedding_dict["values"], list) + assert isinstance(embedding_dict["indices"][0], int) + assert isinstance(embedding_dict["values"][0], float) + @pytest.mark.integration def test_run(self): embedder = FastembedSparseDocumentEmbedder( diff --git a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py index 7e9197493..70a2afc10 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py @@ -54,6 +54,7 @@ def test_to_dict(self): "progress_bar": True, "parallel": None, "local_files_only": False, + "bm25": None, }, } @@ -79,6 +80,7 @@ def test_to_dict_with_custom_init_parameters(self): "progress_bar": False, "parallel": 1, "local_files_only": True, + "bm25": None, }, } @@ -135,7 +137,7 @@ def test_warmup(self, mocked_factory): mocked_factory.get_embedding_backend.assert_not_called() embedder.warm_up() mocked_factory.get_embedding_backend.assert_called_once_with( - model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False + model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False, bm25=None ) @patch( @@ -195,6 +197,70 @@ def test_run_wrong_incorrect_format(self): with pytest.raises(TypeError, match="FastembedSparseTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) + def test_init_with_bm25_parameters(self): + """ + Test initialization of FastembedSparseTextEmbedder with BM25 parameters. + """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 300.0, + "language": "english", + "token_max_length": 50, + } + + embedder = FastembedSparseTextEmbedder( + model="Qdrant/bm25", + bm25=bm25_config, + ) + + assert embedder.bm25 == bm25_config + + def test_bm25_not_passed_for_non_bm25_model(self): + """ + Test that BM25 parameters are not used if model is not "Qdrant/bm25". + """ + bm25_config = { + "k": 1.5, + "b": 0.9, + "avg_len": 250.0, + } + + embedder = FastembedSparseTextEmbedder( + model="prithvida/Splade_PP_en_v1", + bm25=bm25_config, + ) + assert embedder.bm25 is None + + @pytest.mark.integration + def test_run_with_bm25(self): + """ + Integration test to check the embedding with bm25 parameters. + """ + bm25_config = { + "k": 1.2, + "b": 0.75, + "avg_len": 256.0, + } + + embedder = FastembedSparseTextEmbedder( + model="Qdrant/bm25", + bm25=bm25_config, + ) + embedder.warm_up() + + text = "Example content using BM25" + + result = embedder.run(text=text) + embedding = result["sparse_embedding"] + embedding_dict = embedding.to_dict() + + assert isinstance(embedding, SparseEmbedding) + assert isinstance(embedding_dict["indices"], list) + assert isinstance(embedding_dict["values"], list) + assert isinstance(embedding_dict["indices"][0], int) + assert isinstance(embedding_dict["values"][0], float) + @pytest.mark.integration def test_run(self): embedder = FastembedSparseTextEmbedder(