Skip to content

Commit

Permalink
[ENH] Create RoboflowEmbeddingFunction (chroma-core#1434)
Browse files Browse the repository at this point in the history
## Description of changes

This PR adds a new `RoboflowEmbeddingFunction` with which a user can
calculate CLIP text and image embeddings using [Roboflow
Inference](https://inference.roboflow.com).

## Test plan

You can test the embedding function using the following code:

```python
import os
import uuid

import chromadb
import numpy as np
from chromadb.utils.embedding_functions import RoboflowEmbeddingFunction
from PIL import Image

client = chromadb.PersistentClient(path="database")

collection = client.create_collection(name="images", metadata={"hnsw:space": "cosine"})
# collection = client.get_collection(name="images")

IMAGE_DIR = "images/train/images/"
SERVER_URL = "https://infer.roboflow.com"
# Leave empty to let the embedding function read ROBOFLOW_API_KEY from the environment.
API_KEY = ""

ef = RoboflowEmbeddingFunction(API_KEY, api_url=SERVER_URL)

# The file paths are stored as the documents; the pixels are what gets embedded.
documents = [os.path.join(IMAGE_DIR, img) for img in os.listdir(IMAGE_DIR)]

# The embedding function takes a single positional list and expects images as
# arrays (it calls PIL's `fromarray` on each one), so decode every file first.
# Converting to RGB keeps the JPEG re-encoding inside the function from failing
# on images that carry an alpha channel.
images = []
for path in documents:
    with Image.open(path) as img:
        images.append(np.array(img.convert("RGB")))

embeddings = ef(images)
ids = [str(uuid.uuid4()) for _ in documents]

print(len(embeddings))

collection.add(
    embeddings=embeddings,
    documents=documents,
    ids=ids,
)

# Text queries go through the same call: a list containing one document.
query = ef(["baseball"])

results = collection.query(
    query_embeddings=query,
    n_results=3
)

top_result = results["documents"]

for i in top_result:
    print(i)
```

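You will need a [Roboflow API
key](https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key).
Set it as `API_KEY` in the script above, or export it as the `ROBOFLOW_API_KEY` environment variable and leave `API_KEY` empty.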

- [X] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes

I will file a PR to the `chroma-core/docs` repository with
documentation.

---------

Co-authored-by: Anton Troynikov <[email protected]>
  • Loading branch information
capjamesg and atroyn authored Apr 2, 2024
1 parent 1cd2ced commit d3f61b1
Show file tree
Hide file tree
Showing 2 changed files with 328 additions and 1 deletion.
71 changes: 70 additions & 1 deletion chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
is_document,
)

from io import BytesIO
from pathlib import Path
import os
import tarfile
Expand All @@ -27,6 +28,7 @@
import inspect
import json
import sys
import base64

try:
from chromadb.is_thin_client import is_thin_client
Expand Down Expand Up @@ -740,6 +742,74 @@ def __call__(self, input: Union[Documents, Images]) -> Embeddings:
return embeddings


class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]):
    """Embed documents and images with the CLIP endpoints of a Roboflow Inference server.

    Each input item is sent in its own HTTP request: images go to
    ``/clip/embed_image`` and documents go to ``/clip/embed_text``.
    """

    def __init__(
        self, api_key: str = "", api_url: str = "https://infer.roboflow.com"
    ) -> None:
        """
        Create a RoboflowEmbeddingFunction.

        Args:
            api_key (str): Your API key for the Roboflow API. When empty, the
                ROBOFLOW_API_KEY environment variable is used instead.
            api_url (str, optional): The URL of the Roboflow API. Defaults to "https://infer.roboflow.com".

        Raises:
            ValueError: If the PIL (pillow) package is not installed.
        """
        if not api_key:
            # May still be None if the variable is unset; the key is then
            # simply left out of the request.
            api_key = os.environ.get("ROBOFLOW_API_KEY")

        self._api_url = api_url
        self._api_key = api_key

        try:
            self._PILImage = importlib.import_module("PIL.Image")
        except ImportError:
            raise ValueError(
                "The PIL python package is not installed. Please install it with `pip install pillow`"
            )

    def __call__(self, input: Union[Documents, Images]) -> Embeddings:
        """
        Compute one embedding per input item.

        Args:
            input: The documents and/or images to embed.

        Returns:
            The embeddings, in input order.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        embeddings = []

        for item in input:
            if is_image(item):
                embeddings.append(self._embed_image(item))
            elif is_document(item):
                embeddings.append(self._embed_text(item))
            # NOTE(review): items that are neither an image nor a document are
            # skipped, so the result can be shorter than the input.

        return embeddings

    def _embed_image(self, item):
        """Encode one image array as base64 JPEG and return its embedding."""
        image = self._PILImage.fromarray(item)

        # JPEG cannot store every PIL mode (e.g. RGBA), so normalise anything
        # other than plain greyscale/RGB before encoding.
        if image.mode not in ("L", "RGB"):
            image = image.convert("RGB")

        buffer = BytesIO()
        image.save(buffer, format="JPEG")
        base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8")

        infer_clip_payload = {
            "image": {
                "type": "base64",
                "value": base64_image,
            },
        }

        return self._post("embed_image", infer_clip_payload)

    def _embed_text(self, item):
        """Return the embedding of a single document."""
        # Send only the current document: the first returned embedding is the
        # one kept, so sending the whole batch would repeat the first
        # document's embedding for every item.
        return self._post("embed_text", {"text": item})

    def _post(self, endpoint, payload):
        """POST `payload` to a CLIP endpoint and return the first embedding in the reply."""
        res = requests.post(
            f"{self._api_url}/clip/{endpoint}",
            # Passing the key through `params` lets requests URL-encode it.
            params={"api_key": self._api_key},
            json=payload,
        )
        # Surface HTTP failures directly instead of as a KeyError on the
        # missing 'embeddings' field below.
        res.raise_for_status()

        return res.json()["embeddings"][0]


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
Expand Down Expand Up @@ -885,7 +955,6 @@ def __call__(self, input: Documents) -> Embeddings:
],
)


# List of all classes in this module
_classes = [
name
Expand Down
Loading

0 comments on commit d3f61b1

Please sign in to comment.