diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/base.py b/packages/ragbits-core/src/ragbits/core/vector_stores/base.py index 9ba342e4..4512c659 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/base.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/base.py @@ -86,6 +86,15 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None The entries. """ + @abstractmethod + async def remove(self, ids: list[str]) -> None: + """ + Remove entries from the vector store. + + Args: + ids: The list of entries' IDs to remove. + """ + @abstractmethod async def list( self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py index d769cb30..fb31d9a9 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/chroma.py @@ -135,6 +135,16 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None if options.max_distance is None or distance <= options.max_distance ] + @traceable + async def remove(self, ids: list[str]) -> None: + """ + Remove entries from the vector store. + + Args: + ids: The list of entries' IDs to remove. + """ + self._collection.delete(ids=ids) + @traceable async def list( self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py b/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py index 28f69608..3ae96c2a 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/in_memory.py @@ -81,6 +81,17 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None if options.max_distance is None or distance <= options.max_distance ] + @traceable + async def remove(self, ids: list[str]) -> None: + """ + Remove entries from the vector store. + + Args: + ids: The list of entries' IDs to remove. + """ + for id in ids: + del self._storage[id] + @traceable async def list( self, where: WhereQuery | None = None, limit: int | None = None, offset: int = 0 diff --git a/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py b/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py index d5bfeb4f..221bf107 100644 --- a/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py +++ b/packages/ragbits-core/src/ragbits/core/vector_stores/qdrant.py @@ -1,7 +1,8 @@ import json +import typing import qdrant_client -from qdrant_client import AsyncQdrantClient +from qdrant_client import AsyncQdrantClient, models from qdrant_client.models import Distance, Filter, VectorParams from ragbits.core.audit import traceable @@ -146,6 +147,24 @@ async def retrieve(self, vector: list[float], options: VectorStoreOptions | None for id, document, vector, metadata in zip(ids, documents, vectors, metadatas, strict=True) ] + @traceable + async def remove(self, ids: list[str]) -> None: + """ + Remove entries from the vector store. + + Args: + ids: The list of entries' IDs to remove. + + Raises: + ValueError: If collection named `self._index_name` is not present in the vector store. + """ + await self._client.delete( + collection_name=self._index_name, + points_selector=models.PointIdsList( + points=typing.cast(list[int | str], ids), + ), + ) + @traceable async def list( # type: ignore self, @@ -168,10 +187,16 @@ async def list( # type: ignore Raises: MetadataNotFoundError: If the metadata is not found. """ + collection_exists = await self._client.collection_exists(collection_name=self._index_name) + if not collection_exists: + return [] + + limit = limit or (await self._client.count(collection_name=self._index_name)).count + results = await self._client.query_points( collection_name=self._index_name, query_filter=where, - limit=limit or 10, + limit=limit, offset=offset, with_payload=True, with_vectors=True, diff --git a/packages/ragbits-core/tests/integration/vector_stores/test_vector_store.py b/packages/ragbits-core/tests/integration/vector_stores/test_vector_store.py new file mode 100644 index 00000000..91160322 --- /dev/null +++ b/packages/ragbits-core/tests/integration/vector_stores/test_vector_store.py @@ -0,0 +1,59 @@ +from unittest.mock import AsyncMock + +import pytest +from chromadb import EphemeralClient +from qdrant_client import AsyncQdrantClient + +from ragbits.core.vector_stores.base import VectorStore +from ragbits.core.vector_stores.chroma import ChromaVectorStore +from ragbits.core.vector_stores.in_memory import InMemoryVectorStore +from ragbits.core.vector_stores.qdrant import QdrantVectorStore +from ragbits.document_search import DocumentSearch +from ragbits.document_search.documents.document import DocumentMeta +from ragbits.document_search.documents.sources import LocalFileSource + + +@pytest.mark.parametrize( + "vector_store", + [ + InMemoryVectorStore(), + ChromaVectorStore( + client=EphemeralClient(), + index_name="test_index_name", + ), + QdrantVectorStore( + client=AsyncQdrantClient(":memory:"), + index_name="test_index_name", + ), + ], +) +async def test_handling_document_ingestion_with_different_content_and_verifying_replacement( + vector_store: VectorStore, +) -> None: + document_1_content = "This is a test sentence and it should be in the vector store" + document_2_content = "This is another test sentence and it should be removed from the vector store" + document_2_new_content = "This is one more test sentence and it should be added to the vector store" + + document_1 = DocumentMeta.create_text_document_from_literal(document_1_content) + document_2 = DocumentMeta.create_text_document_from_literal(document_2_content) + + embedder = AsyncMock() + embedder.embed_text.return_value = [[0.0], [0.0]] + document_search = DocumentSearch( + embedder=embedder, + vector_store=vector_store, + ) + await document_search.ingest([document_1, document_2]) + + if isinstance(document_2.source, LocalFileSource): + document_2_path = document_2.source.path + with open(document_2_path, "w") as file: + file.write(document_2_new_content) + + await document_search.ingest([document_2]) + + document_contents = {entry.key for entry in await vector_store.list()} + + assert document_1_content in document_contents + assert document_2_new_content in document_contents + assert document_2_content not in document_contents diff --git a/packages/ragbits-core/tests/unit/vector_stores/test_chroma.py b/packages/ragbits-core/tests/unit/vector_stores/test_chroma.py index ed16d16e..11ab486e 100644 --- a/packages/ragbits-core/tests/unit/vector_stores/test_chroma.py +++ b/packages/ragbits-core/tests/unit/vector_stores/test_chroma.py @@ -101,6 +101,15 @@ async def test_retrieve( assert entry.key == result["content"] +async def test_remove(mock_chromadb_store: ChromaVectorStore) -> None: + ids_to_remove = ["1c7d6b27-4ef1-537c-ad7c-676edb8bc8a8"] + + await mock_chromadb_store.remove(ids_to_remove) + + mock_chromadb_store._client.get_or_create_collection().delete.assert_called_once() # type: ignore + mock_chromadb_store._client.get_or_create_collection().delete.assert_called_with(ids=ids_to_remove) # type: ignore + + async def test_list(mock_chromadb_store: ChromaVectorStore) -> None: mock_chromadb_store._collection.get.return_value = { # type: ignore "metadatas": [ diff --git a/packages/ragbits-core/tests/unit/vector_stores/test_in_memory.py b/packages/ragbits-core/tests/unit/vector_stores/test_in_memory.py index 9896d6a9..ccb22d24 100644 --- a/packages/ragbits-core/tests/unit/vector_stores/test_in_memory.py +++ b/packages/ragbits-core/tests/unit/vector_stores/test_in_memory.py @@ -76,6 +76,16 @@ async def test_retrieve(store: InMemoryVectorStore, k: int, max_distance: float assert entry.metadata["name"] == result +async def test_remove(store: InMemoryVectorStore) -> None: + entries = await store.list() + entry_number = len(entries) + + ids_to_remove = [entries[0].id] + await store.remove(ids_to_remove) + + assert len(await store.list()) == entry_number - 1 + + async def test_list_all(store: InMemoryVectorStore) -> None: results = await store.list() diff --git a/packages/ragbits-core/tests/unit/vector_stores/test_qdrant.py b/packages/ragbits-core/tests/unit/vector_stores/test_qdrant.py index 22cb9391..7868a455 100644 --- a/packages/ragbits-core/tests/unit/vector_stores/test_qdrant.py +++ b/packages/ragbits-core/tests/unit/vector_stores/test_qdrant.py @@ -1,3 +1,4 @@ +import typing from unittest.mock import AsyncMock import pytest @@ -96,7 +97,22 @@ async def test_retrieve(mock_qdrant_store: QdrantVectorStore) -> None: assert entry.vector == result["vector"] +async def test_remove(mock_qdrant_store: QdrantVectorStore) -> None: + ids_to_remove = ["1c7d6b27-4ef1-537c-ad7c-676edb8bc8a8"] + + await mock_qdrant_store.remove(ids_to_remove) + + mock_qdrant_store._client.delete.assert_called_once() # type: ignore + mock_qdrant_store._client.delete.assert_called_with( # type: ignore + collection_name="test_collection", + points_selector=models.PointIdsList( + points=typing.cast(list[int | str], ids_to_remove), + ), + ) + + async def test_list(mock_qdrant_store: QdrantVectorStore) -> None: + mock_qdrant_store._client.collection_exists.return_value = True # type: ignore mock_qdrant_store._client.query_points.return_value = models.QueryResponse( # type: ignore points=[ models.ScoredPoint( diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index db432eba..104ebb61 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -141,8 +141,27 @@ async def ingest( elements = await self.processing_strategy.process_documents( documents, self.document_processor_router, document_processor ) + await self._remove_entries_with_same_sources(elements) await self.insert_elements(elements) + async def _remove_entries_with_same_sources(self, elements: list[Element]) -> None: + """ + Remove entries from the vector store whose source id is present in the elements' metadata. + + Args: + elements: List of elements whose source ids will be checked and removed from the vector store if present. + """ + unique_source_ids = {element.document_meta.source.id for element in elements} + + ids_to_delete = [] + # TODO: Pass 'where' argument to the list method to filter results and optimize search + for entry in await self.vector_store.list(): + if entry.metadata.get("document_meta", {}).get("source", {}).get("id") in unique_source_ids: + ids_to_delete.append(entry.id) + + if ids_to_delete: + await self.vector_store.remove(ids_to_delete) + async def insert_elements(self, elements: list[Element]) -> None: """ Insert Elements into the vector store. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py index d049d349..cb23b37e 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py @@ -121,6 +121,7 @@ def to_vector_db_entry(self, vector: list[float], embedding_type: EmbeddingType) vector_store_entry_id = str(uuid.uuid5(uuid.NAMESPACE_OID, ";".join(id_components))) metadata = self.model_dump(exclude={"id", "key"}) metadata["embedding_type"] = str(embedding_type) + metadata["document_meta"]["source"]["id"] = self.document_meta.source.id return VectorStoreEntry(id=vector_store_entry_id, key=str(self.key), vector=vector, metadata=metadata)