diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/__init__.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py
new file mode 100644
index 00000000..ede4fcad
--- /dev/null
+++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py
@@ -0,0 +1,19 @@
+from abc import ABC, abstractmethod
+
+
+class Embeddings(ABC):
+    """
+    Abstract client for communication with embedding models.
+    """
+
+    @abstractmethod
+    async def embed_text(self, data: list[str]) -> list[list[float]]:
+        """
+        Creates embeddings for the given strings.
+
+        Args:
+            data: List of strings to get embeddings for.
+
+        Returns:
+            List of embeddings for the given strings.
+        """
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/exceptions.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/exceptions.py
new file mode 100644
index 00000000..4dd99ad1
--- /dev/null
+++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/exceptions.py
@@ -0,0 +1,36 @@
+class EmbeddingError(Exception):
+    """
+    Base class for all exceptions raised by embeddings clients.
+    """
+
+    def __init__(self, message: str) -> None:
+        super().__init__(message)
+        self.message = message
+
+
+class EmbeddingConnectionError(EmbeddingError):
+    """
+    Raised when there is an error connecting to the embedding API.
+    """
+
+    def __init__(self, message: str = "Connection error.") -> None:
+        super().__init__(message)
+
+
+class EmbeddingStatusError(EmbeddingError):
+    """
+    Raised when an API response has a status code of 4xx or 5xx.
+    """
+
+    def __init__(self, message: str, status_code: int) -> None:
+        super().__init__(message)
+        self.status_code = status_code
+
+
+class EmbeddingResponseError(EmbeddingError):
+    """
+    Raised when an API response has an invalid schema.
+    """
+
+    def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
+        super().__init__(message)
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/litellm.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/litellm.py
new file mode 100644
index 00000000..d1821e45
--- /dev/null
+++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/litellm.py
@@ -0,0 +1,86 @@
+from typing import Optional
+
+from ragnarok_common.embeddings.exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError
+
+try:
+    import litellm
+
+    HAS_LITELLM = True
+except ImportError:
+    HAS_LITELLM = False
+
+from ragnarok_common.embeddings.base import Embeddings
+
+
+class LiteLLMEmbeddings(Embeddings):
+    """
+    Client for creating text embeddings using the LiteLLM API.
+    """
+
+    def __init__(
+        self,
+        model: str = "text-embedding-3-small",
+        options: Optional[dict] = None,
+        api_base: Optional[str] = None,
+        api_key: Optional[str] = None,
+        api_version: Optional[str] = None,
+    ) -> None:
+        """
+        Constructs the LiteLLMEmbeddings client.
+
+        Args:
+            model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\
+                to be used. Default is "text-embedding-3-small".
+            options: Additional options to pass to the LiteLLM API.
+            api_base: The API endpoint you want to call the model with.
+            api_key: API key to be used. If not specified, an environment variable will be used;
+                for more information, follow the instructions for your specific vendor in the\
+                [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
+            api_version: The API version for the call.
+
+        Raises:
+            ImportError: If the litellm package is not installed.
+        """
+        if not HAS_LITELLM:
+            raise ImportError("You need to install the litellm package to use LiteLLM models")
+
+        super().__init__()
+        self.model = model
+        self.options = options or {}
+        self.api_base = api_base
+        self.api_key = api_key
+        self.api_version = api_version
+
+    async def embed_text(self, data: list[str]) -> list[list[float]]:
+        """
+        Creates embeddings for the given strings.
+
+        Args:
+            data: List of strings to get embeddings for.
+
+        Returns:
+            List of embeddings for the given strings.
+
+        Raises:
+            EmbeddingConnectionError: If there is a connection error with the embedding API.
+            EmbeddingStatusError: If the embedding API returns an error status code.
+            EmbeddingResponseError: If the embedding API response is invalid.
+        """
+
+        try:
+            response = await litellm.aembedding(
+                input=data,
+                model=self.model,
+                api_base=self.api_base,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                **self.options,
+            )
+        except litellm.openai.APIConnectionError as exc:
+            raise EmbeddingConnectionError() from exc
+        except litellm.openai.APIStatusError as exc:
+            raise EmbeddingStatusError(exc.message, exc.status_code) from exc
+        except litellm.openai.APIResponseValidationError as exc:
+            raise EmbeddingResponseError() from exc
+
+        return [embedding["embedding"] for embedding in response.data]
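
Usage note (reviewer sketch, not part of the diff): a minimal example of how the new client and exception hierarchy are expected to be used, assuming ragnarok-common is installed together with the optional litellm dependency and that the vendor's API key (e.g. OPENAI_API_KEY for the default model) is available in the environment.

import asyncio

from ragnarok_common.embeddings.exceptions import EmbeddingError
from ragnarok_common.embeddings.litellm import LiteLLMEmbeddings


async def main() -> None:
    # No api_key is passed, so LiteLLM falls back to the vendor's environment variable.
    embedder = LiteLLMEmbeddings(model="text-embedding-3-small")

    try:
        vectors = await embedder.embed_text(["first document", "second document"])
    except EmbeddingError as exc:
        # EmbeddingConnectionError, EmbeddingStatusError and EmbeddingResponseError
        # all derive from EmbeddingError, so one handler covers every failure mode.
        print(f"Embedding request failed: {exc.message}")
        return

    # One embedding per input string; the inner length is the model's dimensionality.
    print(len(vectors), len(vectors[0]))


if __name__ == "__main__":
    asyncio.run(main())

Catching the broad EmbeddingError is enough for callers that only care whether the request succeeded; the three subclasses remain available for finer-grained handling (retrying connection errors, surfacing status codes, and so on).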