Skip to content

Commit

Permalink
more type fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Dec 19, 2024
1 parent 924b771 commit ea9c187
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 28 deletions.
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .base import Embeddings, EmbeddingsClientOptions, EmbeddingType
from .base import Embeddings, EmbeddingsOptionsT, EmbeddingType
from .litellm import LiteLLMEmbeddings
from .noop import NoopEmbeddings

__all__ = ["EmbeddingType", "Embeddings", "EmbeddingsClientOptions", "LiteLLMEmbeddings", "NoopEmbeddings"]
__all__ = ["EmbeddingType", "Embeddings", "EmbeddingsOptionsT", "LiteLLMEmbeddings", "NoopEmbeddings"]
15 changes: 8 additions & 7 deletions packages/ragbits-core/src/ragbits/core/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ragbits.core.options import Options
from ragbits.core.utils.config_handling import ConfigurableComponent

EmbeddingsClientOptions = TypeVar("EmbeddingsClientOptions", bound=Options)
EmbeddingsOptionsT = TypeVar("EmbeddingsOptionsT", bound=Options)


class EmbeddingType(Enum):
Expand All @@ -20,21 +20,21 @@ class EmbeddingType(Enum):
allowing for the creation of different embeddings for the same element.
"""

TEXT: str = "text"
IMAGE: str = "image"
TEXT = "text"
IMAGE = "image"


class Embeddings(ConfigurableComponent[EmbeddingsClientOptions], ABC):
class Embeddings(ConfigurableComponent[EmbeddingsOptionsT], ABC):
"""
Abstract client for communication with embedding models.
"""

options_cls: type[EmbeddingsClientOptions]
options_cls: type[EmbeddingsOptionsT]
default_module: ClassVar = embeddings
configuration_key: ClassVar = "embedder"

@abstractmethod
async def embed_text(self, data: list[str], options: EmbeddingsClientOptions | None = None) -> list[list[float]]:
async def embed_text(self, data: list[str], options: EmbeddingsOptionsT | None = None) -> list[list[float]]:
"""
Creates embeddings for the given strings.
Expand All @@ -55,12 +55,13 @@ def image_support(self) -> bool: # noqa: PLR6301
"""
return False

async def embed_image(self, images: list[bytes]) -> list[list[float]]:
async def embed_image(self, images: list[bytes], options: EmbeddingsOptionsT | None = None) -> list[list[float]]:
"""
Creates embeddings for the given images.
Args:
images: List of images to get embeddings for.
options: Additional settings used by the Embeddings model.
Returns:
List of embeddings for the given images.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ async def embed_text(self, data: list[str], options: LiteLLMEmbeddingsOptions |
EmbeddingResponseError: If the embedding API response is invalid.
"""
merged_options = (self.default_options | options) if options else self.default_options

with trace(
data=data,
model=self.model,
Expand Down
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/embeddings/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ async def embed_text(self, data: list[str], options: LocalEmbeddingsOptions | No
Returns:
List of embeddings for the given strings.
"""
merged_options: LocalEmbeddingsOptions = (self.default_options | options) if options else self.default_options # type: ignore
# for some reason, mypy doesn't recognize that merged_options is LocalEmbeddingsOptions, it thinks it's Options
merged_options = (self.default_options | options) if options else self.default_options

embeddings = []
for batch in self._batch(data, merged_options.batch_size):
batch_dict = self.tokenizer(
Expand Down
6 changes: 3 additions & 3 deletions packages/ragbits-core/src/ragbits/core/embeddings/noop.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from ragbits.core.audit import traceable
from ragbits.core.embeddings.base import Embeddings, EmbeddingsClientOptions
from ragbits.core.embeddings.base import Embeddings
from ragbits.core.options import Options


class NoopEmbeddings(Embeddings):
class NoopEmbeddings(Embeddings[Options]):
"""
A no-op implementation of the Embeddings class.
Expand All @@ -15,7 +15,7 @@ class NoopEmbeddings(Embeddings):
options_cls = Options

@traceable
async def embed_text(self, data: list[str], options: EmbeddingsClientOptions | None = None) -> list[list[float]]: # noqa: PLR6301
async def embed_text(self, data: list[str], options: Options | None = None) -> list[list[float]]: # noqa: PLR6301
"""
Embeds a list of strings into a list of vectors.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import base64

from ragbits.core.embeddings.litellm import LiteLLMEmbeddingsOptions

try:
import litellm
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import VertexAIError
Expand All @@ -16,15 +18,14 @@
EmbeddingResponseError,
EmbeddingStatusError,
)
from ragbits.core.options import Options


class VertexAIMultimodelEmbeddings(Embeddings):
class VertexAIMultimodelEmbeddings(Embeddings[LiteLLMEmbeddingsOptions]):
"""
Client for creating text and image embeddings using the LiteLLM API.
"""

options_cls = Options
options_cls = LiteLLMEmbeddingsOptions
VERTEX_AI_PREFIX = "vertex_ai/"

def __init__(
Expand All @@ -33,7 +34,7 @@ def __init__(
api_base: str | None = None,
api_key: str | None = None,
concurency: int = 10,
default_options: Options | None = None,
default_options: LiteLLMEmbeddingsOptions | None = None,
) -> None:
"""
Constructs the embedding client for multimodal VertexAI models.
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(
if model not in supported_models:
raise ValueError(f"Model {model} is not supported by VertexAI multimodal embeddings")

async def _embed(self, data: list[dict], options: Options | None = None) -> list[dict]:
async def _embed(self, data: list[dict], options: LiteLLMEmbeddingsOptions | None = None) -> list[dict]:
"""
Creates embeddings for the given data. The format is defined in the VertexAI API:
https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings
Expand Down Expand Up @@ -106,7 +107,7 @@ async def _embed(self, data: list[dict], options: Options | None = None) -> list
return outputs.embeddings

async def _call_litellm(
self, instance: dict, semaphore: asyncio.Semaphore, options: Options
self, instance: dict, semaphore: asyncio.Semaphore, options: LiteLLMEmbeddingsOptions
) -> litellm.EmbeddingResponse:
"""
Calls the LiteLLM API to get embeddings for the given data.
Expand All @@ -130,7 +131,7 @@ async def _call_litellm(

return response

async def embed_text(self, data: list[str], options: Options | None = None) -> list[list[float]]:
async def embed_text(self, data: list[str], options: LiteLLMEmbeddingsOptions | None = None) -> list[list[float]]:
"""
Creates embeddings for the given strings.
Expand All @@ -157,7 +158,9 @@ def image_support(self) -> bool: # noqa: PLR6301
"""
return True

async def embed_image(self, images: list[bytes], options: Options | None = None) -> list[list[float]]:
async def embed_image(
self, images: list[bytes], options: LiteLLMEmbeddingsOptions | None = None
) -> list[list[float]]:
"""
Creates embeddings for the given images.
Expand Down
2 changes: 1 addition & 1 deletion packages/ragbits-core/src/ragbits/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ragbits.core.types import NotGiven

OptionsTypeVar = TypeVar("OptionsTypeVar", bound="Options")
OptionsT = TypeVar("OptionsT", bound="Options")


class Options(BaseModel, ABC):
Expand Down
10 changes: 5 additions & 5 deletions packages/ragbits-core/src/ragbits/core/utils/config_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import BaseModel
from typing_extensions import Self

from ragbits.core.options import OptionsTypeVar
from ragbits.core.options import OptionsT
from ragbits.core.utils._pyproject import get_config_from_yaml

if TYPE_CHECKING:
Expand Down Expand Up @@ -172,21 +172,21 @@ def from_config(cls, config: dict) -> Self:
return cls(**config)


class ConfigurableComponent(Generic[OptionsTypeVar], WithConstructionConfig):
class ConfigurableComponent(Generic[OptionsT], WithConstructionConfig):
"""
Base class for components with configurable options.
"""

options_cls: type[OptionsTypeVar]
options_cls: type[OptionsT]

def __init__(self, default_options: OptionsTypeVar | None = None) -> None:
def __init__(self, default_options: OptionsT | None = None) -> None:
"""
Constructs a new ConfigurableComponent instance.
Args:
default_options: The default options for the component.
"""
self.default_options: OptionsTypeVar = default_options or self.options_cls()
self.default_options: OptionsT = default_options or self.options_cls()

@classmethod
def from_config(cls, config: dict[str, Any]) -> ConfigurableComponent:
Expand Down

0 comments on commit ea9c187

Please sign in to comment.