Skip to content

Commit

Permalink
feat(core): Implement generic Options class (#248)
Browse files Browse the repository at this point in the history
Co-authored-by: Michał Pstrąg <[email protected]>
  • Loading branch information
akonarski-ds and micpst authored Dec 19, 2024
1 parent 94b1e94 commit 80de16a
Show file tree
Hide file tree
Showing 31 changed files with 368 additions and 234 deletions.
11 changes: 10 additions & 1 deletion examples/document-search/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@

from chromadb import EphemeralClient

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings, LiteLLMEmbeddingsOptions
from ragbits.core.vector_stores import VectorStoreOptions
from ragbits.core.vector_stores.chroma import ChromaVectorStore
from ragbits.document_search import DocumentSearch, SearchConfig
from ragbits.document_search.documents.document import DocumentMeta
Expand Down Expand Up @@ -70,10 +71,18 @@ async def main() -> None:
"""
embedder = LiteLLMEmbeddings(
model="text-embedding-3-small",
default_options=LiteLLMEmbeddingsOptions(
dimensions=1024,
timeout=1000,
),
)
vector_store = ChromaVectorStore(
client=EphemeralClient(),
index_name="jokes",
default_options=VectorStoreOptions(
k=10,
max_distance=0.22,
),
)
document_search = DocumentSearch(
embedder=embedder,
Expand Down
4 changes: 4 additions & 0 deletions packages/ragbits-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Unreleased

### Changed

- Feat: Implement generic Options class (#248).

## 0.5.1 (2024-12-09)

### Changed
Expand Down
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base import Embeddings, EmbeddingType
from .base import Embeddings, EmbeddingsOptionsT, EmbeddingType
from .litellm import LiteLLMEmbeddings
from .noop import NoopEmbeddings

__all__ = ["EmbeddingType", "Embeddings", "LiteLLMEmbeddings", "NoopEmbeddings"]
__all__ = ["EmbeddingType", "Embeddings", "EmbeddingsOptionsT", "LiteLLMEmbeddings", "NoopEmbeddings"]
20 changes: 13 additions & 7 deletions packages/ragbits-core/src/ragbits/core/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import ClassVar
from typing import ClassVar, TypeVar

from ragbits.core import embeddings
from ragbits.core.utils.config_handling import WithConstructionConfig
from ragbits.core.options import Options
from ragbits.core.utils.config_handling import ConfigurableComponent

EmbeddingsOptionsT = TypeVar("EmbeddingsOptionsT", bound=Options)


class EmbeddingType(Enum):
Expand All @@ -17,25 +20,27 @@ class EmbeddingType(Enum):
allowing for the creation of different embeddings for the same element.
"""

TEXT: str = "text"
IMAGE: str = "image"
TEXT = "text"
IMAGE = "image"


class Embeddings(WithConstructionConfig, ABC):
class Embeddings(ConfigurableComponent[EmbeddingsOptionsT], ABC):
"""
Abstract client for communication with embedding models.
"""

options_cls: type[EmbeddingsOptionsT]
default_module: ClassVar = embeddings
configuration_key: ClassVar = "embedder"

@abstractmethod
async def embed_text(self, data: list[str]) -> list[list[float]]:
async def embed_text(self, data: list[str], options: EmbeddingsOptionsT | None = None) -> list[list[float]]:
"""
Creates embeddings for the given strings.
Args:
data: List of strings to get embeddings for.
options: Additional settings used by the Embeddings model.
Returns:
List of embeddings for the given strings.
Expand All @@ -50,12 +55,13 @@ def image_support(self) -> bool: # noqa: PLR6301
"""
return False

async def embed_image(self, images: list[bytes]) -> list[list[float]]:
async def embed_image(self, images: list[bytes], options: EmbeddingsOptionsT | None = None) -> list[list[float]]:
"""
Creates embeddings for the given images.
Args:
images: List of images to get embeddings for.
options: Additional settings used by the Embeddings model.
Returns:
List of embeddings for the given images.
Expand Down
34 changes: 26 additions & 8 deletions packages/ragbits-core/src/ragbits/core/embeddings/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,33 @@
EmbeddingResponseError,
EmbeddingStatusError,
)
from ragbits.core.options import Options
from ragbits.core.types import NOT_GIVEN, NotGiven


class LiteLLMEmbeddings(Embeddings):
class LiteLLMEmbeddingsOptions(Options):
"""
Dataclass that represents available call options for the LiteLLMEmbeddingClient client.
Each of them is described in the [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding#optional-litellm-fields).
"""

dimensions: int | None | NotGiven = NOT_GIVEN
timeout: int | None | NotGiven = NOT_GIVEN
user: str | None | NotGiven = NOT_GIVEN
encoding_format: str | None | NotGiven = NOT_GIVEN


class LiteLLMEmbeddings(Embeddings[LiteLLMEmbeddingsOptions]):
"""
Client for creating text embeddings using LiteLLM API.
"""

options_cls = LiteLLMEmbeddingsOptions

def __init__(
self,
model: str = "text-embedding-3-small",
options: dict | None = None,
default_options: LiteLLMEmbeddingsOptions | None = None,
api_base: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
Expand All @@ -29,26 +45,26 @@ def __init__(
Args:
model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\
to be used. Default is "text-embedding-3-small".
options: Additional options to pass to the LiteLLM API.
default_options: Defualt options to pass to the LiteLLM API.
api_base: The API endpoint you want to call the model with.
api_key: API key to be used. API key to be used. If not specified, an environment variable will be used,
for more information, follow the instructions for your specific vendor in the\
[LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
api_version: The API version for the call.
"""
super().__init__()
super().__init__(default_options=default_options)
self.model = model
self.options = options or {}
self.api_base = api_base
self.api_key = api_key
self.api_version = api_version

async def embed_text(self, data: list[str]) -> list[list[float]]:
async def embed_text(self, data: list[str], options: LiteLLMEmbeddingsOptions | None = None) -> list[list[float]]:
"""
Creates embeddings for the given strings.
Args:
data: List of strings to get embeddings for.
options: Additional options to pass to the Lite LLM API.
Returns:
List of embeddings for the given strings.
Expand All @@ -59,12 +75,14 @@ async def embed_text(self, data: list[str]) -> list[list[float]]:
EmbeddingStatusError: If the embedding API returns an error status code.
EmbeddingResponseError: If the embedding API response is invalid.
"""
merged_options = (self.default_options | options) if options else self.default_options

with trace(
data=data,
model=self.model,
api_base=self.api_base,
api_version=self.api_version,
options=self.options,
options=merged_options.dict(),
) as outputs:
try:
response = await litellm.aembedding(
Expand All @@ -73,7 +91,7 @@ async def embed_text(self, data: list[str]) -> list[list[float]]:
api_base=self.api_base,
api_key=self.api_key,
api_version=self.api_version,
**self.options,
**merged_options.dict(),
)
except litellm.openai.APIConnectionError as exc:
raise EmbeddingConnectionError() from exc
Expand Down
27 changes: 21 additions & 6 deletions packages/ragbits-core/src/ragbits/core/embeddings/local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from collections.abc import Iterator

from ragbits.core.embeddings import Embeddings
from ragbits.core.options import Options

try:
import torch
import torch.nn.functional as F
Expand All @@ -9,32 +12,42 @@
except ImportError:
HAS_LOCAL_EMBEDDINGS = False

from ragbits.core.embeddings import Embeddings

class LocalEmbeddingsOptions(Options):
"""
Dataclass that represents available call options for the LocalEmbeddings client.
"""

class LocalEmbeddings(Embeddings):
batch_size: int = 1


class LocalEmbeddings(Embeddings[LocalEmbeddingsOptions]):
"""
Class for interaction with any encoder available in HuggingFace.
"""

options_cls = LocalEmbeddingsOptions

def __init__(
self,
model_name: str,
api_key: str | None = None,
default_options: LocalEmbeddingsOptions | None = None,
) -> None:
"""Constructs a new local LLM instance.
Args:
model_name: Name of the model to use.
api_key: The API key for Hugging Face authentication.
default_options: Default options for the embedding model.
Raises:
ImportError: If the 'local' extra requirements are not installed.
"""
if not HAS_LOCAL_EMBEDDINGS:
raise ImportError("You need to install the 'local' extra requirements to use local embeddings models")

super().__init__()
super().__init__(default_options=default_options)

self.hf_api_key = api_key
self.model_name = model_name
Expand All @@ -43,18 +56,20 @@ def __init__(
self.model = AutoModel.from_pretrained(self.model_name, token=self.hf_api_key).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_api_key)

async def embed_text(self, data: list[str], batch_size: int = 1) -> list[list[float]]:
async def embed_text(self, data: list[str], options: LocalEmbeddingsOptions | None = None) -> list[list[float]]:
"""Calls the appropriate encoder endpoint with the given data and options.
Args:
data: List of strings to get embeddings for.
batch_size: Batch size.
options: Additional options to pass to the embedding model.
Returns:
List of embeddings for the given strings.
"""
merged_options = (self.default_options | options) if options else self.default_options

embeddings = []
for batch in self._batch(data, batch_size):
for batch in self._batch(data, merged_options.batch_size):
batch_dict = self.tokenizer(
batch,
max_length=self.tokenizer.model_max_length,
Expand Down
8 changes: 6 additions & 2 deletions packages/ragbits-core/src/ragbits/core/embeddings/noop.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from ragbits.core.audit import traceable
from ragbits.core.embeddings.base import Embeddings
from ragbits.core.options import Options


class NoopEmbeddings(Embeddings):
class NoopEmbeddings(Embeddings[Options]):
"""
A no-op implementation of the Embeddings class.
Expand All @@ -11,13 +12,16 @@ class NoopEmbeddings(Embeddings):
or as a placeholder when an actual embedding model is not required.
"""

options_cls = Options

@traceable
async def embed_text(self, data: list[str]) -> list[list[float]]: # noqa: PLR6301
async def embed_text(self, data: list[str], options: Options | None = None) -> list[list[float]]: # noqa: PLR6301
"""
Embeds a list of strings into a list of vectors.
Args:
data: A list of input text strings to embed.
options: Additional settings used by the Embeddings model.
Returns:
A list of embedding vectors, where each vector
Expand Down
Loading

0 comments on commit 80de16a

Please sign in to comment.