Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(document-search): LiteLLM Reranker #109

Merged
merged 19 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions examples/document-search/from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
]

config = {
"embedder": {"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings"},
"embedder": {
"type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings",
},
"vector_store": {
"type": "ragbits.core.vector_stores.chroma:ChromaVectorStore",
"config": {
Expand All @@ -42,7 +44,16 @@
},
},
},
"reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"},
"reranker": {
"type": "ragbits.document_search.retrieval.rerankers.litellm:LiteLLMReranker",
"config": {
"model": "cohere/rerank-english-v3.0",
"default_options": {
"top_n": 3,
"max_chunks_per_doc": None,
},
},
},
"providers": {"txt": {"type": "DummyProvider"}},
"rephraser": {
"type": "LLMQueryRephraser",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser
from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser
from ragbits.document_search.retrieval.rerankers import get_reranker
from ragbits.document_search.retrieval.rerankers.base import Reranker
from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions
from ragbits.document_search.retrieval.rerankers.noop import NoopReranker


Expand Down Expand Up @@ -83,7 +83,7 @@ def from_config(cls, config: dict) -> "DocumentSearch":

return cls(embedder, vector_store, query_rephraser, reranker, document_processor_router)

async def search(self, query: str, config: SearchConfig | None = None) -> list[Element]:
async def search(self, query: str, config: SearchConfig | None = None) -> Sequence[Element]:
"""
Search for the most relevant chunks for a query.

Expand All @@ -105,7 +105,11 @@ async def search(self, query: str, config: SearchConfig | None = None) -> list[E
)
elements.extend([Element.from_vector_db_entry(entry) for entry in entries])

return self.reranker.rerank(elements)
return await self.reranker.rerank(
elements=elements,
query=query,
options=RerankerOptions(**config.reranker_kwargs),
)

async def _process_document(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
import sys

from ragbits.core.utils.config_handling import get_cls_from_config

from .base import Reranker
from .noop import NoopReranker
from ragbits.document_search.retrieval.rerankers.base import Reranker
from ragbits.document_search.retrieval.rerankers.noop import NoopReranker

__all__ = ["NoopReranker", "Reranker"]

module = sys.modules[__name__]


def get_reranker(reranker_config: dict | None) -> Reranker:
def get_reranker(config: dict | None = None) -> Reranker:
"""
Initializes and returns a Reranker object based on the provided configuration.

Args:
reranker_config: A dictionary containing configuration details for the Reranker.
config: A dictionary containing configuration details for the Reranker.

Returns:
An instance of the specified Reranker class, initialized with the provided config
(if any) or default arguments.

Raises:
KeyError: If the provided configuration does not contain a valid "type" key.
InvalidConfigurationError: If the provided configuration is invalid.
NotImplementedError: If the specified Reranker class cannot be created from the provided configuration.
"""
if reranker_config is None:
if config is None:
return NoopReranker()

reranker_cls = get_cls_from_config(reranker_config["type"], module)
config = reranker_config.get("config", {})

return reranker_cls(**config)
reranker_cls = get_cls_from_config(config["type"], sys.modules[__name__])
return reranker_cls.from_config(config.get("config", {}))
Original file line number Diff line number Diff line change
@@ -1,22 +1,65 @@
import abc
from abc import ABC, abstractmethod
from collections.abc import Sequence

from pydantic import BaseModel

from ragbits.document_search.documents.element import Element


class Reranker(abc.ABC):
class RerankerOptions(BaseModel):
"""
Options for the reranker.
"""

top_n: int | None = None
max_chunks_per_doc: int | None = None


class Reranker(ABC):
"""
Reranks chunks retrieved from vector store.
Reranks elements retrieved from vector store.
"""

@staticmethod
@abc.abstractmethod
def rerank(chunks: list[Element]) -> list[Element]:
def __init__(self, default_options: RerankerOptions | None = None) -> None:
"""
Constructs a new Reranker instance.

Args:
default_options: The default options for reranking.
"""
self._default_options = default_options or RerankerOptions()

@classmethod
def from_config(cls, config: dict) -> "Reranker":
"""
Creates and returns an instance of the Reranker class from the given configuration.

Args:
config: A dictionary containing the configuration for initializing the Reranker instance.

Returns:
An initialized instance of the Reranker class.

Raises:
NotImplementedError: If the class cannot be created from the provided configuration.
"""
raise NotImplementedError(f"Cannot create class {cls.__name__} from config.")

@abstractmethod
async def rerank(
self,
elements: Sequence[Element],
query: str,
options: RerankerOptions | None = None,
) -> Sequence[Element]:
"""
Rerank chunks.
Rerank elements.

Args:
chunks: The chunks to rerank.
elements: The elements to rerank.
query: The query to rerank the elements against.
options: The options for reranking.

Returns:
The reranked chunks.
The reranked elements.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from collections.abc import Sequence

import litellm

from ragbits.document_search.documents.element import Element
from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions


class LiteLLMReranker(Reranker):
"""
A [LiteLLM](https://docs.litellm.ai/docs/rerank) reranker for providers such as Cohere, Together AI, Azure AI.
"""

def __init__(self, model: str, default_options: RerankerOptions | None = None) -> None:
"""
Constructs a new LiteLLMReranker instance.

Args:
model: The reranker model to use.
default_options: The default options for reranking.
"""
super().__init__(default_options)
self.model = model

@classmethod
def from_config(cls, config: dict) -> "LiteLLMReranker":
"""
Creates and returns an instance of the LiteLLMReranker class from the given configuration.

Args:
config: A dictionary containing the configuration for initializing the LiteLLMReranker instance.

Returns:
An initialized instance of the LiteLLMReranker class.
"""
return cls(
model=config["model"],
default_options=RerankerOptions(**config.get("default_options", {})),
)

async def rerank(
self,
elements: Sequence[Element],
query: str,
options: RerankerOptions | None = None,
) -> Sequence[Element]:
"""
Rerank elements with LiteLLM API.

Args:
elements: The elements to rerank.
query: The query to rerank the elements against.
options: The options for reranking.

Returns:
The reranked elements.
"""
options = self._default_options if options is None else options
documents = [element.get_text_representation() for element in elements]

response = await litellm.arerank(
model=self.model,
query=query,
documents=documents, # type: ignore
top_n=options.top_n,
max_chunks_per_doc=options.max_chunks_per_doc,
)

return [elements[result["index"]] for result in response.results] # type: ignore
Original file line number Diff line number Diff line change
@@ -1,21 +1,42 @@
from collections.abc import Sequence

from ragbits.document_search.documents.element import Element
from ragbits.document_search.retrieval.rerankers.base import Reranker
from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions


class NoopReranker(Reranker):
"""
A no-op reranker that does not change the order of the chunks.
A no-op reranker that does not change the order of the elements.
"""

@staticmethod
def rerank(chunks: list[Element]) -> list[Element]:
@classmethod
def from_config(cls, config: dict) -> "NoopReranker":
"""
Creates and returns an instance of the NoopReranker class from the given configuration.

Args:
config: A dictionary containing the configuration for initializing the NoopReranker instance.

Returns:
An initialized instance of the NoopReranker class.
"""
return cls(default_options=RerankerOptions(**config.get("default_options", {})))

async def rerank( # noqa: PLR6301
self,
elements: Sequence[Element],
query: str,
options: RerankerOptions | None = None,
) -> Sequence[Element]:
"""
No reranking, returning the same chunks as in input.
No reranking, returning the elements in the same order.

Args:
chunks: The chunks to rerank.
elements: The elements to rerank.
query: The query to rerank the elements against.
options: The options for reranking.

Returns:
The reranked chunks.
The reranked elements.
"""
return chunks
return elements
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest

from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.documents.element import TextElement
from ragbits.document_search.retrieval.rerankers.base import RerankerOptions
from ragbits.document_search.retrieval.rerankers.litellm import LiteLLMReranker

from ..helpers import env_vars_not_set

COHERE_API_KEY_ENV = "COHERE_API_KEY" # noqa: S105


@pytest.mark.skipif(
env_vars_not_set([COHERE_API_KEY_ENV]),
reason="Cohere API KEY environment variables not set",
)
async def test_litellm_cohere_reranker_rerank() -> None:
options = RerankerOptions(top_n=2, max_chunks_per_doc=None)
reranker = LiteLLMReranker(
model="cohere/rerank-english-v3.0",
default_options=options,
)
elements = [
TextElement(
content="Element 1", document_meta=DocumentMeta.create_text_document_from_literal("Mock document 1")
),
TextElement(
content="Element 2", document_meta=DocumentMeta.create_text_document_from_literal("Mock document 1")
),
TextElement(
content="Element 3", document_meta=DocumentMeta.create_text_document_from_literal("Mock document 1")
),
]
query = "Test query"

results = await reranker.rerank(elements, query)

assert len(results) == 2
Loading
Loading