From 80de16a16cbdd94b6e1b66dfcbbc1e3a95b2ce0c Mon Sep 17 00:00:00 2001 From: Alan Konarski <129968242+akonarski-ds@users.noreply.github.com> Date: Thu, 19 Dec 2024 11:48:20 +0100 Subject: [PATCH] feat(core): Implement generic Options class (#248) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Michał Pstrąg --- examples/document-search/chroma.py | 11 +++- packages/ragbits-core/CHANGELOG.md | 4 ++ .../src/ragbits/core/embeddings/__init__.py | 4 +- .../src/ragbits/core/embeddings/base.py | 20 ++++--- .../src/ragbits/core/embeddings/litellm.py | 34 ++++++++--- .../src/ragbits/core/embeddings/local.py | 27 +++++++-- .../src/ragbits/core/embeddings/noop.py | 8 ++- .../core/embeddings/vertex_multimodal.py | 41 ++++++++----- .../src/ragbits/core/llms/base.py | 58 ++++++------------- .../src/ragbits/core/llms/clients/__init__.py | 3 +- .../src/ragbits/core/llms/clients/base.py | 53 ++--------------- .../src/ragbits/core/llms/clients/litellm.py | 17 +++--- .../src/ragbits/core/llms/clients/local.py | 11 ++-- .../src/ragbits/core/llms/litellm.py | 2 +- .../src/ragbits/core/llms/local.py | 2 +- .../ragbits-core/src/ragbits/core/options.py | 50 ++++++++++++++++ .../src/ragbits/core/{llms => }/types.py | 0 .../src/ragbits/core/utils/config_handling.py | 35 ++++++++++- .../src/ragbits/core/vector_stores/_cli.py | 2 +- .../src/ragbits/core/vector_stores/base.py | 22 ++++--- .../src/ragbits/core/vector_stores/chroma.py | 10 ++-- .../ragbits/core/vector_stores/in_memory.py | 10 ++-- .../src/ragbits/core/vector_stores/qdrant.py | 10 ++-- .../tests/unit/embeddings/test_from_config.py | 18 ++++-- .../ragbits-core/tests/unit/test_options.py | 40 +++++++++++++ .../unit/vector_stores/test_from_config.py | 22 +++---- .../src/ragbits/document_search/_main.py | 6 +- .../retrieval/rerankers/base.py | 42 ++++---------- .../retrieval/rerankers/litellm.py | 12 ++-- .../retrieval/rerankers/noop.py | 4 +- .../tests/unit/test_rerankers.py | 24 ++++---- 31 files changed, 368 insertions(+), 234 deletions(-) create mode 100644 packages/ragbits-core/src/ragbits/core/options.py rename packages/ragbits-core/src/ragbits/core/{llms => }/types.py (100%) create mode 100644 packages/ragbits-core/tests/unit/test_options.py diff --git a/examples/document-search/chroma.py b/examples/document-search/chroma.py index 7fa18480..ce1fefb1 100644 --- a/examples/document-search/chroma.py +++ b/examples/document-search/chroma.py @@ -35,7 +35,8 @@ from chromadb import EphemeralClient -from ragbits.core.embeddings.litellm import LiteLLMEmbeddings +from ragbits.core.embeddings.litellm import LiteLLMEmbeddings, LiteLLMEmbeddingsOptions +from ragbits.core.vector_stores import VectorStoreOptions from ragbits.core.vector_stores.chroma import ChromaVectorStore from ragbits.document_search import DocumentSearch, SearchConfig from ragbits.document_search.documents.document import DocumentMeta @@ -70,10 +71,18 @@ async def main() -> None: """ embedder = LiteLLMEmbeddings( model="text-embedding-3-small", + default_options=LiteLLMEmbeddingsOptions( + dimensions=1024, + timeout=1000, + ), ) vector_store = ChromaVectorStore( client=EphemeralClient(), index_name="jokes", + default_options=VectorStoreOptions( + k=10, + max_distance=0.22, + ), ) document_search = DocumentSearch( embedder=embedder, diff --git a/packages/ragbits-core/CHANGELOG.md b/packages/ragbits-core/CHANGELOG.md index 1fe9c4e0..4a82d99d 100644 --- a/packages/ragbits-core/CHANGELOG.md +++ b/packages/ragbits-core/CHANGELOG.md @@ -2,6 
+2,10 @@ ## Unreleased +### Changed + +- Feat: Implement generic Options class (#248). + ## 0.5.1 (2024-12-09) ### Changed diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py b/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py index 825b0062..270df2fe 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/__init__.py @@ -1,5 +1,5 @@ -from .base import Embeddings, EmbeddingType +from .base import Embeddings, EmbeddingsOptionsT, EmbeddingType from .litellm import LiteLLMEmbeddings from .noop import NoopEmbeddings -__all__ = ["EmbeddingType", "Embeddings", "LiteLLMEmbeddings", "NoopEmbeddings"] +__all__ = ["EmbeddingType", "Embeddings", "EmbeddingsOptionsT", "LiteLLMEmbeddings", "NoopEmbeddings"] diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/base.py b/packages/ragbits-core/src/ragbits/core/embeddings/base.py index 130113fa..ac99df0e 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/base.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/base.py @@ -1,9 +1,12 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import ClassVar +from typing import ClassVar, TypeVar from ragbits.core import embeddings -from ragbits.core.utils.config_handling import WithConstructionConfig +from ragbits.core.options import Options +from ragbits.core.utils.config_handling import ConfigurableComponent + +EmbeddingsOptionsT = TypeVar("EmbeddingsOptionsT", bound=Options) class EmbeddingType(Enum): @@ -17,25 +20,27 @@ class EmbeddingType(Enum): allowing for the creation of different embeddings for the same element. """ - TEXT: str = "text" - IMAGE: str = "image" + TEXT = "text" + IMAGE = "image" -class Embeddings(WithConstructionConfig, ABC): +class Embeddings(ConfigurableComponent[EmbeddingsOptionsT], ABC): """ Abstract client for communication with embedding models. """ + options_cls: type[EmbeddingsOptionsT] default_module: ClassVar = embeddings configuration_key: ClassVar = "embedder" @abstractmethod - async def embed_text(self, data: list[str]) -> list[list[float]]: + async def embed_text(self, data: list[str], options: EmbeddingsOptionsT | None = None) -> list[list[float]]: """ Creates embeddings for the given strings. Args: data: List of strings to get embeddings for. + options: Additional settings used by the Embeddings model. Returns: List of embeddings for the given strings. @@ -50,12 +55,13 @@ def image_support(self) -> bool: # noqa: PLR6301 """ return False - async def embed_image(self, images: list[bytes]) -> list[list[float]]: + async def embed_image(self, images: list[bytes], options: EmbeddingsOptionsT | None = None) -> list[list[float]]: """ Creates embeddings for the given images. Args: images: List of images to get embeddings for. + options: Additional settings used by the Embeddings model. Returns: List of embeddings for the given images. 
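The base-class change above sets the pattern every embedder below follows: subclass Embeddings with a concrete Options model, declare it as options_cls, and accept a per-call options override. A minimal sketch of a conforming subclass (FakeEmbeddings and its scale field are illustrative, not part of this patch):

    from ragbits.core.embeddings.base import Embeddings
    from ragbits.core.options import Options


    class FakeOptions(Options):
        scale: float = 1.0  # illustrative knob, not a ragbits-defined option


    class FakeEmbeddings(Embeddings[FakeOptions]):
        options_cls = FakeOptions  # required by the generic base class

        async def embed_text(self, data: list[str], options: FakeOptions | None = None) -> list[list[float]]:
            # Per-call options are layered over the instance defaults, as in the real clients.
            merged = (self.default_options | options) if options else self.default_options
            return [[merged.scale] * 4 for _ in data]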
diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py b/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py index 4dbea167..292a91de 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py @@ -8,17 +8,33 @@ EmbeddingResponseError, EmbeddingStatusError, ) +from ragbits.core.options import Options +from ragbits.core.types import NOT_GIVEN, NotGiven -class LiteLLMEmbeddings(Embeddings): +class LiteLLMEmbeddingsOptions(Options): + """ + Dataclass that represents available call options for the LiteLLMEmbeddings client. + Each of them is described in the [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding#optional-litellm-fields). + """ + + dimensions: int | None | NotGiven = NOT_GIVEN + timeout: int | None | NotGiven = NOT_GIVEN + user: str | None | NotGiven = NOT_GIVEN + encoding_format: str | None | NotGiven = NOT_GIVEN + + +class LiteLLMEmbeddings(Embeddings[LiteLLMEmbeddingsOptions]): """ Client for creating text embeddings using LiteLLM API. """ + options_cls = LiteLLMEmbeddingsOptions + def __init__( self, model: str = "text-embedding-3-small", - options: dict | None = None, + default_options: LiteLLMEmbeddingsOptions | None = None, api_base: str | None = None, api_key: str | None = None, api_version: str | None = None, @@ -29,26 +45,26 @@ def __init__( Args: model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\ to be used. Default is "text-embedding-3-small". - options: Additional options to pass to the LiteLLM API. + default_options: Default options to pass to the LiteLLM API. api_base: The API endpoint you want to call the model with. api_key: API key to be used. If not specified, an environment variable will be used; for more information, follow the instructions for your specific vendor in the\ [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding). api_version: The API version for the call. """ - super().__init__() + super().__init__(default_options=default_options) self.model = model - self.options = options or {} self.api_base = api_base self.api_key = api_key self.api_version = api_version - async def embed_text(self, data: list[str]) -> list[list[float]]: + async def embed_text(self, data: list[str], options: LiteLLMEmbeddingsOptions | None = None) -> list[list[float]]: """ Creates embeddings for the given strings. Args: data: List of strings to get embeddings for. + options: Additional options to pass to the LiteLLM API. Returns: List of embeddings for the given strings. @@ -59,12 +75,14 @@ async def embed_text(self, data: list[str]) -> list[list[float]]: EmbeddingStatusError: If the embedding API returns an error status code. EmbeddingResponseError: If the embedding API response is invalid.
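        Example (illustrative option values):
            embedder = LiteLLMEmbeddings(
                model="text-embedding-3-small",
                default_options=LiteLLMEmbeddingsOptions(dimensions=1024),
            )
            # timeout applies to this call only; dimensions still comes from the defaults,
            # because NOT_GIVEN fields fall back to default_options during the merge.
            vectors = await embedder.embed_text(["hello"], options=LiteLLMEmbeddingsOptions(timeout=10))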
""" + merged_options = (self.default_options | options) if options else self.default_options + with trace( data=data, model=self.model, api_base=self.api_base, api_version=self.api_version, - options=self.options, + options=merged_options.dict(), ) as outputs: try: response = await litellm.aembedding( @@ -73,7 +91,7 @@ async def embed_text(self, data: list[str]) -> list[list[float]]: api_base=self.api_base, api_key=self.api_key, api_version=self.api_version, - **self.options, + **merged_options.dict(), ) except litellm.openai.APIConnectionError as exc: raise EmbeddingConnectionError() from exc diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/local.py b/packages/ragbits-core/src/ragbits/core/embeddings/local.py index a13f7f1d..9fae458e 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/local.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/local.py @@ -1,5 +1,8 @@ from collections.abc import Iterator +from ragbits.core.embeddings import Embeddings +from ragbits.core.options import Options + try: import torch import torch.nn.functional as F @@ -9,24 +12,34 @@ except ImportError: HAS_LOCAL_EMBEDDINGS = False -from ragbits.core.embeddings import Embeddings +class LocalEmbeddingsOptions(Options): + """ + Dataclass that represents available call options for the LocalEmbeddings client. + """ -class LocalEmbeddings(Embeddings): + batch_size: int = 1 + + +class LocalEmbeddings(Embeddings[LocalEmbeddingsOptions]): """ Class for interaction with any encoder available in HuggingFace. """ + options_cls = LocalEmbeddingsOptions + def __init__( self, model_name: str, api_key: str | None = None, + default_options: LocalEmbeddingsOptions | None = None, ) -> None: """Constructs a new local LLM instance. Args: model_name: Name of the model to use. api_key: The API key for Hugging Face authentication. + default_options: Default options for the embedding model. Raises: ImportError: If the 'local' extra requirements are not installed. @@ -34,7 +47,7 @@ def __init__( if not HAS_LOCAL_EMBEDDINGS: raise ImportError("You need to install the 'local' extra requirements to use local embeddings models") - super().__init__() + super().__init__(default_options=default_options) self.hf_api_key = api_key self.model_name = model_name @@ -43,18 +56,20 @@ def __init__( self.model = AutoModel.from_pretrained(self.model_name, token=self.hf_api_key).to(self.device) self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_api_key) - async def embed_text(self, data: list[str], batch_size: int = 1) -> list[list[float]]: + async def embed_text(self, data: list[str], options: LocalEmbeddingsOptions | None = None) -> list[list[float]]: """Calls the appropriate encoder endpoint with the given data and options. Args: data: List of strings to get embeddings for. - batch_size: Batch size. + options: Additional options to pass to the embedding model. Returns: List of embeddings for the given strings. 
""" + merged_options = (self.default_options | options) if options else self.default_options + embeddings = [] - for batch in self._batch(data, batch_size): + for batch in self._batch(data, merged_options.batch_size): batch_dict = self.tokenizer( batch, max_length=self.tokenizer.model_max_length, diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/noop.py b/packages/ragbits-core/src/ragbits/core/embeddings/noop.py index 07d7b3d0..5934e4ce 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/noop.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/noop.py @@ -1,8 +1,9 @@ from ragbits.core.audit import traceable from ragbits.core.embeddings.base import Embeddings +from ragbits.core.options import Options -class NoopEmbeddings(Embeddings): +class NoopEmbeddings(Embeddings[Options]): """ A no-op implementation of the Embeddings class. @@ -11,13 +12,16 @@ class NoopEmbeddings(Embeddings): or as a placeholder when an actual embedding model is not required. """ + options_cls = Options + @traceable - async def embed_text(self, data: list[str]) -> list[list[float]]: # noqa: PLR6301 + async def embed_text(self, data: list[str], options: Options | None = None) -> list[list[float]]: # noqa: PLR6301 """ Embeds a list of strings into a list of vectors. Args: data: A list of input text strings to embed. + options: Additional settings used by the Embeddings model. Returns: A list of embedding vectors, where each vector diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py b/packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py index 8f4bf422..51c40c5c 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py @@ -1,6 +1,8 @@ import asyncio import base64 +from ragbits.core.embeddings.litellm import LiteLLMEmbeddingsOptions + try: import litellm from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import VertexAIError @@ -18,11 +20,12 @@ ) -class VertexAIMultimodelEmbeddings(Embeddings): +class VertexAIMultimodelEmbeddings(Embeddings[LiteLLMEmbeddingsOptions]): """ Client for creating text embeddings using LiteLLM API. """ + options_cls = LiteLLMEmbeddingsOptions VERTEX_AI_PREFIX = "vertex_ai/" def __init__( @@ -31,7 +34,7 @@ def __init__( api_base: str | None = None, api_key: str | None = None, concurency: int = 10, - options: dict | None = None, + default_options: LiteLLMEmbeddingsOptions | None = None, ) -> None: """ Constructs the embedding client for multimodal VertexAI models. @@ -41,7 +44,7 @@ def __init__( api_base: The API endpoint you want to call the model with. api_key: API key to be used. If not specified, an environment variable will be used. concurency: The number of concurrent requests to make to the API. - options: Additional options to pass to the API. + default_options: Additional options to pass to the API. Raises: ImportError: If the 'litellm' extra requirements are not installed. 
@@ -50,7 +53,7 @@ def __init__( if not HAS_LITELLM: raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models") - super().__init__() + super().__init__(default_options=default_options) if model.startswith(self.VERTEX_AI_PREFIX): model = model[len(self.VERTEX_AI_PREFIX) :] @@ -58,19 +61,19 @@ def __init__( self.api_base = api_base self.api_key = api_key self.concurency = concurency - self.options = options or {} supported_models = VertexMultimodalEmbedding().SUPPORTED_MULTIMODAL_EMBEDDING_MODELS if model not in supported_models: raise ValueError(f"Model {model} is not supported by VertexAI multimodal embeddings") - async def _embed(self, data: list[dict]) -> list[dict]: + async def _embed(self, data: list[dict], options: LiteLLMEmbeddingsOptions | None = None) -> list[dict]: """ Creates embeddings for the given data. The format is defined in the VertexAI API: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings Args: data: List of instances in the format expected by the VertexAI API. + options: Additional options to pass to the VertexAI multimodal embeddings API. Returns: List of embeddings for the given VertexAI instances, each instance is a dictionary @@ -80,16 +83,17 @@ async def _embed(self, data: list[dict]) -> list[dict]: EmbeddingStatusError: If the embedding API returns an error status code. EmbeddingResponseError: If the embedding API response is invalid. """ + merged_options = (self.default_options | options) if options else self.default_options with trace( data=data, model=self.model, api_base=self.api_base, - options=self.options, + options=merged_options.dict(), ) as outputs: semaphore = asyncio.Semaphore(self.concurency) try: response = await asyncio.gather( - *[self._call_litellm(instance, semaphore) for instance in data], + *[self._call_litellm(instance, semaphore, merged_options) for instance in data], ) except VertexAIError as exc: raise EmbeddingStatusError(exc.message, exc.status_code) from exc @@ -102,13 +106,16 @@ async def _embed(self, data: list[dict]) -> list[dict]: return outputs.embeddings - async def _call_litellm(self, instance: dict, semaphore: asyncio.Semaphore) -> litellm.EmbeddingResponse: + async def _call_litellm( + self, instance: dict, semaphore: asyncio.Semaphore, options: LiteLLMEmbeddingsOptions + ) -> litellm.EmbeddingResponse: """ Calls the LiteLLM API to get embeddings for the given data. Args: instance: Single VertexAI instance to get embeddings for. semaphore: Semaphore to limit the number of concurrent requests. + options: Additional options to pass to the VertexAI multimodal embeddings API. Returns: List of embeddings for the given LiteLLM instances. @@ -119,17 +126,18 @@ async def _call_litellm(self, instance: dict, semaphore: asyncio.Semaphore) -> l model=f"{self.VERTEX_AI_PREFIX}{self.model}", api_base=self.api_base, api_key=self.api_key, - **self.options, + **options.dict(), ) return response - async def embed_text(self, data: list[str]) -> list[list[float]]: + async def embed_text(self, data: list[str], options: LiteLLMEmbeddingsOptions | None = None) -> list[list[float]]: """ Creates embeddings for the given strings. Args: data: List of strings to get embeddings for. + options: Additional options to pass to the VertexAI multimodal embeddings API. Returns: List of embeddings for the given strings. @@ -138,7 +146,7 @@ async def embed_text(self, data: list[str]) -> list[list[float]]: EmbeddingStatusError: If the embedding API returns an error status code. 
EmbeddingResponseError: If the embedding API response is invalid. """ - response = await self._embed([{"text": text} for text in data]) + response = await self._embed([{"text": text} for text in data], options=options) return [embedding["textEmbedding"] for embedding in response] def image_support(self) -> bool: # noqa: PLR6301 @@ -150,12 +158,15 @@ def image_support(self) -> bool: # noqa: PLR6301 """ return True - async def embed_image(self, images: list[bytes]) -> list[list[float]]: + async def embed_image( + self, images: list[bytes], options: LiteLLMEmbeddingsOptions | None = None + ) -> list[list[float]]: """ Creates embeddings for the given images. Args: images: List of images to get embeddings for. + options: Additional options to pass to the VertexAI multimodal embeddings API. Returns: List of embeddings for the given images. @@ -165,6 +176,8 @@ async def embed_image(self, images: list[bytes]) -> list[list[float]]: EmbeddingResponseError: If the embedding API response is invalid. """ images_b64 = (base64.b64encode(image).decode() for image in images) - response = await self._embed([{"image": {"bytesBase64Encoded": image}} for image in images_b64]) + response = await self._embed( + [{"image": {"bytesBase64Encoded": image}} for image in images_b64], options=options + ) return [embedding["imageEmbedding"] for embedding in response] diff --git a/packages/ragbits-core/src/ragbits/core/llms/base.py b/packages/ragbits-core/src/ragbits/core/llms/base.py index 0f46811e..7aa45e98 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/base.py @@ -3,15 +3,12 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from functools import cached_property -from typing import ClassVar, Generic, cast, overload - -from typing_extensions import Self +from typing import ClassVar, cast, overload from ragbits.core import llms +from ragbits.core.llms.clients.base import LLMClient, LLMClientOptionsT from ragbits.core.prompt.base import BasePrompt, BasePromptWithParser, ChatFormat, OutputT -from ragbits.core.utils.config_handling import WithConstructionConfig - -from .clients.base import LLMClient, LLMClientOptions, LLMOptions +from ragbits.core.utils.config_handling import ConfigurableComponent class LLMType(enum.Enum): @@ -24,16 +21,16 @@ class LLMType(enum.Enum): STRUCTURED_OUTPUT = "structured_output" -class LLM(WithConstructionConfig, Generic[LLMClientOptions], ABC): +class LLM(ConfigurableComponent[LLMClientOptionsT], ABC): """ Abstract class for interaction with Large Language Model. """ - _options_cls: type[LLMClientOptions] + options_cls: type[LLMClientOptionsT] default_module: ClassVar = llms configuration_key: ClassVar = "llm" - def __init__(self, model_name: str, default_options: LLMOptions | None = None) -> None: + def __init__(self, model_name: str, default_options: LLMClientOptionsT | None = None) -> None: """ Constructs a new LLM instance. @@ -42,14 +39,14 @@ def __init__(self, model_name: str, default_options: LLMOptions | None = None) - default_options: Default options to be used. Raises: - TypeError: If the subclass is missing the '_options_cls' attribute. + TypeError: If the subclass is missing the 'options_cls' attribute. 
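        Example (an illustrative subclass; MyLLM and MyOptions are placeholders):
            class MyOptions(Options):
                temperature: float | NotGiven = NOT_GIVEN

            class MyLLM(LLM[MyOptions]):
                options_cls = MyOptions  # omitting this line triggers the TypeError above
                ...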
""" + super().__init__(default_options=default_options) self.model_name = model_name - self.default_options = default_options or self._options_cls() def __init_subclass__(cls) -> None: - if not hasattr(cls, "_options_cls"): - raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute") + if not hasattr(cls, "options_cls"): + raise TypeError(f"Class {cls.__name__} is missing the 'options_cls' attribute") @cached_property @abstractmethod @@ -74,7 +71,7 @@ async def generate_raw( self, prompt: BasePrompt, *, - options: LLMOptions | None = None, + options: LLMClientOptionsT | None = None, ) -> str: """ Prepares and sends a prompt to the LLM and returns the raw response (without parsing). @@ -86,10 +83,10 @@ async def generate_raw( Returns: Raw text response from LLM. """ - options = (self.default_options | options) if options else self.default_options + merged_options = (self.default_options | options) if options else self.default_options response = await self.client.call( conversation=self._format_chat_for_llm(prompt), - options=options, + options=merged_options, json_mode=prompt.json_mode, output_schema=prompt.output_schema(), ) @@ -101,7 +98,7 @@ async def generate( self, prompt: BasePromptWithParser[OutputT], *, - options: LLMOptions | None = None, + options: LLMClientOptionsT | None = None, ) -> OutputT: ... @overload @@ -109,14 +106,14 @@ async def generate( self, prompt: BasePrompt, *, - options: LLMOptions | None = None, + options: LLMClientOptionsT | None = None, ) -> OutputT: ... async def generate( self, prompt: BasePrompt, *, - options: LLMOptions | None = None, + options: LLMClientOptionsT | None = None, ) -> OutputT: """ Prepares and sends a prompt to the LLM and returns response parsed to the @@ -140,7 +137,7 @@ async def generate_streaming( self, prompt: BasePrompt, *, - options: LLMOptions | None = None, + options: LLMClientOptionsT | None = None, ) -> AsyncGenerator[str, None]: """ Prepares and sends a prompt to the LLM and streams the results. @@ -152,10 +149,10 @@ async def generate_streaming( Returns: Response stream from LLM. """ - options = (self.default_options | options) if options else self.default_options + merged_options = (self.default_options | options) if options else self.default_options response = await self.client.call_streaming( conversation=self._format_chat_for_llm(prompt), - options=options, + options=merged_options, json_mode=prompt.json_mode, output_schema=prompt.output_schema(), ) @@ -166,20 +163,3 @@ def _format_chat_for_llm(self, prompt: BasePrompt) -> ChatFormat: if prompt.list_images(): wrngs.warn(message=f"Image input not implemented for {self.__class__.__name__}") return prompt.chat - - @classmethod - def from_config(cls, config: dict) -> Self: - """ - Initializes the class with the provided configuration. - - Args: - config: A dictionary containing configuration details for the class. - - Returns: - An instance of the class initialized with the provided configuration. 
- """ - default_options = config.pop("default_options", None) - - options = cls._options_cls(**default_options) if default_options else None - - return cls(**config, default_options=options) diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/__init__.py b/packages/ragbits-core/src/ragbits/core/llms/clients/__init__.py index e365c0ce..8e23ece9 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/clients/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/llms/clients/__init__.py @@ -1,10 +1,9 @@ -from .base import LLMClient, LLMOptions +from .base import LLMClient from .litellm import LiteLLMClient, LiteLLMOptions from .local import LocalLLMClient, LocalLLMOptions __all__ = [ "LLMClient", - "LLMOptions", "LiteLLMClient", "LiteLLMOptions", "LocalLLMClient", diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/base.py b/packages/ragbits-core/src/ragbits/core/llms/clients/base.py index b8df9dd9..8d062484 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/clients/base.py +++ b/packages/ragbits-core/src/ragbits/core/llms/clients/base.py @@ -1,57 +1,16 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator -from dataclasses import asdict, dataclass -from typing import Any, ClassVar, Generic, TypeVar +from typing import Generic, TypeVar from pydantic import BaseModel +from ragbits.core.options import Options from ragbits.core.prompt import ChatFormat -from ..types import NotGiven +LLMClientOptionsT = TypeVar("LLMClientOptionsT", bound=Options) -LLMClientOptions = TypeVar("LLMClientOptions", bound="LLMOptions") - -@dataclass -class LLMOptions(ABC): - """ - A dataclass that represents all available LLM call options. - """ - - _not_given: ClassVar[Any] = None - - def __or__(self, other: "LLMOptions") -> "LLMOptions": - """ - Merges two LLMOptions, prioritizing non-NOT_GIVEN values from the 'other' object. - """ - self_dict = asdict(self) - other_dict = asdict(other) - - updated_dict = { - key: other_dict.get(key, self_dict[key]) - if not isinstance(other_dict.get(key), NotGiven) - else self_dict[key] - for key in self_dict - } - - return self.__class__(**updated_dict) - - def dict(self) -> dict[str, Any]: - """ - Creates a dictionary representation of the LLMOptions instance. - If a value is None, it will be replaced with a provider-specific not-given sentinel. - - Returns: - A dictionary representation of the LLMOptions instance. - """ - options = asdict(self) - return { - key: self._not_given if value is None or isinstance(value, NotGiven) else value - for key, value in options.items() - } - - -class LLMClient(Generic[LLMClientOptions], ABC): +class LLMClient(Generic[LLMClientOptionsT], ABC): """ Abstract client for a direct communication with LLM. 
""" @@ -69,7 +28,7 @@ def __init__(self, model_name: str) -> None: async def call( self, conversation: ChatFormat, - options: LLMClientOptions, + options: LLMClientOptionsT, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, ) -> str: @@ -90,7 +49,7 @@ async def call( async def call_streaming( self, conversation: ChatFormat, - options: LLMClientOptions, + options: LLMClientOptionsT, json_mode: bool = False, output_schema: type[BaseModel] | dict | None = None, ) -> AsyncGenerator[str, None]: diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py b/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py index a5915d29..f584b377 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py @@ -1,20 +1,23 @@ from collections.abc import AsyncGenerator -from dataclasses import dataclass import litellm from litellm.utils import CustomStreamWrapper, ModelResponse from pydantic import BaseModel from ragbits.core.audit import trace +from ragbits.core.llms.clients.base import LLMClient +from ragbits.core.llms.clients.exceptions import ( + LLMConnectionError, + LLMEmptyResponseError, + LLMResponseError, + LLMStatusError, +) +from ragbits.core.options import Options from ragbits.core.prompt import ChatFormat +from ragbits.core.types import NOT_GIVEN, NotGiven -from ..types import NOT_GIVEN, NotGiven -from .base import LLMClient, LLMOptions -from .exceptions import LLMConnectionError, LLMEmptyResponseError, LLMResponseError, LLMStatusError - -@dataclass -class LiteLLMOptions(LLMOptions): +class LiteLLMOptions(Options): """ Dataclass that represents all available LLM call options for the LiteLLM client. Each of them is described in the [LiteLLM documentation](https://docs.litellm.ai/docs/completion/input). diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/local.py b/packages/ragbits-core/src/ragbits/core/llms/clients/local.py index 0d8ddee3..691245ff 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/clients/local.py +++ b/packages/ragbits-core/src/ragbits/core/llms/clients/local.py @@ -1,7 +1,6 @@ import asyncio import threading from collections.abc import AsyncGenerator -from dataclasses import dataclass from pydantic import BaseModel @@ -14,14 +13,14 @@ except ImportError: HAS_LOCAL_LLM = False -from ragbits.core.prompt import ChatFormat -from ..types import NOT_GIVEN, NotGiven -from .base import LLMClient, LLMOptions +from ragbits.core.llms.clients.base import LLMClient +from ragbits.core.options import Options +from ragbits.core.prompt import ChatFormat +from ragbits.core.types import NOT_GIVEN, NotGiven -@dataclass -class LocalLLMOptions(LLMOptions): +class LocalLLMOptions(Options): """ Dataclass that represents all available LLM call options for the local LLM client. Each of them is described in the [HuggingFace documentation] diff --git a/packages/ragbits-core/src/ragbits/core/llms/litellm.py b/packages/ragbits-core/src/ragbits/core/llms/litellm.py index 13c1ebe7..0714bb92 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/llms/litellm.py @@ -15,7 +15,7 @@ class LiteLLM(LLM[LiteLLMOptions]): Class for interaction with any LLM supported by LiteLLM API. 
""" - _options_cls = LiteLLMOptions + options_cls = LiteLLMOptions def __init__( self, diff --git a/packages/ragbits-core/src/ragbits/core/llms/local.py b/packages/ragbits-core/src/ragbits/core/llms/local.py index 0d469906..42d23979 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/local.py +++ b/packages/ragbits-core/src/ragbits/core/llms/local.py @@ -18,7 +18,7 @@ class LocalLLM(LLM[LocalLLMOptions]): Class for interaction with any LLM available in HuggingFace. """ - _options_cls = LocalLLMOptions + options_cls = LocalLLMOptions def __init__( self, diff --git a/packages/ragbits-core/src/ragbits/core/options.py b/packages/ragbits-core/src/ragbits/core/options.py new file mode 100644 index 00000000..c9ea884f --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/options.py @@ -0,0 +1,50 @@ +from abc import ABC +from typing import Any, ClassVar, TypeVar + +from pydantic import BaseModel, ConfigDict +from typing_extensions import Self + +from ragbits.core.types import NotGiven + +OptionsT = TypeVar("OptionsT", bound="Options") + + +class Options(BaseModel, ABC): + """ + A dataclass that represents all available options. Thanks to the extra='allow' configuration, it allows for + additional fields that are not defined in the class. + """ + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + _not_given: ClassVar[Any] = None + + def __or__(self, other: "Options") -> Self: + """ + Merges two Options, prioritizing non-NOT_GIVEN values from the 'other' object. + """ + self_dict = self.model_dump() + other_dict = other.model_dump() + + updated_dict = { + key: other_dict[key] + if not isinstance(other_dict.get(key), NotGiven) and key in other_dict + else self_dict[key] + for key in self_dict.keys() | other_dict.keys() + } + + return self.__class__(**updated_dict) + + def dict(self) -> dict[str, Any]: # type: ignore # mypy complains about overriding BaseModel.dict + """ + Creates a dictionary representation of the Options instance. + If a value is None, it will be replaced with a provider-specific not-given sentinel. + + Returns: + A dictionary representation of the Options instance. + """ + options = self.model_dump() + + return { + key: self._not_given if value is None or isinstance(value, NotGiven) else value + for key, value in options.items() + } diff --git a/packages/ragbits-core/src/ragbits/core/llms/types.py b/packages/ragbits-core/src/ragbits/core/types.py similarity index 100% rename from packages/ragbits-core/src/ragbits/core/llms/types.py rename to packages/ragbits-core/src/ragbits/core/types.py diff --git a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py index b6959697..928cc16b 100644 --- a/packages/ragbits-core/src/ragbits/core/utils/config_handling.py +++ b/packages/ragbits-core/src/ragbits/core/utils/config_handling.py @@ -4,11 +4,12 @@ from importlib import import_module from pathlib import Path from types import ModuleType -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic from pydantic import BaseModel from typing_extensions import Self +from ragbits.core.options import OptionsT from ragbits.core.utils._pyproject import get_config_from_yaml if TYPE_CHECKING: @@ -169,3 +170,35 @@ def from_config(cls, config: dict) -> Self: An instance of the class initialized with the provided configuration. 
""" return cls(**config) + + +class ConfigurableComponent(Generic[OptionsT], WithConstructionConfig): + """ + Base class for components with configurable options. + """ + + options_cls: type[OptionsT] + + def __init__(self, default_options: OptionsT | None = None) -> None: + """ + Constructs a new ConfigurableComponent instance. + + Args: + default_options: The default options for the component. + """ + self.default_options: OptionsT = default_options or self.options_cls() + + @classmethod + def from_config(cls, config: dict[str, Any]) -> ConfigurableComponent: + """ + Initializes the class with the provided configuration. + + Args: + config: A dictionary containing configuration details for the class. + + Returns: + An instance of the class initialized with the provided configuration. + """ + default_options = config.pop("default_options", None) + options = cls.options_cls(**default_options) if default_options else None + return cls(**config, default_options=options) diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py b/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py index 5bac47da..6e49488d 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/_cli.py @@ -96,7 +96,7 @@ async def run() -> None: raise ValueError("Vector store not initialized") try: - embedder = Embeddings.subclass_from_defaults( + embedder: Embeddings = Embeddings.subclass_from_defaults( core_config, factory_path_override=embedder_factory_path, yaml_path_override=Path.cwd() / embedder_yaml_path if embedder_yaml_path else None, diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/base.py b/packages/ragbits-core/src/ragbits/core/vector_stores/base.py index e8adb746..44d0d54a 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/base.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/base.py @@ -1,12 +1,13 @@ from abc import ABC, abstractmethod -from typing import ClassVar +from typing import ClassVar, TypeVar from pydantic import BaseModel from typing_extensions import Self from ragbits.core import vector_stores from ragbits.core.metadata_stores.base import MetadataStore -from ragbits.core.utils.config_handling import ObjectContructionConfig, WithConstructionConfig +from ragbits.core.options import Options +from ragbits.core.utils.config_handling import ConfigurableComponent, ObjectContructionConfig WhereQuery = dict[str, str | int | float | bool] @@ -22,7 +23,7 @@ class VectorStoreEntry(BaseModel): metadata: dict -class VectorStoreOptions(BaseModel, ABC): +class VectorStoreOptions(Options): """ An object representing the options for the vector store. """ @@ -31,17 +32,21 @@ class VectorStoreOptions(BaseModel, ABC): max_distance: float | None = None -class VectorStore(WithConstructionConfig, ABC): +VectorStoreOptionsT = TypeVar("VectorStoreOptionsT", bound=VectorStoreOptions) + + +class VectorStore(ConfigurableComponent[VectorStoreOptionsT], ABC): """ A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function. """ + options_cls: type[VectorStoreOptionsT] default_module: ClassVar = vector_stores configuration_key: ClassVar = "vector_store" def __init__( self, - default_options: VectorStoreOptions | None = None, + default_options: VectorStoreOptionsT | None = None, metadata_store: MetadataStore | None = None, ) -> None: """ @@ -51,8 +56,7 @@ def __init__( default_options: The default options for querying the vector store. 
metadata_store: The metadata store to use. """ - super().__init__() - self._default_options = default_options or VectorStoreOptions() + super().__init__(default_options=default_options) self._metadata_store = metadata_store @classmethod @@ -71,7 +75,7 @@ def from_config(cls, config: dict) -> Self: InvalidConfigError: The metadata_store class can't be found or is not the correct type. """ default_options = config.pop("default_options", None) - options = VectorStoreOptions(**default_options) if default_options else None + options = cls.options_cls(**default_options) if default_options else None store_config = config.pop("metadata_store", None) store = ( @@ -92,7 +96,7 @@ async def store(self, entries: list[VectorStoreEntry]) -> None: """ @abstractmethod - async def retrieve(self, vector: list[float], options: VectorStoreOptions | None = None) -> list[VectorStoreEntry]: + async def retrieve(self, vector: list[float], options: VectorStoreOptionsT | None = None) -> list[VectorStoreEntry]: """ Retrieve entries from the vector store. diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py index d89e4489..cc78136b 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py @@ -11,11 +11,13 @@ from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery -class ChromaVectorStore(VectorStore): +class ChromaVectorStore(VectorStore[VectorStoreOptions]): """ Vector store implementation using [Chroma](https://docs.trychroma.com). """ + options_cls = VectorStoreOptions + def __init__( self, client: ClientAPI, @@ -105,11 +107,11 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None Raises: MetadataNotFoundError: If the metadata is not found. """ - options = self._default_options if options is None else options + merged_options = (self.default_options | options) if options else self.default_options results = self._collection.query( query_embeddings=vector, - n_results=options.k, + n_results=merged_options.k, include=["metadatas", "embeddings", "distances", "documents"], ) @@ -132,7 +134,7 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None ) for batch in zip(ids, metadatas, embeddings, distances, documents, strict=True) for id, metadata, embeddings, distance, document in zip(*batch, strict=True) - if options.max_distance is None or distance <= options.max_distance + if merged_options.max_distance is None or distance <= merged_options.max_distance ] @traceable diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py b/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py index 63227442..a163eacc 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py @@ -8,11 +8,13 @@ from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions, WhereQuery -class InMemoryVectorStore(VectorStore): +class InMemoryVectorStore(VectorStore[VectorStoreOptions]): """ A simple in-memory implementation of Vector Store, storing vectors in memory. """ + options_cls = VectorStoreOptions + def __init__( self, default_options: VectorStoreOptions | None = None, @@ -76,7 +78,7 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None Returns: The entries. 
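        Example (illustrative vector and option values):
            store = InMemoryVectorStore(default_options=VectorStoreOptions(k=5))
            entries = await store.retrieve([0.1, 0.2, 0.3], options=VectorStoreOptions(k=2, max_distance=0.5))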
""" - options = self._default_options if options is None else options + merged_options = (self.default_options | options) if options else self.default_options entries = sorted( ( (entry, float(np.linalg.norm(np.array(entry.vector) - np.array(vector)))) @@ -86,8 +88,8 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None ) return [ entry - for entry, distance in entries[: options.k] - if options.max_distance is None or distance <= options.max_distance + for entry, distance in entries[: merged_options.k] + if merged_options.max_distance is None or distance <= merged_options.max_distance ] @traceable diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py b/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py index 122c21c0..8d38830f 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py @@ -12,11 +12,13 @@ from ragbits.core.vector_stores.base import VectorStore, VectorStoreEntry, VectorStoreOptions -class QdrantVectorStore(VectorStore): +class QdrantVectorStore(VectorStore[VectorStoreOptions]): """ Vector store implementation using [Qdrant](https://qdrant.tech). """ + options_cls = VectorStoreOptions + def __init__( self, client: AsyncQdrantClient, @@ -116,13 +118,13 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None Raises: MetadataNotFoundError: If metadata cannot be retrieved """ - options = options or self._default_options - score_threshold = 1 - options.max_distance if options.max_distance else None + merged_options = (self.default_options | options) if options else self.default_options + score_threshold = 1 - merged_options.max_distance if merged_options.max_distance else None results = await self._client.query_points( collection_name=self._index_name, query=vector, - limit=options.k, + limit=merged_options.k, score_threshold=score_threshold, with_payload=True, with_vectors=True, diff --git a/packages/ragbits-core/tests/unit/embeddings/test_from_config.py b/packages/ragbits-core/tests/unit/embeddings/test_from_config.py index c0d56a21..c8a41e08 100644 --- a/packages/ragbits-core/tests/unit/embeddings/test_from_config.py +++ b/packages/ragbits-core/tests/unit/embeddings/test_from_config.py @@ -1,5 +1,6 @@ from ragbits.core.embeddings import Embeddings, NoopEmbeddings -from ragbits.core.embeddings.litellm import LiteLLMEmbeddings +from ragbits.core.embeddings.litellm import LiteLLMEmbeddings, LiteLLMEmbeddingsOptions +from ragbits.core.types import NOT_GIVEN from ragbits.core.utils.config_handling import ObjectContructionConfig @@ -9,20 +10,27 @@ def test_subclass_from_config(): "type": "ragbits.core.embeddings.litellm:LiteLLMEmbeddings", "config": { "model": "some_model", - "options": { + "default_options": { "option1": "value1", "option2": "value2", }, }, } ) - embedding = Embeddings.subclass_from_config(config) + embedding: Embeddings = Embeddings.subclass_from_config(config) assert isinstance(embedding, LiteLLMEmbeddings) assert embedding.model == "some_model" - assert embedding.options == {"option1": "value1", "option2": "value2"} + assert embedding.default_options == LiteLLMEmbeddingsOptions( + dimensions=NOT_GIVEN, + timeout=NOT_GIVEN, + user=NOT_GIVEN, + encoding_format=NOT_GIVEN, + option1="value1", + option2="value2", + ) # type: ignore def test_subclass_from_config_default_path(): config = ObjectContructionConfig.model_validate({"type": "NoopEmbeddings"}) - embedding = 
Embeddings.subclass_from_config(config) + embedding: Embeddings = Embeddings.subclass_from_config(config) assert isinstance(embedding, NoopEmbeddings) diff --git a/packages/ragbits-core/tests/unit/test_options.py b/packages/ragbits-core/tests/unit/test_options.py new file mode 100644 index 00000000..0195247b --- /dev/null +++ b/packages/ragbits-core/tests/unit/test_options.py @@ -0,0 +1,40 @@ +import pytest + +from ragbits.core.options import Options +from ragbits.core.types import NOT_GIVEN, NotGiven + + +class OptionA(Options): + a: int = 1 + d: int | NotGiven = NOT_GIVEN + + +class OptionsB(Options): + b: int = 2 + e: int | None = None + + +class OptionsC(Options): + a: int = 2 + c: str = "c" + + +@pytest.mark.parametrize( + ("options", "expected"), + [ + (OptionA(), {"a": 1, "d": None}), + (OptionsB(), {"b": 2, "e": None}), + ], +) +def test_default_options(options: Options, expected: dict) -> None: + assert options.dict() == expected + + +def test_merge_options() -> None: + options_a = OptionA() + options_b = OptionsB() + options_c = OptionsC() + + merged = options_a | options_b | options_c + + assert merged.dict() == {"a": 2, "b": 2, "c": "c", "d": None, "e": None} diff --git a/packages/ragbits-core/tests/unit/vector_stores/test_from_config.py b/packages/ragbits-core/tests/unit/vector_stores/test_from_config.py index 3206e3e4..32f9e9e6 100644 --- a/packages/ragbits-core/tests/unit/vector_stores/test_from_config.py +++ b/packages/ragbits-core/tests/unit/vector_stores/test_from_config.py @@ -25,17 +25,17 @@ def test_subclass_from_config(): }, } ) - store = VectorStore.subclass_from_config(config) + store = VectorStore.subclass_from_config(config) # type: ignore assert isinstance(store, InMemoryVectorStore) - assert isinstance(store._default_options, VectorStoreOptions) - assert store._default_options.k == 10 - assert store._default_options.max_distance == 0.22 + assert isinstance(store.default_options, VectorStoreOptions) + assert store.default_options.k == 10 + assert store.default_options.max_distance == 0.22 assert isinstance(store._metadata_store, InMemoryMetadataStore) def test_subclass_from_config_default_path(): config = ObjectContructionConfig.model_validate({"type": "InMemoryVectorStore"}) - store = VectorStore.subclass_from_config(config) + store = VectorStore.subclass_from_config(config) # type: ignore assert isinstance(store, InMemoryVectorStore) @@ -53,12 +53,12 @@ def test_subclass_from_config_chroma_client(): }, } ) - store = VectorStore.subclass_from_config(config) + store = VectorStore.subclass_from_config(config) # type: ignore assert isinstance(store, ChromaVectorStore) assert store._index_name == "some_index" assert isinstance(store._client, ClientAPI) - assert store._default_options.k == 10 - assert store._default_options.max_distance == 0.22 + assert store.default_options.k == 10 + assert store.default_options.max_distance == 0.22 def test_subclass_from_config_drant_client(): @@ -80,10 +80,10 @@ def test_subclass_from_config_drant_client(): }, } ) - store = VectorStore.subclass_from_config(config) + store = VectorStore.subclass_from_config(config) # type: ignore assert isinstance(store, QdrantVectorStore) assert store._index_name == "some_index" assert isinstance(store._client, AsyncQdrantClient) assert isinstance(store._client._client, AsyncQdrantLocal) - assert store._default_options.k == 10 - assert store._default_options.max_distance == 0.22 + assert store.default_options.k == 10 + assert store.default_options.max_distance == 0.22 diff --git 
a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index 5db3f289..34682432 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -100,10 +100,10 @@ def from_config(cls, config: dict) -> "DocumentSearch": """ model = DocumentSearchConfig.model_validate(config) - embedder = Embeddings.subclass_from_config(model.embedder) + embedder: Embeddings = Embeddings.subclass_from_config(model.embedder) query_rephraser = QueryRephraser.subclass_from_config(model.rephraser) - reranker = Reranker.subclass_from_config(model.reranker) - vector_store = VectorStore.subclass_from_config(model.vector_store) + reranker: Reranker = Reranker.subclass_from_config(model.reranker) + vector_store: VectorStore = VectorStore.subclass_from_config(model.vector_store) processing_strategy = ProcessingExecutionStrategy.subclass_from_config(model.processing_strategy) providers_config = DocumentProcessorRouter.from_dict_to_providers_config(model.providers) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py index 6b335652..90436b48 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/base.py @@ -1,16 +1,14 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import ClassVar +from typing import ClassVar, TypeVar -from pydantic import BaseModel -from typing_extensions import Self - -from ragbits.core.utils.config_handling import WithConstructionConfig +from ragbits.core.options import Options +from ragbits.core.utils.config_handling import ConfigurableComponent from ragbits.document_search.documents.element import Element from ragbits.document_search.retrieval import rerankers -class RerankerOptions(BaseModel): +class RerankerOptions(Options): """ Options for the reranker. """ @@ -19,44 +17,24 @@ class RerankerOptions(BaseModel): max_chunks_per_doc: int | None = None -class Reranker(WithConstructionConfig, ABC): +RerankerOptionsT = TypeVar("RerankerOptionsT", bound=RerankerOptions) + + +class Reranker(ConfigurableComponent[RerankerOptionsT], ABC): """ Reranks elements retrieved from vector store. """ default_module: ClassVar = rerankers + options_cls: type[RerankerOptionsT] configuration_key: ClassVar = "reranker" - def __init__(self, default_options: RerankerOptions | None = None) -> None: - """ - Constructs a new Reranker instance. - - Args: - default_options: The default options for reranking. - """ - self._default_options = default_options or RerankerOptions() - - @classmethod - def from_config(cls, config: dict) -> Self: - """ - Initializes the class with the provided configuration. - - Args: - config: A dictionary containing configuration details for the class. - - Returns: - An instance of the class initialized with the provided configuration. 
- """ - default_options = config.pop("default_options", None) - options = RerankerOptions(**default_options) if default_options else None - return cls(**config, default_options=options) - @abstractmethod async def rerank( self, elements: Sequence[Element], query: str, - options: RerankerOptions | None = None, + options: RerankerOptionsT | None = None, ) -> Sequence[Element]: """ Rerank elements. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py index aa877a9c..506b5a54 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/litellm.py @@ -7,11 +7,13 @@ from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions -class LiteLLMReranker(Reranker): +class LiteLLMReranker(Reranker[RerankerOptions]): """ A [LiteLLM](https://docs.litellm.ai/docs/rerank) reranker for providers such as Cohere, Together AI, Azure AI. """ + options_cls = RerankerOptions + def __init__(self, model: str, default_options: RerankerOptions | None = None) -> None: """ Constructs a new LiteLLMReranker instance. @@ -20,7 +22,7 @@ def __init__(self, model: str, default_options: RerankerOptions | None = None) - model: The reranker model to use. default_options: The default options for reranking. """ - super().__init__(default_options) + super().__init__(default_options=default_options) self.model = model @traceable @@ -41,15 +43,15 @@ async def rerank( Returns: The reranked elements. """ - options = self._default_options if options is None else options + merged_options = (self.default_options | options) if options else self.default_options documents = [element.text_representation for element in elements] response = await litellm.arerank( model=self.model, query=query, documents=documents, # type: ignore - top_n=options.top_n, - max_chunks_per_doc=options.max_chunks_per_doc, + top_n=merged_options.top_n, + max_chunks_per_doc=merged_options.max_chunks_per_doc, ) return [elements[result["index"]] for result in response.results] # type: ignore diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py index 681fa7e4..f0e8cd9d 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rerankers/noop.py @@ -5,11 +5,13 @@ from ragbits.document_search.retrieval.rerankers.base import Reranker, RerankerOptions -class NoopReranker(Reranker): +class NoopReranker(Reranker[RerankerOptions]): """ A no-op reranker that does not change the order of the elements. """ + options_cls = RerankerOptions + @traceable async def rerank( # noqa: PLR6301 self, diff --git a/packages/ragbits-document-search/tests/unit/test_rerankers.py b/packages/ragbits-document-search/tests/unit/test_rerankers.py index fa483b0d..ff127e76 100644 --- a/packages/ragbits-document-search/tests/unit/test_rerankers.py +++ b/packages/ragbits-document-search/tests/unit/test_rerankers.py @@ -15,6 +15,8 @@ class CustomReranker(Reranker): Custom implementation of Reranker for testing. 
""" + options_cls = RerankerOptions + async def rerank( # noqa: PLR6301 self, elements: Sequence[Element], query: str, options: RerankerOptions | None = None ) -> Sequence[Element]: @@ -37,8 +39,8 @@ def test_litellm_reranker_from_config() -> None: } ) - assert reranker.model == "test-provder/test-model" - assert reranker._default_options == RerankerOptions(top_n=2, max_chunks_per_doc=None) + assert reranker.model == "test-provder/test-model" # type: ignore + assert reranker.default_options == RerankerOptions(top_n=2, max_chunks_per_doc=None) async def test_litellm_reranker_rerank() -> None: @@ -92,16 +94,16 @@ def test_subclass_from_config(): }, } ) - reranker = Reranker.subclass_from_config(config) + reranker: Reranker = Reranker.subclass_from_config(config) assert isinstance(reranker, NoopReranker) - assert isinstance(reranker._default_options, RerankerOptions) - assert reranker._default_options.top_n == 12 - assert reranker._default_options.max_chunks_per_doc == 42 + assert isinstance(reranker.default_options, RerankerOptions) + assert reranker.default_options.top_n == 12 + assert reranker.default_options.max_chunks_per_doc == 42 def test_subclass_from_config_default_path(): config = ObjectContructionConfig.model_validate({"type": "NoopReranker"}) - reranker = Reranker.subclass_from_config(config) + reranker: Reranker = Reranker.subclass_from_config(config) assert isinstance(reranker, NoopReranker) @@ -118,9 +120,9 @@ def test_subclass_from_config_llm(): }, } ) - reranker = Reranker.subclass_from_config(config) + reranker: Reranker = Reranker.subclass_from_config(config) assert isinstance(reranker, LiteLLMReranker) - assert isinstance(reranker._default_options, RerankerOptions) + assert isinstance(reranker.default_options, RerankerOptions) assert reranker.model == "some_model" - assert reranker._default_options.top_n == 12 - assert reranker._default_options.max_chunks_per_doc == 42 + assert reranker.default_options.top_n == 12 + assert reranker.default_options.max_chunks_per_doc == 42