diff --git a/packages/ragnarok-common/setup.cfg b/packages/ragnarok-common/setup.cfg
index b8bdeb872..713afad16 100644
--- a/packages/ragnarok-common/setup.cfg
+++ b/packages/ragnarok-common/setup.cfg
@@ -41,6 +41,7 @@ local =
     torch~=2.2.1
     transformers~=4.44.2
     numpy~=1.24.0
+    accelerate~=0.34.2
 
 [options.packages.find]
 where=src
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py
index ede4fcadf..ba035b634 100644
--- a/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py
+++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py
@@ -1,19 +1,53 @@
 from abc import ABC, abstractmethod
+from functools import cached_property
+from typing import Generic, Optional, Type
+
+from .clients.base import EmbeddingsClient, EmbeddingsClientOptions
 
-class Embeddings(ABC):
+
+class Embeddings(Generic[EmbeddingsClientOptions], ABC):
     """
     Abstract client for communication with embedding models.
     """
 
+    _options_cls: Type[EmbeddingsClientOptions]
+
+    def __init__(self, model_name: str, default_options: Optional[EmbeddingsClientOptions] = None) -> None:
+        """
+        Constructs a new Embeddings instance.
+
+        Args:
+            model_name: Name of the model to be used.
+            default_options: Default options to be used.
+        """
+        self.model_name = model_name
+        self.default_options = default_options or self._options_cls()
+
+    def __init_subclass__(cls) -> None:
+        if not hasattr(cls, "_options_cls"):
+            raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute")
+
+    @cached_property
     @abstractmethod
-    async def embed_text(self, data: list[str]) -> list[list[float]]:
+    def client(self) -> EmbeddingsClient:
+        """
+        Client for embeddings.
+        """
+
+    async def embed_text(self, data: list[str], options: Optional[EmbeddingsClientOptions] = None) -> list[list[float]]:
         """
         Creates embeddings for the given strings.
 
         Args:
             data: List of strings to get embeddings for.
+            options: Additional settings used by the embedding model.
 
         Returns:
             List of embeddings for the given strings.
         """
+
+        options = options or self.default_options
+
+        response = await self.client.call(data=data, options=options)
+
+        return response
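For reviewers, a minimal sketch of the pattern this base class introduces: a concrete subclass supplies `_options_cls` (enforced by `__init_subclass__`) and a `client` property, while `embed_text` resolves per-call options against the defaults. `FakeOptions`, `FakeClient`, and `FakeEmbeddings` are hypothetical stand-ins, not part of this PR:

```python
import asyncio
from dataclasses import dataclass
from functools import cached_property

from ragnarok_common.embeddings.base import Embeddings
from ragnarok_common.embeddings.clients.base import EmbeddingsClient, EmbeddingsOptions


@dataclass
class FakeOptions(EmbeddingsOptions):
    scale: float = 1.0  # hypothetical option


class FakeClient(EmbeddingsClient[FakeOptions]):
    async def call(self, data: list[str], options: FakeOptions) -> list[list[float]]:
        # Toy "embedding": one dimension holding the scaled string length.
        return [[len(text) * options.scale] for text in data]


class FakeEmbeddings(Embeddings[FakeOptions]):
    _options_cls = FakeOptions  # required; __init_subclass__ raises TypeError without it

    @cached_property
    def client(self) -> FakeClient:
        return FakeClient(self.model_name)


async def main() -> None:
    embeddings = FakeEmbeddings("fake-model", default_options=FakeOptions(scale=2.0))
    # Defaults apply when no per-call options are given...
    print(await embeddings.embed_text(["hello"]))  # [[10.0]]
    # ...and per-call options override them.
    print(await embeddings.embed_text(["hello"], FakeOptions(scale=1.0)))  # [[5.0]]


asyncio.run(main())
```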
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/__init__.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/base.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/base.py
new file mode 100644
index 000000000..b7da6fd14
--- /dev/null
+++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/base.py
@@ -0,0 +1,58 @@
+from abc import ABC, abstractmethod
+from dataclasses import asdict, dataclass
+from typing import Any, ClassVar, Dict, Generic, TypeVar
+
+from ...types import NotGiven
+
+EmbeddingsClientOptions = TypeVar("EmbeddingsClientOptions", bound="EmbeddingsOptions")
+
+
+@dataclass
+class EmbeddingsOptions(ABC):
+    """
+    Abstract dataclass that represents all available encoder call options.
+    """
+
+    _not_given: ClassVar[Any] = None
+
+    def dict(self) -> Dict[str, Any]:
+        """
+        Creates a dictionary representation of the EmbeddingsOptions instance.
+        If a value is None, it will be replaced with a provider-specific not-given sentinel.
+
+        Returns:
+            A dictionary representation of the EmbeddingsOptions instance.
+        """
+        options = asdict(self)
+        return {
+            key: self._not_given if value is None or isinstance(value, NotGiven) else value
+            for key, value in options.items()
+        }
+
+
+class EmbeddingsClient(Generic[EmbeddingsClientOptions], ABC):
+    """
+    Abstract client for direct communication with encoder models.
+    """
+
+    def __init__(self, model_name: str) -> None:
+        """
+        Constructs a new EmbeddingsClient instance.
+
+        Args:
+            model_name: Name of the model to be used.
+        """
+        self.model_name = model_name
+
+    @abstractmethod
+    async def call(self, data: list[str], options: EmbeddingsClientOptions) -> list[list[float]]:
+        """
+        Calls the encoder model's inference API.
+
+        Args:
+            data: List of strings to get embeddings for.
+            options: Additional options to pass to the Embeddings Client.
+
+        Returns:
+            List of embeddings for the given strings.
+        """
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/exceptions.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/exceptions.py
similarity index 100%
rename from packages/ragnarok-common/src/ragnarok_common/embeddings/exceptions.py
rename to packages/ragnarok-common/src/ragnarok_common/embeddings/clients/exceptions.py
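A short sketch of the None-to-sentinel translation that `EmbeddingsOptions.dict()` performs before options are forwarded to a provider. `DemoOptions` and its fields are hypothetical, and the `_not_given` override shown is an assumption about how providers are expected to customize the sentinel:

```python
from dataclasses import dataclass
from typing import Any, ClassVar, Optional, Union

from ragnarok_common.embeddings.clients.base import EmbeddingsOptions
from ragnarok_common.types import NOT_GIVEN, NotGiven


@dataclass
class DemoOptions(EmbeddingsOptions):
    # ClassVar fields are excluded from asdict(), so overriding the sentinel
    # here changes only what unset values are replaced with.
    _not_given: ClassVar[Any] = "PROVIDER_DEFAULT"

    dimensions: Union[Optional[int], NotGiven] = NOT_GIVEN
    user: Optional[str] = None


options = DemoOptions(dimensions=256)
# Both None and NOT_GIVEN values collapse to the provider sentinel:
print(options.dict())  # {'dimensions': 256, 'user': 'PROVIDER_DEFAULT'}
```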
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/litellm.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/litellm.py
new file mode 100644
index 000000000..c3ffec01b
--- /dev/null
+++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/litellm.py
@@ -0,0 +1,96 @@
+from dataclasses import dataclass
+from typing import Optional, Union
+
+try:
+    import litellm
+
+    HAS_LITELLM = True
+except ImportError:
+    HAS_LITELLM = False
+
+from ...types import NOT_GIVEN, NotGiven
+from .base import EmbeddingsClient, EmbeddingsOptions
+from .exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError
+
+
+@dataclass
+class LiteLLMOptions(EmbeddingsOptions):
+    """
+    Dataclass that represents all available encoder call options for the LiteLLM client.
+    """
+
+    dimensions: Union[Optional[int], NotGiven] = NOT_GIVEN
+    encoding_format: Union[Optional[str], NotGiven] = NOT_GIVEN
+
+
+class LiteLLMEmbeddingsClient(EmbeddingsClient[LiteLLMOptions]):
+    """
+    Client for LiteLLM, supporting calls to various encoders' APIs, including OpenAI, VertexAI,
+    Hugging Face and others.
+    """
+
+    def __init__(
+        self,
+        model_name: str = "text-embedding-3-small",
+        api_base: Optional[str] = None,
+        api_key: Optional[str] = None,
+        api_version: Optional[str] = None,
+    ) -> None:
+        """
+        Constructs the LiteLLMEmbeddingsClient.
+
+        Args:
+            model_name: Name of the [LiteLLM supported model]
+                (https://docs.litellm.ai/docs/embedding/supported_embedding)\
+                to be used. Default is "text-embedding-3-small".
+            api_base: The API endpoint you want to call the model with.
+            api_key: API key to be used. If not specified, an environment variable will be used;
+                for more information, follow the instructions for your specific vendor in the\
+                [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
+            api_version: The API version for the call.
+
+        Raises:
+            ImportError: If the litellm package is not installed.
+        """
+        if not HAS_LITELLM:
+            raise ImportError("You need to install the litellm package to use LiteLLM models")
+
+        super().__init__(model_name)
+        self.api_base = api_base
+        self.api_key = api_key
+        self.api_version = api_version
+
+    async def call(self, data: list[str], options: LiteLLMOptions) -> list[list[float]]:
+        """
+        Calls the appropriate encoder endpoint with the given data and options.
+
+        Args:
+            data: List of strings to get embeddings for.
+            options: Additional options to pass to the LiteLLM API.
+
+        Returns:
+            List of embeddings for the given strings.
+
+        Raises:
+            EmbeddingConnectionError: If there is a connection error with the embedding API.
+            EmbeddingStatusError: If the embedding API returns an error status code.
+            EmbeddingResponseError: If the embedding API response is invalid.
+        """
+        try:
+            response = await litellm.aembedding(
+                input=data,
+                model=self.model_name,
+                api_base=self.api_base,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                **options.dict(),
+            )
+        except litellm.openai.APIConnectionError as exc:
+            raise EmbeddingConnectionError() from exc
+        except litellm.openai.APIStatusError as exc:
+            raise EmbeddingStatusError(exc.message, exc.status_code) from exc
+        except litellm.openai.APIResponseValidationError as exc:
+            raise EmbeddingResponseError() from exc
+
+        return [embedding["embedding"] for embedding in response.data]
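A hypothetical end-to-end call against the low-level client above. This is a sketch, not part of the diff; it assumes a valid `OPENAI_API_KEY` in the environment and network access:

```python
import asyncio

from ragnarok_common.embeddings.clients.litellm import LiteLLMEmbeddingsClient, LiteLLMOptions


async def main() -> None:
    client = LiteLLMEmbeddingsClient(model_name="text-embedding-3-small")
    # Options left as NOT_GIVEN are translated to the provider sentinel by
    # LiteLLMOptions.dict() before being forwarded to litellm.aembedding.
    vectors = await client.call(
        data=["first document", "second document"],
        options=LiteLLMOptions(dimensions=256),
    )
    print(len(vectors), len(vectors[0]))  # 2 256


asyncio.run(main())
```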
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/local.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/local.py
new file mode 100644
index 000000000..c9885c094
--- /dev/null
+++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/clients/local.py
@@ -0,0 +1,88 @@
+from dataclasses import dataclass
+from typing import Any, Iterator, Optional
+
+try:
+    import torch
+    import torch.nn.functional as F
+    from transformers import AutoModel, AutoTokenizer
+
+    HAS_LOCAL_EMBEDDINGS = True
+except ImportError:
+    HAS_LOCAL_EMBEDDINGS = False
+
+from .base import EmbeddingsClient, EmbeddingsOptions
+
+
+def _batch(iterable: Any, per_batch: int = 1) -> Iterator:
+    """Yields consecutive slices of at most `per_batch` items."""
+    length = len(iterable)
+    for ndx in range(0, length, per_batch):
+        yield iterable[ndx : min(ndx + per_batch, length)]
+
+
+def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+    """Mean-pools token embeddings, ignoring positions masked out by the attention mask."""
+    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+@dataclass
+class LocalEmbeddingsOptions(EmbeddingsOptions):
+    """
+    Dataclass that represents all available encoder call options for the local client.
+    """
+
+    batch_size: int = 1
+
+
+class LocalEmbeddingsClient(EmbeddingsClient[LocalEmbeddingsOptions]):
+    """
+    Client for local encoders, supporting Hugging Face models.
+    """
+
+    _options_cls = LocalEmbeddingsOptions
+
+    def __init__(self, model_name: str, hf_api_key: Optional[str] = None) -> None:
+        """
+        Constructs a new local EmbeddingsClient instance.
+
+        Args:
+            model_name: Name of the model to use.
+            hf_api_key: The Hugging Face API key for authentication.
+
+        Raises:
+            ImportError: If the 'local' extra requirements are not installed.
+        """
+        if not HAS_LOCAL_EMBEDDINGS:
+            raise ImportError("You need to install the 'local' extra requirements to use local embeddings models")
+
+        super().__init__(model_name)
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.model = AutoModel.from_pretrained(model_name, token=hf_api_key).to(self.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_api_key)
+
+    async def call(self, data: list[str], options: LocalEmbeddingsOptions) -> list[list[float]]:
+        """
+        Calls the appropriate encoder endpoint with the given data and options.
+
+        Args:
+            data: List of strings to get embeddings for.
+            options: Additional options to pass to the Embeddings Client.
+
+        Returns:
+            List of embeddings for the given strings.
+        """
+        embeddings = []
+        for batch in _batch(data, options.batch_size):
+            batch_dict = self.tokenizer(
+                batch, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+            ).to(self.device)
+            with torch.no_grad():
+                outputs = self.model(**batch_dict)
+                batch_embeddings = _average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
+                batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
+            embeddings.extend(batch_embeddings.to("cpu").tolist())
+
+        torch.cuda.empty_cache()
+        return embeddings
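A sketch of the local client in use. The checkpoint name is illustrative (any Hugging Face encoder whose recommended pooling is mean pooling fits the `_average_pool` implementation above), and the 'local' extras must be installed:

```python
import asyncio

from ragnarok_common.embeddings.clients.local import LocalEmbeddingsClient, LocalEmbeddingsOptions


async def main() -> None:
    client = LocalEmbeddingsClient(model_name="intfloat/e5-small-v2")
    vectors = await client.call(
        data=["query: how do embeddings work?", "passage: embeddings map text to vectors"],
        options=LocalEmbeddingsOptions(batch_size=2),
    )
    # Outputs are L2-normalized, so a dot product is cosine similarity.
    similarity = sum(a * b for a, b in zip(vectors[0], vectors[1]))
    print(len(vectors[0]), round(similarity, 3))


asyncio.run(main())
```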
-        Raises:
-            ImportError: If the litellm package is not installed.
         """
-        if not HAS_LITELLM:
-            raise ImportError("You need to install litellm package to use LiteLLM models")
-
-        super().__init__()
-        self.model = model
-        self.options = options or {}
+        super().__init__(model_name, default_options)
         self.api_base = api_base
         self.api_key = api_key
         self.api_version = api_version
 
-    async def embed_text(self, data: list[str]) -> list[list[float]]:
+    @cached_property
+    def client(self) -> LiteLLMEmbeddingsClient:
         """
-        Creates embeddings for the given strings.
-
-        Args:
-            data: List of strings to get embeddings for.
-
-        Returns:
-            List of embeddings for the given strings.
-
-        Raises:
-            EmbeddingConnectionError: If there is a connection error with the embedding API.
-            EmbeddingStatusError: If the embedding API returns an error status code.
-            EmbeddingResponseError: If the embedding API response is invalid.
+        Client for the LiteLLM encoder.
         """
-
-        try:
-            response = await litellm.aembedding(
-                input=data,
-                model=self.model,
-                api_base=self.api_base,
-                api_key=self.api_key,
-                api_version=self.api_version,
-                **self.options,
-            )
-        except litellm.openai.APIConnectionError as exc:
-            raise EmbeddingConnectionError() from exc
-        except litellm.openai.APIStatusError as exc:
-            raise EmbeddingStatusError(exc.message, exc.status_code) from exc
-        except litellm.openai.APIResponseValidationError as exc:
-            raise EmbeddingResponseError() from exc
-
-        return [embedding["embedding"] for embedding in response.data]
+        return LiteLLMEmbeddingsClient(
+            model_name=self.model_name,
+            api_base=self.api_base,
+            api_key=self.api_key,
+            api_version=self.api_version,
+        )
diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/local.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/local.py
new file mode 100644
index 000000000..51fec8ea9
--- /dev/null
+++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/local.py
@@ -0,0 +1,50 @@
+from functools import cached_property
+from typing import Optional
+
+try:
+    from transformers import PreTrainedModel  # imported to check that the 'local' extras are installed
+
+    HAS_LOCAL_EMBEDDINGS = True
+except ImportError:
+    HAS_LOCAL_EMBEDDINGS = False
+
+from .base import Embeddings
+from .clients.local import LocalEmbeddingsClient, LocalEmbeddingsOptions
+
+
+class LocalEmbeddings(Embeddings[LocalEmbeddingsOptions]):
+    """
+    Class for interaction with any encoder available on Hugging Face.
+    """
+
+    _options_cls = LocalEmbeddingsOptions
+
+    def __init__(
+        self,
+        model_name: str,
+        default_options: Optional[LocalEmbeddingsOptions] = None,
+        api_key: Optional[str] = None,
+    ) -> None:
+        """
+        Constructs a new LocalEmbeddings instance.
+
+        Args:
+            model_name: Name of the model to use.
+            default_options: Default options to be used.
+            api_key: The API key for Hugging Face authentication.
+
+        Raises:
+            ImportError: If the 'local' extra requirements are not installed.
+        """
+        if not HAS_LOCAL_EMBEDDINGS:
+            raise ImportError("You need to install the 'local' extra requirements to use local embeddings models")
+
+        super().__init__(model_name, default_options)
+        self.api_key = api_key
+
+    @cached_property
+    def client(self) -> LocalEmbeddingsClient:
+        """
+        Client for the local encoder.
+
+        Returns:
+            The client used to interact with the embedding model.
+        """
+        return LocalEmbeddingsClient(model_name=self.model_name, hf_api_key=self.api_key)
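Taken together, the high-level interface after this refactor looks roughly as below; `LocalEmbeddings` is constructed the same way with `LocalEmbeddingsOptions`. A sketch assuming a configured OpenAI key, showing default options and a per-call override:

```python
import asyncio

from ragnarok_common.embeddings.clients.litellm import LiteLLMOptions
from ragnarok_common.embeddings.litellm import LiteLLMEmbeddings


async def main() -> None:
    embeddings = LiteLLMEmbeddings(
        model_name="text-embedding-3-small",
        default_options=LiteLLMOptions(dimensions=512),
    )
    # The client is built lazily via the cached_property on first use.
    # Defaults apply here...
    print(len((await embeddings.embed_text(["hello world"]))[0]))  # 512
    # ...and can be overridden per call.
    vectors = await embeddings.embed_text(["hello world"], LiteLLMOptions(dimensions=64))
    print(len(vectors[0]))  # 64


asyncio.run(main())
```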
+        """
+        return LocalEmbeddingsClient(model_name=self.model_name, hf_api_key=self.api_key)
diff --git a/packages/ragnarok-common/src/ragnarok_common/llms/clients/base.py b/packages/ragnarok-common/src/ragnarok_common/llms/clients/base.py
index ae9c5ffbd..d3cf7f428 100644
--- a/packages/ragnarok-common/src/ragnarok_common/llms/clients/base.py
+++ b/packages/ragnarok-common/src/ragnarok_common/llms/clients/base.py
@@ -6,7 +6,7 @@
 
 from ragnarok_common.prompt import ChatFormat
 
-from ..types import NotGiven
+from ...types import NotGiven
 
 LLMClientOptions = TypeVar("LLMClientOptions", bound="LLMOptions")
diff --git a/packages/ragnarok-common/src/ragnarok_common/llms/clients/litellm.py b/packages/ragnarok-common/src/ragnarok_common/llms/clients/litellm.py
index 0824c55e7..095e9d2a6 100644
--- a/packages/ragnarok-common/src/ragnarok_common/llms/clients/litellm.py
+++ b/packages/ragnarok-common/src/ragnarok_common/llms/clients/litellm.py
@@ -13,7 +13,7 @@
 
 from ragnarok_common.prompt import ChatFormat
 
-from ..types import NOT_GIVEN, NotGiven
+from ...types import NOT_GIVEN, NotGiven
 from .base import LLMClient, LLMOptions
 from .exceptions import LLMConnectionError, LLMResponseError, LLMStatusError
diff --git a/packages/ragnarok-common/src/ragnarok_common/llms/clients/local.py b/packages/ragnarok-common/src/ragnarok_common/llms/clients/local.py
index 0acb3e9d2..8f840eee4 100644
--- a/packages/ragnarok-common/src/ragnarok_common/llms/clients/local.py
+++ b/packages/ragnarok-common/src/ragnarok_common/llms/clients/local.py
@@ -13,7 +13,7 @@
 
 from ragnarok_common.prompt import ChatFormat
 
-from ..types import NOT_GIVEN, NotGiven
+from ...types import NOT_GIVEN, NotGiven
 from .base import LLMClient, LLMOptions
 
@@ -55,6 +55,9 @@ def __init__(
         Args:
             model_name: Name of the model to use.
             hf_api_key: The Hugging Face API key for authentication.
+
+        Raises:
+            ImportError: If the 'local' extra requirements are not installed.
         """
         if not HAS_LOCAL_LLM:
             raise ImportError("You need to install the 'local' extra requirements to use local LLM models")
diff --git a/packages/ragnarok-common/src/ragnarok_common/llms/types.py b/packages/ragnarok-common/src/ragnarok_common/types.py
similarity index 100%
rename from packages/ragnarok-common/src/ragnarok_common/llms/types.py
rename to packages/ragnarok-common/src/ragnarok_common/types.py
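Finally, with `types.py` moved to package level, the `NotGiven` sentinel is shared by both the llms and embeddings stacks. A small sketch of the intended pattern (the `resolve` helper is hypothetical; only the import path and sentinel come from this PR):

```python
from typing import Union

from ragnarok_common.types import NOT_GIVEN, NotGiven


def resolve(value: Union[int, NotGiven], fallback: int) -> int:
    # NOT_GIVEN distinguishes "not supplied at all" from an explicit value.
    return fallback if isinstance(value, NotGiven) else value


print(resolve(NOT_GIVEN, 42))  # 42
print(resolve(7, 42))          # 7
```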