From 97ca818b5da6ea275c156f460ce1aff6d7f16d48 Mon Sep 17 00:00:00 2001 From: Trayan Azarov Date: Sat, 19 Oct 2024 14:14:30 +0300 Subject: [PATCH] docs: Added docs - renamed EF to align with rest of the EFs naming convention - API token now defaults to CF_API_TOKEN env var (Account or GW endpoint still need to be passed) - renamed gateway_url to gateway_endpoint for better consistency --- chromadbx/embeddings/cloudflare.py | 38 ++++++++++++++++++--------- docs/embeddings.md | 41 ++++++++++++++++++++++++++++++ test/embeddings/test_cloudflare.py | 20 +++++++-------- 3 files changed, 77 insertions(+), 22 deletions(-) diff --git a/chromadbx/embeddings/cloudflare.py b/chromadbx/embeddings/cloudflare.py index 274259e..8f9eaf8 100644 --- a/chromadbx/embeddings/cloudflare.py +++ b/chromadbx/embeddings/cloudflare.py @@ -1,5 +1,6 @@ # original work on this was done by @mileszim - https://github.com/mileszim/chroma/tree/cloudflare-workers-ai-embedding import logging +import os from typing import Optional, Dict, cast import httpx @@ -9,17 +10,18 @@ logger = logging.getLogger(__name__) -class CloudflareWorkersAIEmbeddingFunction(EmbeddingFunction[Documents]): +class CloudflareWorkersAIEmbeddings(EmbeddingFunction[Documents]): # Follow API Quickstart for Cloudflare Workers AI # https://developers.cloudflare.com/workers-ai/ # Information about the text embedding modules in Google Vertex AI # https://developers.cloudflare.com/workers-ai/models/embedding/ def __init__( self, - api_token: str, - account_id: Optional[str] = None, model_name: Optional[str] = "@cf/baai/bge-base-en-v1.5", - gateway_url: Optional[ + *, + api_token: Optional[str] = os.getenv("CF_API_TOKEN"), + account_id: Optional[str] = None, + gateway_endpoint: Optional[ str ] = None, # use Cloudflare AI Gateway instead of the usual endpoint # right now endpoint schema supports up to 100 docs at a time @@ -27,17 +29,29 @@ def __init__( max_batch_size: Optional[int] = 100, headers: Optional[Dict[str, str]] = None, ): - if 
not gateway_url and not account_id: - raise ValueError("Please provide either an account_id or a gateway_url.") - if gateway_url and account_id: + """ + Initialize the Cloudflare Workers AI Embeddings function. + + :param model_name: The name of the model to use. Defaults to "@cf/baai/bge-base-en-v1.5". + :param api_token: The API token to use. Defaults to the CF_API_TOKEN environment variable. + :param account_id: The account ID to use. + :param gateway_endpoint: The gateway URL to use. + :param max_batch_size: The maximum batch size to use. Defaults to 100. + :param headers: The headers to use. Defaults to None. + """ + if not gateway_endpoint and not account_id: + raise ValueError( + "Please provide either an account_id or a gateway_endpoint." + ) + if gateway_endpoint and account_id: raise ValueError( - "Please provide either an account_id or a gateway_url, not both." + "Please provide either an account_id or a gateway_endpoint, not both." ) - if gateway_url is not None and not gateway_url.endswith("/"): - gateway_url += "/" + if gateway_endpoint is not None and not gateway_endpoint.endswith("/"): + gateway_endpoint += "/" self._api_url = ( - f"{gateway_url}{model_name}" - if gateway_url is not None + f"{gateway_endpoint}{model_name}" + if gateway_endpoint is not None else f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}" ) self._session = httpx.Client() diff --git a/docs/embeddings.md b/docs/embeddings.md index 249f961..f00540f 100644 --- a/docs/embeddings.md +++ b/docs/embeddings.md @@ -145,3 +145,44 @@ col = client.get_or_create_collection("test", embedding_function=ef) col.add(ids=["id1", "id2", "id3"], documents=["lorem ipsum...", "doc2", "doc3"]) col.query(query_texts=["lorem ipsum..."], n_results=2) ``` + + +## Cloudflare Workers AI + +A convenient way to generate embeddings using Cloudflare Workers AI models. 
+ +_Using with Account ID and API Token:_ + +```py +import os +import chromadb +from chromadbx.embeddings.cloudflare import CloudflareWorkersAIEmbeddings + +ef = CloudflareWorkersAIEmbeddings( + model_name="@cf/baai/bge-base-en-v1.5", + api_token=os.getenv("CF_API_TOKEN"), + account_id=os.getenv("CF_ACCOUNT_ID") +) + +client = chromadb.Client() + +col = client.get_or_create_collection("test", embedding_function=ef) + +col.add(ids=["id1", "id2", "id3"], documents=["lorem ipsum...", "doc2", "doc3"]) +col.query(query_texts=["lorem ipsum..."], n_results=2) +``` + +_Using with Gateway Endpoint:_ + +```py +import os +from chromadbx.embeddings.cloudflare import CloudflareWorkersAIEmbeddings + +ef = CloudflareWorkersAIEmbeddings( + model_name="@cf/baai/bge-base-en-v1.5", + api_token=os.getenv("CF_API_TOKEN"), + gateway_endpoint=os.getenv("CF_GATEWAY_ENDPOINT") # "https://gateway.ai.cloudflare.com/v1/[account_id]/[project]/workers-ai" +) + +# ... rest of the code +``` diff --git a/test/embeddings/test_cloudflare.py b/test/embeddings/test_cloudflare.py index 5378499..d4f04a0 100644 --- a/test/embeddings/test_cloudflare.py +++ b/test/embeddings/test_cloudflare.py @@ -3,7 +3,7 @@ import pytest from chromadbx.embeddings.cloudflare import ( - CloudflareWorkersAIEmbeddingFunction, + CloudflareWorkersAIEmbeddings, ) @@ -12,7 +12,7 @@ reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.", ) def test_cf_ef_token_and_account() -> None: - ef = CloudflareWorkersAIEmbeddingFunction( + ef = CloudflareWorkersAIEmbeddings( api_token=os.environ.get("CF_API_TOKEN", ""), account_id=os.environ.get("CF_ACCOUNT_ID"), ) @@ -27,9 +27,9 @@ def test_cf_ef_token_and_account() -> None: reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.", ) def test_cf_ef_gateway() -> None: - ef = CloudflareWorkersAIEmbeddingFunction( + ef = CloudflareWorkersAIEmbeddings( api_token=os.environ.get("CF_API_TOKEN", ""), - gateway_url=os.environ.get("CF_GATEWAY_ENDPOINT"), +
gateway_endpoint=os.environ.get("CF_GATEWAY_ENDPOINT"), ) embeddings = ef(["test doc"]) assert embeddings is not None @@ -42,7 +42,7 @@ def test_cf_ef_gateway() -> None: reason="CF_API_TOKEN and CF_ACCOUNT_ID not set, skipping test.", ) def test_cf_ef_large_batch() -> None: - ef = CloudflareWorkersAIEmbeddingFunction(api_token="dummy", account_id="dummy") + ef = CloudflareWorkersAIEmbeddings(api_token="dummy", account_id="dummy") with pytest.raises(ValueError, match="Batch too large"): ef(["test doc"] * 101) @@ -53,9 +53,9 @@ def test_cf_ef_large_batch() -> None: ) def test_cf_ef_missing_account_or_gateway() -> None: with pytest.raises( - ValueError, match="Please provide either an account_id or a gateway_url" + ValueError, match="Please provide either an account_id or a gateway_endpoint" ): - CloudflareWorkersAIEmbeddingFunction(api_token="dummy") + CloudflareWorkersAIEmbeddings(api_token="dummy") @pytest.mark.skipif( @@ -65,8 +65,8 @@ def test_cf_ef_missing_account_or_gateway() -> None: def test_cf_ef_with_account_or_gateway() -> None: with pytest.raises( ValueError, - match="Please provide either an account_id or a gateway_url, not both", + match="Please provide either an account_id or a gateway_endpoint, not both", ): - CloudflareWorkersAIEmbeddingFunction( - api_token="dummy", account_id="dummy", gateway_url="dummy" + CloudflareWorkersAIEmbeddings( + api_token="dummy", account_id="dummy", gateway_endpoint="dummy" )