From a7d345b75812ef99e98c33783dca9444147646e8 Mon Sep 17 00:00:00 2001
From: Philippe PRADOS
Date: Mon, 10 Jun 2024 20:01:26 +0200
Subject: [PATCH] Add async mode for pgvector (v2) (#64)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

I made the mistake of rebasing while a review was in progress, which seems to have blocked the process: there is a "1 change requested" status that I cannot resolve. I am therefore opening another PR identical to the [previous one](https://github.com/langchain-ai/langchain-postgres/pull/32).

This PR adds the **async** approach for pgvector.

Some remarks:
- We use `assert` rather than `if` to check invocations, so these checks can be removed in production with `python -O ...`.
- We expose a public `session_maker` attribute. This is very important for resilient scenarios.

In a RAG architecture, document chunks must be imported. To keep track of the links between chunks and documents, the [index()](https://python.langchain.com/docs/modules/data_connection/indexing/) API can be used. This API relies on an SQL-based record manager.

In a classic setup that combines `SQLRecordManager` with a separate vector database, it is impossible to guarantee the consistency of the import: if a crash occurs during the import, the SQL database and the vector database end up inconsistent with each other.

**PGVector is the solution to this problem**, because a single database can be used (rather than a two-phase commit across two technologies, assuming both even support it). For this to work, the transactions of `SQLRecordManager` and `PGVector` must be combined, which is only possible if the `session_maker` can be overridden. This is why we propose making this attribute public.

By unifying the `session_maker` of `SQLRecordManager` and `PGVector`, all operations are guaranteed to execute in a single transaction. This is, as far as we know, the only way to guarantee a consistent import of chunks into a vector database. It works only if the outer session is built from the connection.

```python
def main():
    db_url = "postgresql+psycopg://postgres:password_postgres@localhost:5432/"
    engine = create_engine(db_url, echo=True)
    embeddings = FakeEmbeddings()
    pgvector: VectorStore = PGVector(
        embeddings=embeddings,
        connection=engine,
    )

    record_manager = SQLRecordManager(
        namespace="namespace",
        engine=engine,
    )
    record_manager.create_schema()

    with engine.connect() as connection:
        session_maker = scoped_session(sessionmaker(bind=connection))
        # NOTE: Update session_factories
        record_manager.session_factory = session_maker
        pgvector.session_maker = session_maker
        with connection.begin():
            loader = CSVLoader(
                "data/faq/faq.csv",
                source_column="source",
                autodetect_encoding=True,
            )
            result = index(
                source_id_key="source",
                docs_source=loader.load()[:1],
                cleanup="incremental",
                vector_store=pgvector,
                record_manager=record_manager,
            )
            print(result)
```

The same thing is possible asynchronously, but a bug in `_amake_session()` in `sql_record_manager.py` must first be fixed (see this [PR](https://github.com/langchain-ai/langchain/pull/20735)).
```python
async def _amake_session(self) -> AsyncGenerator[AsyncSession, None]:
    """Create a session and close it after use."""
    # FIXME: REMOVE `if not isinstance(self.session_factory, async_sessionmaker):`
    if not isinstance(self.engine, AsyncEngine):
        raise AssertionError("This method is not supported for sync engines.")

    async with self.session_factory() as session:
        yield session
```

Then, it is possible to do the same thing asynchronously:

```python
async def main():
    db_url = "postgresql+psycopg://postgres:password_postgres@localhost:5432/"
    engine = create_async_engine(db_url, echo=True)
    embeddings = FakeEmbeddings()
    pgvector: VectorStore = PGVector(
        embeddings=embeddings,
        connection=engine,
    )

    record_manager = SQLRecordManager(
        namespace="namespace",
        engine=engine,
        async_mode=True,
    )
    await record_manager.acreate_schema()

    async with engine.connect() as connection:
        session_maker = async_scoped_session(
            async_sessionmaker(bind=connection),
            scopefunc=current_task,
        )
        record_manager.session_factory = session_maker
        pgvector.session_maker = session_maker
        async with connection.begin():
            loader = CSVLoader(
                "data/faq/faq.csv",
                source_column="source",
                autodetect_encoding=True,
            )
            result = await aindex(
                source_id_key="source",
                docs_source=loader.load()[:1],
                cleanup="incremental",
                vector_store=pgvector,
                record_manager=record_manager,
            )
            print(result)


asyncio.run(main())
```

The constructor's promise, via the `create_extension` parameter, is to guarantee that the extension is installed before the APIs are used. Since this promise cannot be kept as-is in an `async` scenario, there are two alternatives:
- Remove this parameter, since the promise cannot be kept. An `async` method would then be needed to install the extension before the APIs are used, and each API would have to check that this method has been invoked.
- Use a lazy approach, as proposed here, which simply honors the constructor's promise.
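To make the lazy approach concrete, here is a minimal sketch of the idea: defer extension creation until the first async call, so the constructor's promise still holds without doing I/O in `__init__`. This is not the code from this PR (the PR implements the pattern via `__apost_init__` in `PGVector`); the class and names `LazyAsyncStore`, `_ainit`, and `asimilarity_search` below are purely illustrative.

```python
import asyncio

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine


class LazyAsyncStore:
    """Illustrative only: run async setup lazily, before the first real query."""

    def __init__(self, engine: AsyncEngine, create_extension: bool = True) -> None:
        self._engine = engine
        self._create_extension = create_extension
        self._init_lock = asyncio.Lock()
        self._initialized = False

    async def _ainit(self) -> None:
        # Runs at most once; the lock guards against concurrent first calls.
        async with self._init_lock:
            if self._initialized:
                return
            if self._create_extension:
                async with self._engine.begin() as conn:
                    await conn.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))
            self._initialized = True

    async def asimilarity_search(self, query: str) -> list:
        await self._ainit()  # every public async API starts with the lazy init
        ...  # actual search logic would go here
        return []
```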
--- langchain_postgres/vectorstores.py | 904 ++++++++++++++++++++++++--- pyproject.toml | 4 + tests/unit_tests/test_vectorstore.py | 506 ++++++++++++++- 3 files changed, 1328 insertions(+), 86 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index 89b78fb..659b6f3 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,36 +1,50 @@ +# pylint: disable=too-many-lines from __future__ import annotations +import contextlib import enum import logging import uuid from typing import ( Any, + AsyncGenerator, Callable, Dict, + Generator, Iterable, List, Optional, + Sequence, Tuple, Type, Union, ) +from typing import ( + cast as typing_cast, +) import numpy as np import sqlalchemy -from sqlalchemy import SQLColumnExpression, cast, delete, func -from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert -from sqlalchemy.orm import Session, relationship, sessionmaker - -try: - from sqlalchemy.orm import declarative_base -except ImportError: - from sqlalchemy.ext.declarative import declarative_base - from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.runnables.config import run_in_executor from langchain_core.utils import get_from_dict_or_env from langchain_core.vectorstores import VectorStore +from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select +from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert +from sqlalchemy.engine import Connection, Engine +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import ( + Session, + declarative_base, + relationship, + scoped_session, + sessionmaker, +) from langchain_postgres._utils import maximal_marginal_relevance @@ -112,7 +126,27 @@ class CollectionStore(Base): def get_by_name( cls, session: Session, name: str ) -> Optional["CollectionStore"]: - return session.query(cls).filter(cls.name == name).first() # type: ignore + return ( + session.query(cls) + .filter(typing_cast(sqlalchemy.Column, cls.name) == name) + .first() + ) + + @classmethod + async def aget_by_name( + cls, session: AsyncSession, name: str + ) -> Optional["CollectionStore"]: + return ( + ( + await session.execute( + select(CollectionStore).where( + typing_cast(sqlalchemy.Column, cls.name) == name + ) + ) + ) + .scalars() + .first() + ) @classmethod def get_or_create( @@ -136,6 +170,28 @@ def get_or_create( created = True return collection, created + @classmethod + async def aget_or_create( + cls, + session: AsyncSession, + name: str, + cmetadata: Optional[dict] = None, + ) -> Tuple["CollectionStore", bool]: + """ + Get or create a collection. + Returns [Collection, bool] where the bool is True if the collection was created. 
+ """ # noqa: E501 + created = False + collection = await cls.aget_by_name(session, name) + if collection: + return collection, created + + collection = cls(name=name, cmetadata=cmetadata) + session.add(collection) + await session.commit() + created = True + return collection, created + class EmbeddingStore(Base): """Embedding store.""" @@ -177,7 +233,16 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]: return [doc for doc, _ in docs_and_scores] -Connection = Union[sqlalchemy.engine.Engine, str] +def _create_vector_extension(conn: Connection) -> None: + statement = sqlalchemy.text( + "SELECT pg_advisory_xact_lock(1573678846307946496);" + "CREATE EXTENSION IF NOT EXISTS vector;" + ) + conn.execute(statement) + conn.commit() + + +DBConnection = Union[sqlalchemy.engine.Engine, str] class PGVector(VectorStore): @@ -215,6 +280,7 @@ class PGVector(VectorStore): connection=connection_string, collection_name=collection_name, use_jsonb=True, + async_mode=False, ) @@ -232,13 +298,16 @@ class PGVector(VectorStore): You will need to recreate the tables if you are using an existing database. * A Connection object has to be provided explicitly. Connections will not be picked up automatically based on env variables. + * langchain_postgres now accept async connections. If you want to use the async + version, you need to set `async_mode=True` when initializing the store or + use an async engine. """ def __init__( self, embeddings: Embeddings, *, - connection: Optional[Connection] = None, + connection: Union[None, DBConnection, Engine, AsyncEngine, str] = None, embedding_length: Optional[int] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, collection_metadata: Optional[dict] = None, @@ -249,11 +318,13 @@ def __init__( engine_args: Optional[dict[str, Any]] = None, use_jsonb: bool = True, create_extension: bool = True, + async_mode: bool = False, ) -> None: """Initialize the PGVector store. + For an async version, use `PGVector.acreate()` instead. Args: - connection: Postgres connection string. + connection: Postgres connection string or (async)engine. embeddings: Any embedding function implementing `langchain.embeddings.base.Embeddings` interface. embedding_length: The length of the embedding vector. (default: None) @@ -277,6 +348,7 @@ def __init__( doesn't exist. disabling creation is useful when using ReadOnly Databases. 
""" + self.async_mode = async_mode self.embedding_function = embeddings self._embedding_length = embedding_length self.collection_name = collection_name @@ -285,20 +357,33 @@ def __init__( self.pre_delete_collection = pre_delete_collection self.logger = logger or logging.getLogger(__name__) self.override_relevance_score_fn = relevance_score_fn + self._engine: Optional[Engine] = None + self._async_engine: Optional[AsyncEngine] = None + self._async_init = False if isinstance(connection, str): - self._engine = sqlalchemy.create_engine( - url=connection, **(engine_args or {}) - ) - elif isinstance(connection, sqlalchemy.engine.Engine): + if async_mode: + self._async_engine = create_async_engine( + connection, **(engine_args or {}) + ) + else: + self._engine = create_engine(url=connection, **(engine_args or {})) + elif isinstance(connection, Engine): + self.async_mode = False self._engine = connection + elif isinstance(connection, AsyncEngine): + self.async_mode = True + self._async_engine = connection else: raise ValueError( "connection should be a connection string or an instance of " - "sqlalchemy.engine.Engine" + "sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine" ) - - self._session_maker = sessionmaker(bind=self._engine) + self.session_maker: Union[scoped_session, async_sessionmaker] + if self.async_mode: + self.session_maker = async_sessionmaker(bind=self._async_engine) + else: + self.session_maker = scoped_session(sessionmaker(bind=self._engine)) self.use_jsonb = use_jsonb self.create_extension = create_extension @@ -306,7 +391,8 @@ def __init__( if not use_jsonb: # Replace with a deprecation warning. raise NotImplementedError("use_jsonb=False is no longer supported.") - self.__post_init__() + if not self.async_mode: + self.__post_init__() def __post_init__( self, @@ -323,52 +409,99 @@ def __post_init__( self.create_tables_if_not_exists() self.create_collection() - def __del__(self) -> None: - if isinstance(self._engine, sqlalchemy.engine.Connection): - self._engine.close() + async def __apost_init__( + self, + ) -> None: + """Async initialize the store (use lazy approach).""" + if self._async_init: # Warning: possible race condition + return + self._async_init = True + + EmbeddingStore, CollectionStore = _get_embedding_collection_store( + self._embedding_length + ) + self.CollectionStore = CollectionStore + self.EmbeddingStore = EmbeddingStore + if self.create_extension: + await self.acreate_vector_extension() + + await self.acreate_tables_if_not_exists() + await self.acreate_collection() @property def embeddings(self) -> Embeddings: return self.embedding_function def create_vector_extension(self) -> None: + assert self._engine, "engine not found" try: - with self._session_maker() as session: # type: ignore[arg-type] - # The advisor lock fixes issue arising from concurrent - # creation of the vector extension. 
- # https://github.com/langchain-ai/langchain/issues/12933 - # For more information see: - # https://www.postgresql.org/docs/16/explicit-locking.html#ADVISORY-LOCKS - statement = sqlalchemy.text( - "BEGIN;" - "SELECT pg_advisory_xact_lock(1573678846307946496);" - "CREATE EXTENSION IF NOT EXISTS vector;" - "COMMIT;" - ) - session.execute(statement) - session.commit() + with self._engine.connect() as conn: + _create_vector_extension(conn) except Exception as e: raise Exception(f"Failed to create vector extension: {e}") from e + async def acreate_vector_extension(self) -> None: + assert self._async_engine, "_async_engine not found" + + async with self._async_engine.begin() as conn: + await conn.run_sync(_create_vector_extension) + def create_tables_if_not_exists(self) -> None: - with self._session_maker() as session: + with self._make_sync_session() as session: Base.metadata.create_all(session.get_bind()) + session.commit() + + async def acreate_tables_if_not_exists(self) -> None: + assert self._async_engine, "This method must be called with async_mode" + async with self._async_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) def drop_tables(self) -> None: - with self._session_maker() as session: + with self._make_sync_session() as session: Base.metadata.drop_all(session.get_bind()) + session.commit() + + async def adrop_tables(self) -> None: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + async with self._async_engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) def create_collection(self) -> None: if self.pre_delete_collection: self.delete_collection() - with self._session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) + session.commit() + + async def acreate_collection(self) -> None: + await self.__apost_init__() # Lazy async init + async with self._make_async_session() as session: + if self.pre_delete_collection: + await self._adelete_collection(session) + await self.CollectionStore.aget_or_create( + session, self.collection_name, cmetadata=self.collection_metadata + ) + await session.commit() + + def _delete_collection(self, session: Session) -> None: + collection = self.get_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + session.delete(collection) + + async def _adelete_collection(self, session: AsyncSession) -> None: + collection = await self.aget_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + await session.delete(collection) def delete_collection(self) -> None: - self.logger.debug("Trying to delete collection") - with self._session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -376,6 +509,16 @@ def delete_collection(self) -> None: session.delete(collection) session.commit() + async def adelete_collection(self) -> None: + await self.__apost_init__() # Lazy async init + async with self._make_async_session() as session: + collection = await self.aget_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + await session.delete(collection) + await session.commit() + def delete( self, ids: Optional[List[str]] = None, @@ -388,7 +531,7 @@ def 
delete( ids: List of ids to delete. collection_only: Only delete ids in the collection. """ - with self._session_maker() as session: + with self._make_sync_session() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -411,9 +554,51 @@ def delete( session.execute(stmt) session.commit() + async def adelete( + self, + ids: Optional[List[str]] = None, + collection_only: bool = False, + **kwargs: Any, + ) -> None: + """Async delete vectors by ids or uuids. + + Args: + ids: List of ids to delete. + collection_only: Only delete ids in the collection. + """ + await self.__apost_init__() # Lazy async init + async with self._make_async_session() as session: + if ids is not None: + self.logger.debug( + "Trying to delete vectors by ids (represented by the model " + "using the custom ids field)" + ) + + stmt = delete(self.EmbeddingStore) + + if collection_only: + collection = await self.aget_collection(session) + if not collection: + self.logger.warning("Collection not found") + return + + stmt = stmt.where( + self.EmbeddingStore.collection_id == collection.uuid + ) + + stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) + await session.execute(stmt) + await session.commit() + def get_collection(self, session: Session) -> Any: + assert not self._async_engine, "This method must be called without async_mode" return self.CollectionStore.get_by_name(session, self.collection_name) + async def aget_collection(self, session: AsyncSession) -> Any: + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + return await self.CollectionStore.aget_by_name(session, self.collection_name) + @classmethod def __from( cls, @@ -452,6 +637,45 @@ def __from( return store + @classmethod + async def __afrom( + cls, + texts: List[str], + embeddings: List[List[float]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + connection: Optional[str] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, + ) -> PGVector: + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + store = cls( + connection=connection, + collection_name=collection_name, + embeddings=embedding, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + use_jsonb=use_jsonb, + async_mode=True, + **kwargs, + ) + + await store.aadd_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + + return store + def add_embeddings( self, texts: Iterable[str], @@ -466,15 +690,18 @@ def add_embeddings( texts: Iterable of strings to add to the vectorstore. embeddings: List of list of embedding vectors. metadatas: List of metadatas associated with the texts. + ids: Optional list of ids for the documents. + If not provided, will generate a new id for each document. 
kwargs: vectorstore specific parameters """ + assert not self._async_engine, "This method must be called with sync_mode" if ids is None: ids = [str(uuid.uuid4()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] - with self._session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -505,6 +732,62 @@ def add_embeddings( return ids + async def aadd_embeddings( + self, + texts: Iterable[str], + embeddings: List[List[float]], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Async add embeddings to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + embeddings: List of list of embedding vectors. + metadatas: List of metadatas associated with the texts. + ids: Optional list of ids for the texts. + If not provided, will generate a new id for each text. + kwargs: vectorstore specific parameters + """ + await self.__apost_init__() # Lazy async init + if ids is None: + ids = [str(uuid.uuid1()) for _ in texts] + + if not metadatas: + metadatas = [{} for _ in texts] + + async with self._make_async_session() as session: # type: ignore[arg-type] + collection = await self.aget_collection(session) + if not collection: + raise ValueError("Collection not found") + data = [ + { + "id": id, + "collection_id": collection.uuid, + "embedding": embedding, + "document": text, + "cmetadata": metadata or {}, + } + for text, metadata, embedding, id in zip( + texts, metadatas, embeddings, ids + ) + ] + stmt = insert(self.EmbeddingStore).values(data) + on_conflict_stmt = stmt.on_conflict_do_update( + index_elements=["id"], + # Conflict detection based on these columns + set_={ + "embedding": stmt.excluded.embedding, + "document": stmt.excluded.document, + "cmetadata": stmt.excluded.cmetadata, + }, + ) + await session.execute(on_conflict_stmt) + await session.commit() + + return ids + def add_texts( self, texts: Iterable[str], @@ -517,16 +800,44 @@ def add_texts( Args: texts: Iterable of strings to add to the vectorstore. metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of ids for the texts. + If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters Returns: List of ids from adding the texts into the vectorstore. """ + assert not self._async_engine, "This method must be called without async_mode" embeddings = self.embedding_function.embed_documents(list(texts)) return self.add_embeddings( texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs ) + async def aadd_texts( + self, + texts: Iterable[str], + metadatas: Optional[List[dict]] = None, + ids: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[str]: + """Run more texts through the embeddings and add to the vectorstore. + + Args: + texts: Iterable of strings to add to the vectorstore. + metadatas: Optional list of metadatas associated with the texts. + ids: Optional list of ids for the texts. + If not provided, will generate a new id for each text. + kwargs: vectorstore specific parameters + + Returns: + List of ids from adding the texts into the vectorstore. 
+ """ + await self.__apost_init__() # Lazy async init + embeddings = await self.embedding_function.aembed_documents(list(texts)) + return await self.aadd_embeddings( + texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs + ) + def similarity_search( self, query: str, @@ -544,6 +855,7 @@ def similarity_search( Returns: List of Documents most similar to the query. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(text=query) return self.similarity_search_by_vector( embedding=embedding, @@ -551,6 +863,31 @@ def similarity_search( filter=filter, ) + async def asimilarity_search( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Run similarity search with PGVector with distance. + + Args: + query (str): Query text to search for. + k (int): Number of results to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query. + """ + await self.__apost_init__() # Lazy async init + embedding = self.embedding_function.embed_query(text=query) + return await self.asimilarity_search_by_vector( + embedding=embedding, + k=k, + filter=filter, + ) + def similarity_search_with_score( self, query: str, @@ -567,12 +904,36 @@ def similarity_search_with_score( Returns: List of Documents most similar to the query and score for each. """ + assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) return docs + async def asimilarity_search_with_score( + self, + query: str, + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to query. + + Args: + query: Text to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query and score for each. 
+ """ + await self.__apost_init__() # Lazy async init + embedding = self.embedding_function.embed_query(query) + docs = await self.asimilarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return docs + @property def distance_strategy(self) -> Any: if self._distance_strategy == DistanceStrategy.EUCLIDEAN: @@ -593,10 +954,25 @@ def similarity_search_with_score_by_vector( k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: + assert not self._async_engine, "This method must be called without async_mode" results = self.__query_collection(embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) + async def asimilarity_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + ) -> List[Tuple[Document, float]]: + await self.__apost_init__() # Lazy async init + async with self._make_async_session() as session: # type: ignore[arg-type] + results = await self.__aquery_collection( + session=session, embedding=embedding, k=k, filter=filter + ) + + return self._results_to_docs_and_scores(results) + def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]: """Return docs and scores from results.""" docs = [ @@ -908,9 +1284,9 @@ def __query_collection( embedding: List[float], k: int = 4, filter: Optional[Dict[str, str]] = None, - ) -> List[Any]: + ) -> Sequence[Any]: """Query the collection.""" - with self._session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -931,7 +1307,7 @@ def __query_collection( results: List[Any] = ( session.query( self.EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore + self.distance_strategy(embedding).label("distance"), ) .filter(*filter_by) .order_by(sqlalchemy.asc("distance")) @@ -945,28 +1321,97 @@ def __query_collection( return results - def similarity_search_by_vector( + async def __aquery_collection( self, + session: AsyncSession, embedding: List[float], k: int = 4, - filter: Optional[dict] = None, - **kwargs: Any, - ) -> List[Document]: - """Return docs most similar to embedding vector. + filter: Optional[Dict[str, str]] = None, + ) -> Sequence[Any]: + """Query the collection.""" + async with self._make_async_session() as session: # type: ignore[arg-type] + collection = await self.aget_collection(session) + if not collection: + raise ValueError("Collection not found") - Args: - embedding: Embedding to look up documents similar to. - k: Number of Documents to return. Defaults to 4. 
+ filter_by = [self.EmbeddingStore.collection_id == collection.uuid] + if filter: + if self.use_jsonb: + filter_clauses = self._create_filter_clause(filter) + if filter_clauses is not None: + filter_by.append(filter_clauses) + else: + # Old way of doing things + filter_clauses = self._create_filter_clause_json_deprecated(filter) + filter_by.extend(filter_clauses) + + _type = self.EmbeddingStore + + stmt = ( + select( + self.EmbeddingStore, + self.distance_strategy(embedding).label("distance"), + ) + .filter(*filter_by) + .order_by(sqlalchemy.asc("distance")) + .join( + self.CollectionStore, + self.EmbeddingStore.collection_id == self.CollectionStore.uuid, + ) + .limit(k) + ) + + results: Sequence[Any] = (await session.execute(stmt)).all() + + return results + + def similarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. Returns: List of Documents most similar to the query vector. """ + assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.similarity_search_with_score_by_vector( embedding=embedding, k=k, filter=filter ) return _results_to_docs(docs_and_scores) + async def asimilarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs most similar to embedding vector. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List of Documents most similar to the query vector. + """ + assert self._async_engine, "This method must be called with async_mode" + await self.__apost_init__() # Lazy async init + docs_and_scores = await self.asimilarity_search_with_score_by_vector( + embedding=embedding, k=k, filter=filter + ) + return _results_to_docs(docs_and_scores) + @classmethod def from_texts( cls: Type[PGVector], @@ -997,6 +1442,35 @@ def from_texts( **kwargs, ) + @classmethod + async def afrom_texts( + cls: Type[PGVector], + texts: List[str], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, + ) -> PGVector: + """Return VectorStore initialized from documents and embeddings.""" + embeddings = embedding.embed_documents(list(texts)) + return await cls.__afrom( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + use_jsonb=use_jsonb, + **kwargs, + ) + @classmethod def from_embeddings( cls, @@ -1019,6 +1493,7 @@ def from_embeddings( collection_name: Name of the collection. distance_strategy: Distance strategy to use. ids: Optional list of ids for the documents. + If not provided, will generate a new id for each document. pre_delete_collection: If True, will delete the collection if it exists. **Attention**: This will delete all the documents in the existing collection. 
@@ -1053,6 +1528,51 @@ def from_embeddings( **kwargs, ) + @classmethod + async def afrom_embeddings( + cls, + text_embeddings: List[Tuple[str, List[float]]], + embedding: Embeddings, + metadatas: Optional[List[dict]] = None, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + pre_delete_collection: bool = False, + **kwargs: Any, + ) -> PGVector: + """Construct PGVector wrapper from raw documents and pre- + generated embeddings. + + Return VectorStore initialized from documents and embeddings. + Postgres connection string is required + "Either pass it as a parameter + or set the PGVECTOR_CONNECTION_STRING environment variable. + + Example: + .. code-block:: python + + from langchain_community.vectorstores import PGVector + from langchain_community.embeddings import OpenAIEmbeddings + embeddings = OpenAIEmbeddings() + text_embeddings = embeddings.embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + faiss = PGVector.from_embeddings(text_embedding_pairs, embeddings) + """ + texts = [t[0] for t in text_embeddings] + embeddings = [t[1] for t in text_embeddings] + + return await cls.__afrom( + texts, + embeddings, + embedding, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + **kwargs, + ) + @classmethod def from_existing_index( cls: Type[PGVector], @@ -1061,7 +1581,7 @@ def from_existing_index( collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, pre_delete_collection: bool = False, - connection: Optional[Connection] = None, + connection: Optional[DBConnection] = None, **kwargs: Any, ) -> PGVector: """ @@ -1080,11 +1600,39 @@ def from_existing_index( return store + @classmethod + async def afrom_existing_index( + cls: Type[PGVector], + embedding: Embeddings, + *, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + pre_delete_collection: bool = False, + connection: Optional[DBConnection] = None, + **kwargs: Any, + ) -> PGVector: + """ + Get instance of an existing PGVector store.This method will + return the instance of the store without inserting any new + embeddings + """ + store = PGVector( + connection=connection, + collection_name=collection_name, + embeddings=embedding, + distance_strategy=distance_strategy, + pre_delete_collection=pre_delete_collection, + async_mode=True, + **kwargs, + ) + + return store + @classmethod def get_connection_string(cls, kwargs: Dict[str, Any]) -> str: connection_string: str = get_from_dict_or_env( data=kwargs, - key="connection_string", + key="connection", env_key="PGVECTOR_CONNECTION_STRING", ) @@ -1103,7 +1651,7 @@ def from_documents( documents: List[Document], embedding: Embeddings, *, - connection: Optional[Connection] = None, + connection: Optional[DBConnection] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, ids: Optional[List[str]] = None, @@ -1129,6 +1677,44 @@ def from_documents( **kwargs, ) + @classmethod + async def afrom_documents( + cls: Type[PGVector], + documents: List[Document], + embedding: Embeddings, + collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, + distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY, + ids: Optional[List[str]] = None, + 
pre_delete_collection: bool = False, + *, + use_jsonb: bool = True, + **kwargs: Any, + ) -> PGVector: + """ + Return VectorStore initialized from documents and embeddings. + Postgres connection string is required + "Either pass it as a parameter + or set the PGVECTOR_CONNECTION_STRING environment variable. + """ + + texts = [d.page_content for d in documents] + metadatas = [d.metadata for d in documents] + connection_string = cls.get_connection_string(kwargs) + + kwargs["connection"] = connection_string + + return await cls.afrom_texts( + texts=texts, + pre_delete_collection=pre_delete_collection, + embedding=embedding, + distance_strategy=distance_strategy, + metadatas=metadatas, + ids=ids, + collection_name=collection_name, + use_jsonb=use_jsonb, + **kwargs, + ) + @classmethod def connection_string_from_db_params( cls, @@ -1201,6 +1787,7 @@ def max_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ + assert not self._async_engine, "This method must be called without async_mode" results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -1216,6 +1803,55 @@ def max_marginal_relevance_search_with_score_by_vector( return [r for i, r in enumerate(candidates) if i in mmr_selected] + async def amax_marginal_relevance_search_with_score_by_vector( + self, + embedding: List[float], + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs selected using the maximal marginal relevance with score + to embedding vector. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding: Embedding to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Tuple[Document, float]]: List of Documents selected by maximal marginal + relevance to the query and score for each. + """ + await self.__apost_init__() # Lazy async init + async with self._make_async_session() as session: + results = await self.__aquery_collection( + session=session, embedding=embedding, k=fetch_k, filter=filter + ) + + embedding_list = [result.EmbeddingStore.embedding for result in results] + + mmr_selected = maximal_marginal_relevance( + np.array(embedding, dtype=np.float32), + embedding_list, + k=k, + lambda_mult=lambda_mult, + ) + + candidates = self._results_to_docs_and_scores(results) + + return [r for i, r in enumerate(candidates) if i in mmr_selected] + def max_marginal_relevance_search( self, query: str, @@ -1254,6 +1890,45 @@ def max_marginal_relevance_search( **kwargs, ) + async def amax_marginal_relevance_search( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[Dict[str, str]] = None, + **kwargs: Any, + ) -> List[Document]: + """Return docs selected using the maximal marginal relevance. 
+ + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Document]: List of Documents selected by maximal marginal relevance. + """ + await self.__apost_init__() # Lazy async init + embedding = self.embedding_function.embed_query(query) + return await self.amax_marginal_relevance_search_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + def max_marginal_relevance_search_with_score( self, query: str, @@ -1294,6 +1969,47 @@ def max_marginal_relevance_search_with_score( ) return docs + async def amax_marginal_relevance_search_with_score( + self, + query: str, + k: int = 4, + fetch_k: int = 20, + lambda_mult: float = 0.5, + filter: Optional[dict] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs selected using the maximal marginal relevance with score. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + query (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. + lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Tuple[Document, float]]: List of Documents selected by maximal marginal + relevance to the query and score for each. + """ + await self.__apost_init__() # Lazy async init + embedding = self.embedding_function.embed_query(query) + docs = await self.amax_marginal_relevance_search_with_score_by_vector( + embedding=embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) + return docs + def max_marginal_relevance_search_by_vector( self, embedding: List[float], @@ -1343,18 +2059,58 @@ async def amax_marginal_relevance_search_by_vector( filter: Optional[Dict[str, str]] = None, **kwargs: Any, ) -> List[Document]: - """Return docs selected using the maximal marginal relevance.""" - - # This is a temporary workaround to make the similarity search - # asynchronous. The proper solution is to make the similarity search - # asynchronous in the vector store implementations. - return await run_in_executor( - None, - self.max_marginal_relevance_search_by_vector, - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, + """Return docs selected using the maximal marginal relevance + to embedding vector. + + Maximal marginal relevance optimizes for similarity to query AND diversity + among selected documents. + + Args: + embedding (str): Text to look up documents similar to. + k (int): Number of Documents to return. Defaults to 4. + fetch_k (int): Number of Documents to fetch to pass to MMR algorithm. + Defaults to 20. 
+ lambda_mult (float): Number between 0 and 1 that determines the degree + of diversity among the results with 0 corresponding + to maximum diversity and 1 to minimum diversity. + Defaults to 0.5. + filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None. + + Returns: + List[Document]: List of Documents selected by maximal marginal relevance. + """ + await self.__apost_init__() # Lazy async init + docs_and_scores = ( + await self.amax_marginal_relevance_search_with_score_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, + ) ) + + return _results_to_docs(docs_and_scores) + + @contextlib.contextmanager + def _make_sync_session(self) -> Generator[Session, None, None]: + """Make an async session.""" + if self.async_mode: + raise ValueError( + "Attempting to use a sync method in when async mode is turned on. " + "Please use the corresponding async method instead." + ) + with self.session_maker() as session: + yield typing_cast(Session, session) + + @contextlib.asynccontextmanager + async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]: + """Make an async session.""" + if not self.async_mode: + raise ValueError( + "Attempting to use an async method in when sync mode is turned on. " + "Please use the corresponding async method instead." + ) + async with self.session_maker() as session: + yield typing_cast(AsyncSession, session) diff --git a/pyproject.toml b/pyproject.toml index 9ab45d9..fa3a5d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,3 +88,7 @@ timeout = 30 markers = [] asyncio_mode = "auto" +[tool.codespell] +skip = '.git,*.pdf,*.svg,*.pdf,*.yaml,*.ipynb,poetry.lock,*.min.js,*.css,package-lock.json,example_data,_dist,examples,templates,*.trig' +ignore-regex = '.*(Stati Uniti|Tense=Pres).*' +ignore-words-list = 'momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogyny,unsecure,damon,crate,aadd,symbl,precesses,accademia,nin' diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 7516968..fcba8ef 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -1,9 +1,10 @@ """Test PGVector functionality.""" import contextlib -from typing import Any, Dict, Generator, List +from typing import Any, AsyncGenerator, Dict, Generator, List import pytest from langchain_core.documents import Document +from sqlalchemy import select from langchain_postgres.vectorstores import ( SUPPORTED_OPERATORS, @@ -17,7 +18,6 @@ TYPE_3_FILTERING_TEST_CASES, TYPE_4_FILTERING_TEST_CASES, TYPE_5_FILTERING_TEST_CASES, - TYPE_6_FILTERING_TEST_CASES, ) from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING @@ -38,7 +38,7 @@ def embed_query(self, text: str) -> List[float]: return [float(1.0)] * (ADA_TOKEN_COUNT - 1) + [float(0.0)] -def test_pgvector(pgvector: PGVector) -> None: +def test_pgvector() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] docsearch = PGVector.from_texts( @@ -52,6 +52,21 @@ def test_pgvector(pgvector: PGVector) -> None: assert output == [Document(page_content="foo")] +@pytest.mark.asyncio +async def test_async_pgvector() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search("foo", k=1) 
+ assert output == [Document(page_content="foo")] + + def test_pgvector_embeddings() -> None: """Test end to end construction with embeddings and search.""" texts = ["foo", "bar", "baz"] @@ -68,6 +83,23 @@ def test_pgvector_embeddings() -> None: assert output == [Document(page_content="foo")] +@pytest.mark.asyncio +async def test_async_pgvector_embeddings() -> None: + """Test end to end construction with embeddings and search.""" + texts = ["foo", "bar", "baz"] + text_embeddings = FakeEmbeddingsWithAdaDimension().embed_documents(texts) + text_embedding_pairs = list(zip(texts, text_embeddings)) + docsearch = await PGVector.afrom_embeddings( + text_embeddings=text_embedding_pairs, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search("foo", k=1) + assert output == [Document(page_content="foo")] + + def test_pgvector_with_metadatas() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -84,6 +116,23 @@ def test_pgvector_with_metadatas() -> None: assert output == [Document(page_content="foo", metadata={"page": "0"})] +@pytest.mark.asyncio +async def test_async_pgvector_with_metadatas() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search("foo", k=1) + assert output == [Document(page_content="foo", metadata={"page": "0"})] + + def test_pgvector_with_metadatas_with_scores() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -100,6 +149,23 @@ def test_pgvector_with_metadatas_with_scores() -> None: assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] +@pytest.mark.asyncio +async def test_async_pgvector_with_metadatas_with_scores() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search_with_score("foo", k=1) + assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] + + def test_pgvector_with_filter_match() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -116,6 +182,25 @@ def test_pgvector_with_filter_match() -> None: assert output == [(Document(page_content="foo", metadata={"page": "0"}), 0.0)] +@pytest.mark.asyncio +async def test_async_pgvector_with_filter_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, filter={"page": "0"} + ) + assert output == 
[(Document(page_content="foo", metadata={"page": "0"}), 0.0)] + + def test_pgvector_with_filter_distant_match() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -134,6 +219,27 @@ def test_pgvector_with_filter_distant_match() -> None: ] +@pytest.mark.asyncio +async def test_async_pgvector_with_filter_distant_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, filter={"page": "2"} + ) + assert output == [ + (Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406) + ] + + def test_pgvector_with_filter_no_match() -> None: """Test end to end construction and search.""" texts = ["foo", "bar", "baz"] @@ -150,6 +256,25 @@ def test_pgvector_with_filter_no_match() -> None: assert output == [] +@pytest.mark.asyncio +async def test_async_pgvector_with_filter_no_match() -> None: + """Test end to end construction and search.""" + texts = ["foo", "bar", "baz"] + metadatas = [{"page": str(i)} for i in range(len(texts))] + docsearch = await PGVector.afrom_texts( + texts=texts, + collection_name="test_collection_filter", + embedding=FakeEmbeddingsWithAdaDimension(), + metadatas=metadatas, + connection=CONNECTION_STRING, + pre_delete_collection=True, + ) + output = await docsearch.asimilarity_search_with_score( + "foo", k=1, filter={"page": "5"} + ) + assert output == [] + + def test_pgvector_collection_with_metadata() -> None: """Test end to end collection construction""" pgvector = PGVector( @@ -159,7 +284,7 @@ def test_pgvector_collection_with_metadata() -> None: connection=CONNECTION_STRING, pre_delete_collection=True, ) - with pgvector._session_maker() as session: + with pgvector.session_maker() as session: collection = pgvector.get_collection(session) if collection is None: assert False, "Expected a CollectionStore object but received None" @@ -168,6 +293,26 @@ def test_pgvector_collection_with_metadata() -> None: assert collection.cmetadata == {"foo": "bar"} +@pytest.mark.asyncio +async def test_async_pgvector_collection_with_metadata() -> None: + """Test end to end collection construction""" + pgvector = PGVector( + collection_name="test_collection", + collection_metadata={"foo": "bar"}, + embeddings=FakeEmbeddingsWithAdaDimension(), + connection=CONNECTION_STRING, + pre_delete_collection=True, + async_mode=True, + ) + async with pgvector.session_maker() as session: + collection = await pgvector.aget_collection(session) + if collection is None: + assert False, "Expected a CollectionStore object but received None" + else: + assert collection.name == "test_collection" + assert collection.cmetadata == {"foo": "bar"} + + def test_pgvector_delete_docs() -> None: """Add and delete documents.""" texts = ["foo", "bar", "baz"] @@ -182,20 +327,69 @@ def test_pgvector_delete_docs() -> None: pre_delete_collection=True, ) vectorstore.delete(["1", "2"]) - with vectorstore._session_maker() as session: + with vectorstore.session_maker() as session: records = list(session.query(vectorstore.EmbeddingStore).all()) # ignoring type error since mypy cannot determine whether # the list is sortable assert sorted(record.id for record in records) == ["3"] # 
type: ignore

     vectorstore.delete(["2", "3"])  # Should not raise on missing ids
-    with vectorstore._session_maker() as session:
+    with vectorstore.session_maker() as session:
         records = list(session.query(vectorstore.EmbeddingStore).all())
         # ignoring type error since mypy cannot determine whether
         # the list is sortable
         assert sorted(record.id for record in records) == []  # type: ignore


+def test_pgvector_delete_collection() -> None:
+    """Add and delete documents."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    vectorstore = PGVector.from_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        ids=["1", "2", "3"],
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    vectorstore.delete(collection_only=True)
+
+
+@pytest.mark.asyncio
+async def test_async_pgvector_delete_docs() -> None:
+    """Add and delete documents."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    vectorstore = await PGVector.afrom_texts(
+        texts=texts,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        ids=["1", "2", "3"],
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    await vectorstore.adelete(["1", "2"])
+    async with vectorstore.session_maker() as session:
+        records = (
+            (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all()
+        )
+        # ignoring type error since mypy cannot determine whether
+        # the list is sortable
+        assert sorted(record.id for record in records) == ["3"]  # type: ignore
+
+    await vectorstore.adelete(["2", "3"])  # Should not raise on missing ids
+    async with vectorstore.session_maker() as session:
+        records = (
+            (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all()
+        )
+        # ignoring type error since mypy cannot determine whether
+        # the list is sortable
+        assert sorted(record.id for record in records) == []  # type: ignore
+
+
 def test_pgvector_index_documents() -> None:
     """Test adding duplicate documents results in overwrites."""
     documents = [
@@ -229,7 +423,7 @@ def test_pgvector_index_documents() -> None:
         connection=CONNECTION_STRING,
         pre_delete_collection=True,
     )
-    with vectorstore._session_maker() as session:
+    with vectorstore.session_maker() as session:
         records = list(session.query(vectorstore.EmbeddingStore).all())
         # ignoring type error since mypy cannot determine whether
         # the list is sortable
@@ -251,7 +445,7 @@ def test_pgvector_index_documents() -> None:
     vectorstore.add_documents(documents, ids=[doc.metadata["id"] for doc in documents])

-    with vectorstore._session_maker() as session:
+    with vectorstore.session_maker() as session:
         records = list(session.query(vectorstore.EmbeddingStore).all())
         ordered_records = sorted(records, key=lambda x: x.id)
         # ignoring type error since mypy cannot determine whether
@@ -271,6 +465,88 @@ def test_pgvector_index_documents() -> None:
     }


+@pytest.mark.asyncio
+async def test_async_pgvector_index_documents() -> None:
+    """Test adding duplicate documents results in overwrites."""
+    documents = [
+        Document(
+            page_content="there are cats in the pond",
+            metadata={"id": 1, "location": "pond", "topic": "animals"},
+        ),
+        Document(
+            page_content="ducks are also found in the pond",
+            metadata={"id": 2, "location": "pond", "topic": "animals"},
+        ),
+        Document(
+            page_content="fresh apples are available at the market",
+            metadata={"id": 3, "location": "market", "topic": "food"},
+        ),
+        Document(
+            page_content="the market also sells fresh oranges",
+            metadata={"id": 4, "location": "market", "topic": "food"},
+        ),
+        Document(
+            page_content="the new art exhibit is fascinating",
+            metadata={"id": 5, "location": "museum", "topic": "art"},
+        ),
+    ]
+
+    vectorstore = await PGVector.afrom_documents(
+        documents=documents,
+        collection_name="test_collection_filter",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        ids=[doc.metadata["id"] for doc in documents],
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    async with vectorstore.session_maker() as session:
+        records = (
+            (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all()
+        )
+        # ignoring type error since mypy cannot determine whether
+        # the list is sortable
+        assert sorted(record.id for record in records) == [
+            "1",
+            "2",
+            "3",
+            "4",
+            "5",
+        ]
+
+    # Try to overwrite the first document
+    documents = [
+        Document(
+            page_content="new content in the zoo",
+            metadata={"id": 1, "location": "zoo", "topic": "zoo"},
+        ),
+    ]
+
+    await vectorstore.aadd_documents(
+        documents, ids=[doc.metadata["id"] for doc in documents]
+    )
+
+    async with vectorstore.session_maker() as session:
+        records = (
+            (await session.execute(select(vectorstore.EmbeddingStore))).scalars().all()
+        )
+        ordered_records = sorted(records, key=lambda x: x.id)
+        # ignoring type error since mypy cannot determine whether
+        # the list is sortable
+        assert [record.id for record in ordered_records] == [
+            "1",
+            "2",
+            "3",
+            "4",
+            "5",
+        ]
+
+        assert ordered_records[0].cmetadata == {
+            "id": 1,
+            "location": "zoo",
+            "topic": "zoo",
+        }
+
+
 def test_pgvector_relevance_score() -> None:
     """Test to make sure the relevance score is scaled to 0-1."""
     texts = ["foo", "bar", "baz"]
@@ -292,6 +568,28 @@ def test_pgvector_relevance_score() -> None:
     ]


+@pytest.mark.asyncio
+async def test_async_pgvector_relevance_score() -> None:
+    """Test to make sure the relevance score is scaled to 0-1."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = await PGVector.afrom_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+
+    output = await docsearch.asimilarity_search_with_relevance_scores("foo", k=3)
+    assert output == [
+        (Document(page_content="foo", metadata={"page": "0"}), 1.0),
+        (Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065),
+        (Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621),
+    ]
+
+
 def test_pgvector_retriever_search_threshold() -> None:
     """Test using retriever for searching with threshold."""
     texts = ["foo", "bar", "baz"]
@@ -316,6 +614,31 @@ def test_pgvector_retriever_search_threshold() -> None:
     ]


+@pytest.mark.asyncio
+async def test_async_pgvector_retriever_search_threshold() -> None:
+    """Test using retriever for searching with threshold."""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = await PGVector.afrom_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+
+    retriever = docsearch.as_retriever(
+        search_type="similarity_score_threshold",
+        search_kwargs={"k": 3, "score_threshold": 0.999},
+    )
+    output = await retriever.aget_relevant_documents("summer")
+    assert output == [
+        Document(page_content="foo", metadata={"page": "0"}),
+        Document(page_content="bar", metadata={"page": "1"}),
+    ]
+
+
 def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
     """Test searching with threshold and custom normalization function"""
     texts = ["foo", "bar", "baz"]
@@ -338,6 +661,31 @@ def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
     assert output == []


+@pytest.mark.asyncio
+async def test_async_pgvector_retriever_search_threshold_custom_normalization_fn() -> (
+    None
+):
+    """Test searching with threshold and custom normalization function"""
+    texts = ["foo", "bar", "baz"]
+    metadatas = [{"page": str(i)} for i in range(len(texts))]
+    docsearch = await PGVector.afrom_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        metadatas=metadatas,
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+        relevance_score_fn=lambda d: d * 0,
+    )
+
+    retriever = docsearch.as_retriever(
+        search_type="similarity_score_threshold",
+        search_kwargs={"k": 3, "score_threshold": 0.5},
+    )
+    output = await retriever.aget_relevant_documents("foo")
+    assert output == []
+
+
 def test_pgvector_max_marginal_relevance_search() -> None:
     """Test max marginal relevance search."""
     texts = ["foo", "bar", "baz"]
@@ -352,6 +700,21 @@ def test_pgvector_max_marginal_relevance_search() -> None:
     assert output == [Document(page_content="foo")]


+@pytest.mark.asyncio
+async def test_async_pgvector_max_marginal_relevance_search() -> None:
+    """Test max marginal relevance search."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = await PGVector.afrom_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = await docsearch.amax_marginal_relevance_search("foo", k=1, fetch_k=3)
+    assert output == [Document(page_content="foo")]
+
+
 def test_pgvector_max_marginal_relevance_search_with_score() -> None:
     """Test max marginal relevance search with relevance scores."""
     texts = ["foo", "bar", "baz"]
@@ -366,6 +729,23 @@ def test_pgvector_max_marginal_relevance_search_with_score() -> None:
     assert output == [(Document(page_content="foo"), 0.0)]


+@pytest.mark.asyncio
+async def test_async_pgvector_max_marginal_relevance_search_with_score() -> None:
+    """Test max marginal relevance search with relevance scores."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = await PGVector.afrom_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = await docsearch.amax_marginal_relevance_search_with_score(
+        "foo", k=1, fetch_k=3
+    )
+    assert output == [(Document(page_content="foo"), 0.0)]
+
+
 def test_pgvector_with_custom_connection() -> None:
     """Test construction using a custom connection."""
     texts = ["foo", "bar", "baz"]
@@ -380,6 +760,21 @@ def test_pgvector_with_custom_connection() -> None:
     assert output == [Document(page_content="foo")]


+@pytest.mark.asyncio
+async def test_async_pgvector_with_custom_connection() -> None:
+    """Test construction using a custom connection."""
+    texts = ["foo", "bar", "baz"]
+    docsearch = await PGVector.afrom_texts(
+        texts=texts,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+    )
+    output = await docsearch.asimilarity_search("foo", k=1)
+    assert output == [Document(page_content="foo")]
+
+
 def test_pgvector_with_custom_engine_args() -> None:
     """Test construction using custom engine arguments."""
     texts = ["foo", "bar", "baz"]
@@ -412,6 +807,26 @@ def pgvector() -> Generator[PGVector, None, None]:
     yield vector_store


+@pytest.mark.asyncio
+@pytest.fixture
+async def async_pgvector() -> AsyncGenerator[PGVector, None]:
+    """Create an async PGVector instance."""
+    store = await PGVector.afrom_documents(
+        documents=DOCUMENTS,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+        relevance_score_fn=lambda d: d * 0,
+        use_jsonb=True,
+    )
+    try:
+        yield store
+        # Do clean up
+    finally:
+        await store.adrop_tables()
+
+
 @contextlib.contextmanager
 def get_vectorstore() -> Generator[PGVector, None, None]:
     """Get a pre-populated-vectorstore"""
@@ -430,6 +845,24 @@ def get_vectorstore() -> Generator[PGVector, None, None]:
     store.drop_tables()


+@contextlib.asynccontextmanager
+async def aget_vectorstore() -> AsyncGenerator[PGVector, None]:
+    """Get a pre-populated-vectorstore"""
+    store = await PGVector.afrom_documents(
+        documents=DOCUMENTS,
+        collection_name="test_collection",
+        embedding=FakeEmbeddingsWithAdaDimension(),
+        connection=CONNECTION_STRING,
+        pre_delete_collection=True,
+        relevance_score_fn=lambda d: d * 0,
+        use_jsonb=True,
+    )
+    try:
+        yield store
+    finally:
+        await store.adrop_tables()
+
+
 @pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
 def test_pgvector_with_with_metadata_filters_1(
     test_filter: Dict[str, Any],
@@ -441,6 +874,18 @@ def test_pgvector_with_with_metadata_filters_1(
     assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter


+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_1_FILTERING_TEST_CASES)
+async def test_async_pgvector_with_with_metadata_filters_1(
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search."""
+    async with aget_vectorstore() as pgvector:
+        docs = await pgvector.asimilarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
 @pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
 def test_pgvector_with_with_metadata_filters_2(
     pgvector: PGVector,
@@ -452,6 +897,18 @@ def test_pgvector_with_with_metadata_filters_2(
     docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
     assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter


+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_2_FILTERING_TEST_CASES)
+async def test_async_pgvector_with_with_metadata_filters_2(
+    async_pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search."""
+    docs = await async_pgvector.asimilarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
 @pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
 def test_pgvector_with_with_metadata_filters_3(
     pgvector: PGVector,
@@ -463,6 +920,18 @@ def test_pgvector_with_with_metadata_filters_3(
     docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
     assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter


+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_3_FILTERING_TEST_CASES)
+async def test_async_pgvector_with_with_metadata_filters_3(
+    async_pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search."""
+    docs = await async_pgvector.asimilarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
 @pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
 def test_pgvector_with_with_metadata_filters_4(
     pgvector: PGVector,
@@ -474,6 +943,18 @@ def test_pgvector_with_with_metadata_filters_4(
     docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
     assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter


+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_4_FILTERING_TEST_CASES)
+async def test_async_pgvector_with_with_metadata_filters_4(
+    async_pgvector: PGVector,
+    test_filter: Dict[str, Any],
+    expected_ids: List[int],
+) -> None:
+    """Test end to end construction and search."""
+    docs = await async_pgvector.asimilarity_search("meow", k=5, filter=test_filter)
+    assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter
+
+
 @pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
 def test_pgvector_with_with_metadata_filters_5(
     pgvector: PGVector,
@@ -485,14 +966,15 @@ def test_pgvector_with_with_metadata_filters_5(
     docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
     assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter


-@pytest.mark.parametrize("test_filter, expected_ids", TYPE_6_FILTERING_TEST_CASES)
-def test_pgvector_with_with_metadata_filters_6(
-    pgvector: PGVector,
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_filter, expected_ids", TYPE_5_FILTERING_TEST_CASES)
+async def test_async_pgvector_with_with_metadata_filters_5(
+    async_pgvector: PGVector,
     test_filter: Dict[str, Any],
     expected_ids: List[int],
 ) -> None:
     """Test end to end construction and search."""
-    docs = pgvector.similarity_search("meow", k=5, filter=test_filter)
+    docs = await async_pgvector.asimilarity_search("meow", k=5, filter=test_filter)
     assert [doc.metadata["id"] for doc in docs] == expected_ids, test_filter