Skip to content

Commit

Permalink
ENH: Validate response of embedding function to be in the expected fo…
Browse files Browse the repository at this point in the history
…rmat during runtime (chroma-core#1615)

## Description of changes
Requested in chroma-core#1488
 - Improvements & Bug fixes
- Raise an exception when an external embedding function doesn't return
embeddings in the expected format

## Test plan

- [X] Tests pass locally with `pytest` for python, `yarn test` for js

## Documentation Changes
*Are all docstrings for user-facing APIs updated if required? Do we need
to make documentation changes in the [docs
repository](https://github.com/chroma-core/docs)?*
  • Loading branch information
GauravWaghmare authored Jan 11, 2024
1 parent 28aa64c commit caa10f6
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
11 changes: 11 additions & 0 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ class EmbeddingFunction(Protocol[D]):
def __call__(self, input: D) -> Embeddings:
...

def __init_subclass__(cls) -> None:
super().__init_subclass__()
# Raise an exception if __call__ is not defined since it is expected to be defined
call = getattr(cls, "__call__")

def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:
result = call(self, input)
return validate_embeddings(maybe_cast_one_to_many_embedding(result))

setattr(cls, "__call__", __call__)


def validate_embedding_function(
embedding_function: EmbeddingFunction[Embeddable],
Expand Down
40 changes: 40 additions & 0 deletions chromadb/test/api/test_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pytest
from typing import List, cast
from chromadb.api.types import EmbeddingFunction, Documents, Image, Document, Embeddings
import numpy as np


def random_embeddings() -> Embeddings:
return cast(Embeddings, np.random.random(size=(10, 10)).tolist())


def random_image() -> Image:
return np.random.randint(0, 255, size=(10, 10, 3), dtype=np.int32)


def random_documents() -> List[Document]:
return [str(random_image()) for _ in range(10)]


def test_embedding_function_results_format_when_response_is_valid() -> None:
valid_embeddings = random_embeddings()

class TestEmbeddingFunction(EmbeddingFunction[Documents]):
def __call__(self, input: Documents) -> Embeddings:
return valid_embeddings

ef = TestEmbeddingFunction()
assert valid_embeddings == ef(random_documents())


def test_embedding_function_results_format_when_response_is_invalid() -> None:
invalid_embedding = {"error": "test"}

class TestEmbeddingFunction(EmbeddingFunction[Documents]):
def __call__(self, input: Documents) -> Embeddings:
return cast(Embeddings, invalid_embedding)

ef = TestEmbeddingFunction()
with pytest.raises(ValueError) as e:
ef(random_documents())
assert e.type is ValueError

0 comments on commit caa10f6

Please sign in to comment.