Skip to content

Commit

Permalink
feat(example): add chromadb embedding function (#270)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomas2D authored Jan 29, 2024
1 parent 15ab900 commit 2caec71
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/extra/vector_database/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Vector Databases
"""
3 changes: 3 additions & 0 deletions examples/extra/vector_database/chroma_db/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Chroma DB
"""
38 changes: 38 additions & 0 deletions examples/extra/vector_database/chroma_db/chroma_db_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
Create ChromaDB Embedding Function
"""
from typing import Optional

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from dotenv import load_dotenv

from genai import Client, Credentials
from genai.schema import TextEmbeddingParameters

# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
# GENAI_API=<genai-api-endpoint>
load_dotenv()


class ChromaEmbeddingFunction(EmbeddingFunction):
def __init__(self, *, model_id: str, client: Client, parameters: Optional[TextEmbeddingParameters] = None):
self._model_id = model_id
self._parameters = parameters
self._client = client

def __call__(self, inputs: Documents) -> Embeddings:
embeddings: Embeddings = []
for response in self._client.text.embedding.create(
model_id=self._model_id, inputs=inputs, parameters=self._parameters
):
embeddings.extend(response.results)

return embeddings


credentials = Credentials.from_env()
client = Client(credentials=credentials)
embedding_fn = ChromaEmbeddingFunction(model_id="sentence-transformers/all-minilm-l6-v2", client=client)

print(embedding_fn(["Hello world!"]))
1 change: 1 addition & 0 deletions tests/e2e/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# These files are skipped for python >= 3.12 because transformers library cannot be installed
"local_server.py",
"huggingface_agent.py",
"chroma_db_embedding.py",
}

scripts_lt_3_12 = {script for script in all_scripts if script.name not in ignore_files | skip_for_python_3_12}
Expand Down

0 comments on commit 2caec71

Please sign in to comment.