Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): add embeddings #2

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ jobs:
run: |
# run with coverage to not execute tests twice
source venv/bin/activate
coverage run -m pytest -v -p no:warnings --junitxml=report.xml
coverage run -m pytest -v -p no:warnings --junitxml=report.xml
coverage report
coverage xml

Expand Down
Empty file.
19 changes: 19 additions & 0 deletions packages/ragnarok-common/src/ragnarok_common/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod


class Embeddings(ABC):
    """
    Abstract interface for clients that talk to embedding models.

    Concrete subclasses implement :meth:`embed_text` against a specific
    embedding provider.
    """

    @abstractmethod
    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Compute an embedding vector for each input string.

        Args:
            data: Strings to embed.

        Returns:
            One embedding (list of floats) per input string, in the same order.
        """
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
class EmbeddingError(Exception):
    """
    Root of the embedding-client exception hierarchy.

    All errors raised by the EmbeddingClient derive from this class.
    """

    def __init__(self, message: str) -> None:
        # Keep the message on the instance so callers can read it
        # without parsing str(exc).
        self.message = message
        super().__init__(message)


class EmbeddingConnectionError(EmbeddingError):
    """
    Signals a failure to connect to the embedding API.
    """

    def __init__(self, message: str = "Connection error.") -> None:
        super().__init__(message=message)


class EmbeddingStatusError(EmbeddingError):
    """
    Signals that the API responded with a 4xx or 5xx status code.
    """

    def __init__(self, message: str, status_code: int) -> None:
        # Preserve the HTTP status so callers can branch on it.
        self.status_code = status_code
        super().__init__(message)


class EmbeddingResponseError(EmbeddingError):
    """
    Signals an API response whose payload does not match the expected schema.
    """

    def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
        super().__init__(message=message)
85 changes: 85 additions & 0 deletions packages/ragnarok-common/src/ragnarok_common/embeddings/litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Optional

try:
import litellm

HAS_LITELLM = True
except ImportError:
HAS_LITELLM = False

from ragnarok_common.embeddings.base import Embeddings
from ragnarok_common.embeddings.exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError


class LiteLLMEmbeddings(Embeddings):
    """
    Client for creating text embeddings using LiteLLM API.
    """

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        options: Optional[dict] = None,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
    ) -> None:
        """
        Constructs the LiteLLMEmbeddingClient.

        Args:
            model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\
                to be used. Default is "text-embedding-3-small".
            options: Additional options to pass to the LiteLLM API.
            api_base: The API endpoint you want to call the model with.
            api_key: API key to be used. If not specified, an environment variable will be used,
                for more information, follow the instructions for your specific vendor in the\
                [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
            api_version: The API version for the call.

        Raises:
            ImportError: If the litellm package is not installed.
        """
        if not HAS_LITELLM:
            raise ImportError("You need to install litellm package to use LiteLLM models")

        super().__init__()
        self.model = model
        # Copy-or-default so a shared mutable dict is never stored implicitly.
        self.options = options or {}
        self.api_base = api_base
        self.api_key = api_key
        self.api_version = api_version

    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Creates embeddings for the given strings.

        Args:
            data: List of strings to get embeddings for.

        Returns:
            List of embeddings for the given strings.

        Raises:
            EmbeddingConnectionError: If there is a connection error with the embedding API.
            EmbeddingStatusError: If the embedding API returns an error status code.
            EmbeddingResponseError: If the embedding API response is invalid.
        """
        try:
            response = await litellm.aembedding(
                input=data,
                model=self.model,
                api_base=self.api_base,
                api_key=self.api_key,
                api_version=self.api_version,
                **self.options,
            )
        # Translate provider-level exceptions into the package's own hierarchy
        # so callers only depend on ragnarok_common exceptions.
        except litellm.openai.APIConnectionError as exc:
            raise EmbeddingConnectionError() from exc
        except litellm.openai.APIStatusError as exc:
            raise EmbeddingStatusError(exc.message, exc.status_code) from exc
        except litellm.openai.APIResponseValidationError as exc:
            raise EmbeddingResponseError() from exc

        return [embedding["embedding"] for embedding in response.data]
8 changes: 4 additions & 4 deletions packages/ragnarok-dev-kit/src/ragnarok_dev_kit/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import jinja2
import typer
from pydantic import BaseModel
from ragnarok_dev_kit.discovery.prompt_discovery import PromptDiscovery

from ragnarok_common.llms import LiteLLM
from ragnarok_common.llms.clients import LiteLLMOptions
from ragnarok_dev_kit.discovery.prompt_discovery import PromptDiscovery


class PromptState:
Expand All @@ -26,7 +26,7 @@ class PromptState:
"Render Prompt" button and reflects in the "Rendered Prompt" field.
It is used for communication with the LLM.
llm_model_name (str): The name of the selected LLM model.
llm_api_key (str): The API key for the chosen LLM model.
llm_api_key (str | None): The API key for the chosen LLM model.
temp_field_name (str): Temporary field name used internally.
"""

Expand All @@ -35,7 +35,7 @@ class PromptState:
dynamic_tb: dict = {}
current_prompt = None
llm_model_name: str = ""
llm_api_key: str = ""
llm_api_key: str | None = ""
temp_field_name: str = ""


Expand Down Expand Up @@ -170,7 +170,7 @@ def get_input_type_fields(obj: BaseModel) -> list[dict]:


@typer_app.command()
def run_app(prompts_paths: str, llm_model: str, llm_api_key: str = None) -> None:
def run_app(prompts_paths: str, llm_model: str, llm_api_key: str | None = None) -> None:
"""
Launches the interactive application for working with Large Language Models (LLMs).

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,4 +92,3 @@ def discover(self) -> dict:
)._asdict()

return result_dict

Loading