diff --git a/chromadb/utils/data_loaders.py b/chromadb/utils/data_loaders.py index 60057e0e584..82ea894aa9a 100644 --- a/chromadb/utils/data_loaders.py +++ b/chromadb/utils/data_loaders.py @@ -1,8 +1,8 @@ import importlib import multiprocessing -from typing import Optional, Sequence, List +from typing import Optional, Sequence, List, Tuple import numpy as np -from chromadb.api.types import URI, DataLoader, Image +from chromadb.api.types import URI, DataLoader, Image, URIs from concurrent.futures import ThreadPoolExecutor @@ -22,3 +22,10 @@ def _load_image(self, uri: Optional[URI]) -> Optional[Image]: def __call__(self, uris: Sequence[Optional[URI]]) -> List[Optional[Image]]: with ThreadPoolExecutor(max_workers=self._max_workers) as executor: return list(executor.map(self._load_image, uris)) + + +class ChromaLangchainPassthroughDataLoader(DataLoader[List[Optional[Image]]]): + # This is a simple pass through data loader that just returns the input data with "images" + # flag which lets the langchain embedding function know that the data is image uris + def __call__(self, uris: URIs) -> Tuple[str, URIs]: # type: ignore + return ("images", uris) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index d54174a9e71..3f0a1ce043b 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -900,6 +900,69 @@ def __call__(self, input: Documents) -> Embeddings: ) +def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore + try: + from langchain_core.embeddings import Embeddings as LangchainEmbeddings + except ImportError: + raise ValueError( + "The langchain_core python package is not installed. Please install it with `pip install langchain-core`" + ) + + class ChromaLangchainEmbeddingFunction( + LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore + ): + """ + This class is used as bridge between langchain embedding functions and custom chroma embedding functions. + """ + + def __init__(self, embedding_function: LangchainEmbeddings) -> None: + """ + Initialize the ChromaLangchainEmbeddingFunction + + Args: + embedding_function : The embedding function implementing Embeddings from langchain_core. + """ + self.embedding_function = embedding_function + + def embed_documents(self, documents: Documents) -> List[List[float]]: + return self.embedding_function.embed_documents(documents) # type: ignore + + def embed_query(self, query: str) -> List[float]: + return self.embedding_function.embed_query(query) # type: ignore + + def embed_image(self, uris: List[str]) -> List[List[float]]: + if hasattr(self.embedding_function, "embed_image"): + return self.embedding_function.embed_image(uris) # type: ignore + else: + raise ValueError( + "The provided embedding function does not support image embeddings." + ) + + def __call__(self, input: Documents) -> Embeddings: # type: ignore + """ + Get the embeddings for a list of texts or images. + + Args: + input (Documents | Images): A list of texts or images to get embeddings for. + Images should be provided as a list of URIs passed through the langchain data loader + + Returns: + Embeddings: The embeddings for the texts or images. + + Example: + >>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large")) + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = langchain_embedding(texts) + """ + # Due to langchain quirks, the dataloader returns a tuple if the input is uris of images + if input[0] == "images": + return self.embed_image(list(input[1])) # type: ignore + + return self.embed_documents(list(input)) # type: ignore + + return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) + + class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): """ This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). @@ -955,6 +1018,7 @@ def __call__(self, input: Documents) -> Embeddings: ], ) + # List of all classes in this module _classes = [ name