diff --git a/packages/ragbits-document-search/examples/reranker_example.py b/packages/ragbits-document-search/examples/reranker_example.py index 7fc8ca4a..b915447d 100644 --- a/packages/ragbits-document-search/examples/reranker_example.py +++ b/packages/ragbits-document-search/examples/reranker_example.py @@ -28,7 +28,9 @@ async def main(): """Run the example.""" document_search = DocumentSearch( - embedder=LiteLLMEmbeddings(), vector_store=InMemoryVectorStore(), reranker=LiteLLMReranker() + embedder=LiteLLMEmbeddings(), + vector_store=InMemoryVectorStore(), + reranker=LiteLLMReranker(model="cohere/rerank-english-v3.0"), ) for document in documents: diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index 910f76c5..4914ea8a 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -101,9 +101,7 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig()) entries = await self.vector_store.retrieve(search_vector[0], **search_config.vector_store_kwargs) elements.extend([Element.from_vector_db_entry(entry) for entry in entries]) - if self.reranker and query: - return self.reranker.rerank(elements, query=query) - return self.reranker.rerank(elements) + return self.reranker.rerank(elements, query=query) async def ingest_document( self, document: Union[DocumentMeta, Document], document_processor: Optional[BaseProvider] = None diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py index 43d0f4b2..b8e91978 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py @@ -1,23 
+1,24 @@ import abc from typing import Any +from pydantic import BaseModel + from ragbits.document_search.documents.element import Element -class Reranker(abc.ABC): +class Reranker(BaseModel, abc.ABC): """ Reranks chunks retrieved from vector store. """ - @staticmethod @abc.abstractmethod - def rerank(chunks: list[Element], **kwargs: Any) -> list[Element]: + def rerank(self, chunks: list[Element], query: str) -> list[Element]: """ Rerank chunks. Args: chunks: The chunks to rerank. - kwargs: Additional arguments. + query: The query to rerank the chunks against. Returns: The reranked chunks. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py index 3f13c930..383f7463 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py @@ -11,14 +11,19 @@ class LiteLLMReranker(Reranker): A LiteLLM reranker for providers such as Cohere, Together AI, Azure AI. """ - @staticmethod - def rerank(chunks: List[Element], **kwargs: Any) -> List[Element]: + model: str + top_n: int | None = None + return_documents: bool = True + rank_fields: list[str] | None = None + max_chunks_per_doc: int | None = None + + def rerank(self, chunks: List[Element], query: str) -> List[Element]: """ Reranking with LiteLLM API. Args: chunks: The chunks to rerank. - kwargs: Additional arguments for the LiteLLM API. + query: The query to rerank the chunks against. Returns: The reranked chunks. 
@@ -32,11 +37,13 @@ def rerank(chunks: List[Element], **kwargs: Any) -> List[Element]: documents = [chunk.content if isinstance(chunk, TextElement) else None for chunk in chunks] response = litellm.rerank( - model="cohere/rerank-english-v3.0", - query=kwargs.get("query"), + model=self.model, + query=query, documents=documents, - top_n=kwargs.get("top_n"), - return_documents=False, + top_n=self.top_n, + return_documents=self.return_documents, + rank_fields=self.rank_fields, + max_chunks_per_doc=self.max_chunks_per_doc, ) target_order = [result["index"] for result in response.results] reranked_chunks = [chunks[i] for i in target_order] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py index fad6ea9a..e29539a1 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py @@ -9,14 +9,13 @@ class NoopReranker(Reranker): A no-op reranker that does not change the order of the chunks. """ - @staticmethod - def rerank(chunks: List[Element], **kwargs: Any) -> List[Element]: # pylint: disable=unused-argument + def rerank(self, chunks: List[Element], query: str) -> List[Element]: # pylint: disable=unused-argument """ No reranking, returning the same chunks as in input. Args: chunks: The chunks to rerank. - kwargs: Additional arguments. + query: The query to rerank the chunks against. Returns: The reranked chunks.