Skip to content

Commit

Permalink
Providing unit tests for LiteLLMReranker class.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrykWyzgowski committed Oct 18, 2024
1 parent 2ecde29 commit 9046252
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 7 deletions.
10 changes: 4 additions & 6 deletions packages/ragbits-document-search/examples/reranker_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@
# ///
import asyncio

from ragbits.core.embeddings import LiteLLMEmbeddings
from ragbits.core.vector_store import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker

documents = [
DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
Expand All @@ -25,10 +22,11 @@

config = {
"embedder": {"type": "LiteLLMEmbeddings"},
"vector_store": {
"type": "InMemoryVectorStore"
"vector_store": {"type": "InMemoryVectorStore"},
"reranker": {
"type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker",
"config": {"model": "cohere/rerank-english-v3.0"},
},
"reranker": {"type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker", "config": {"model": "cohere/rerank-english-v3.0"}},
"providers": {"txt": {"type": "DummyProvider"}},
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class LiteLLMReranker(Reranker):

model: str
top_n: int | None = None
return_documents: bool = True
return_documents: bool = False
rank_fields: list[str] | None = None
max_chunks_per_doc: int | None = None

Expand Down
61 changes: 61 additions & 0 deletions packages/ragbits-document-search/tests/unit/test_rerankers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from pathlib import Path
import pytest
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, TextElement
from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker


@pytest.fixture
def mock_litellm_response(monkeypatch):
class MockResponse:
results = [{"index": 1}, {"index": 0}]

def mock_rerank(*args, **kwargs):
return MockResponse()

monkeypatch.setattr("litellm.rerank", mock_rerank)


@pytest.fixture
def reranker():
return LiteLLMReranker(
model="test_model",
top_n=2,
return_documents=True,
rank_fields=["content"],
max_chunks_per_doc=1,
)


@pytest.fixture
def mock_document_meta():
return DocumentMeta(document_type=DocumentType.TXT, source=LocalFileSource(path=Path("test.txt")))


@pytest.fixture
def mock_custom_element(mock_document_meta):
class CustomElement(Element):

def get_key(self):
return "test_key"

return CustomElement(element_type="test_type", document_meta=mock_document_meta)


def test_rerank_success(reranker, mock_litellm_response, mock_document_meta):
chunks = [TextElement(content="chunk1", document_meta=mock_document_meta), TextElement(content="chunk2", document_meta=mock_document_meta)]
query = "test query"

reranked_chunks = reranker.rerank(chunks, query)

assert reranked_chunks[0].content == "chunk2"
assert reranked_chunks[1].content == "chunk1"


def test_rerank_invalid_chunks(reranker, mock_custom_element):
chunks = [mock_custom_element]
query = "test query"

with pytest.raises(ValueError, match="All chunks must be TextElement instances"):
reranker.rerank(chunks, query)

0 comments on commit 9046252

Please sign in to comment.