Skip to content

Commit

Permalink
FastembedRanker: remove the backend as suggested
Browse files Browse the repository at this point in the history
  • Loading branch information
paulmartrencharpro committed Nov 13, 2024
1 parent 1ca3d93 commit 383bac8
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from haystack import Document, component, default_from_dict, default_to_dict, logging

from .ranker_backend.fastembed_backend import _FastembedRankerBackendFactory
from fastembed.rerank.cross_encoder import TextCrossEncoder

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -106,8 +106,8 @@ def warm_up(self):
"""
Initializes the component.
"""
if not hasattr(self, "ranker_backend"):
self.ranker_backend = _FastembedRankerBackendFactory.get_ranker_backend(
if not hasattr(self, "ranker"):
self.ranker = TextCrossEncoder(
model_name=self.model_name,
cache_dir=self.cache_dir,
threads=self.threads,
Expand Down Expand Up @@ -164,14 +164,14 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None
msg = f"top_k must be > 0, but got {top_k}"
raise ValueError(msg)

if not hasattr(self, "ranker_backend"):
if not hasattr(self, "ranker"):
msg = "The ranker model has not been loaded. Please call warm_up() before running."
raise RuntimeError(msg)

fastembed_input_docs = self._prepare_fastembed_input_docs(documents)

scores = list(
self.ranker_backend.rerank(
self.ranker.rerank(
query=query,
documents=fastembed_input_docs,
batch_size=self.batch_size,
Expand Down

This file was deleted.

This file was deleted.

27 changes: 2 additions & 25 deletions integrations/fastembed/tests/test_fastembed_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,29 +159,6 @@ def test_from_dict_with_custom_init_parameters(self):
assert ranker.meta_fields_to_embed == ["test_field"]
assert ranker.meta_data_separator == " | "

@patch("haystack_integrations.components.rankers.fastembed.ranker._FastembedRankerBackendFactory")
def test_warmup(self, mocked_factory):
"""
Test for checking ranker instances after warm-up.
"""
ranker = FastembedRanker(model_name="BAAI/bge-reranker-base")
mocked_factory.get_ranker_backend.assert_not_called()
ranker.warm_up()
mocked_factory.get_ranker_backend.assert_called_once_with(
model_name="BAAI/bge-reranker-base", cache_dir=None, threads=None, local_files_only=False
)

@patch("haystack_integrations.components.rankers.fastembed.ranker._FastembedRankerBackendFactory")
def test_warmup_does_not_reload(self, mocked_factory):
"""
Test for checking backend instances after multiple warm-ups.
"""
ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-12-v2")
mocked_factory.get_ranker_backend.assert_not_called()
ranker.warm_up()
ranker.warm_up()
mocked_factory.get_ranker_backend.assert_called_once()

def test_embed_incorrect_input_format(self):
"""
Test for checking incorrect input format when creating embedding.
Expand Down Expand Up @@ -214,13 +191,13 @@ def test_embed_metadata(self):
model_name="model_name",
meta_fields_to_embed=["meta_field"],
)
ranker.ranker_backend = MagicMock()
ranker.ranker = MagicMock()

documents = [Document(content=f"document-number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
query = "test"
ranker.run(query=query, documents=documents)

ranker.ranker_backend.rerank.assert_called_once_with(
ranker.ranker.rerank.assert_called_once_with(
query=query,
documents=[
"meta_value 0\ndocument-number 0",
Expand Down

0 comments on commit 383bac8

Please sign in to comment.