Skip to content

Commit

Permalink
Config now belongs to an instance. No kwargs needed.
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrykWyzgowski committed Oct 17, 2024
1 parent 8f71200 commit 985fa0d
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ async def main():
"""Run the example."""

document_search = DocumentSearch(
embedder=LiteLLMEmbeddings(), vector_store=InMemoryVectorStore(), reranker=LiteLLMReranker()
embedder=LiteLLMEmbeddings(),
vector_store=InMemoryVectorStore(),
reranker=LiteLLMReranker(model="cohere/rerank-english-v3.0"),
)

for document in documents:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig())
entries = await self.vector_store.retrieve(search_vector[0], **search_config.vector_store_kwargs)
elements.extend([Element.from_vector_db_entry(entry) for entry in entries])

if self.reranker and query:
return self.reranker.rerank(elements, query=query)
return self.reranker.rerank(elements)
return self.reranker.rerank(elements, query=query)

async def ingest_document(
self, document: Union[DocumentMeta, Document], document_processor: Optional[BaseProvider] = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import abc
from typing import Any

from pydantic import BaseModel

from ragbits.document_search.documents.element import Element


class Reranker(abc.ABC):
class Reranker(BaseModel, abc.ABC):
"""
Reranks chunks retrieved from vector store.
"""

@staticmethod
@abc.abstractmethod
def rerank(chunks: list[Element], **kwargs: Any) -> list[Element]:
def rerank(self, chunks: list[Element], query: str) -> list[Element]:
"""
Rerank chunks.
Args:
chunks: The chunks to rerank.
kwargs: Additional arguments.
query: The query to rerank the chunks against.
Returns:
The reranked chunks.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,19 @@ class LiteLLMReranker(Reranker):
A LiteLLM reranker for providers such as Cohere, Together AI, Azure AI.
"""

@staticmethod
def rerank(chunks: List[Element], **kwargs: Any) -> List[Element]:
model: str
top_n: int | None = None
return_documents: bool = True
rank_fields: list[str] | None = None
max_chunks_per_doc: int | None = None

def rerank(self, chunks: List[Element], query: str) -> List[Element]:
"""
Reranking with LiteLLM API.
Args:
chunks: The chunks to rerank.
kwargs: Additional arguments for the LiteLLM API.
query: The query to rerank the chunks against.
Returns:
The reranked chunks.
Expand All @@ -32,11 +37,13 @@ def rerank(chunks: List[Element], **kwargs: Any) -> List[Element]:
documents = [chunk.content if isinstance(chunk, TextElement) else None for chunk in chunks]

response = litellm.rerank(
model="cohere/rerank-english-v3.0",
query=kwargs.get("query"),
model=self.model,
query=query,
documents=documents,
top_n=kwargs.get("top_n"),
return_documents=False,
top_n=self.top_n,
return_documents=self.return_documents,
rank_fields=self.rank_fields,
max_chunks_per_doc=self.max_chunks_per_doc,
)
target_order = [result["index"] for result in response.results]
reranked_chunks = [chunks[i] for i in target_order]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@ class NoopReranker(Reranker):
A no-op reranker that does not change the order of the chunks.
"""

@staticmethod
def rerank(chunks: List[Element], **kwargs: Any) -> List[Element]: # pylint: disable=unused-argument
def rerank(self, chunks: List[Element], query:str) -> List[Element]: # pylint: disable=unused-argument
"""
No reranking, returning the same chunks as in input.
Args:
chunks: The chunks to rerank.
kwargs: Additional arguments.
query: The query to rerank the chunks against.
Returns:
The reranked chunks.
Expand Down

0 comments on commit 985fa0d

Please sign in to comment.