From c611e9fad3c10bdfc331a2bfb3e36693ce1a79be Mon Sep 17 00:00:00 2001 From: Ludwik Trammer Date: Tue, 22 Oct 2024 13:56:43 +0200 Subject: [PATCH] feat(core): Add a "list" method to vector stores --- examples/document-search/chroma.py | 15 ++- .../src/ragbits/core/vector_store/__init__.py | 4 +- .../src/ragbits/core/vector_store/base.py | 23 +++- .../core/vector_store/chromadb_store.py | 52 +++++++- .../ragbits/core/vector_store/in_memory.py | 34 ++++- .../unit/vector_stores/test_chromadb_store.py | 30 +++++ .../vector_stores/test_simple_vector_store.py | 121 +++++++++++++++++- .../src/ragbits/document_search/__init__.py | 4 +- 8 files changed, 259 insertions(+), 24 deletions(-) diff --git a/examples/document-search/chroma.py b/examples/document-search/chroma.py index acd29a7a..734b724e 100644 --- a/examples/document-search/chroma.py +++ b/examples/document-search/chroma.py @@ -11,7 +11,7 @@ from ragbits.core.embeddings import LiteLLMEmbeddings from ragbits.core.vector_store.chromadb_store import ChromaDBStore -from ragbits.document_search import DocumentSearch +from ragbits.document_search import DocumentSearch, SearchConfig from ragbits.document_search.documents.document import DocumentMeta documents = [ @@ -19,6 +19,7 @@ DocumentMeta.create_text_document_from_literal( "Why programmers don't like to swim? Because they're scared of the floating points." ), + DocumentMeta.create_text_document_from_literal("This one is completely unrelated."), ] @@ -37,8 +38,16 @@ async def main(): await document_search.ingest(documents) - results = await document_search.search("I'm boiling my water and I need a joke") - print(results) + print() + print("All documents:") + all_documents = await vector_store.list() + print([doc.metadata["content"] for doc in all_documents]) + + query = "I'm boiling my water and I need a joke" + print() + print(f"Documents similar to: {query}") + results = await document_search.search(query, search_config=SearchConfig(vector_store_kwargs={"k": 2})) + print([element.get_key() for element in results]) if __name__ == "__main__": diff --git a/packages/ragbits-core/src/ragbits/core/vector_store/__init__.py b/packages/ragbits-core/src/ragbits/core/vector_store/__init__.py index 8d48c78b..1ae7c15c 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_store/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/vector_store/__init__.py @@ -1,11 +1,11 @@ import sys from ..utils.config_handling import get_cls_from_config -from .base import VectorDBEntry, VectorStore +from .base import VectorDBEntry, VectorStore, WhereQuery from .chromadb_store import ChromaDBStore from .in_memory import InMemoryVectorStore -__all__ = ["InMemoryVectorStore", "VectorDBEntry", "VectorStore", "ChromaDBStore"] +__all__ = ["InMemoryVectorStore", "VectorDBEntry", "VectorStore", "ChromaDBStore", "WhereQuery"] module = sys.modules[__name__] diff --git a/packages/ragbits-core/src/ragbits/core/vector_store/base.py b/packages/ragbits-core/src/ragbits/core/vector_store/base.py index 4d494c56..29e76e7e 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_store/base.py +++ b/packages/ragbits-core/src/ragbits/core/vector_store/base.py @@ -1,5 +1,4 @@ import abc -from typing import List from pydantic import BaseModel @@ -14,13 +13,16 @@ class VectorDBEntry(BaseModel): metadata: dict +WhereQuery = dict[str, str | int | float | bool] + + class VectorStore(abc.ABC): """ A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function. """ @abc.abstractmethod - async def store(self, entries: List[VectorDBEntry]) -> None: + async def store(self, entries: list[VectorDBEntry]) -> None: """ Store entries in the vector store. @@ -40,3 +42,20 @@ async def retrieve(self, vector: list[float], k: int = 5) -> list[VectorDBEntry] Returns: The entries. """ + + @abc.abstractmethod + async def list( + self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 + ) -> list[VectorDBEntry]: + """ + List entries from the vector store. The entries can be filtered, limited and offset. + + Args: + where: The filter dictionary - the keys are the field names and the values are the values to filter by. + Not specifying the key means no filtering. + limit: The maximum number of entries to return. + offset: The number of entries to skip. + + Returns: + The entries. + """ diff --git a/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py b/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py index 6bee093e..5e71d4bd 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py +++ b/packages/ragbits-core/src/ragbits/core/vector_store/chromadb_store.py @@ -11,7 +11,7 @@ from ragbits.core.embeddings import Embeddings from ragbits.core.utils.config_handling import get_cls_from_config -from ragbits.core.vector_store import VectorDBEntry, VectorStore +from ragbits.core.vector_store import VectorDBEntry, VectorStore, WhereQuery class ChromaDBStore(VectorStore): @@ -148,21 +148,59 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry] Returns: The retrieved entries. """ - query_result = self._collection.query(query_embeddings=[vector], n_results=k) + query_result = self._collection.query(query_embeddings=vector, n_results=k, include=["metadatas", "embeddings"]) + metadatas = query_result.get("metadatas") or [] + embeddings = query_result.get("embeddings") or [] db_entries = [] - for meta in query_result.get("metadatas"): - for result in meta: + for meta_list, embeddings_list in zip(metadatas, embeddings): + for meta, embedding in zip(meta_list, embeddings_list): db_entry = VectorDBEntry( - key=result["__key"], - vector=vector, - metadata=json.loads(result["__metadata"]), + key=str(meta["__key"]), + vector=list(embedding), + metadata=json.loads(str(meta["__metadata"])), ) db_entries.append(db_entry) return db_entries + async def list( + self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 + ) -> list[VectorDBEntry]: + """ + List entries from the vector store. The entries can be filtered, limited and offset. + + Args: + where: The filter dictionary - the keys are the field names and the values are the values to filter by. + Not specifying the key means no filtering. + limit: The maximum number of entries to return. + offset: The number of entries to skip. + + Returns: + The entries. + """ + # Cast `where` to chromadb's Where type + where_chroma: chromadb.Where | None = dict(where) if where else None + + get_results = self._collection.get( + where=where_chroma, limit=limit, offset=offset, include=["metadatas", "embeddings"] + ) + metadatas = get_results.get("metadatas") or [] + embeddings = get_results.get("embeddings") or [] + + db_entries = [] + for meta, embedding in zip(metadatas, embeddings): + db_entry = VectorDBEntry( + key=str(meta["__key"]), + vector=list(embedding), + metadata=json.loads(str(meta["__metadata"])), + ) + + db_entries.append(db_entry) + + return db_entries + def __repr__(self) -> str: """ Returns the string representation of the object. diff --git a/packages/ragbits-core/src/ragbits/core/vector_store/in_memory.py b/packages/ragbits-core/src/ragbits/core/vector_store/in_memory.py index ce0576fa..121f1c43 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_store/in_memory.py +++ b/packages/ragbits-core/src/ragbits/core/vector_store/in_memory.py @@ -1,6 +1,8 @@ +from itertools import islice + import numpy as np -from ragbits.core.vector_store.base import VectorDBEntry, VectorStore +from ragbits.core.vector_store.base import VectorDBEntry, VectorStore, WhereQuery class InMemoryVectorStore(VectorStore): @@ -45,3 +47,33 @@ async def retrieve(self, vector: list[float], k: int = 5) -> list[VectorDBEntry] @staticmethod def _calculate_squared_euclidean(vector_x: list[float], vector_b: list[float]) -> float: return np.linalg.norm(np.array(vector_x) - np.array(vector_b)) + + async def list( + self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 + ) -> list[VectorDBEntry]: + """ + List entries from the vector store. The entries can be filtered, limited and offset. + + Args: + where: The filter dictionary - the keys are the field names and the values are the values to filter by. + Not specifying the key means no filtering. + limit: The maximum number of entries to return. + offset: The number of entries to skip. + + Returns: + The entries. + """ + entries = iter(self._storage.values()) + + if where: + entries = ( + entry for entry in entries if all(entry.metadata.get(key) == value for key, value in where.items()) + ) + + if offset: + entries = islice(entries, offset, None) + + if limit: + entries = islice(entries, limit) + + return list(entries) diff --git a/packages/ragbits-core/tests/unit/vector_stores/test_chromadb_store.py b/packages/ragbits-core/tests/unit/vector_stores/test_chromadb_store.py index 020e756d..240f8a3d 100644 --- a/packages/ragbits-core/tests/unit/vector_stores/test_chromadb_store.py +++ b/packages/ragbits-core/tests/unit/vector_stores/test_chromadb_store.py @@ -120,6 +120,7 @@ async def test_retrieves_entries_correctly(mock_chromadb_store): } ] ], + "embeddings": [[[0.12, 0.25, 0.29]]], } entries = await mock_chromadb_store.retrieve(vector) @@ -127,6 +128,35 @@ async def test_retrieves_entries_correctly(mock_chromadb_store): assert len(entries) == 1 assert entries[0].metadata["content"] == "test content" assert entries[0].metadata["document"]["title"] == "test title" + assert entries[0].vector == [0.12, 0.25, 0.29] + + +async def test_lists_entries_correctly(mock_chromadb_store): + mock_collection = mock_chromadb_store._get_chroma_collection() + mock_collection.get.return_value = { + "documents": ["test content", "test content 2"], + "metadatas": [ + { + "__key": "test_key", + "__metadata": '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}', + }, + { + "__key": "test_key_2", + "__metadata": '{"content": "test content 2", "document": {"title": "test title 2", "source": {"path": "/test/path"}, "document_type": "test_type"}}', + }, + ], + "embeddings": [[0.12, 0.25, 0.29], [0.13, 0.26, 0.30]], + } + + entries = await mock_chromadb_store.list() + + assert len(entries) == 2 + assert entries[0].metadata["content"] == "test content" + assert entries[0].metadata["document"]["title"] == "test title" + assert entries[0].vector == [0.12, 0.25, 0.29] + assert entries[1].metadata["content"] == "test content 2" + assert entries[1].metadata["document"]["title"] == "test title 2" + assert entries[1].vector == [0.13, 0.26, 0.30] async def test_handles_empty_retrieve(mock_chromadb_store): diff --git a/packages/ragbits-core/tests/unit/vector_stores/test_simple_vector_store.py b/packages/ragbits-core/tests/unit/vector_stores/test_simple_vector_store.py index 8461d93b..6356bef5 100644 --- a/packages/ragbits-core/tests/unit/vector_stores/test_simple_vector_store.py +++ b/packages/ragbits-core/tests/unit/vector_stores/test_simple_vector_store.py @@ -1,28 +1,135 @@ from pathlib import Path +import pytest + from ragbits.core.vector_store.in_memory import InMemoryVectorStore from ragbits.document_search.documents.document import DocumentMeta, DocumentType -from ragbits.document_search.documents.element import TextElement +from ragbits.document_search.documents.element import Element from ragbits.document_search.documents.sources import LocalFileSource -async def test_simple_vector_store(): - store = InMemoryVectorStore() +class AnimalElement(Element): + """ + A test element representing an animal. + """ + + element_type: str = "animal" + name: str + species: str + type: str + age: int + + def get_key(self) -> str: + """ + Get the key of the element which will be used to generate the vector. + + Returns: + The key. + """ + return self.name + +@pytest.fixture(name="store") +async def store_fixture(): document_meta = DocumentMeta(document_type=DocumentType.TXT, source=LocalFileSource(path=Path("test.txt"))) elements = [ - (TextElement(content="dog", document_meta=document_meta), [0.5, 0.5]), - (TextElement(content="cat", document_meta=document_meta), [0.6, 0.6]), + (AnimalElement(name="spikey", species="dog", type="mammal", age=5, document_meta=document_meta), [0.5, 0.5]), + (AnimalElement(name="fluffy", species="cat", type="mammal", age=3, document_meta=document_meta), [0.6, 0.6]), + (AnimalElement(name="slimy", species="frog", type="amphibian", age=1, document_meta=document_meta), [0.7, 0.7]), + (AnimalElement(name="scaly", species="snake", type="reptile", age=2, document_meta=document_meta), [0.8, 0.8]), + (AnimalElement(name="hairy", species="spider", type="insect", age=6, document_meta=document_meta), [0.9, 0.9]), + ( + AnimalElement(name="spotty", species="ladybug", type="insect", age=1, document_meta=document_meta), + [0.1, 0.1], + ), ] entries = [element[0].to_vector_db_entry(vector=element[1]) for element in elements] + store = InMemoryVectorStore() await store.store(entries) + return store + +async def test_simple_vector_store(store): search_vector = [0.4, 0.4] results = await store.retrieve(search_vector, 2) assert len(results) == 2 - assert results[0].metadata["content"] == "dog" - assert results[1].metadata["content"] == "cat" + assert results[0].metadata["name"] == "spikey" + assert results[1].metadata["name"] == "fluffy" + + +async def test_list_all(store): + results = await store.list() + + assert len(results) == 6 + names = [result.metadata["name"] for result in results] + assert names == ["spikey", "fluffy", "slimy", "scaly", "hairy", "spotty"] + + +async def test_list_limit(store): + results = await store.list(limit=3) + + assert len(results) == 3 + names = {result.metadata["name"] for result in results} + assert names == {"spikey", "fluffy", "slimy"} + + +async def test_list_offset(store): + results = await store.list(offset=3) + + assert len(results) == 3 + names = {result.metadata["name"] for result in results} + assert names == {"scaly", "hairy", "spotty"} + + +async def test_limit_with_offset(store): + results = await store.list(limit=2, offset=3) + + assert len(results) == 2 + names = {result.metadata["name"] for result in results} + assert names == {"scaly", "hairy"} + + +async def test_where(store): + results = await store.list(where={"type": "insect"}) + + assert len(results) == 2 + names = {result.metadata["name"] for result in results} + assert names == {"hairy", "spotty"} + + +async def test_multiple_where(store): + results = await store.list(where={"type": "insect", "age": 1}) + + assert len(results) == 1 + assert results[0].metadata["name"] == "spotty" + + +async def test_empty_where(store): + results = await store.list(where={}) + + assert len(results) == 6 + names = {result.metadata["name"] for result in results} + assert names == {"spikey", "fluffy", "slimy", "scaly", "hairy", "spotty"} + + +async def test_empty_results(store): + results = await store.list(where={"type": "bird"}) + + assert len(results) == 0 + + +async def test_empty_results_with_limit(store): + results = await store.list(where={"type": "bird"}, limit=2) + + assert len(results) == 0 + + +async def test_where_limit(store): + results = await store.list(where={"type": "insect"}, limit=1) + + assert len(results) == 1 + assert results[0].metadata["name"] == "hairy" diff --git a/packages/ragbits-document-search/src/ragbits/document_search/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/__init__.py index aafe8c11..2d5a8d67 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/__init__.py @@ -1,3 +1,3 @@ -from ._main import DocumentSearch +from ._main import DocumentSearch, SearchConfig -__all__ = ["DocumentSearch"] +__all__ = ["DocumentSearch", "SearchConfig"]