From dd7d7cd200a355261101bcad105b7da6386952ab Mon Sep 17 00:00:00 2001 From: Philippe Prados Date: Mon, 27 May 2024 08:29:04 +0200 Subject: [PATCH] Align the code with SQLChatMessageHistory --- langchain_postgres/vectorstores.py | 86 ++++++++++++++-------------- tests/unit_tests/test_vectorstore.py | 1 - 2 files changed, 43 insertions(+), 44 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index f00107c..659b6f3 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -1,13 +1,16 @@ # pylint: disable=too-many-lines from __future__ import annotations +import contextlib import enum import logging import uuid from typing import ( Any, + AsyncGenerator, Callable, Dict, + Generator, Iterable, List, Optional, @@ -430,7 +433,6 @@ def embeddings(self) -> Embeddings: return self.embedding_function def create_vector_extension(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" assert self._engine, "engine not found" try: with self._engine.connect() as conn: @@ -439,15 +441,13 @@ def create_vector_extension(self) -> None: raise Exception(f"Failed to create vector extension: {e}") from e async def acreate_vector_extension(self) -> None: - assert self.async_mode, "This method must be called with async_mode" assert self._async_engine, "_async_engine not found" async with self._async_engine.begin() as conn: await conn.run_sync(_create_vector_extension) def create_tables_if_not_exists(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" - with self.session_maker() as session: + with self._make_sync_session() as session: Base.metadata.create_all(session.get_bind()) session.commit() @@ -457,8 +457,7 @@ async def acreate_tables_if_not_exists(self) -> None: await conn.run_sync(Base.metadata.create_all) def drop_tables(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" - with self.session_maker() as session: + with self._make_sync_session() as session: Base.metadata.drop_all(session.get_bind()) session.commit() @@ -469,19 +468,17 @@ async def adrop_tables(self) -> None: await conn.run_sync(Base.metadata.drop_all) def create_collection(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" if self.pre_delete_collection: self.delete_collection() - with self.session_maker() as session: + with self._make_sync_session() as session: self.CollectionStore.get_or_create( session, self.collection_name, cmetadata=self.collection_metadata ) session.commit() async def acreate_collection(self) -> None: - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: + async with self._make_async_session() as session: if self.pre_delete_collection: await self._adelete_collection(session) await self.CollectionStore.aget_or_create( @@ -490,7 +487,6 @@ async def acreate_collection(self) -> None: await session.commit() def _delete_collection(self, session: Session) -> None: - self.logger.debug("Trying to delete collection") collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -498,7 +494,6 @@ def _delete_collection(self, session: Session) -> None: session.delete(collection) async def _adelete_collection(self, session: AsyncSession) -> None: - self.logger.debug("Trying to delete collection") collection = await self.aget_collection(session) if not collection: self.logger.warning("Collection not found") @@ -506,9 +501,7 @@ async def _adelete_collection(self, session: AsyncSession) -> None: await session.delete(collection) def delete_collection(self) -> None: - assert not self._async_engine, "This method must be called without async_mode" - self.logger.debug("Trying to delete collection") - with self.session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: collection = self.get_collection(session) if not collection: self.logger.warning("Collection not found") @@ -517,14 +510,13 @@ def delete_collection(self) -> None: session.commit() async def adelete_collection(self) -> None: - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._make_async_session() as session: collection = await self.aget_collection(session) if not collection: self.logger.warning("Collection not found") return - await session.adelete(collection) + await session.delete(collection) await session.commit() def delete( @@ -539,8 +531,7 @@ def delete( ids: List of ids to delete. collection_only: Only delete ids in the collection. """ - assert not self._async_engine, "This method must be called without async_mode" - with self.session_maker() as session: + with self._make_sync_session() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -575,9 +566,8 @@ async def adelete( ids: List of ids to delete. collection_only: Only delete ids in the collection. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: + async with self._make_async_session() as session: if ids is not None: self.logger.debug( "Trying to delete vectors by ids (represented by the model " @@ -704,14 +694,14 @@ def add_embeddings( If not provided, will generate a new id for each document. kwargs: vectorstore specific parameters """ - assert not self._async_engine, "This method must be called without async_mode" + assert not self._async_engine, "This method must be called with sync_mode" if ids is None: ids = [str(uuid.uuid4()) for _ in texts] if not metadatas: metadatas = [{} for _ in texts] - with self.session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -760,7 +750,6 @@ async def aadd_embeddings( If not provided, will generate a new id for each text. kwargs: vectorstore specific parameters """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init if ids is None: ids = [str(uuid.uuid1()) for _ in texts] @@ -768,7 +757,7 @@ async def aadd_embeddings( if not metadatas: metadatas = [{} for _ in texts] - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._make_async_session() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: raise ValueError("Collection not found") @@ -843,7 +832,6 @@ async def aadd_texts( Returns: List of ids from adding the texts into the vectorstore. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embeddings = await self.embedding_function.aembed_documents(list(texts)) return await self.aadd_embeddings( @@ -892,7 +880,6 @@ async def asimilarity_search( Returns: List of Documents most similar to the query. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(text=query) return await self.asimilarity_search_by_vector( @@ -940,7 +927,6 @@ async def asimilarity_search_with_score( Returns: List of Documents most similar to the query and score for each. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.asimilarity_search_with_score_by_vector( @@ -979,9 +965,8 @@ async def asimilarity_search_with_score_by_vector( k: int = 4, filter: Optional[dict] = None, ) -> List[Tuple[Document, float]]: - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._make_async_session() as session: # type: ignore[arg-type] results = await self.__aquery_collection( session=session, embedding=embedding, k=k, filter=filter ) @@ -1301,7 +1286,7 @@ def __query_collection( filter: Optional[Dict[str, str]] = None, ) -> Sequence[Any]: """Query the collection.""" - with self.session_maker() as session: # type: ignore[arg-type] + with self._make_sync_session() as session: # type: ignore[arg-type] collection = self.get_collection(session) if not collection: raise ValueError("Collection not found") @@ -1322,7 +1307,7 @@ def __query_collection( results: List[Any] = ( session.query( self.EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore + self.distance_strategy(embedding).label("distance"), ) .filter(*filter_by) .order_by(sqlalchemy.asc("distance")) @@ -1344,7 +1329,7 @@ async def __aquery_collection( filter: Optional[Dict[str, str]] = None, ) -> Sequence[Any]: """Query the collection.""" - async with self.session_maker() as session: # type: ignore[arg-type] + async with self._make_async_session() as session: # type: ignore[arg-type] collection = await self.aget_collection(session) if not collection: raise ValueError("Collection not found") @@ -1365,7 +1350,7 @@ async def __aquery_collection( stmt = ( select( self.EmbeddingStore, - self.distance_strategy(embedding).label("distance"), # type: ignore + self.distance_strategy(embedding).label("distance"), ) .filter(*filter_by) .order_by(sqlalchemy.asc("distance")) @@ -1848,9 +1833,8 @@ async def amax_marginal_relevance_search_with_score_by_vector( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - async with self.session_maker() as session: + async with self._make_async_session() as session: results = await self.__aquery_collection( session=session, embedding=embedding, k=fetch_k, filter=filter ) @@ -1896,7 +1880,6 @@ def max_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) return self.max_marginal_relevance_search_by_vector( embedding, @@ -1935,7 +1918,6 @@ async def amax_marginal_relevance_search( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) return await self.amax_marginal_relevance_search_by_vector( @@ -1976,7 +1958,6 @@ def max_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - assert not self._async_engine, "This method must be called without async_mode" embedding = self.embedding_function.embed_query(query) docs = self.max_marginal_relevance_search_with_score_by_vector( embedding=embedding, @@ -2017,7 +1998,6 @@ async def amax_marginal_relevance_search_with_score( List[Tuple[Document, float]]: List of Documents selected by maximal marginal relevance to the query and score for each. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init embedding = self.embedding_function.embed_query(query) docs = await self.amax_marginal_relevance_search_with_score_by_vector( @@ -2059,7 +2039,6 @@ def max_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert not self._async_engine, "This method must be called without async_mode" docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector( embedding, k=k, @@ -2100,7 +2079,6 @@ async def amax_marginal_relevance_search_by_vector( Returns: List[Document]: List of Documents selected by maximal marginal relevance. """ - assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init docs_and_scores = ( await self.amax_marginal_relevance_search_with_score_by_vector( @@ -2114,3 +2092,25 @@ async def amax_marginal_relevance_search_by_vector( ) return _results_to_docs(docs_and_scores) + + @contextlib.contextmanager + def _make_sync_session(self) -> Generator[Session, None, None]: + """Make an async session.""" + if self.async_mode: + raise ValueError( + "Attempting to use a sync method in when async mode is turned on. " + "Please use the corresponding async method instead." + ) + with self.session_maker() as session: + yield typing_cast(Session, session) + + @contextlib.asynccontextmanager + async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]: + """Make an async session.""" + if not self.async_mode: + raise ValueError( + "Attempting to use an async method in when sync mode is turned on. " + "Please use the corresponding async method instead." + ) + async with self.session_maker() as session: + yield typing_cast(AsyncSession, session) diff --git a/tests/unit_tests/test_vectorstore.py b/tests/unit_tests/test_vectorstore.py index 5a4f877..5ad8116 100644 --- a/tests/unit_tests/test_vectorstore.py +++ b/tests/unit_tests/test_vectorstore.py @@ -18,7 +18,6 @@ TYPE_3_FILTERING_TEST_CASES, TYPE_4_FILTERING_TEST_CASES, TYPE_5_FILTERING_TEST_CASES, - TYPE_6_FILTERING_TEST_CASES, ) from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING