-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: add embeddings * Fix pylint and mypy errors. --------- Co-authored-by: Patryk Wyżgowski <[email protected]>
- Loading branch information
1 parent
c6dc50c
commit 25d8249
Showing
7 changed files
with
145 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
19 changes: 19 additions & 0 deletions
19
packages/ragnarok-common/src/ragnarok_common/embeddings/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class Embeddings(ABC):
    """
    Abstract client for communication with embedding models.

    Subclasses provide the backend-specific implementation of `embed_text`.
    """

    @abstractmethod
    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Creates embeddings for the given strings.

        Args:
            data: List of strings to get embeddings for.

        Returns:
            List of embeddings for the given strings.
        """
36 changes: 36 additions & 0 deletions
36
packages/ragnarok-common/src/ragnarok_common/embeddings/exceptions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
class EmbeddingError(Exception):
    """
    Base class for all exceptions raised by embedding clients.

    The description passed to the constructor is available both through
    `str(error)` and through the `message` attribute.
    """

    def __init__(self, message: str) -> None:
        """
        Constructs the error.

        Args:
            message: Description of what went wrong.
        """
        super().__init__(message)
        # Kept as an explicit attribute so callers do not have to index `args`.
        self.message = message
|
||
|
||
class EmbeddingConnectionError(EmbeddingError):
    """
    Raised when there is an error connecting to the embedding API.
    """

    def __init__(self, message: str = "Connection error.") -> None:
        """
        Constructs the error.

        Args:
            message: Description of the connection failure. Defaults to "Connection error.".
        """
        super().__init__(message)
|
||
|
||
class EmbeddingStatusError(EmbeddingError):
    """
    Raised when an API response has a status code of 4xx or 5xx.

    The offending HTTP status is exposed through the `status_code` attribute.
    """

    def __init__(self, message: str, status_code: int) -> None:
        """
        Constructs the error.

        Args:
            message: Description of the failure.
            status_code: Status code of the API response.
        """
        super().__init__(message)
        self.status_code = status_code

    def __reduce__(self) -> tuple:
        """
        Describes how to rebuild the exception when it is pickled or copied.

        The default exception reduction re-creates the instance from `self.args`,
        which only holds the message, so the required `status_code` argument would
        be missing and unpickling/copying would fail with a TypeError.

        Returns:
            The class, its constructor arguments, and the instance state.
        """
        return (type(self), (self.message, self.status_code), self.__dict__)
|
||
|
||
class EmbeddingResponseError(EmbeddingError):
    """
    Raised when an API response has an invalid schema.
    """

    def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
        """
        Constructs the error.

        Args:
            message: Description of the schema problem. A generic message is used when omitted.
        """
        super().__init__(message)
85 changes: 85 additions & 0 deletions
85
packages/ragnarok-common/src/ragnarok_common/embeddings/litellm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from typing import Optional | ||
|
||
try: | ||
import litellm | ||
|
||
HAS_LITELLM = True | ||
except ImportError: | ||
HAS_LITELLM = False | ||
|
||
from ragnarok_common.embeddings.base import Embeddings | ||
from ragnarok_common.embeddings.exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError | ||
|
||
|
||
class LiteLLMEmbeddings(Embeddings):
    """
    Client for creating text embeddings using LiteLLM API.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        options: Optional[dict] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
    ) -> None:
        """
        Constructs the LiteLLMEmbeddings client.

        Args:
            model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)
                to be used. Default is "text-embedding-3-small".
            options: Additional options to pass to the LiteLLM API.
            api_base: The API endpoint you want to call the model with.
            api_key: API key to be used. If not specified, an environment variable will be used,
                for more information, follow the instructions for your specific vendor in the
                [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
            api_version: The API version for the call.

        Raises:
            ImportError: If the litellm package is not installed.
        """
        # litellm is an optional dependency; fail at construction time with a clear
        # message instead of a NameError on the first embedding call.
        if not HAS_LITELLM:
            raise ImportError("You need to install litellm package to use LiteLLM models")

        super().__init__()
        self.model = model
        self.options = options or {}
        self.api_base = api_base
        self.api_key = api_key
        self.api_version = api_version

    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Creates embeddings for the given strings.

        Args:
            data: List of strings to get embeddings for.

        Returns:
            List of embeddings for the given strings. An empty input yields an empty list
            without contacting the API.

        Raises:
            EmbeddingConnectionError: If there is a connection error with the embedding API.
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """
        # An empty batch has no embeddings to compute, so skip the request entirely.
        if not data:
            return []

        try:
            response = await litellm.aembedding(
                input=data,
                model=self.model,
                api_base=self.api_base,
                api_key=self.api_key,
                api_version=self.api_version,
                **self.options,
            )
        # Translate the provider's exceptions into this package's own hierarchy,
        # keeping the original exception chained as the cause.
        except litellm.openai.APIConnectionError as exc:
            raise EmbeddingConnectionError() from exc
        except litellm.openai.APIStatusError as exc:
            raise EmbeddingStatusError(exc.message, exc.status_code) from exc
        except litellm.openai.APIResponseValidationError as exc:
            raise EmbeddingResponseError() from exc

        return [embedding["embedding"] for embedding in response.data]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,4 +92,3 @@ def discover(self) -> dict: | |
)._asdict() | ||
|
||
return result_dict | ||
|