From 904625266b2b08aa501d8a9f894d6f20de58798f Mon Sep 17 00:00:00 2001 From: PatrykWyzgowski Date: Fri, 18 Oct 2024 15:11:35 +0200 Subject: [PATCH] Providing unit tests for LiteLLMReranker class. --- .../examples/reranker_example.py | 10 ++- .../retrieval/rerankers/litellm.py | 2 +- .../tests/unit/test_rerankers.py | 61 +++++++++++++++++++ 3 files changed, 66 insertions(+), 7 deletions(-) create mode 100644 packages/ragbits-document-search/tests/unit/test_rerankers.py diff --git a/packages/ragbits-document-search/examples/reranker_example.py b/packages/ragbits-document-search/examples/reranker_example.py index 856a099e..57f18995 100644 --- a/packages/ragbits-document-search/examples/reranker_example.py +++ b/packages/ragbits-document-search/examples/reranker_example.py @@ -7,11 +7,8 @@ # /// import asyncio -from ragbits.core.embeddings import LiteLLMEmbeddings -from ragbits.core.vector_store import InMemoryVectorStore from ragbits.document_search import DocumentSearch from ragbits.document_search.documents.document import DocumentMeta -from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker documents = [ DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."), @@ -25,10 +22,11 @@ config = { "embedder": {"type": "LiteLLMEmbeddings"}, - "vector_store": { - "type": "InMemoryVectorStore" + "vector_store": {"type": "InMemoryVectorStore"}, + "reranker": { + "type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker", + "config": {"model": "cohere/rerank-english-v3.0"}, }, - "reranker": {"type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker", "config": {"model": "cohere/rerank-english-v3.0"}}, "providers": {"txt": {"type": "DummyProvider"}}, } diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py index 8e517d6d..5312bff7 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py @@ -13,7 +13,7 @@ class LiteLLMReranker(Reranker): model: str top_n: int | None = None - return_documents: bool = True + return_documents: bool = False rank_fields: list[str] | None = None max_chunks_per_doc: int | None = None diff --git a/packages/ragbits-document-search/tests/unit/test_rerankers.py b/packages/ragbits-document-search/tests/unit/test_rerankers.py new file mode 100644 index 00000000..7cadfcf0 --- /dev/null +++ b/packages/ragbits-document-search/tests/unit/test_rerankers.py @@ -0,0 +1,61 @@ +from pathlib import Path +import pytest +from ragbits.document_search.documents.document import DocumentMeta, DocumentType +from ragbits.document_search.documents.element import Element, TextElement +from ragbits.document_search.documents.sources import LocalFileSource +from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker + + +@pytest.fixture +def mock_litellm_response(monkeypatch): + class MockResponse: + results = [{"index": 1}, {"index": 0}] + + def mock_rerank(*args, **kwargs): + return MockResponse() + + monkeypatch.setattr("litellm.rerank", mock_rerank) + + +@pytest.fixture +def reranker(): + return LiteLLMReranker( + model="test_model", + top_n=2, + return_documents=True, + rank_fields=["content"], + max_chunks_per_doc=1, + ) + + +@pytest.fixture +def mock_document_meta(): + return DocumentMeta(document_type=DocumentType.TXT, source=LocalFileSource(path=Path("test.txt"))) + + +@pytest.fixture +def mock_custom_element(mock_document_meta): + class CustomElement(Element): + + def get_key(self): + return "test_key" + + return CustomElement(element_type="test_type", document_meta=mock_document_meta) + + +def test_rerank_success(reranker, mock_litellm_response, mock_document_meta): + chunks = [TextElement(content="chunk1", document_meta=mock_document_meta), TextElement(content="chunk2", document_meta=mock_document_meta)] + query = "test query" + + reranked_chunks = reranker.rerank(chunks, query) + + assert reranked_chunks[0].content == "chunk2" + assert reranked_chunks[1].content == "chunk1" + + +def test_rerank_invalid_chunks(reranker, mock_custom_element): + chunks = [mock_custom_element] + query = "test query" + + with pytest.raises(ValueError, match="All chunks must be TextElement instances"): + reranker.rerank(chunks, query)