-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
359 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
38 changes: 36 additions & 2 deletions
38
packages/ragnarok-common/src/ragnarok_common/embeddings/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,53 @@ | ||
from abc import ABC, abstractmethod | ||
from functools import cached_property | ||
from typing import Generic, Optional, Type | ||
|
||
from .clients.base import EmbeddingsClient, EmbeddingsClientOptions | ||
|
||
class Embeddings(ABC): | ||
|
||
class Embeddings(Generic[EmbeddingsClientOptions], ABC): | ||
""" | ||
Abstract client for communication with embedding models. | ||
""" | ||
|
||
_options_cls: Type[EmbeddingsClientOptions] | ||
|
||
def __init__(self, model_name: str, default_options: Optional[EmbeddingsClientOptions] = None) -> None: | ||
""" | ||
Constructs a new Embeddings instance. | ||
Args: | ||
model_name: Name of the model to be used. | ||
default_options: Default options to be used. | ||
""" | ||
self.model_name = model_name | ||
self.default_options = default_options or self._options_cls() | ||
|
||
def __init_subclass__(cls) -> None: | ||
if not hasattr(cls, "_options_cls"): | ||
raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute") | ||
|
||
@cached_property | ||
@abstractmethod | ||
async def embed_text(self, data: list[str]) -> list[list[float]]: | ||
def client(self) -> EmbeddingsClient: | ||
""" | ||
Client for embeddings. | ||
""" | ||
|
||
async def embed_text(self, data: list[str], options: Optional[EmbeddingsClientOptions] = None) -> list[list[float]]: | ||
""" | ||
Creates embeddings for the given strings. | ||
Args: | ||
data: List of strings to get embeddings for. | ||
options: Additional settings used by the embedding model. | ||
Returns: | ||
List of embeddings for the given strings. | ||
""" | ||
|
||
options = options or self.default_options | ||
|
||
response = await self.client.call(data=data, options=options) | ||
|
||
return response |
Empty file.
58 changes: 58 additions & 0 deletions
58
packages/ragnarok-common/src/ragnarok_common/embeddings/clients/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from abc import ABC, abstractmethod | ||
from dataclasses import asdict, dataclass | ||
from typing import Any, ClassVar, Dict, Generic, TypeVar | ||
|
||
from ...types import NotGiven | ||
|
||
EmbeddingsClientOptions = TypeVar("EmbeddingsClientOptions", bound="EmbeddingsOptions") | ||
|
||
|
||
@dataclass | ||
class EmbeddingsOptions(ABC): | ||
""" | ||
Abstract dataclass that represents all available encoder call options. | ||
""" | ||
|
||
_not_given: ClassVar[Any] = None | ||
|
||
def dict(self) -> Dict[str, Any]: | ||
""" | ||
Creates a dictionary representation of the EmbeddingsOptions instance. | ||
If a value is None, it will be replaced with a provider-specific not-given sentinel. | ||
Returns: | ||
A dictionary representation of the EmbeddingsOptions instance. | ||
""" | ||
options = asdict(self) | ||
return { | ||
key: self._not_given if value is None or isinstance(value, NotGiven) else value | ||
for key, value in options.items() | ||
} | ||
|
||
|
||
class EmbeddingsClient(Generic[EmbeddingsClientOptions], ABC): | ||
""" | ||
Abstract client for a direct communication with encoder models. | ||
""" | ||
|
||
def __init__(self, model_name: str) -> None: | ||
""" | ||
Constructs a new EmbeddingsClient instance. | ||
Args: | ||
model_name: Name of the model to be used. | ||
""" | ||
self.model_name = model_name | ||
|
||
@abstractmethod | ||
async def call(self, data: list[str], options: EmbeddingsClientOptions) -> list[list[float]]: | ||
""" | ||
Calls encoder model inference API. | ||
Args: | ||
data: List of strings to get embeddings for. | ||
options: Additional options to pass to the Embeddings Client. | ||
Returns: | ||
List of embeddings for the given strings. | ||
""" |
File renamed without changes.
96 changes: 96 additions & 0 deletions
96
packages/ragnarok-common/src/ragnarok_common/embeddings/clients/litellm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from dataclasses import dataclass | ||
from typing import Optional, Union | ||
|
||
try: | ||
import litellm | ||
|
||
HAS_LITELLM = True | ||
except ImportError: | ||
HAS_LITELLM = False | ||
|
||
from ...types import NOT_GIVEN, NotGiven | ||
from .base import EmbeddingsClient, EmbeddingsOptions | ||
from .exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError | ||
|
||
|
||
@dataclass | ||
class LiteLLMOptions(EmbeddingsOptions): | ||
""" | ||
Dataclass that represents all available encoder call options for the LiteLLM client. | ||
""" | ||
|
||
dimensions: Union[Optional[int], NotGiven] = NOT_GIVEN | ||
encoding_format: Union[Optional[str], NotGiven] = NOT_GIVEN | ||
|
||
|
||
class LiteLLMEmbeddingsClient(EmbeddingsClient): | ||
""" | ||
Client for the LiteLLM that supports calls to various encoders' APIs, including OpenAI, VertexAI, | ||
Hugging Face and others. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model_name: str = "text-embedding-3-small", | ||
api_base: Optional[str] = None, | ||
api_key: Optional[str] = None, | ||
api_version: Optional[str] = None, | ||
) -> None: | ||
""" | ||
Constructs the LiteLLMEmbeddingClient. | ||
Args: | ||
model_name: Name of the [LiteLLM supported model] | ||
(https://docs.litellm.ai/docs/embedding/supported_embedding)\ | ||
to be used. Default is "text-embedding-3-small". | ||
api_base: The API endpoint you want to call the model with. | ||
api_key: API key to be used. API key to be used. If not specified, an environment variable will be used, | ||
for more information, follow the instructions for your specific vendor in the\ | ||
[LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding). | ||
api_version: The API version for the call. | ||
Raises: | ||
ImportError: If the litellm package is not installed. | ||
""" | ||
if not HAS_LITELLM: | ||
raise ImportError("You need to install litellm package to use LiteLLM models") | ||
|
||
super().__init__(model_name) | ||
self.api_base = api_base | ||
self.api_key = api_key | ||
self.api_version = api_version | ||
|
||
async def call(self, data: list[str], options: LiteLLMOptions) -> list[list[float]]: | ||
""" | ||
Calls the appropriate encoder endpoint with the given data and options. | ||
Args: | ||
data: List of strings to get embeddings for. | ||
options: Additional options to pass to the LiteLLM API. | ||
Returns: | ||
List of embeddings for the given strings. | ||
Raises: | ||
EmbeddingConnectionError: If there is a connection error with the embedding API. | ||
EmbeddingStatusError: If the embedding API returns an error status code. | ||
EmbeddingResponseError: If the embedding API response is invalid. | ||
""" | ||
|
||
try: | ||
response = await litellm.aembedding( | ||
input=data, | ||
model=self.model_name, | ||
api_base=self.api_base, | ||
api_key=self.api_key, | ||
api_version=self.api_version, | ||
**options, | ||
) | ||
except litellm.openai.APIConnectionError as exc: | ||
raise EmbeddingConnectionError() from exc | ||
except litellm.openai.APIStatusError as exc: | ||
raise EmbeddingStatusError(exc.message, exc.status_code) from exc | ||
except litellm.openai.APIResponseValidationError as exc: | ||
raise EmbeddingResponseError() from exc | ||
|
||
return [embedding["embedding"] for embedding in response.data] |
88 changes: 88 additions & 0 deletions
88
packages/ragnarok-common/src/ragnarok_common/embeddings/clients/local.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from dataclasses import dataclass | ||
from typing import Any, Iterator, Optional | ||
|
||
try: | ||
import torch | ||
import torch.nn.functional as F | ||
from transformers import AutoModel, AutoTokenizer | ||
|
||
HAS_LOCAL_EMBEDDINGS = True | ||
except ImportError: | ||
HAS_LOCAL_EMBEDDINGS = False | ||
|
||
from .base import EmbeddingsClient, EmbeddingsOptions | ||
|
||
|
||
def _batch(iterable: Any, per_batch: int = 1) -> Iterator: | ||
length = len(iterable) | ||
for ndx in range(0, length, per_batch): | ||
yield iterable[ndx : min(ndx + per_batch, length)] | ||
|
||
|
||
def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: | ||
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) | ||
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] | ||
|
||
|
||
@dataclass | ||
class LocalEmbeddingsOptions(EmbeddingsOptions): | ||
""" | ||
Dataclass that represents all available encoder call options for the LiteLLM client. | ||
""" | ||
|
||
batch_size: int = 1 | ||
|
||
|
||
class LocalEmbeddingsClient(EmbeddingsClient[LocalEmbeddingsOptions]): | ||
""" | ||
Client for the local encoders that supports Hugging Face models. | ||
""" | ||
|
||
_options_cls = LocalEmbeddingsOptions | ||
|
||
def __init__(self, model_name: str, hf_api_key: Optional[str] = None) -> None: | ||
""" | ||
Constructs a new local EmbeddingsClient instance. | ||
Args: | ||
model_name: Name of the model to use. | ||
hf_api_key: The Hugging Face API key for authentication. | ||
Raises: | ||
ImportError: If the 'local' extra requirements are not installed. | ||
""" | ||
if not HAS_LOCAL_EMBEDDINGS: | ||
raise ImportError("You need to install the 'local' extra requirements to use local LLM models") | ||
|
||
super().__init__(model_name) | ||
|
||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
self.model = AutoModel.from_pretrained(model_name, token=hf_api_key).to(self.device) | ||
self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_api_key) | ||
|
||
async def call(self, data: list[str], options: LocalEmbeddingsOptions) -> list[list[float]]: | ||
""" | ||
Calls the appropriate encoder endpoint with the given data and options. | ||
Args: | ||
data: List of strings to get embeddings for. | ||
options: Additional options to pass to the Embeddings Client. | ||
Returns: | ||
List of embeddings for the given strings. | ||
""" | ||
|
||
embeddings = [] | ||
for batch in _batch(data, options.batch_size): | ||
batch_dict = self.tokenizer( | ||
batch, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt" | ||
).to(self.device) | ||
with torch.no_grad(): | ||
outputs = self.model(**batch_dict) | ||
batch_embeddings = _average_pool(outputs.last_hidden_state, batch_dict["attention_mask"]) | ||
batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1) | ||
embeddings.extend(batch_embeddings.to("cpu").tolist()) | ||
|
||
torch.cuda.empty_cache() | ||
return embeddings |
Oops, something went wrong.