Skip to content

Commit

Permalink
add support for bm25 in FastEmbed integration
Browse files Browse the repository at this point in the history
  • Loading branch information
alperkaya committed Oct 4, 2024
1 parent ddb0c63 commit 1b8507a
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, Dict, List, Optional
from typing import Any, ClassVar, Dict, List, Optional

from haystack.dataclasses.sparse_embedding import SparseEmbedding
from tqdm import tqdm
Expand Down Expand Up @@ -73,14 +73,15 @@ def get_embedding_backend(
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
bm25: Optional[Dict[str, Any]] = None,
):
embedding_backend_id = f"{model_name}{cache_dir}{threads}"
embedding_backend_id = f"{model_name}{cache_dir}{threads}{bm25}"

if embedding_backend_id in _FastembedSparseEmbeddingBackendFactory._instances:
return _FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id]

embedding_backend = _FastembedSparseEmbeddingBackend(
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only, bm25=bm25
)
_FastembedSparseEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
return embedding_backend
Expand All @@ -97,9 +98,14 @@ def __init__(
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
local_files_only: bool = False,
bm25: Optional[Dict[str, Any]] = None,
):
self.model = SparseTextEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=threads, local_files_only=local_files_only
model_name=model_name,
cache_dir=cache_dir,
threads=threads,
local_files_only=local_files_only,
**(bm25 if bm25 else {}),
)

def embed(self, data: List[List[str]], progress_bar=True, **kwargs) -> List[SparseEmbedding]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
local_files_only: bool = False,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
bm25: Optional[Dict[str, Any]] = None,
):
"""
Create an FastembedDocumentEmbedder component.
Expand All @@ -81,6 +82,7 @@ def __init__(
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
:param bm25: Dictionary containing BM25 parameters (`k`, `b`, `avg_len`, `language`, `token_max_length`).
"""

self.model_name = model
Expand All @@ -92,6 +94,7 @@ def __init__(
self.local_files_only = local_files_only
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.bm25 = bm25 if model == "Qdrant/bm25" else None

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -110,6 +113,7 @@ def to_dict(self) -> Dict[str, Any]:
local_files_only=self.local_files_only,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
bm25=self.bm25,
)

def warm_up(self):
Expand All @@ -122,6 +126,7 @@ def warm_up(self):
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
bm25=self.bm25,
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
progress_bar: bool = True,
parallel: Optional[int] = None,
local_files_only: bool = False,
bm25: Optional[Dict[str, Any]] = None,
):
"""
Create a FastembedSparseTextEmbedder component.
Expand All @@ -50,6 +51,7 @@ def __init__(
If 0, use all available cores.
If None, don't use data-parallel processing, use default onnxruntime threading instead.
:param local_files_only: If `True`, only use the model files in the `cache_dir`.
:param bm25: Dictionary containing BM25 parameters (`k`, `b`, `avg_len`, `language`, `token_max_length`).
"""

self.model_name = model
Expand All @@ -58,6 +60,7 @@ def __init__(
self.progress_bar = progress_bar
self.parallel = parallel
self.local_files_only = local_files_only
self.bm25 = bm25 if model == "Qdrant/bm25" else None

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -74,6 +77,7 @@ def to_dict(self) -> Dict[str, Any]:
progress_bar=self.progress_bar,
parallel=self.parallel,
local_files_only=self.local_files_only,
bm25=self.bm25,
)

def warm_up(self):
Expand All @@ -86,6 +90,7 @@ def warm_up(self):
cache_dir=self.cache_dir,
threads=self.threads,
local_files_only=self.local_files_only,
bm25=self.bm25,
)

@component.output_types(sparse_embedding=SparseEmbedding)
Expand Down
26 changes: 26 additions & 0 deletions integrations/fastembed/tests/test_fastembed_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend import (
_FastembedEmbeddingBackendFactory,
_FastembedSparseEmbeddingBackendFactory,
)


Expand Down Expand Up @@ -44,3 +45,28 @@ def test_embedding_function_with_kwargs(mock_instructor): # noqa: ARG001
embedding_backend.model.embed.assert_called_once_with(data)
# restore the factory stateTrue
_FastembedEmbeddingBackendFactory._instances = {}


@patch("haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend.SparseTextEmbedding")
def test_bm25_model_initialization(mock_instructor):
bm25_config = {
"k": 1.2,
"b": 0.75,
"avg_len": 300.0,
"language": "english",
"token_max_length": 40,
}

# Invoke the backend factory with the BM25 configuration
_FastembedSparseEmbeddingBackendFactory.get_embedding_backend(
model_name="Qdrant/bm25",
bm25=bm25_config,
)

# Check if SparseTextEmbedding was called with the correct arguments
mock_instructor.assert_called_once_with(
model_name="Qdrant/bm25", cache_dir=None, threads=None, local_files_only=False, **bm25_config
)

# Restore factory state after the test
_FastembedSparseEmbeddingBackendFactory._instances = {}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def test_to_dict(self):
"local_files_only": False,
"embedding_separator": "\n",
"meta_fields_to_embed": [],
"bm25": None,
},
}

Expand Down Expand Up @@ -100,6 +101,7 @@ def test_to_dict_with_custom_init_parameters(self):
"local_files_only": True,
"meta_fields_to_embed": ["test_field"],
"embedding_separator": " | ",
"bm25": None,
},
}

Expand Down Expand Up @@ -174,7 +176,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False
model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False, bm25=None
)

@patch(
Expand Down Expand Up @@ -275,6 +277,70 @@ def test_embed_metadata(self):
parallel=None,
)

def test_init_with_bm25_parameters(self):
"""
Test initialization of FastembedSparseDocumentEmbedder with BM25 parameters.
"""
bm25_config = {
"k": 1.2,
"b": 0.75,
"avg_len": 300.0,
"language": "english",
"token_max_length": 50,
}

embedder = FastembedSparseDocumentEmbedder(
model="Qdrant/bm25",
bm25=bm25_config,
)

assert embedder.bm25 == bm25_config

def test_bm25_not_passed_for_non_bm25_model(self):
"""
Test that BM25 parameters are not used if model is not "Qdrant/bm25".
"""
bm25_config = {
"k": 1.5,
"b": 0.9,
"avg_len": 250.0,
}

embedder = FastembedSparseDocumentEmbedder(
model="prithvida/Splade_PP_en_v1",
bm25=bm25_config,
)
assert embedder.bm25 is None

@pytest.mark.integration
def test_run_with_bm25(self):
"""
Integration test to check the embedding with bm25 parameters.
"""
bm25_config = {
"k": 1.2,
"b": 0.75,
"avg_len": 256.0,
}

embedder = FastembedSparseDocumentEmbedder(
model="Qdrant/bm25",
bm25=bm25_config,
)
embedder.warm_up()

doc = Document(content="Example content using BM25")

result = embedder.run(documents=[doc])
embedding = result["documents"][0].sparse_embedding
embedding_dict = embedding.to_dict()

assert isinstance(embedding, SparseEmbedding)
assert isinstance(embedding_dict["indices"], list)
assert isinstance(embedding_dict["values"], list)
assert isinstance(embedding_dict["indices"][0], int)
assert isinstance(embedding_dict["values"][0], float)

@pytest.mark.integration
def test_run(self):
embedder = FastembedSparseDocumentEmbedder(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_to_dict(self):
"progress_bar": True,
"parallel": None,
"local_files_only": False,
"bm25": None,
},
}

Expand All @@ -79,6 +80,7 @@ def test_to_dict_with_custom_init_parameters(self):
"progress_bar": False,
"parallel": 1,
"local_files_only": True,
"bm25": None,
},
}

Expand Down Expand Up @@ -135,7 +137,7 @@ def test_warmup(self, mocked_factory):
mocked_factory.get_embedding_backend.assert_not_called()
embedder.warm_up()
mocked_factory.get_embedding_backend.assert_called_once_with(
model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False
model_name="prithvida/Splade_PP_en_v1", cache_dir=None, threads=None, local_files_only=False, bm25=None
)

@patch(
Expand Down Expand Up @@ -195,6 +197,70 @@ def test_run_wrong_incorrect_format(self):
with pytest.raises(TypeError, match="FastembedSparseTextEmbedder expects a string as input"):
embedder.run(text=list_integers_input)

def test_init_with_bm25_parameters(self):
"""
Test initialization of FastembedSparseTextEmbedder with BM25 parameters.
"""
bm25_config = {
"k": 1.2,
"b": 0.75,
"avg_len": 300.0,
"language": "english",
"token_max_length": 50,
}

embedder = FastembedSparseTextEmbedder(
model="Qdrant/bm25",
bm25=bm25_config,
)

assert embedder.bm25 == bm25_config

def test_bm25_not_passed_for_non_bm25_model(self):
"""
Test that BM25 parameters are not used if model is not "Qdrant/bm25".
"""
bm25_config = {
"k": 1.5,
"b": 0.9,
"avg_len": 250.0,
}

embedder = FastembedSparseTextEmbedder(
model="prithvida/Splade_PP_en_v1",
bm25=bm25_config,
)
assert embedder.bm25 is None

@pytest.mark.integration
def test_run_with_bm25(self):
"""
Integration test to check the embedding with bm25 parameters.
"""
bm25_config = {
"k": 1.2,
"b": 0.75,
"avg_len": 256.0,
}

embedder = FastembedSparseTextEmbedder(
model="Qdrant/bm25",
bm25=bm25_config,
)
embedder.warm_up()

text = "Example content using BM25"

result = embedder.run(text=text)
embedding = result["sparse_embedding"]
embedding_dict = embedding.to_dict()

assert isinstance(embedding, SparseEmbedding)
assert isinstance(embedding_dict["indices"], list)
assert isinstance(embedding_dict["values"], list)
assert isinstance(embedding_dict["indices"][0], int)
assert isinstance(embedding_dict["values"][0], float)

@pytest.mark.integration
def test_run(self):
embedder = FastembedSparseTextEmbedder(
Expand Down

0 comments on commit 1b8507a

Please sign in to comment.