diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe1de654..3ea6ddf7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -141,7 +141,7 @@ jobs: run: | # run with coverage to not execute tests twice source venv/bin/activate - coverage run -m pytest -v -p no:warnings --junitxml=report.xml + coverage run -m pytest -v -p no:warnings --junitxml=report.xml coverage report coverage xml diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/__init__.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py new file mode 100644 index 00000000..ede4fcad --- /dev/null +++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/base.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod + + +class Embeddings(ABC): + """ + Abstract client for communication with embedding models. + """ + + @abstractmethod + async def embed_text(self, data: list[str]) -> list[list[float]]: + """ + Creates embeddings for the given strings. + + Args: + data: List of strings to get embeddings for. + + Returns: + List of embeddings for the given strings. + """ diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/exceptions.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/exceptions.py new file mode 100644 index 00000000..4dd99ad1 --- /dev/null +++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/exceptions.py @@ -0,0 +1,36 @@ +class EmbeddingError(Exception): + """ + Base class for all exceptions raised by the EmbeddingClient. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message + + +class EmbeddingConnectionError(EmbeddingError): + """ + Raised when there is an error connecting to the embedding API. + """ + + def __init__(self, message: str = "Connection error.") -> None: + super().__init__(message) + + +class EmbeddingStatusError(EmbeddingError): + """ + Raised when an API response has a status code of 4xx or 5xx. + """ + + def __init__(self, message: str, status_code: int) -> None: + super().__init__(message) + self.status_code = status_code + + +class EmbeddingResponseError(EmbeddingError): + """ + Raised when an API response has an invalid schema. + """ + + def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None: + super().__init__(message) diff --git a/packages/ragnarok-common/src/ragnarok_common/embeddings/litellm.py b/packages/ragnarok-common/src/ragnarok_common/embeddings/litellm.py new file mode 100644 index 00000000..4a8654fc --- /dev/null +++ b/packages/ragnarok-common/src/ragnarok_common/embeddings/litellm.py @@ -0,0 +1,85 @@ +from typing import Optional + +try: + import litellm + + HAS_LITELLM = True +except ImportError: + HAS_LITELLM = False + +from ragnarok_common.embeddings.base import Embeddings +from ragnarok_common.embeddings.exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError + + +class LiteLLMEmbeddings(Embeddings): + """ + Client for creating text embeddings using LiteLLM API. + """ + + def __init__( + self, + model: str = "text-embedding-3-small", + options: Optional[dict] = None, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + api_version: Optional[str] = None, + ) -> None: + """ + Constructs the LiteLLMEmbeddingClient. + + Args: + model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\ + to be used. Default is "text-embedding-3-small". + options: Additional options to pass to the LiteLLM API. + api_base: The API endpoint you want to call the model with. + api_key: API key to be used. API key to be used. If not specified, an environment variable will be used, + for more information, follow the instructions for your specific vendor in the\ + [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding). + api_version: The API version for the call. + + Raises: + ImportError: If the litellm package is not installed. + """ + if not HAS_LITELLM: + raise ImportError("You need to install litellm package to use LiteLLM models") + + super().__init__() + self.model = model + self.options = options or {} + self.api_base = api_base + self.api_key = api_key + self.api_version = api_version + + async def embed_text(self, data: list[str]) -> list[list[float]]: + """ + Creates embeddings for the given strings. + + Args: + data: List of strings to get embeddings for. + + Returns: + List of embeddings for the given strings. + + Raises: + EmbeddingConnectionError: If there is a connection error with the embedding API. + EmbeddingStatusError: If the embedding API returns an error status code. + EmbeddingResponseError: If the embedding API response is invalid. + """ + + try: + response = await litellm.aembedding( + input=data, + model=self.model, + api_base=self.api_base, + api_key=self.api_key, + api_version=self.api_version, + **self.options, + ) + except litellm.openai.APIConnectionError as exc: + raise EmbeddingConnectionError() from exc + except litellm.openai.APIStatusError as exc: + raise EmbeddingStatusError(exc.message, exc.status_code) from exc + except litellm.openai.APIResponseValidationError as exc: + raise EmbeddingResponseError() from exc + + return [embedding["embedding"] for embedding in response.data] diff --git a/packages/ragnarok-dev-kit/src/ragnarok_dev_kit/app/main.py b/packages/ragnarok-dev-kit/src/ragnarok_dev_kit/app/main.py index 050eac08..292ec980 100644 --- a/packages/ragnarok-dev-kit/src/ragnarok_dev_kit/app/main.py +++ b/packages/ragnarok-dev-kit/src/ragnarok_dev_kit/app/main.py @@ -5,10 +5,10 @@ import jinja2 import typer from pydantic import BaseModel +from ragnarok_dev_kit.discovery.prompt_discovery import PromptDiscovery from ragnarok_common.llms import LiteLLM from ragnarok_common.llms.clients import LiteLLMOptions -from ragnarok_dev_kit.discovery.prompt_discovery import PromptDiscovery class PromptState: @@ -26,7 +26,7 @@ class PromptState: "Render Prompt" button and reflects in the "Rendered Prompt" field. It is used for communication with the LLM. llm_model_name (str): The name of the selected LLM model. - llm_api_key (str): The API key for the chosen LLM model. + llm_api_key (str | None): The API key for the chosen LLM model. temp_field_name (str): Temporary field name used internally. """ @@ -35,7 +35,7 @@ class PromptState: dynamic_tb: dict = {} current_prompt = None llm_model_name: str = "" - llm_api_key: str = "" + llm_api_key: str | None = "" temp_field_name: str = "" @@ -170,7 +170,7 @@ def get_input_type_fields(obj: BaseModel) -> list[dict]: @typer_app.command() -def run_app(prompts_paths: str, llm_model: str, llm_api_key: str = None) -> None: +def run_app(prompts_paths: str, llm_model: str, llm_api_key: str | None = None) -> None: """ Launches the interactive application for working with Large Language Models (LLMs). diff --git a/packages/ragnarok-dev-kit/src/ragnarok_dev_kit/discovery/prompt_discovery.py b/packages/ragnarok-dev-kit/src/ragnarok_dev_kit/discovery/prompt_discovery.py index 5a85ff96..5353ac3d 100644 --- a/packages/ragnarok-dev-kit/src/ragnarok_dev_kit/discovery/prompt_discovery.py +++ b/packages/ragnarok-dev-kit/src/ragnarok_dev_kit/discovery/prompt_discovery.py @@ -92,4 +92,3 @@ def discover(self) -> dict: )._asdict() return result_dict -