Skip to content

Commit

Permalink
Add local embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
akotyla committed Sep 16, 2024
1 parent 25d8249 commit 91dffbd
Show file tree
Hide file tree
Showing 13 changed files with 359 additions and 59 deletions.
1 change: 1 addition & 0 deletions packages/ragnarok-common/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ local =
torch~=2.2.1
transformers~=4.44.2
numpy~=1.24.0
accelerate~=0.34.2

[options.packages.find]
where=src
Expand Down
38 changes: 36 additions & 2 deletions packages/ragnarok-common/src/ragnarok_common/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,53 @@
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Generic, Optional, Type

from .clients.base import EmbeddingsClient, EmbeddingsClientOptions

class Embeddings(ABC):

class Embeddings(Generic[EmbeddingsClientOptions], ABC):
"""
Abstract client for communication with embedding models.
"""

_options_cls: Type[EmbeddingsClientOptions]

def __init__(self, model_name: str, default_options: Optional[EmbeddingsClientOptions] = None) -> None:
"""
Constructs a new Embeddings instance.
Args:
model_name: Name of the model to be used.
default_options: Default options to be used.
"""
self.model_name = model_name
self.default_options = default_options or self._options_cls()

def __init_subclass__(cls) -> None:
if not hasattr(cls, "_options_cls"):
raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute")

@cached_property
@abstractmethod
async def embed_text(self, data: list[str]) -> list[list[float]]:
def client(self) -> EmbeddingsClient:
"""
Client for embeddings.
"""

async def embed_text(self, data: list[str], options: Optional[EmbeddingsClientOptions] = None) -> list[list[float]]:
"""
Creates embeddings for the given strings.
Args:
data: List of strings to get embeddings for.
options: Additional settings used by the embedding model.
Returns:
List of embeddings for the given strings.
"""

options = options or self.default_options

response = await self.client.call(data=data, options=options)

return response
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, ClassVar, Dict, Generic, TypeVar

from ...types import NotGiven

EmbeddingsClientOptions = TypeVar("EmbeddingsClientOptions", bound="EmbeddingsOptions")


@dataclass
class EmbeddingsOptions(ABC):
"""
Abstract dataclass that represents all available encoder call options.
"""

_not_given: ClassVar[Any] = None

def dict(self) -> Dict[str, Any]:
"""
Creates a dictionary representation of the EmbeddingsOptions instance.
If a value is None, it will be replaced with a provider-specific not-given sentinel.
Returns:
A dictionary representation of the EmbeddingsOptions instance.
"""
options = asdict(self)
return {
key: self._not_given if value is None or isinstance(value, NotGiven) else value
for key, value in options.items()
}


class EmbeddingsClient(Generic[EmbeddingsClientOptions], ABC):
"""
Abstract client for a direct communication with encoder models.
"""

def __init__(self, model_name: str) -> None:
"""
Constructs a new EmbeddingsClient instance.
Args:
model_name: Name of the model to be used.
"""
self.model_name = model_name

@abstractmethod
async def call(self, data: list[str], options: EmbeddingsClientOptions) -> list[list[float]]:
"""
Calls encoder model inference API.
Args:
data: List of strings to get embeddings for.
options: Additional options to pass to the Embeddings Client.
Returns:
List of embeddings for the given strings.
"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from dataclasses import dataclass
from typing import Optional, Union

try:
import litellm

HAS_LITELLM = True
except ImportError:
HAS_LITELLM = False

from ...types import NOT_GIVEN, NotGiven
from .base import EmbeddingsClient, EmbeddingsOptions
from .exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError


@dataclass
class LiteLLMOptions(EmbeddingsOptions):
"""
Dataclass that represents all available encoder call options for the LiteLLM client.
"""

dimensions: Union[Optional[int], NotGiven] = NOT_GIVEN
encoding_format: Union[Optional[str], NotGiven] = NOT_GIVEN


class LiteLLMEmbeddingsClient(EmbeddingsClient):
"""
Client for the LiteLLM that supports calls to various encoders' APIs, including OpenAI, VertexAI,
Hugging Face and others.
"""

def __init__(
self,
model_name: str = "text-embedding-3-small",
api_base: Optional[str] = None,
api_key: Optional[str] = None,
api_version: Optional[str] = None,
) -> None:
"""
Constructs the LiteLLMEmbeddingClient.
Args:
model_name: Name of the [LiteLLM supported model]
(https://docs.litellm.ai/docs/embedding/supported_embedding)\
to be used. Default is "text-embedding-3-small".
api_base: The API endpoint you want to call the model with.
api_key: API key to be used. API key to be used. If not specified, an environment variable will be used,
for more information, follow the instructions for your specific vendor in the\
[LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
api_version: The API version for the call.
Raises:
ImportError: If the litellm package is not installed.
"""
if not HAS_LITELLM:
raise ImportError("You need to install litellm package to use LiteLLM models")

super().__init__(model_name)
self.api_base = api_base
self.api_key = api_key
self.api_version = api_version

async def call(self, data: list[str], options: LiteLLMOptions) -> list[list[float]]:
"""
Calls the appropriate encoder endpoint with the given data and options.
Args:
data: List of strings to get embeddings for.
options: Additional options to pass to the LiteLLM API.
Returns:
List of embeddings for the given strings.
Raises:
EmbeddingConnectionError: If there is a connection error with the embedding API.
EmbeddingStatusError: If the embedding API returns an error status code.
EmbeddingResponseError: If the embedding API response is invalid.
"""

try:
response = await litellm.aembedding(
input=data,
model=self.model_name,
api_base=self.api_base,
api_key=self.api_key,
api_version=self.api_version,
**options,
)
except litellm.openai.APIConnectionError as exc:
raise EmbeddingConnectionError() from exc
except litellm.openai.APIStatusError as exc:
raise EmbeddingStatusError(exc.message, exc.status_code) from exc
except litellm.openai.APIResponseValidationError as exc:
raise EmbeddingResponseError() from exc

return [embedding["embedding"] for embedding in response.data]
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from dataclasses import dataclass
from typing import Any, Iterator, Optional

try:
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

HAS_LOCAL_EMBEDDINGS = True
except ImportError:
HAS_LOCAL_EMBEDDINGS = False

from .base import EmbeddingsClient, EmbeddingsOptions


def _batch(iterable: Any, per_batch: int = 1) -> Iterator:
length = len(iterable)
for ndx in range(0, length, per_batch):
yield iterable[ndx : min(ndx + per_batch, length)]


def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


@dataclass
class LocalEmbeddingsOptions(EmbeddingsOptions):
"""
Dataclass that represents all available encoder call options for the LiteLLM client.
"""

batch_size: int = 1


class LocalEmbeddingsClient(EmbeddingsClient[LocalEmbeddingsOptions]):
"""
Client for the local encoders that supports Hugging Face models.
"""

_options_cls = LocalEmbeddingsOptions

def __init__(self, model_name: str, hf_api_key: Optional[str] = None) -> None:
"""
Constructs a new local EmbeddingsClient instance.
Args:
model_name: Name of the model to use.
hf_api_key: The Hugging Face API key for authentication.
Raises:
ImportError: If the 'local' extra requirements are not installed.
"""
if not HAS_LOCAL_EMBEDDINGS:
raise ImportError("You need to install the 'local' extra requirements to use local LLM models")

super().__init__(model_name)

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self.model = AutoModel.from_pretrained(model_name, token=hf_api_key).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_api_key)

async def call(self, data: list[str], options: LocalEmbeddingsOptions) -> list[list[float]]:
"""
Calls the appropriate encoder endpoint with the given data and options.
Args:
data: List of strings to get embeddings for.
options: Additional options to pass to the Embeddings Client.
Returns:
List of embeddings for the given strings.
"""

embeddings = []
for batch in _batch(data, options.batch_size):
batch_dict = self.tokenizer(
batch, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self.model(**batch_dict)
batch_embeddings = _average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
embeddings.extend(batch_embeddings.to("cpu").tolist())

torch.cuda.empty_cache()
return embeddings
Loading

0 comments on commit 91dffbd

Please sign in to comment.