Add ChromaDB Support. #27

Merged: 15 commits, merged on Sep 25, 2024
Changes from 10 commits
49 changes: 49 additions & 0 deletions packages/ragbits-document-search/examples/chromadb_example.py
@@ -0,0 +1,49 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-document-search",
# "ragbits[litellm]",
# ]
# ///
import asyncio
import os

import chromadb

from ragbits.core.embeddings.litellm import LiteLLMEmbeddings
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta
from ragbits.document_search.vector_store.chromadb_store import ChromaDBStore

documents = [
    DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
    DocumentMeta.create_text_document_from_literal(
        "Why programmers don't like to swim? Because they're scared of the floating points."
    ),
]


async def main():
    """Run the example."""

    chroma_client = chromadb.PersistentClient(path="chroma")
    embedding_client = LiteLLMEmbeddings(
        api_key=os.getenv("OPENAI_API_KEY"),
    )

    vector_store = ChromaDBStore(
        index_name="jokes",
        chroma_client=chroma_client,
        embedding_function=embedding_client,
    )
    document_search = DocumentSearch(embedder=vector_store.embedding_function, vector_store=vector_store)

    for document in documents:
        await document_search.ingest_document(document)

    results = await document_search.search("I'm boiling my water and I need a joke")
    print(results)


if __name__ == "__main__":
    asyncio.run(main())
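Beyond the `DocumentSearch` flow above, the PR's `ChromaDBStore` also exposes `find_similar` together with the `distance_method` and `max_distance` options. A minimal sketch (not part of the PR; the threshold and query text are illustrative, and the imports from the example above are assumed):

```python
async def demo() -> None:
    # Sketch only: exercise ChromaDBStore directly instead of going through DocumentSearch.
    chroma_client = chromadb.PersistentClient(path="chroma")
    embedding_client = LiteLLMEmbeddings(api_key=os.getenv("OPENAI_API_KEY"))

    store = ChromaDBStore(
        index_name="jokes",
        chroma_client=chroma_client,
        embedding_function=embedding_client,
        distance_method="cosine",  # the PR's default is "l2"
        max_distance=0.5,          # illustrative threshold; beyond it find_similar returns None
    )

    # Returns the closest stored joke, or None if it is farther than max_distance.
    print(await store.find_similar("a joke about water"))


# e.g. asyncio.run(demo()) from a script entry point
```

The example script itself can presumably be run with a PEP 723-aware tool, e.g. `uv run chromadb_example.py`, since the `# /// script` block declares its dependencies inline; the only external requirement is an `OPENAI_API_KEY` in the environment.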
6 changes: 6 additions & 0 deletions packages/ragbits-document-search/pyproject.toml
@@ -36,6 +36,12 @@ dependencies = [
"ragbits"
]


[project.optional-dependencies]
chromadb = [
"chromadb~=0.4.24",
]

[tool.uv]
dev-dependencies = [
"pre-commit~=3.8.0",
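Presumably this new extra is what the import guard in the ChromaDB store module (below) expects: installing it, e.g. `pip install "ragbits-document-search[chromadb]"`, pulls in `chromadb~=0.4.24`, while a base install leaves `HAS_CHROMADB` set to False and `ChromaDBStore` raising ImportError on construction.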
@@ -52,14 +52,14 @@ async def search(self, query: str) -> list[Element]:
             A list of chunks.
         """
         queries = self.query_rephraser.rephrase(query)
-        chunks = []
+        elements = []
         for rephrased_query in queries:
             search_vector = await self.embedder.embed_text([rephrased_query])
             # TODO: search parameters should be configurable
             entries = await self.vector_store.retrieve(search_vector[0], k=1)
-            chunks.extend([Element.from_vector_db_entry(entry) for entry in entries])
+            elements.extend([Element.from_vector_db_entry(entry) for entry in entries])

-        return self.reranker.rerank(chunks)
+        return self.reranker.rerank(elements)

     async def ingest_document(self, document: DocumentMeta) -> None:
         """
@@ -0,0 +1,193 @@
import json
from copy import deepcopy
from hashlib import sha256
from typing import List, Literal, Optional, Union

try:
    import chromadb

    HAS_CHROMADB = True
except ImportError:
    HAS_CHROMADB = False

from ragbits.core.embeddings.base import Embeddings
from ragbits.document_search.vector_store.in_memory import InMemoryVectorStore, VectorDBEntry


class ChromaDBStore(InMemoryVectorStore):
@mhordynski (Member), Sep 23, 2024:
Why does it derive from InMemoryVectorStore? It should use the VectorStore base class.

@PatrykWyzgowski (Contributor, author):
Consider it done.

"""Class that stores text embeddings using [Chroma](https://docs.trychroma.com/)"""

def __init__(
self,
index_name: str,
chroma_client: chromadb.ClientAPI,
embedding_function: Union[Embeddings, chromadb.EmbeddingFunction],
max_distance: Optional[float] = None,
distance_method: Literal["l2", "ip", "cosine"] = "l2",
):
"""
Initializes the ChromaDBStore with the given parameters.

Args:
index_name (str): The name of the index.
chroma_client (chromadb.ClientAPI): The ChromaDB client.
embedding_function (Union[Embeddings, chromadb.EmbeddingFunction]): The embedding function.
max_distance (Optional[float], default=None): The maximum distance for similarity.
distance_method (Literal["l2", "ip", "cosine"], default="l2"): The distance method to use.
"""
if not HAS_CHROMADB:
raise ImportError("Install the 'ragbits-document-search[chromadb]' extra to use LiteLLM embeddings models")

        super().__init__()
        self.index_name = index_name
        self.chroma_client = chroma_client
        self.embedding_function = embedding_function
        self.max_distance = max_distance

        self._metadata = {"hnsw:space": distance_method}

    def _get_chroma_collection(self) -> chromadb.Collection:
"""
Based on the selected embedding_function, chooses how to retrieve the ChromaDB collection.
If the collection doesn't exist, it creates one.

Returns:
chromadb.Collection: Retrieved collection
"""
if isinstance(self.embedding_function, Embeddings):
return self.chroma_client.get_or_create_collection(name=self.index_name, metadata=self._metadata)

return self.chroma_client.get_or_create_collection(
name=self.index_name,
metadata=self._metadata,
embedding_function=self.embedding_function,
)

def _return_best_match(self, retrieved: dict) -> Optional[str]:
"""
Based on the retrieved data, returns the best match or None if no match is found.

Args:
retrieved (dict): Retrieved data, with a column-first format

Returns:
Optional[str]: The best match or None if no match is found
"""
if self.max_distance is None or retrieved["distances"][0][0] <= self.max_distance:
return retrieved["documents"][0][0]

return None

    def _process_db_entry(self, entry: VectorDBEntry) -> tuple[str, list[float], str, dict]:
        """
        Flattens a VectorDBEntry into the (id, embedding, document, metadata) tuple expected by ChromaDB,
        JSON-encoding nested metadata values so that only scalar types are stored.
        """
        doc_id = sha256(entry.key.encode("utf-8")).hexdigest()
        embedding = entry.vector
        text = entry.metadata["content"]

        metadata = deepcopy(entry.metadata)
        metadata["document"]["source"]["path"] = str(metadata["document"]["source"]["path"])
        metadata["key"] = entry.key
        metadata = {key: json.dumps(val) if isinstance(val, dict) else val for key, val in metadata.items()}

        return doc_id, embedding, text, metadata

    def _process_metadata(self, metadata: dict) -> dict[str, Union[str, int, float, bool]]:
        """
        Processes the metadata dictionary by parsing JSON strings if applicable.

        Args:
            metadata (dict): A dictionary containing metadata where values may be JSON strings.

        Returns:
            dict[str, Union[str, int, float, bool]]: A dictionary with the same keys as the input,
            where JSON strings are parsed into their respective Python data types.
        """
        return {key: json.loads(val) if self.is_json(val) else val for key, val in metadata.items()}

    def is_json(self, myjson: str) -> bool:
"""
Check if the provided string is a valid JSON.

Args:
myjson (str): The string to be checked.

Returns:
bool: True if the string is a valid JSON, False otherwise.
"""
try:
if isinstance(myjson, str):
json.loads(myjson)
return True
return False
except ValueError:
return False

async def store(self, entries: List[VectorDBEntry]) -> None:
"""
Stores entries in the ChromaDB collection.

Args:
entries (List[VectorDBEntry]): The entries to store.
"""
collection = self._get_chroma_collection()

entries_processed = list(map(self._process_db_entry, entries))
ids, embeddings, texts, metadatas = map(list, zip(*entries_processed))

collection.add(ids=ids, embeddings=embeddings, documents=texts, metadatas=metadatas)

async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]:
"""
Retrieves entries from the ChromaDB collection.

Args:
vector (List[float]): The vector to query.
k (int): The number of entries to retrieve.

Returns:
List[VectorDBEntry]: The retrieved entries.
"""
collection = self._get_chroma_collection()
query_result = collection.query(query_embeddings=[vector], n_results=k)

db_entries = []
for meta in query_result.get("metadatas"):
db_entry = VectorDBEntry(
key=meta[0].get("key"),
vector=vector,
metadata=self._process_metadata(meta[0]),
)

db_entries.append(db_entry)

return db_entries

    async def find_similar(self, text: str) -> Optional[str]:
        """
        Finds the most similar text in the Chroma collection, or returns None if the most similar text
        has a distance greater than `self.max_distance`.

        Args:
            text (str): The text to find similar to.

        Returns:
            Optional[str]: The most similar text or None if no similar text is found.
        """

        collection = self._get_chroma_collection()

        if isinstance(self.embedding_function, Embeddings):
            embedding = await self.embedding_function.embed_text([text])
            retrieved = collection.query(query_embeddings=embedding, n_results=1)
        else:
            retrieved = collection.query(query_texts=[text], n_results=1)

        return self._return_best_match(retrieved)

    def __repr__(self) -> str:
        """
        Returns the string representation of the object.

        Returns:
            str: The string representation of the object.
        """
        return f"{self.__class__.__name__}(index_name={self.index_name})"