Skip to content

Commit

Permalink
feat(embedding): add litellm embeddings (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst authored May 21, 2024
1 parent c06b11d commit 2fb275f
Show file tree
Hide file tree
Showing 19 changed files with 248 additions and 120 deletions.
25 changes: 14 additions & 11 deletions docs/how-to/use_chromadb_store.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

To use Chromadb with db-ally you need to install the chromadb extension

```python
```bash
pip install dbally[chromadb]
```

Expand All @@ -33,28 +33,33 @@ or [set up Chromadb in the client/server mode](https://docs.trychroma.com/usage-
chroma_client = chromadb.HttpClient(host='localhost', port=8000)
```

Next, you can either use [one of dbally embedding clients][dbally.embedding_client.EmbeddingClient], such as [OpenAiEmbeddingClient][dbally.embedding_client.OpenAiEmbeddingClient]
Next, you can either use [one of dbally embedding clients][dbally.embeddings.EmbeddingClient], such as [LiteLLMEmbeddingClient][dbally.embeddings.LiteLLMEmbeddingClient]

```python
from dbally.embedding_client import OpenAiEmbeddingClient

embedding_client=OpenAiEmbeddingClient(
api_key="your-api-key",
)
from dbally.embeddings.litellm import LiteLLMEmbeddingClient

embedding_client=LiteLLMEmbeddingClient(
model="text-embedding-3-small", # to use openai embedding model
api_key="your-api-key",
)
```

or [Chromadb embedding functions](https://docs.trychroma.com/embeddings)

```
```python
from chromadb.utils import embedding_functions

embedding_client = embedding_functions.DefaultEmbeddingFunction()
```

to define your [`ChromadbStore`][dbally.similarity.ChromadbStore].

```python
store = ChromadbStore(index_name="myChromaIndex", chroma_client=chroma_client, embedding_function=embedding_client)
store = ChromadbStore(
index_name="myChromaIndex",
chroma_client=chroma_client,
embedding_function=embedding_client,
)
```

After this setup, you can initialize the SimilarityIndex
Expand All @@ -63,8 +68,6 @@ After this setup, you can initialize the SimilarityIndex
from typing import Annotated

country_similarity = SimilarityIndex(store, DummyCountryFetcher())


```

and [update it and find the closest matches in the same way as in built-in similarity indices](./use_custom_similarity_store.md/#using-the-similarity-index) .
14 changes: 8 additions & 6 deletions docs/how-to/use_custom_similarity_fetcher.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,17 @@ from dbally.similarity.store import FaissStore
breeds_similarity = SimilarityIndex(
fetcher=DogBreedsFetcher(),
store=FaissStore(
index_dir="./similarity_indexes",
index_name="breeds_similarity",
embedding_client=OpenAiEmbeddingClient(
api_key="your-api-key",
)
index_dir="./similarity_indexes",
index_name="breeds_similarity",
),
embedding_client=LiteLLMEmbeddingClient(
model="text-embedding-3-small", # to use openai embedding model
api_key=os.environ["OPENAI_API_KEY"],
),
)
```

In this example, we used the FaissStore, which utilizes the `faiss` library for rapid similarity search. We also employed the `OpenAiEmbeddingClient` to get the semantic embeddings for the dog breeds. Depending on your needs, you can use a different built-in store or create [a custom one](../how-to/use_custom_similarity_store.md).
In this example, we used the FaissStore, which utilizes the `faiss` library for rapid similarity search. We also employed the `LiteLLMEmbeddingClient` to get the semantic embeddings for the dog breeds. Depending on your needs, you can use a different built-in store or create [a custom one](../how-to/use_custom_similarity_store.md).

## Using the Similarity Index

Expand Down
11 changes: 6 additions & 5 deletions docs/quickstart/quickstart2.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,19 @@ Next, let's define a store that will store the country names and can be used to

```python
from dbally.similarity import FaissStore
from dbally.embedding_client.openai import OpenAiEmbeddingClient
from dbally.embeddings.litellm import LiteLLMEmbeddingClient

country_store = FaissStore(
index_dir="./similarity_indexes",
index_name="country_similarity",
embedding_client=OpenAiEmbeddingClient(
api_key="your-api-key",
)
embedding_client=LiteLLMEmbeddingClient(
model="text-embedding-3-small", # to use openai embedding model
api_key=os.environ["OPENAI_API_KEY"],
),
)
```

In this example, we used the `FaissStore` store, which employs the `faiss` library for fast similarity search. We also used the `OpenAiEmbeddingClient` to get the semantic embeddings for the country names. Replace `your-api-key` with your OpenAI API key.
In this example, we used the `FaissStore` store, which employs the `faiss` library for fast similarity search. We also used the `LiteLLMEmbeddingClient` to get the semantic embeddings for the country names. Replace `your-api-key` with your OpenAI API key.

Finally, let's define the similarity index:

Expand Down
7 changes: 4 additions & 3 deletions docs/quickstart/quickstart2_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dbally import decorators, SqlAlchemyBaseView
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
from dbally.embedding_client.openai import OpenAiEmbeddingClient
from dbally.embeddings.litellm import LiteLLMEmbeddingClient
from dbally.llms.litellm import LiteLLM

engine = create_engine('sqlite:///candidates.db')
Expand All @@ -30,9 +30,10 @@
store=FaissStore(
index_dir="./similarity_indexes",
index_name="country_similarity",
embedding_client=OpenAiEmbeddingClient(
embedding_client=LiteLLMEmbeddingClient(
model="text-embedding-3-small", # to use openai embedding model
api_key=os.environ["OPENAI_API_KEY"],
)
),
),
)

Expand Down
7 changes: 4 additions & 3 deletions docs/quickstart/quickstart3_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult
from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex
from dbally.embedding_client.openai import OpenAiEmbeddingClient
from dbally.embeddings.litellm import LiteLLMEmbeddingClient
from dbally.llms.litellm import LiteLLM

engine = create_engine('sqlite:///candidates.db')
Expand All @@ -31,9 +31,10 @@
store=FaissStore(
index_dir="./similarity_indexes",
index_name="country_similarity",
embedding_client=OpenAiEmbeddingClient(
embedding_client=LiteLLMEmbeddingClient(
model="text-embedding-3-small", # to use openai embedding model
api_key=os.environ["OPENAI_API_KEY"],
)
),
),
)

Expand Down
23 changes: 23 additions & 0 deletions src/dbally/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
from ._main import create_collection
from ._types import NOT_GIVEN, NotGiven
from .collection import Collection
from .embeddings._exceptions import (
EmbeddingConnectionError,
EmbeddingError,
EmbeddingResponseError,
EmbeddingStatusError,
)
from .llms.clients._exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError

__all__ = [
Expand All @@ -25,10 +31,27 @@
"DataFrameBaseView",
"ExecutionResult",
"DbAllyError",
"EmbeddingError",
"EmbeddingConnectionError",
"EmbeddingResponseError",
"EmbeddingStatusError",
"LLMError",
"LLMConnectionError",
"LLMResponseError",
"LLMStatusError",
"NotGiven",
"NOT_GIVEN",
]

# Update the __module__ attribute for exported symbols so that
# error messages point to this module instead of the module
# it was originally defined in, e.g.
# dbally._exceptions.LLMError -> dbally.LLMError
__locals = locals()
for __name in __all__:
if not __name.startswith("__"):
try:
__locals[__name].__module__ = "dbally"
except (TypeError, AttributeError):
# Some of our exported symbols are builtins which we can't set attributes for.
pass
4 changes: 0 additions & 4 deletions src/dbally/embedding_client/__init__.py

This file was deleted.

19 changes: 0 additions & 19 deletions src/dbally/embedding_client/base.py

This file was deleted.

52 changes: 0 additions & 52 deletions src/dbally/embedding_client/openai.py

This file was deleted.

4 changes: 4 additions & 0 deletions src/dbally/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .base import EmbeddingClient
from .litellm import LiteLLMEmbeddingClient

__all__ = ["EmbeddingClient", "LiteLLMEmbeddingClient"]
39 changes: 39 additions & 0 deletions src/dbally/embeddings/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from .._exceptions import DbAllyError


class EmbeddingError(DbAllyError):
"""
Base class for all exceptions raised by the EmbeddingClient.
"""

def __init__(self, message: str) -> None:
super().__init__(message)
self.message = message


class EmbeddingConnectionError(EmbeddingError):
"""
Raised when there is an error connecting to the embedding API.
"""

def __init__(self, message: str = "Connection error.") -> None:
super().__init__(message)


class EmbeddingStatusError(EmbeddingError):
"""
Raised when an API response has a status code of 4xx or 5xx.
"""

def __init__(self, message: str, status_code: int) -> None:
super().__init__(message)
self.status_code = status_code


class EmbeddingResponseError(EmbeddingError):
"""
Raised when an API response has an invalid schema.
"""

def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None:
super().__init__(message)
20 changes: 20 additions & 0 deletions src/dbally/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod
from typing import List


class EmbeddingClient(ABC):
"""
Abstract client for creating text embeddings.
"""

@abstractmethod
async def get_embeddings(self, data: List[str]) -> List[List[float]]:
"""
Creates embeddings for the given strings.
Args:
data: List of strings to get embeddings for.
Returns:
List of embeddings for the given strings.
"""
Loading

0 comments on commit 2fb275f

Please sign in to comment.