Skip to content

Commit

Permalink
docs: Added docs
Browse files Browse the repository at this point in the history
- renamed EF to align with rest of the EFs naming convention
- API token now defaults to CF_API_TOKEN env var (Account or GW endpoint still need to be passed)
- renamed gateway_url to gateway_endpoint for better consistency
  • Loading branch information
tazarov committed Oct 19, 2024
1 parent 9a310a6 commit 97ca818
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 22 deletions.
38 changes: 26 additions & 12 deletions chromadbx/embeddings/cloudflare.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# original work on this was done by @mileszim - https://github.com/mileszim/chroma/tree/cloudflare-workers-ai-embedding
import logging
import os
from typing import Optional, Dict, cast

import httpx
Expand All @@ -9,35 +10,48 @@
logger = logging.getLogger(__name__)


class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]):
class CloudflareWorkersAIEmbeddings(EmbeddingFunction[Documents]):
# Follow API Quickstart for Cloudflare Workers AI
# https://developers.cloudflare.com/workers-ai/
# Information about the text embedding modules in Google Vertex AI
# https://developers.cloudflare.com/workers-ai/models/embedding/
def __init__(
self,
api_token: str,
account_id: Optional[str] = None,
model_name: Optional[str] = "@cf/baai/bge-base-en-v1.5",
gateway_url: Optional[
*,
api_token: Optional[str] = os.getenv("CF_API_TOKEN"),
account_id: Optional[str] = None,
gateway_endpoint: Optional[
str
] = None, # use Cloudflare AI Gateway instead of the usual endpoint
# right now endpoint schema supports up to 100 docs at a time
# https://developers.cloudflare.com/workers-ai/models/bge-small-en-v1.5/#api-schema (Input JSON Schema)
max_batch_size: Optional[int] = 100,
headers: Optional[Dict[str, str]] = None,
):
if not gateway_url and not account_id:
raise ValueError("Please provide either an account_id or a gateway_url.")
if gateway_url and account_id:
"""
Initialize the Cloudflare Workers AI Embeddings function.
:param model_name: The name of the model to use. Defaults to "@cf/baai/bge-base-en-v1.5".
:param api_token: The API token to use. Defaults to the CF_API_TOKEN environment variable.
:param account_id: The account ID to use.
:param gateway_endpoint: The gateway URL to use.
:param max_batch_size: The maximum batch size to use. Defaults to 100.
:param headers: The headers to use. Defaults to None.
"""
if not gateway_endpoint and not account_id:
raise ValueError(
"Please provide either an account_id or a gateway_endpoint."
)
if gateway_endpoint and account_id:
raise ValueError(
"Please provide either an account_id or a gateway_url, not both."
"Please provide either an account_id or a gateway_endpoint, not both."
)
if gateway_url is not None and not gateway_url.endswith("/"):
gateway_url += "/"
if gateway_endpoint is not None and not gateway_endpoint.endswith("/"):
gateway_endpoint += "/"
self._api_url = (
f"{gateway_url}{model_name}"
if gateway_url is not None
f"{gateway_endpoint}{model_name}"
if gateway_endpoint is not None
else f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}"
)
self._session = httpx.Client()
Expand Down
41 changes: 41 additions & 0 deletions docs/embeddings.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,44 @@ col = client.get_or_create_collection("test", embedding_function=ef)
col.add(ids=["id1", "id2", "id3"], documents=["lorem ipsum...", "doc2", "doc3"])
col.query(query_texts=["lorem ipsum..."], n_results=2)
```


## Cloudflare Workers AI

A convenient way to generate embeddings using Cloudflare Workers AI models.

_Using with Account ID and API Token:_

```py
import os
import chromadb
from chromadbx.embeddings.cloudflare import CloudflareWorkersAIEmbeddings

ef = CloudflareWorkersAIEmbeddings(
model_name="@cf/baai/bge-base-en-v1.5",
api_token=os.getenv("CF_API_TOKEN"),
account_id=os.getenv("CF_ACCOUNT_ID")
)

client = chromadb.Client()

col = client.get_or_create_collection("test", embedding_function=ef)

col.add(ids=["id1", "id2", "id3"], documents=["lorem ipsum...", "doc2", "doc3"])
col.query(query_texts=["lorem ipsum..."], n_results=2)
```

_Using with Gateway Endpoint:_

```py
import os
from chromadbx.embeddings.cloudflare import CloudflareWorkersAIEmbeddings

ef = CloudflareWorkersAIEmbeddings(
model_name="@cf/baai/bge-base-en-v1.5",
api_token=os.getenv("CF_API_TOKEN"),
gateway_url=os.getenv("CF_GATEWAY_ENDPOINT") # "https://gateway.ai.cloudflare.com/v1/[account_id]/[project]/workers-ai"
)

# ... rest of the code
```
20 changes: 10 additions & 10 deletions test/embeddings/test_cloudflare.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

from chromadbx.embeddings.cloudflare import (
CloudflareWorkersAIEmbeddingFunction,
CloudflareWorkersAIEmbeddings,
)


Expand All @@ -12,7 +12,7 @@
reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.",
)
def test_cf_ef_token_and_account() -> None:
ef = CloudflareWorkersAIEmbeddingFunction(
ef = CloudflareWorkersAIEmbeddings(
api_token=os.environ.get("CF_API_TOKEN", ""),
account_id=os.environ.get("CF_ACCOUNT_ID"),
)
Expand All @@ -27,9 +27,9 @@ def test_cf_ef_token_and_account() -> None:
reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.",
)
def test_cf_ef_gateway() -> None:
ef = CloudflareWorkersAIEmbeddingFunction(
ef = CloudflareWorkersAIEmbeddings(
api_token=os.environ.get("CF_API_TOKEN", ""),
gateway_url=os.environ.get("CF_GATEWAY_ENDPOINT"),
gateway_endpoint=os.environ.get("CF_GATEWAY_ENDPOINT"),
)
embeddings = ef(["test doc"])
assert embeddings is not None
Expand All @@ -42,7 +42,7 @@ def test_cf_ef_gateway() -> None:
reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.",
)
def test_cf_ef_large_batch() -> None:
ef = CloudflareWorkersAIEmbeddingFunction(api_token="dummy", account_id="dummy")
ef = CloudflareWorkersAIEmbeddings(api_token="dummy", account_id="dummy")
with pytest.raises(ValueError, match="Batch too large"):
ef(["test doc"] * 101)

Expand All @@ -53,9 +53,9 @@ def test_cf_ef_large_batch() -> None:
)
def test_cf_ef_missing_account_or_gateway() -> None:
with pytest.raises(
ValueError, match="Please provide either an account_id or a gateway_url"
ValueError, match="Please provide either an account_id or a gateway_endpoint"
):
CloudflareWorkersAIEmbeddingFunction(api_token="dummy")
CloudflareWorkersAIEmbeddings(api_token="dummy")


@pytest.mark.skipif(
Expand All @@ -65,8 +65,8 @@ def test_cf_ef_missing_account_or_gateway() -> None:
def test_cf_ef_with_account_or_gateway() -> None:
with pytest.raises(
ValueError,
match="Please provide either an account_id or a gateway_url, not both",
match="Please provide either an account_id or a gateway_endpoint, not both",
):
CloudflareWorkersAIEmbeddingFunction(
api_token="dummy", account_id="dummy", gateway_url="dummy"
CloudflareWorkersAIEmbeddings(
api_token="dummy", account_id="dummy", gateway_endpoint="dummy"
)

0 comments on commit 97ca818

Please sign in to comment.