diff --git a/docs/how-to/use_chromadb_store.md b/docs/how-to/use_chromadb_store.md
index b715d8e8..6e9d5a58 100644
--- a/docs/how-to/use_chromadb_store.md
+++ b/docs/how-to/use_chromadb_store.md
@@ -7,7 +7,7 @@
 To use Chromadb with db-ally you need to install the chromadb extension
 
-```python
+```bash
 pip install dbally[chromadb]
 ```
 
@@ -33,28 +33,33 @@
 or [set up Chromadb in the client/server mode](https://docs.trychroma.com/usage-
 
 chroma_client = chromadb.HttpClient(host='localhost', port=8000)
 ```
 
-Next, you can either use [one of dbally embedding clients][dbally.embedding_client.EmbeddingClient], such as [OpenAiEmbeddingClient][dbally.embedding_client.OpenAiEmbeddingClient]
+Next, you can either use [one of dbally embedding clients][dbally.embeddings.EmbeddingClient], such as [LiteLLMEmbeddingClient][dbally.embeddings.LiteLLMEmbeddingClient]
 
 ```python
-from dbally.embedding_client import OpenAiEmbeddingClient
-
-embedding_client=OpenAiEmbeddingClient(
-    api_key="your-api-key",
-    )
+from dbally.embeddings.litellm import LiteLLMEmbeddingClient
+
+embedding_client = LiteLLMEmbeddingClient(
+    model="text-embedding-3-small",  # to use OpenAI's embedding model
+    api_key="your-api-key",
+)
 ```
 
 or [Chromadb embedding functions](https://docs.trychroma.com/embeddings)
 
-```
+```python
 from chromadb.utils import embedding_functions
+
 embedding_client = embedding_functions.DefaultEmbeddingFunction()
 ```
 
 to define your [`ChromadbStore`][dbally.similarity.ChromadbStore].
 
 ```python
-store = ChromadbStore(index_name="myChromaIndex", chroma_client=chroma_client, embedding_function=embedding_client)
+store = ChromadbStore(
+    index_name="myChromaIndex",
+    chroma_client=chroma_client,
+    embedding_function=embedding_client,
+)
 ```
 
 After this setup, you can initialize the SimilarityIndex
@@ -63,8 +68,6 @@ After this setup, you can initialize the SimilarityIndex
 from typing import Annotated
 
 country_similarity = SimilarityIndex(store, DummyCountryFetcher())
-
-
 ```
 
 and [update it and find the closest matches in the same way as in built-in similarity indices](./use_custom_similarity_store.md/#using-the-similarity-index).
diff --git a/docs/how-to/use_custom_similarity_fetcher.md b/docs/how-to/use_custom_similarity_fetcher.md
index 06a01197..b79b6554 100644
--- a/docs/how-to/use_custom_similarity_fetcher.md
+++ b/docs/how-to/use_custom_similarity_fetcher.md
@@ -41,15 +41,17 @@
 from dbally.similarity.store import FaissStore
 
 breeds_similarity = SimilarityIndex(
     fetcher=DogBreedsFetcher(),
     store=FaissStore(
-        index_dir="./similarity_indexes",
-        index_name="breeds_similarity",
-        embedding_client=OpenAiEmbeddingClient(
-            api_key="your-api-key",
-        )
+        index_dir="./similarity_indexes",
+        index_name="breeds_similarity",
+        embedding_client=LiteLLMEmbeddingClient(
+            model="text-embedding-3-small",  # to use OpenAI's embedding model
+            api_key=os.environ["OPENAI_API_KEY"],
+        ),
     ),
 )
 ```
 
-In this example, we used the FaissStore, which utilizes the `faiss` library for rapid similarity search. We also employed the `OpenAiEmbeddingClient` to get the semantic embeddings for the dog breeds. Depending on your needs, you can use a different built-in store or create [a custom one](../how-to/use_custom_similarity_store.md).
+In this example, we used the FaissStore, which utilizes the `faiss` library for rapid similarity search. We also employed the `LiteLLMEmbeddingClient` to get the semantic embeddings for the dog breeds. Depending on your needs, you can use a different built-in store or create [a custom one](../how-to/use_custom_similarity_store.md).
 
 ## Using the Similarity Index
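For orientation (not part of the diff): once built, the guide's index is used like any other similarity index. A minimal sketch, assuming the `breeds_similarity` object from the snippet above and the async `update()`/`similar()` methods that the linked similarity-index guide refers to:

```python
import asyncio

# Continues the snippet above: `breeds_similarity` is the SimilarityIndex
# built with DogBreedsFetcher and FaissStore.


async def main() -> None:
    # Fetch all dog breeds, embed them, and persist the FAISS index.
    await breeds_similarity.update()

    # Return the stored breed closest in meaning to the query.
    print(await breeds_similarity.similar("bernese mountain dog"))


asyncio.run(main())
```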
diff --git a/docs/quickstart/quickstart2.md b/docs/quickstart/quickstart2.md
index 8f85f093..cb839bff 100644
--- a/docs/quickstart/quickstart2.md
+++ b/docs/quickstart/quickstart2.md
@@ -60,18 +60,19 @@
 Next, let's define a store that will store the country names and can be used to find the closest matches.
 
 ```python
 from dbally.similarity import FaissStore
-from dbally.embedding_client.openai import OpenAiEmbeddingClient
+from dbally.embeddings.litellm import LiteLLMEmbeddingClient
 
 country_store = FaissStore(
     index_dir="./similarity_indexes",
     index_name="country_similarity",
-    embedding_client=OpenAiEmbeddingClient(
-        api_key="your-api-key",
-    )
+    embedding_client=LiteLLMEmbeddingClient(
+        model="text-embedding-3-small",  # to use OpenAI's embedding model
+        api_key=os.environ["OPENAI_API_KEY"],
+    ),
 )
 ```
 
-In this example, we used the `FaissStore` store, which employs the `faiss` library for fast similarity search. We also used the `OpenAiEmbeddingClient` to get the semantic embeddings for the country names. Replace `your-api-key` with your OpenAI API key.
+In this example, we used `FaissStore`, which employs the `faiss` library for fast similarity search. We also used the `LiteLLMEmbeddingClient` to get the semantic embeddings for the country names. Set the `OPENAI_API_KEY` environment variable to your OpenAI API key.
 
 Finally, let's define the similarity index:
diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py
index cd9669d0..d330470a 100644
--- a/docs/quickstart/quickstart2_code.py
+++ b/docs/quickstart/quickstart2_code.py
@@ -11,7 +11,7 @@
 from dbally import decorators, SqlAlchemyBaseView
 from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
 from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
-from dbally.embedding_client.openai import OpenAiEmbeddingClient
+from dbally.embeddings.litellm import LiteLLMEmbeddingClient
 from dbally.llms.litellm import LiteLLM
 
 engine = create_engine('sqlite:///candidates.db')
@@ -30,9 +30,10 @@
     store=FaissStore(
         index_dir="./similarity_indexes",
         index_name="country_similarity",
-        embedding_client=OpenAiEmbeddingClient(
+        embedding_client=LiteLLMEmbeddingClient(
+            model="text-embedding-3-small",  # to use OpenAI's embedding model
             api_key=os.environ["OPENAI_API_KEY"],
-        )
+        ),
     ),
 )
diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py
index 3732c9da..f0e385b1 100644
--- a/docs/quickstart/quickstart3_code.py
+++ b/docs/quickstart/quickstart3_code.py
@@ -11,7 +11,7 @@
 from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult
 from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
-from dbally.embedding_client.openai import OpenAiEmbeddingClient
+from dbally.embeddings.litellm import LiteLLMEmbeddingClient
 from dbally.llms.litellm import LiteLLM
 
 engine = create_engine('sqlite:///candidates.db')
@@ -31,9 +31,10 @@
     store=FaissStore(
         index_dir="./similarity_indexes",
         index_name="country_similarity",
-        embedding_client=OpenAiEmbeddingClient(
+        embedding_client=LiteLLMEmbeddingClient(
+            model="text-embedding-3-small",  # to use OpenAI's embedding model
             api_key=os.environ["OPENAI_API_KEY"],
-        )
+        ),
     ),
 )
diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py
index a947eb97..258c2f79 100644
--- a/src/dbally/__init__.py
+++ b/src/dbally/__init__.py
@@ -12,6 +12,12 @@ from 
._main import create_collection from ._types import NOT_GIVEN, NotGiven from .collection import Collection +from .embeddings._exceptions import ( + EmbeddingConnectionError, + EmbeddingError, + EmbeddingResponseError, + EmbeddingStatusError, +) from .llms.clients._exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError __all__ = [ @@ -25,6 +31,10 @@ "DataFrameBaseView", "ExecutionResult", "DbAllyError", + "EmbeddingError", + "EmbeddingConnectionError", + "EmbeddingResponseError", + "EmbeddingStatusError", "LLMError", "LLMConnectionError", "LLMResponseError", @@ -32,3 +42,16 @@ "NotGiven", "NOT_GIVEN", ] + +# Update the __module__ attribute for exported symbols so that +# error messages point to this module instead of the module +# it was originally defined in, e.g. +# dbally._exceptions.LLMError -> dbally.LLMError +__locals = locals() +for __name in __all__: + if not __name.startswith("__"): + try: + __locals[__name].__module__ = "dbally" + except (TypeError, AttributeError): + # Some of our exported symbols are builtins which we can't set attributes for. + pass diff --git a/src/dbally/embedding_client/__init__.py b/src/dbally/embedding_client/__init__.py deleted file mode 100644 index e2cfb278..00000000 --- a/src/dbally/embedding_client/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base import EmbeddingClient -from .openai import OpenAiEmbeddingClient - -__all__ = ["EmbeddingClient", "OpenAiEmbeddingClient"] diff --git a/src/dbally/embedding_client/base.py b/src/dbally/embedding_client/base.py deleted file mode 100644 index a582aa9e..00000000 --- a/src/dbally/embedding_client/base.py +++ /dev/null @@ -1,19 +0,0 @@ -# disable args docstring check as args are documented in OpenAI API docs -import abc -from typing import List - - -class EmbeddingClient(metaclass=abc.ABCMeta): - """Abstract client for creating text embeddings.""" - - @abc.abstractmethod - async def get_embeddings(self, data: List[str]) -> List[List[float]]: - """ - For a given list of strings returns a list of embeddings. - - Args: - data: List of strings to get embeddings for. - - Returns: - List of embeddings for the given strings. - """ diff --git a/src/dbally/embedding_client/openai.py b/src/dbally/embedding_client/openai.py deleted file mode 100644 index 5f601ead..00000000 --- a/src/dbally/embedding_client/openai.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Any, Dict, List, Optional - -from dbally.embedding_client.base import EmbeddingClient - - -class OpenAiEmbeddingClient(EmbeddingClient): - """ - Client for creating text embeddings using OpenAI API. - """ - - def __init__(self, api_key: str, model: str = "text-embedding-3-small", openai_options: Optional[Dict] = None): - """ - Initializes the OpenAiEmbeddingClient. - - Args: - api_key: The OpenAI API key. - model: The model to use for embeddings. - openai_options: Additional options to pass to the OpenAI API. - """ - super().__init__() - self.api_key = api_key - self.model = model - self.openai_options = openai_options - - try: - from openai import AsyncOpenAI # pylint: disable=import-outside-toplevel - except ImportError as exc: - raise ImportError("You need to install openai package to use GPT models") from exc - - self._openai = AsyncOpenAI(api_key=self.api_key) - - async def get_embeddings(self, data: List[str]) -> List[List[float]]: - """ - For a given list of strings returns a list of embeddings. - - Args: - data: List of strings to get embeddings for. - - Returns: - List of embeddings for the given strings. 
- """ - kwargs: Dict[str, Any] = { - "model": self.model, - } - if self.openai_options: - kwargs.update(self.openai_options) - - response = await self._openai.embeddings.create( - input=data, - **kwargs, - ) - return [embedding.embedding for embedding in response.data] diff --git a/src/dbally/embeddings/__init__.py b/src/dbally/embeddings/__init__.py new file mode 100644 index 00000000..67fe9b7f --- /dev/null +++ b/src/dbally/embeddings/__init__.py @@ -0,0 +1,4 @@ +from .base import EmbeddingClient +from .litellm import LiteLLMEmbeddingClient + +__all__ = ["EmbeddingClient", "LiteLLMEmbeddingClient"] diff --git a/src/dbally/embeddings/_exceptions.py b/src/dbally/embeddings/_exceptions.py new file mode 100644 index 00000000..37c24f3c --- /dev/null +++ b/src/dbally/embeddings/_exceptions.py @@ -0,0 +1,39 @@ +from .._exceptions import DbAllyError + + +class EmbeddingError(DbAllyError): + """ + Base class for all exceptions raised by the EmbeddingClient. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message + + +class EmbeddingConnectionError(EmbeddingError): + """ + Raised when there is an error connecting to the embedding API. + """ + + def __init__(self, message: str = "Connection error.") -> None: + super().__init__(message) + + +class EmbeddingStatusError(EmbeddingError): + """ + Raised when an API response has a status code of 4xx or 5xx. + """ + + def __init__(self, message: str, status_code: int) -> None: + super().__init__(message) + self.status_code = status_code + + +class EmbeddingResponseError(EmbeddingError): + """ + Raised when an API response has an invalid schema. + """ + + def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None: + super().__init__(message) diff --git a/src/dbally/embeddings/base.py b/src/dbally/embeddings/base.py new file mode 100644 index 00000000..4c757251 --- /dev/null +++ b/src/dbally/embeddings/base.py @@ -0,0 +1,20 @@ +from abc import ABC, abstractmethod +from typing import List + + +class EmbeddingClient(ABC): + """ + Abstract client for creating text embeddings. + """ + + @abstractmethod + async def get_embeddings(self, data: List[str]) -> List[List[float]]: + """ + Creates embeddings for the given strings. + + Args: + data: List of strings to get embeddings for. + + Returns: + List of embeddings for the given strings. + """ diff --git a/src/dbally/embeddings/litellm.py b/src/dbally/embeddings/litellm.py new file mode 100644 index 00000000..6e312c56 --- /dev/null +++ b/src/dbally/embeddings/litellm.py @@ -0,0 +1,85 @@ +from typing import Dict, List, Optional + +try: + import litellm + + HAVE_LITELLM = True +except ImportError: + HAVE_LITELLM = False + +from dbally.embeddings.base import EmbeddingClient + +from ._exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError + + +class LiteLLMEmbeddingClient(EmbeddingClient): + """ + Client for creating text embeddings using LiteLLM API. + """ + + def __init__( + self, + model: str = "text-embedding-3-small", + options: Optional[Dict] = None, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + api_version: Optional[str] = None, + ) -> None: + """ + Constructs the LiteLLMEmbeddingClient. + + Args: + model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\ + to be used. Default is "text-embedding-3-small". + options: Additional options to pass to the LiteLLM API. + api_base: The API endpoint you want to call the model with. 
+            api_key: API key to be used. If not specified, an environment variable will be used;
+                for more information, follow the instructions for your specific vendor in the\
+                [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding).
+            api_version: The API version for the call.
+
+        Raises:
+            ImportError: If the litellm package is not installed.
+        """
+        if not HAVE_LITELLM:
+            raise ImportError("You need to install litellm package to use LiteLLM models")
+
+        super().__init__()
+        self.model = model
+        self.options = options or {}
+        self.api_base = api_base
+        self.api_key = api_key
+        self.api_version = api_version
+
+    async def get_embeddings(self, data: List[str]) -> List[List[float]]:
+        """
+        Creates embeddings for the given strings.
+
+        Args:
+            data: List of strings to get embeddings for.
+
+        Returns:
+            List of embeddings for the given strings.
+
+        Raises:
+            EmbeddingConnectionError: If there is a connection error with the embedding API.
+            EmbeddingStatusError: If the embedding API returns an error status code.
+            EmbeddingResponseError: If the embedding API response is invalid.
+        """
+        try:
+            response = await litellm.aembedding(
+                input=data,
+                model=self.model,
+                api_base=self.api_base,
+                api_key=self.api_key,
+                api_version=self.api_version,
+                **self.options,
+            )
+        except litellm.openai.APIConnectionError as exc:
+            raise EmbeddingConnectionError() from exc
+        except litellm.openai.APIStatusError as exc:
+            raise EmbeddingStatusError(exc.message, exc.status_code) from exc
+        except litellm.openai.APIResponseValidationError as exc:
+            raise EmbeddingResponseError() from exc
+
+        return [embedding["embedding"] for embedding in response.data]
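To exercise the new client end to end, here is a sketch (not part of the diff) that uses only symbols these changes export: `LiteLLMEmbeddingClient` from `dbally.embeddings` and the `Embedding*` errors re-exported from `dbally`. It assumes `OPENAI_API_KEY` is set in the environment:

```python
import asyncio
import os

from dbally import EmbeddingConnectionError, EmbeddingError
from dbally.embeddings import LiteLLMEmbeddingClient


async def main() -> None:
    client = LiteLLMEmbeddingClient(
        model="text-embedding-3-small",
        api_key=os.environ["OPENAI_API_KEY"],
    )
    try:
        vectors = await client.get_embeddings(["United States", "Canada"])
    except EmbeddingConnectionError:
        print("Embedding API unreachable; check network and base URL.")
    except EmbeddingError as exc:  # also catches status and schema errors
        print(f"Embedding call failed: {exc.message}")
    else:
        print(f"{len(vectors)} embeddings of dimension {len(vectors[0])}")


asyncio.run(main())
```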
""" + if not HAVE_LITELLM: + raise ImportError("You need to install litellm package to use LiteLLM models") + super().__init__(model_name) self.base_url = base_url self.api_key = api_key @@ -91,17 +102,17 @@ async def call( api_key=self.api_key, api_version=self.api_version, response_format=response_format, - **options.dict(), # type: ignore + **options.dict(), ) - except APIConnectionError as exc: + except litellm.openai.APIConnectionError as exc: raise LLMConnectionError() from exc - except APIStatusError as exc: + except litellm.openai.APIStatusError as exc: raise LLMStatusError(exc.message, exc.status_code) from exc - except APIResponseValidationError as exc: + except litellm.openai.APIResponseValidationError as exc: raise LLMResponseError() from exc event.completion_tokens = response.usage.completion_tokens event.prompt_tokens = response.usage.prompt_tokens event.total_tokens = response.usage.total_tokens - return response.choices[0].message.content # type: ignore + return response.choices[0].message.content diff --git a/src/dbally/llms/litellm.py b/src/dbally/llms/litellm.py index 9b295214..f6b6727d 100644 --- a/src/dbally/llms/litellm.py +++ b/src/dbally/llms/litellm.py @@ -1,7 +1,12 @@ from functools import cached_property from typing import Dict, Optional -from litellm import token_counter +try: + import litellm + + HAVE_LITELLM = True +except ImportError: + HAVE_LITELLM = False from dbally.llms.base import LLM from dbally.llms.clients.litellm import LiteLLMClient, LiteLLMOptions @@ -25,18 +30,24 @@ def __init__( api_version: Optional[str] = None, ) -> None: """ - Construct a new LiteLLM instance. + Constructs a new LiteLLM instance. Args: - model_name: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/providers) to be used, - default is "gpt-3.5-turbo". + model_name: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/providers) to be used.\ + Default is "gpt-3.5-turbo". default_options: Default options to be used. base_url: Base URL of the LLM API. api_key: API key to be used. API key to be used. If not specified, an environment variable will be used, for more information, follow the instructions for your specific vendor in the\ [LiteLLM documentation](https://docs.litellm.ai/docs/providers). api_version: API version to be used. If not specified, the default version will be used. + + Raises: + ImportError: If the litellm package is not installed. """ + if not HAVE_LITELLM: + raise ImportError("You need to install litellm package to use LiteLLM models") + super().__init__(model_name, default_options) self.base_url = base_url self.api_key = api_key @@ -56,7 +67,7 @@ def _client(self) -> LiteLLMClient: def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: """ - Count tokens in the messages using a specified model. + Counts tokens in the messages using a specified model. Args: messages: Messages to count tokens for. @@ -65,4 +76,6 @@ def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: Returns: Number of tokens in the messages. 
""" - return sum(token_counter(model=self.model_name, text=message["content"].format(**fmt)) for message in messages) + return sum( + litellm.token_counter(model=self.model_name, text=message["content"].format(**fmt)) for message in messages + ) diff --git a/src/dbally/similarity/chroma_store.py b/src/dbally/similarity/chroma_store.py index 53ef657a..95dd88c9 100644 --- a/src/dbally/similarity/chroma_store.py +++ b/src/dbally/similarity/chroma_store.py @@ -3,7 +3,7 @@ import chromadb -from dbally.embedding_client.base import EmbeddingClient +from dbally.embeddings.base import EmbeddingClient from dbally.similarity.store import SimilarityStore diff --git a/src/dbally/similarity/faiss_store.py b/src/dbally/similarity/faiss_store.py index 7f43e34b..46f2c725 100644 --- a/src/dbally/similarity/faiss_store.py +++ b/src/dbally/similarity/faiss_store.py @@ -4,7 +4,7 @@ import faiss import numpy as np -from dbally.embedding_client.base import EmbeddingClient +from dbally.embeddings.base import EmbeddingClient from dbally.similarity.store import SimilarityStore diff --git a/tests/integration/test_index_with_chroma.py b/tests/integration/test_index_with_chroma.py index b28ec1db..a2698018 100644 --- a/tests/integration/test_index_with_chroma.py +++ b/tests/integration/test_index_with_chroma.py @@ -2,7 +2,7 @@ import pytest from chromadb import Documents, EmbeddingFunction, Embeddings -from dbally.embedding_client.base import EmbeddingClient +from dbally.embeddings.base import EmbeddingClient from dbally.similarity import ChromadbStore from dbally.similarity.fetcher import SimilarityFetcher from dbally.similarity.index import SimilarityIndex diff --git a/tests/unit/similarity/test_chroma.py b/tests/unit/similarity/test_chroma.py index f54301ed..4cad4497 100644 --- a/tests/unit/similarity/test_chroma.py +++ b/tests/unit/similarity/test_chroma.py @@ -4,7 +4,7 @@ import chromadb import pytest -from dbally.embedding_client import EmbeddingClient +from dbally.embeddings import EmbeddingClient from dbally.similarity import ChromadbStore DEFAULT_METADATA = {"hnsw:space": "l2"}