Skip to content

Commit

Permalink
feat(core): Add a "list" method to vector stores
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Oct 22, 2024
1 parent 2465486 commit c611e9f
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 24 deletions.
15 changes: 12 additions & 3 deletions examples/document-search/chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@

from ragbits.core.embeddings import LiteLLMEmbeddings
from ragbits.core.vector_store.chromadb_store import ChromaDBStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search import DocumentSearch, SearchConfig
from ragbits.document_search.documents.document import DocumentMeta

# Sample documents ingested by this example: two jokes and one unrelated text.
documents = [
    DocumentMeta.create_text_document_from_literal("RIP boiled water. You will be mist."),
    DocumentMeta.create_text_document_from_literal(
        "Why programmers don't like to swim? Because they're scared of the floating points."
    ),
    DocumentMeta.create_text_document_from_literal("This one is completely unrelated."),
]


Expand All @@ -37,8 +38,16 @@ async def main():

await document_search.ingest(documents)

results = await document_search.search("I'm boiling my water and I need a joke")
print(results)
print()
print("All documents:")
all_documents = await vector_store.list()
print([doc.metadata["content"] for doc in all_documents])

query = "I'm boiling my water and I need a joke"
print()
print(f"Documents similar to: {query}")
results = await document_search.search(query, search_config=SearchConfig(vector_store_kwargs={"k": 2}))
print([element.get_key() for element in results])


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import sys

from ..utils.config_handling import get_cls_from_config
from .base import VectorDBEntry, VectorStore
from .base import VectorDBEntry, VectorStore, WhereQuery
from .chromadb_store import ChromaDBStore
from .in_memory import InMemoryVectorStore

__all__ = ["InMemoryVectorStore", "VectorDBEntry", "VectorStore", "ChromaDBStore"]
__all__ = ["InMemoryVectorStore", "VectorDBEntry", "VectorStore", "ChromaDBStore", "WhereQuery"]

module = sys.modules[__name__]

Expand Down
23 changes: 21 additions & 2 deletions packages/ragbits-core/src/ragbits/core/vector_store/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
from typing import List

from pydantic import BaseModel

Expand All @@ -14,13 +13,16 @@ class VectorDBEntry(BaseModel):
metadata: dict


# Filter accepted by `VectorStore.list`: maps a field name to the scalar value
# (str, int, float or bool) to filter that field by.
WhereQuery = dict[str, str | int | float | bool]


class VectorStore(abc.ABC):
"""
A class with an implementation of Vector Store, allowing to store and retrieve vectors by similarity function.
"""

@abc.abstractmethod
async def store(self, entries: List[VectorDBEntry]) -> None:
async def store(self, entries: list[VectorDBEntry]) -> None:
"""
Store entries in the vector store.
Expand All @@ -40,3 +42,20 @@ async def retrieve(self, vector: list[float], k: int = 5) -> list[VectorDBEntry]
Returns:
The entries.
"""

@abc.abstractmethod
async def list(
    self,
    where: WhereQuery | None = None,
    limit: int | None = None,
    offset: int = 0,
) -> list[VectorDBEntry]:
    """
    Return entries held by the vector store, optionally filtered, limited and offset.

    Args:
        where: Filter dictionary mapping field names to the values to filter by.
            Fields that are not listed are not filtered on.
        limit: The maximum number of entries to return.
        offset: The number of entries to skip.

    Returns:
        The entries.
    """
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ragbits.core.embeddings import Embeddings
from ragbits.core.utils.config_handling import get_cls_from_config
from ragbits.core.vector_store import VectorDBEntry, VectorStore
from ragbits.core.vector_store import VectorDBEntry, VectorStore, WhereQuery


class ChromaDBStore(VectorStore):
Expand Down Expand Up @@ -148,21 +148,59 @@ async def retrieve(self, vector: List[float], k: int = 5) -> List[VectorDBEntry]
Returns:
The retrieved entries.
"""
query_result = self._collection.query(query_embeddings=[vector], n_results=k)
query_result = self._collection.query(query_embeddings=vector, n_results=k, include=["metadatas", "embeddings"])
metadatas = query_result.get("metadatas") or []
embeddings = query_result.get("embeddings") or []

db_entries = []
for meta in query_result.get("metadatas"):
for result in meta:
for meta_list, embeddings_list in zip(metadatas, embeddings):
for meta, embedding in zip(meta_list, embeddings_list):
db_entry = VectorDBEntry(
key=result["__key"],
vector=vector,
metadata=json.loads(result["__metadata"]),
key=str(meta["__key"]),
vector=list(embedding),
metadata=json.loads(str(meta["__metadata"])),
)

db_entries.append(db_entry)

return db_entries

async def list(
    self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0
) -> list[VectorDBEntry]:
    """
    List entries from the vector store. The entries can be filtered, limited and offset.

    Args:
        where: The filter dictionary - the keys are the field names and the values are the values to filter by.
            Fields that are not listed are not filtered on.
        limit: The maximum number of entries to return.
        offset: The number of entries to skip.

    Returns:
        The entries, built from the `__key` and `__metadata` fields of each stored metadata record
        together with the embedding stored alongside it.
    """
    # Cast `where` to chromadb's Where type; an empty or missing filter is passed as None.
    where_chroma: chromadb.Where | None = dict(where) if where else None

    get_results = self._collection.get(
        where=where_chroma, limit=limit, offset=offset, include=["metadatas", "embeddings"]
    )

    # Compare against None explicitly instead of using `or`: the truth value of an
    # array-like result (e.g. a multi-element NumPy array) is ambiguous and would raise.
    metadatas = get_results.get("metadatas")
    if metadatas is None:
        metadatas = []
    embeddings = get_results.get("embeddings")
    if embeddings is None:
        embeddings = []

    # Metadata records and embeddings are paired positionally.
    return [
        VectorDBEntry(
            key=str(meta["__key"]),
            vector=list(embedding),
            # The entry's own metadata is stored as a JSON string under "__metadata".
            metadata=json.loads(str(meta["__metadata"])),
        )
        for meta, embedding in zip(metadatas, embeddings)
    ]

def __repr__(self) -> str:
"""
Returns the string representation of the object.
Expand Down
34 changes: 33 additions & 1 deletion packages/ragbits-core/src/ragbits/core/vector_store/in_memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from itertools import islice

import numpy as np

from ragbits.core.vector_store.base import VectorDBEntry, VectorStore
from ragbits.core.vector_store.base import VectorDBEntry, VectorStore, WhereQuery


class InMemoryVectorStore(VectorStore):
Expand Down Expand Up @@ -45,3 +47,33 @@ async def retrieve(self, vector: list[float], k: int = 5) -> list[VectorDBEntry]
@staticmethod
def _calculate_squared_euclidean(vector_x: list[float], vector_b: list[float]) -> float:
return np.linalg.norm(np.array(vector_x) - np.array(vector_b))

async def list(
    self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0
) -> list[VectorDBEntry]:
    """
    List entries from the vector store. The entries can be filtered, limited and offset.

    Filtering is applied first, then the offset, then the limit.

    Args:
        where: The filter dictionary - the keys are the metadata field names and the values are the values
            an entry's metadata must equal. Fields that are not listed are not filtered on.
        limit: The maximum number of entries to return. `None` means no limit; `0` returns no entries.
        offset: The number of entries to skip.

    Returns:
        The entries.
    """
    entries = iter(self._storage.values())

    if where:
        entries = (
            entry for entry in entries if all(entry.metadata.get(key) == value for key, value in where.items())
        )

    if offset:
        entries = islice(entries, offset, None)

    # Test against None rather than truthiness: `limit=0` is a valid request for
    # zero entries and must not be treated as "no limit".
    if limit is not None:
        entries = islice(entries, limit)

    return [*entries]
Original file line number Diff line number Diff line change
Expand Up @@ -120,13 +120,43 @@ async def test_retrieves_entries_correctly(mock_chromadb_store):
}
]
],
"embeddings": [[[0.12, 0.25, 0.29]]],
}

entries = await mock_chromadb_store.retrieve(vector)

assert len(entries) == 1
assert entries[0].metadata["content"] == "test content"
assert entries[0].metadata["document"]["title"] == "test title"
assert entries[0].vector == [0.12, 0.25, 0.29]


async def test_lists_entries_correctly(mock_chromadb_store):
    """Check that `list` converts the collection's `get` result into entries, preserving order."""
    # Two stored records: metadata is kept as a JSON string next to the entry key.
    canned_response = {
        "documents": ["test content", "test content 2"],
        "metadatas": [
            {
                "__key": "test_key",
                "__metadata": '{"content": "test content", "document": {"title": "test title", "source": {"path": "/test/path"}, "document_type": "test_type"}}',
            },
            {
                "__key": "test_key_2",
                "__metadata": '{"content": "test content 2", "document": {"title": "test title 2", "source": {"path": "/test/path"}, "document_type": "test_type"}}',
            },
        ],
        "embeddings": [[0.12, 0.25, 0.29], [0.13, 0.26, 0.30]],
    }
    mock_collection = mock_chromadb_store._get_chroma_collection()
    mock_collection.get.return_value = canned_response

    entries = await mock_chromadb_store.list()

    assert len(entries) == 2
    first, second = entries

    assert first.metadata["content"] == "test content"
    assert first.metadata["document"]["title"] == "test title"
    assert first.vector == [0.12, 0.25, 0.29]

    assert second.metadata["content"] == "test content 2"
    assert second.metadata["document"]["title"] == "test title 2"
    assert second.vector == [0.13, 0.26, 0.30]


async def test_handles_empty_retrieve(mock_chromadb_store):
Expand Down
Loading

0 comments on commit c611e9f

Please sign in to comment.