Commit
Merge remote-tracking branch 'origin/226-cli-event-handler' into 226-cli-event-handler
kzamlynska committed Dec 19, 2024
2 parents dc5164a + 723e86e commit 4ca2cef
Showing 33 changed files with 455 additions and 259 deletions.
11 changes: 10 additions & 1 deletion examples/document-search/chroma.py
@@ -35,7 +35,8 @@

from chromadb import EphemeralClient

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.core.embeddings.litellm import LiteLLMEmbeddings, LiteLLMEmbeddingsOptions
from ragbits.core.vector_stores import VectorStoreOptions
from ragbits.core.vector_stores.chroma import ChromaVectorStore
from ragbits.document_search import DocumentSearch, SearchConfig
from ragbits.document_search.documents.document import DocumentMeta
@@ -70,10 +71,18 @@ async def main() -> None:
"""
embedder = LiteLLMEmbeddings(
model="text-embedding-3-small",
default_options=LiteLLMEmbeddingsOptions(
dimensions=1024,
timeout=1000,
),
)
vector_store = ChromaVectorStore(
client=EphemeralClient(),
index_name="jokes",
default_options=VectorStoreOptions(
k=10,
max_distance=0.22,
),
)
document_search = DocumentSearch(
embedder=embedder,
64 changes: 64 additions & 0 deletions packages/ragbits-cli/src/ragbits/cli/_utils.py
@@ -0,0 +1,64 @@
from pathlib import Path
from typing import Protocol, TypeVar

import typer
from pydantic.alias_generators import to_snake
from rich.console import Console

from ragbits.core.config import CoreConfig, core_config
from ragbits.core.utils.config_handling import InvalidConfigError, NoDefaultConfigError, WithConstructionConfig

WithConstructionConfigT_co = TypeVar("WithConstructionConfigT_co", bound=WithConstructionConfig, covariant=True)


# Using a Protocol instead of simply typing the `cls` argument to `get_instance_or_exit`
# as `type[WithConstructionConfigT_co]` to work around mypy not allowing abstract classes
# to be used as types: https://github.com/python/mypy/issues/4717
class WithConstructionConfigProtocol(Protocol[WithConstructionConfigT_co]):
@classmethod
def subclass_from_defaults(
cls, defaults: CoreConfig, factory_path_override: str | None = None, yaml_path_override: Path | None = None
) -> WithConstructionConfigT_co: ...


def get_instance_or_exit(
cls: WithConstructionConfigProtocol[WithConstructionConfigT_co],
type_name: str | None = None,
yaml_path: Path | None = None,
factory_path: str | None = None,
yaml_path_argument_name: str = "--yaml-path",
factory_path_argument_name: str = "--factory-path",
) -> WithConstructionConfigT_co:
"""
Returns an instance of the provided class, initialized using its `subclass_from_defaults` method.
If the instance can't be created, prints an error message and exits the program.
Args:
cls: The class to create an instance of.
type_name: The name to use in error messages. If None, inferred from the class name.
yaml_path: Path to a YAML configuration file to use for initialization.
factory_path: Python path to a factory function to use for initialization.
yaml_path_argument_name: The name of the argument to use in error messages for the YAML path.
factory_path_argument_name: The name of the argument to use in error messages for the factory path.
"""
if not isinstance(cls, type):
raise TypeError(f"get_instance_or_exit expects the `cls` argument to be a class, got {cls}")

type_name = type_name or to_snake(cls.__name__).replace("_", " ")
try:
return cls.subclass_from_defaults(
core_config,
factory_path_override=factory_path,
yaml_path_override=yaml_path,
)
except NoDefaultConfigError as e:
Console(
stderr=True
).print(f"""You need to provide the [b]{type_name}[/b] instance to be used. You can do this by either:
- providing a path to a YAML configuration file with the [b]{yaml_path_argument_name}[/b] option
- providing a Python path to a function that creates the instance with the [b]{factory_path_argument_name}[/b] option
- setting the default configuration or factory function in your project's [b]pyproject.toml[/b] file""")
raise typer.Exit(1) from e
except InvalidConfigError as e:
Console(stderr=True).print(e)
raise typer.Exit(1) from e
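
For context, here is a minimal sketch (not part of this commit) of how a Typer command might call the new helper. The `info` command, its option names, and the use of `Embeddings` as the target class are illustrative assumptions; the sketch also assumes `Embeddings` still exposes `subclass_from_defaults` through `ConfigurableComponent`, which its retained `configuration_key = "embedder"` suggests but this diff does not show directly.

```python
# Hypothetical CLI command using get_instance_or_exit; names below are assumptions.
from pathlib import Path

import typer

from ragbits.cli._utils import get_instance_or_exit
from ragbits.core.embeddings import Embeddings

app = typer.Typer()


@app.command()
def info(
    yaml_path: Path | None = typer.Option(None, "--yaml-path"),
    factory_path: str | None = typer.Option(None, "--factory-path"),
) -> None:
    """Print the class name of the configured embedder."""
    # Exits with code 1 and a Rich error message when no configuration source is found.
    embedder = get_instance_or_exit(
        Embeddings,
        yaml_path=yaml_path,
        factory_path=factory_path,
    )
    typer.echo(type(embedder).__name__)
```

With neither option set and no default in the project's pyproject.toml, the command prints the error message constructed above and exits with status 1.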
4 changes: 4 additions & 0 deletions packages/ragbits-core/CHANGELOG.md
@@ -2,6 +2,10 @@

## Unreleased

### Changed

- Feat: Implement generic Options class (#248).

## 0.5.1 (2024-12-09)

### Changed
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
@@ -1,5 +1,5 @@
from .base import Embeddings, EmbeddingType
from .base import Embeddings, EmbeddingsOptionsT, EmbeddingType
from .litellm import LiteLLMEmbeddings
from .noop import NoopEmbeddings

__all__ = ["EmbeddingType", "Embeddings", "LiteLLMEmbeddings", "NoopEmbeddings"]
__all__ = ["EmbeddingType", "Embeddings", "EmbeddingsOptionsT", "LiteLLMEmbeddings", "NoopEmbeddings"]
20 changes: 13 additions & 7 deletions packages/ragbits-core/src/ragbits/core/embeddings/base.py
@@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import ClassVar
from typing import ClassVar, TypeVar

from ragbits.core import embeddings
from ragbits.core.utils.config_handling import WithConstructionConfig
from ragbits.core.options import Options
from ragbits.core.utils.config_handling import ConfigurableComponent

EmbeddingsOptionsT = TypeVar("EmbeddingsOptionsT", bound=Options)


class EmbeddingType(Enum):
@@ -17,25 +20,27 @@ class EmbeddingType(Enum):
allowing for the creation of different embeddings for the same element.
"""

TEXT: str = "text"
IMAGE: str = "image"
TEXT = "text"
IMAGE = "image"


class Embeddings(WithConstructionConfig, ABC):
class Embeddings(ConfigurableComponent[EmbeddingsOptionsT], ABC):
"""
Abstract client for communication with embedding models.
"""

options_cls: type[EmbeddingsOptionsT]
default_module: ClassVar = embeddings
configuration_key: ClassVar = "embedder"

@abstractmethod
async def embed_text(self, data: list[str]) -> list[list[float]]:
async def embed_text(self, data: list[str], options: EmbeddingsOptionsT | None = None) -> list[list[float]]:
"""
Creates embeddings for the given strings.
Args:
data: List of strings to get embeddings for.
options: Additional settings used by the Embeddings model.
Returns:
List of embeddings for the given strings.
Expand All @@ -50,12 +55,13 @@ def image_support(self) -> bool: # noqa: PLR6301
"""
return False

async def embed_image(self, images: list[bytes]) -> list[list[float]]:
async def embed_image(self, images: list[bytes], options: EmbeddingsOptionsT | None = None) -> list[list[float]]:
"""
Creates embeddings for the given images.
Args:
images: List of images to get embeddings for.
options: Additional settings used by the Embeddings model.
Returns:
List of embeddings for the given images.
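
To illustrate the new generic contract, here is a small, self-contained sketch of a custom embedder: parameterise `Embeddings` with your `Options` subclass, set `options_cls`, and honour the per-call `options` argument. The class names and hash-based vectors are invented for illustration; the sketch assumes `ConfigurableComponent` falls back to `options_cls()` when no `default_options` are passed, as the `NoopEmbeddings` changes further down imply.

```python
# Illustrative only: a dependency-free Embeddings subclass following the new pattern.
import hashlib

from ragbits.core.embeddings.base import Embeddings
from ragbits.core.options import Options


class FakeEmbeddingsOptions(Options):
    """Call options for the illustrative FakeEmbeddings client."""

    size: int = 8  # number of floats per embedding vector


class FakeEmbeddings(Embeddings[FakeEmbeddingsOptions]):
    """Deterministic embedder used only to demonstrate the generic options API."""

    options_cls = FakeEmbeddingsOptions

    async def embed_text(
        self, data: list[str], options: FakeEmbeddingsOptions | None = None
    ) -> list[list[float]]:
        # Merge per-call options over the defaults, mirroring the built-in clients.
        merged = (self.default_options | options) if options else self.default_options
        # Derive `size` pseudo-embedding values from a SHA-256 hash of each string.
        return [
            [byte / 255 for byte in hashlib.sha256(text.encode()).digest()[: merged.size]]
            for text in data
        ]
```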
34 changes: 26 additions & 8 deletions packages/ragbits-core/src/ragbits/core/embeddings/litellm.py
@@ -8,17 +8,33 @@
EmbeddingResponseError,
EmbeddingStatusError,
)
from ragbits.core.options import Options
from ragbits.core.types import NOT_GIVEN, NotGiven


class LiteLLMEmbeddings(Embeddings):
class LiteLLMEmbeddingsOptions(Options):
"""
Dataclass that represents available call options for the LiteLLMEmbeddings client.
Each of them is described in the [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding#optional-litellm-fields).
"""

dimensions: int | None | NotGiven = NOT_GIVEN
timeout: int | None | NotGiven = NOT_GIVEN
user: str | None | NotGiven = NOT_GIVEN
encoding_format: str | None | NotGiven = NOT_GIVEN


class LiteLLMEmbeddings(Embeddings[LiteLLMEmbeddingsOptions]):
"""
Client for creating text embeddings using LiteLLM API.
"""

options_cls = LiteLLMEmbeddingsOptions

def __init__(
self,
model: str = "text-embedding-3-small",
options: dict | None = None,
default_options: LiteLLMEmbeddingsOptions | None = None,
api_base: str | None = None,
api_key: str | None = None,
api_version: str | None = None,
@@ -29,26 +45,26 @@ def __init__(
Args:
model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\
to be used. Default is "text-embedding-3-small".
options: Additional options to pass to the LiteLLM API.
default_options: Default options to pass to the LiteLLM API.
api_base: The API endpoint you want to call the model with.
api_key: API key to be used. If not specified, an environment variable will be used;
for more information, follow the instructions for your specific vendor in the\
[LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
api_version: The API version for the call.
"""
super().__init__()
super().__init__(default_options=default_options)
self.model = model
self.options = options or {}
self.api_base = api_base
self.api_key = api_key
self.api_version = api_version

async def embed_text(self, data: list[str]) -> list[list[float]]:
async def embed_text(self, data: list[str], options: LiteLLMEmbeddingsOptions | None = None) -> list[list[float]]:
"""
Creates embeddings for the given strings.
Args:
data: List of strings to get embeddings for.
options: Additional options to pass to the LiteLLM API.
Returns:
List of embeddings for the given strings.
@@ -59,12 +75,14 @@ async def embed_text(self, data: list[str]) -> list[list[float]]:
EmbeddingStatusError: If the embedding API returns an error status code.
EmbeddingResponseError: If the embedding API response is invalid.
"""
merged_options = (self.default_options | options) if options else self.default_options

with trace(
data=data,
model=self.model,
api_base=self.api_base,
api_version=self.api_version,
options=self.options,
options=merged_options.dict(),
) as outputs:
try:
response = await litellm.aembedding(
@@ -73,7 +91,7 @@ async def embed_text(self, data: list[str]) -> list[list[float]]:
api_base=self.api_base,
api_key=self.api_key,
api_version=self.api_version,
**self.options,
**merged_options.dict(),
)
except litellm.openai.APIConnectionError as exc:
raise EmbeddingConnectionError() from exc
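
A short usage sketch (not from the diff) of the resulting default-vs-per-call behaviour: because `embed_text` merges with `self.default_options | options`, a per-call object only overrides the fields it sets. Running this requires the relevant provider API key in the environment, which is an assumption here.

```python
import asyncio

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings, LiteLLMEmbeddingsOptions


async def main() -> None:
    embedder = LiteLLMEmbeddings(
        model="text-embedding-3-small",
        default_options=LiteLLMEmbeddingsOptions(dimensions=1024, timeout=1000),
    )
    # Only `dimensions` is overridden here; `timeout=1000` still comes from the defaults.
    vectors = await embedder.embed_text(
        ["hello world"],
        options=LiteLLMEmbeddingsOptions(dimensions=256),
    )
    print(len(vectors), len(vectors[0]))


asyncio.run(main())
```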
27 changes: 21 additions & 6 deletions packages/ragbits-core/src/ragbits/core/embeddings/local.py
@@ -1,5 +1,8 @@
from collections.abc import Iterator

from ragbits.core.embeddings import Embeddings
from ragbits.core.options import Options

try:
import torch
import torch.nn.functional as F
@@ -9,32 +12,42 @@
except ImportError:
HAS_LOCAL_EMBEDDINGS = False

from ragbits.core.embeddings import Embeddings

class LocalEmbeddingsOptions(Options):
"""
Dataclass that represents available call options for the LocalEmbeddings client.
"""

class LocalEmbeddings(Embeddings):
batch_size: int = 1


class LocalEmbeddings(Embeddings[LocalEmbeddingsOptions]):
"""
Class for interaction with any encoder available in HuggingFace.
"""

options_cls = LocalEmbeddingsOptions

def __init__(
self,
model_name: str,
api_key: str | None = None,
default_options: LocalEmbeddingsOptions | None = None,
) -> None:
"""Constructs a new local LLM instance.
Args:
model_name: Name of the model to use.
api_key: The API key for Hugging Face authentication.
default_options: Default options for the embedding model.
Raises:
ImportError: If the 'local' extra requirements are not installed.
"""
if not HAS_LOCAL_EMBEDDINGS:
raise ImportError("You need to install the 'local' extra requirements to use local embeddings models")

super().__init__()
super().__init__(default_options=default_options)

self.hf_api_key = api_key
self.model_name = model_name
@@ -43,18 +56,20 @@ def __init__(
self.model = AutoModel.from_pretrained(self.model_name, token=self.hf_api_key).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_api_key)

async def embed_text(self, data: list[str], batch_size: int = 1) -> list[list[float]]:
async def embed_text(self, data: list[str], options: LocalEmbeddingsOptions | None = None) -> list[list[float]]:
"""Calls the appropriate encoder endpoint with the given data and options.
Args:
data: List of strings to get embeddings for.
batch_size: Batch size.
options: Additional options to pass to the embedding model.
Returns:
List of embeddings for the given strings.
"""
merged_options = (self.default_options | options) if options else self.default_options

embeddings = []
for batch in self._batch(data, batch_size):
for batch in self._batch(data, merged_options.batch_size):
batch_dict = self.tokenizer(
batch,
max_length=self.tokenizer.model_max_length,
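
The batching knob that used to be a bare `batch_size` argument now travels through `LocalEmbeddingsOptions`. Below is a minimal sketch under the assumption that the 'local' extra (torch, transformers) is installed; the model name is an illustrative choice, not something taken from this commit.

```python
import asyncio

from ragbits.core.embeddings.local import LocalEmbeddings, LocalEmbeddingsOptions


async def main() -> None:
    embedder = LocalEmbeddings(
        model_name="intfloat/e5-small-v2",  # assumed example checkpoint
        default_options=LocalEmbeddingsOptions(batch_size=8),
    )
    # A per-call override wins over the default batch size for this call only.
    vectors = await embedder.embed_text(
        ["first sentence", "second sentence", "third sentence"],
        options=LocalEmbeddingsOptions(batch_size=2),
    )
    print(len(vectors))


asyncio.run(main())
```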
8 changes: 6 additions & 2 deletions packages/ragbits-core/src/ragbits/core/embeddings/noop.py
@@ -1,8 +1,9 @@
from ragbits.core.audit import traceable
from ragbits.core.embeddings.base import Embeddings
from ragbits.core.options import Options


class NoopEmbeddings(Embeddings):
class NoopEmbeddings(Embeddings[Options]):
"""
A no-op implementation of the Embeddings class.
@@ -11,13 +12,16 @@ class NoopEmbeddings(Embeddings):
or as a placeholder when an actual embedding model is not required.
"""

options_cls = Options

@traceable
async def embed_text(self, data: list[str]) -> list[list[float]]: # noqa: PLR6301
async def embed_text(self, data: list[str], options: Options | None = None) -> list[list[float]]: # noqa: PLR6301
"""
Embeds a list of strings into a list of vectors.
Args:
data: A list of input text strings to embed.
options: Additional settings used by the Embeddings model.
Returns:
A list of embedding vectors, where each vector
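
Since `NoopEmbeddings` now reuses the base `Options` class directly, it remains a drop-in placeholder, e.g. in tests. A tiny sketch follows; the exact placeholder vectors it returns are not shown in this diff, so only the shape of the result is checked.

```python
import asyncio

from ragbits.core.embeddings.noop import NoopEmbeddings


async def main() -> None:
    embedder = NoopEmbeddings()
    vectors = await embedder.embed_text(["first", "second"])
    # One placeholder vector per input string is expected.
    assert len(vectors) == 2


asyncio.run(main())
```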