Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Deletion of Vectors by Metadata in PGVector #128

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 61 additions & 49 deletions langchain_postgres/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,74 +588,86 @@ async def adelete_collection(self) -> None:
await session.commit()

def delete(
self,
ids: Optional[List[str]] = None,
collection_only: bool = False,
**kwargs: Any,
self,
ids: Optional[List[str]] = None,
*,
filter: Optional[Dict[str, Any]] = None,
collection_only: bool = False,
**kwargs: Any,
) -> None:
"""Delete vectors by ids or uuids.
"""Delete vectors by ids or metadata filter.

Args:
ids: List of ids to delete.
collection_only: Only delete ids in the collection.
ids: Optional list of ids to delete.
filter: Optional metadata filter dictionary.
collection_only: If True, delete only from the current collection.
**kwargs: Additional arguments.
"""
with self._make_sync_session() as session:
if ids is not None:
self.logger.debug(
"Trying to delete vectors by ids (represented by the model "
"using the custom ids field)"
)
if ids is None and filter is None:
self.logger.warning("No ids or filter provided for deletion.")
return

stmt = delete(self.EmbeddingStore)
with self._make_sync_session() as session:
stmt = delete(self.EmbeddingStore)
if collection_only:
collection = self.get_collection(session)
if not collection:
self.logger.warning("Collection not found.")
return
stmt = stmt.where(self.EmbeddingStore.collection_id == collection.uuid)

if collection_only:
collection = self.get_collection(session)
if not collection:
self.logger.warning("Collection not found")
return
if ids is not None:
self.logger.debug("Deleting vectors by ids.")
stmt = stmt.where(self.EmbeddingStore.id.in_(ids))

stmt = stmt.where(
self.EmbeddingStore.collection_id == collection.uuid
)
if filter is not None:
self.logger.debug("Deleting vectors by metadata filter.")
filter_clause = self._create_filter_clause(filter)
stmt = stmt.where(filter_clause)

stmt = stmt.where(self.EmbeddingStore.id.in_(ids))
session.execute(stmt)
session.execute(stmt)
session.commit()

async def adelete(
self,
ids: Optional[List[str]] = None,
collection_only: bool = False,
**kwargs: Any,
self,
ids: Optional[List[str]] = None,
*,
filter: Optional[Dict[str, Any]] = None,
collection_only: bool = False,
**kwargs: Any,
) -> None:
"""Async delete vectors by ids or uuids.
"""Asynchronously delete vectors by ids or metadata filter.

Args:
ids: List of ids to delete.
collection_only: Only delete ids in the collection.
ids: Optional list of ids to delete.
filter: Optional metadata filter dictionary.
collection_only: If True, delete only from the current collection.
**kwargs: Additional arguments.
"""
await self.__apost_init__() # Lazy async init
async with self._make_async_session() as session:
if ids is not None:
self.logger.debug(
"Trying to delete vectors by ids (represented by the model "
"using the custom ids field)"
)
if ids is None and filter is None:
self.logger.warning("No ids or filter provided for deletion.")
return

stmt = delete(self.EmbeddingStore)
await self.__apost_init__()
async with self._make_async_session() as session:
stmt = delete(self.EmbeddingStore)
if collection_only:
collection = await self.aget_collection(session)
if not collection:
self.logger.warning("Collection not found.")
return
stmt = stmt.where(self.EmbeddingStore.collection_id == collection.uuid)

if collection_only:
collection = await self.aget_collection(session)
if not collection:
self.logger.warning("Collection not found")
return
if ids is not None:
self.logger.debug("Deleting vectors by ids.")
stmt = stmt.where(self.EmbeddingStore.id.in_(ids))

stmt = stmt.where(
self.EmbeddingStore.collection_id == collection.uuid
)
if filter is not None:
self.logger.debug("Deleting vectors by metadata filter.")
filter_clause = self._create_filter_clause(filter)
stmt = stmt.where(filter_clause)

stmt = stmt.where(self.EmbeddingStore.id.in_(ids))
await session.execute(stmt)
await session.execute(stmt)
await session.commit()

def get_collection(self, session: Session) -> Any:
Expand Down
49 changes: 49 additions & 0 deletions tests/unit_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,55 @@ async def test_async_pgvector_delete_docs() -> None:
assert sorted(record.id for record in records) == [] # type: ignore


def test_pgvector_delete_by_metadata() -> None:
"""Test deleting documents by metadata."""
texts = ["foo", "bar", "baz"]
metadatas = [{"category": "news"}, {"category": "sports"}, {"category": "news"}]
vectorstore = PGVector.from_texts(
texts=texts,
collection_name="test_delete_by_metadata",
embedding=FakeEmbeddingsWithAdaDimension(),
metadatas=metadatas,
ids=["1", "2", "3"],
connection=CONNECTION_STRING,
pre_delete_collection=True,
)
# Delete documents where category is 'news'
vectorstore.delete(filter={"category": {"$eq": "news"}})
with vectorstore.session_maker() as session:
records = list(session.query(vectorstore.EmbeddingStore).all())
# Should only have the document with category 'sports' remaining
assert len(records) == 1
assert records[0].id == "2"
assert records[0].cmetadata["category"] == "sports"


@pytest.mark.asyncio
async def test_async_pgvector_delete_by_metadata() -> None:
"""Test deleting documents by metadata asynchronously."""
texts = ["foo", "bar", "baz"]
metadatas = [{"category": "news"}, {"category": "sports"}, {"category": "news"}]
vectorstore = await PGVector.afrom_texts(
texts=texts,
collection_name="test_delete_by_metadata",
embedding=FakeEmbeddingsWithAdaDimension(),
metadatas=metadatas,
ids=["1", "2", "3"],
connection=CONNECTION_STRING,
pre_delete_collection=True,
)
# Delete documents where category is 'news'
await vectorstore.adelete(filter={"category": {"$eq": "news"}})
async with vectorstore.session_maker() as session:
records = (
(await session.execute(select(vectorstore.EmbeddingStore))).scalars().all()
)
# Should only have the document with category 'sports' remaining
assert len(records) == 1
assert records[0].id == "2"
assert records[0].cmetadata["category"] == "sports"


def test_pgvector_index_documents() -> None:
"""Test adding duplicate documents results in overwrites."""
documents = [
Expand Down