
Commit

Merge branch 'pprados/fix_chat_message_history' into pprados/async
pprados committed Jun 10, 2024
2 parents 63648c8 + dd7d7cd commit ed5b87b
Showing 2 changed files with 43 additions and 44 deletions.
86 changes: 43 additions & 43 deletions langchain_postgres/vectorstores.py
@@ -1,13 +1,16 @@
# pylint: disable=too-many-lines
from __future__ import annotations

import contextlib
import enum
import logging
import uuid
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
@@ -430,7 +433,6 @@ def embeddings(self) -> Embeddings:
return self.embedding_function

def create_vector_extension(self) -> None:
assert not self._async_engine, "This method must be called without async_mode"
assert self._engine, "engine not found"
try:
with self._engine.connect() as conn:
@@ -439,15 +441,13 @@ def create_vector_extension(self) -> None:
raise Exception(f"Failed to create vector extension: {e}") from e

async def acreate_vector_extension(self) -> None:
assert self.async_mode, "This method must be called with async_mode"
assert self._async_engine, "_async_engine not found"

async with self._async_engine.begin() as conn:
await conn.run_sync(_create_vector_extension)

def create_tables_if_not_exists(self) -> None:
assert not self._async_engine, "This method must be called without async_mode"
with self.session_maker() as session:
with self._make_sync_session() as session:
Base.metadata.create_all(session.get_bind())
session.commit()

@@ -457,8 +457,7 @@ async def acreate_tables_if_not_exists(self) -> None:
await conn.run_sync(Base.metadata.create_all)

def drop_tables(self) -> None:
assert not self._async_engine, "This method must be called without async_mode"
with self.session_maker() as session:
with self._make_sync_session() as session:
Base.metadata.drop_all(session.get_bind())
session.commit()

@@ -469,19 +468,17 @@ async def adrop_tables(self) -> None:
await conn.run_sync(Base.metadata.drop_all)

def create_collection(self) -> None:
assert not self._async_engine, "This method must be called without async_mode"
if self.pre_delete_collection:
self.delete_collection()
with self.session_maker() as session:
with self._make_sync_session() as session:
self.CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)
session.commit()

async def acreate_collection(self) -> None:
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
async with self.session_maker() as session:
async with self._make_async_session() as session:
if self.pre_delete_collection:
await self._adelete_collection(session)
await self.CollectionStore.aget_or_create(
@@ -490,25 +487,21 @@ async def acreate_collection(self) -> None:
await session.commit()

def _delete_collection(self, session: Session) -> None:
self.logger.debug("Trying to delete collection")
collection = self.get_collection(session)
if not collection:
self.logger.warning("Collection not found")
return
session.delete(collection)

async def _adelete_collection(self, session: AsyncSession) -> None:
self.logger.debug("Trying to delete collection")
collection = await self.aget_collection(session)
if not collection:
self.logger.warning("Collection not found")
return
await session.delete(collection)

def delete_collection(self) -> None:
assert not self._async_engine, "This method must be called without async_mode"
self.logger.debug("Trying to delete collection")
with self.session_maker() as session: # type: ignore[arg-type]
with self._make_sync_session() as session:
collection = self.get_collection(session)
if not collection:
self.logger.warning("Collection not found")
@@ -517,14 +510,13 @@ def delete_collection(self) -> None:
session.commit()

async def adelete_collection(self) -> None:
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
async with self.session_maker() as session: # type: ignore[arg-type]
async with self._make_async_session() as session:
collection = await self.aget_collection(session)
if not collection:
self.logger.warning("Collection not found")
return
await session.adelete(collection)
await session.delete(collection)
await session.commit()

def delete(
@@ -539,8 +531,7 @@ def delete(
ids: List of ids to delete.
collection_only: Only delete ids in the collection.
"""
assert not self._async_engine, "This method must be called without async_mode"
with self.session_maker() as session:
with self._make_sync_session() as session:
if ids is not None:
self.logger.debug(
"Trying to delete vectors by ids (represented by the model "
@@ -575,9 +566,8 @@ async def adelete(
ids: List of ids to delete.
collection_only: Only delete ids in the collection.
"""
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
async with self.session_maker() as session:
async with self._make_async_session() as session:
if ids is not None:
self.logger.debug(
"Trying to delete vectors by ids (represented by the model "
@@ -704,14 +694,14 @@ def add_embeddings(
If not provided, will generate a new id for each document.
kwargs: vectorstore specific parameters
"""
assert not self._async_engine, "This method must be called without async_mode"
assert not self._async_engine, "This method must be called with sync_mode"
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]

if not metadatas:
metadatas = [{} for _ in texts]

with self.session_maker() as session: # type: ignore[arg-type]
with self._make_sync_session() as session: # type: ignore[arg-type]
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
@@ -760,15 +750,14 @@ async def aadd_embeddings(
If not provided, will generate a new id for each text.
kwargs: vectorstore specific parameters
"""
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
if ids is None:
ids = [str(uuid.uuid1()) for _ in texts]

if not metadatas:
metadatas = [{} for _ in texts]

async with self.session_maker() as session: # type: ignore[arg-type]
async with self._make_async_session() as session: # type: ignore[arg-type]
collection = await self.aget_collection(session)
if not collection:
raise ValueError("Collection not found")
@@ -843,7 +832,6 @@ async def aadd_texts(
Returns:
List of ids from adding the texts into the vectorstore.
"""
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
embeddings = await self.embedding_function.aembed_documents(list(texts))
return await self.aadd_embeddings(
@@ -892,7 +880,6 @@ async def asimilarity_search(
Returns:
List of Documents most similar to the query.
"""
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
embedding = self.embedding_function.embed_query(text=query)
return await self.asimilarity_search_by_vector(
@@ -940,7 +927,6 @@ async def asimilarity_search_with_score(
Returns:
List of Documents most similar to the query and score for each.
"""
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
embedding = self.embedding_function.embed_query(query)
docs = await self.asimilarity_search_with_score_by_vector(
@@ -979,9 +965,8 @@ async def asimilarity_search_with_score_by_vector(
k: int = 4,
filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
async with self.session_maker() as session: # type: ignore[arg-type]
async with self._make_async_session() as session: # type: ignore[arg-type]
results = await self.__aquery_collection(
session=session, embedding=embedding, k=k, filter=filter
)
@@ -1301,7 +1286,7 @@ def __query_collection(
filter: Optional[Dict[str, str]] = None,
) -> Sequence[Any]:
"""Query the collection."""
with self.session_maker() as session: # type: ignore[arg-type]
with self._make_sync_session() as session: # type: ignore[arg-type]
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
@@ -1322,7 +1307,7 @@ def __query_collection(
results: List[Any] = (
session.query(
self.EmbeddingStore,
self.distance_strategy(embedding).label("distance"), # type: ignore
self.distance_strategy(embedding).label("distance"),
)
.filter(*filter_by)
.order_by(sqlalchemy.asc("distance"))
@@ -1344,7 +1329,7 @@ async def __aquery_collection(
filter: Optional[Dict[str, str]] = None,
) -> Sequence[Any]:
"""Query the collection."""
async with self.session_maker() as session: # type: ignore[arg-type]
async with self._make_async_session() as session: # type: ignore[arg-type]
collection = await self.aget_collection(session)
if not collection:
raise ValueError("Collection not found")
@@ -1365,7 +1350,7 @@ async def __aquery_collection(
stmt = (
select(
self.EmbeddingStore,
self.distance_strategy(embedding).label("distance"), # type: ignore
self.distance_strategy(embedding).label("distance"),
)
.filter(*filter_by)
.order_by(sqlalchemy.asc("distance"))
@@ -1848,9 +1833,8 @@ async def amax_marginal_relevance_search_with_score_by_vector(
List[Tuple[Document, float]]: List of Documents selected by maximal marginal
relevance to the query and score for each.
"""
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
async with self.session_maker() as session:
async with self._make_async_session() as session:
results = await self.__aquery_collection(
session=session, embedding=embedding, k=fetch_k, filter=filter
)
@@ -1896,7 +1880,6 @@ def max_marginal_relevance_search(
Returns:
List[Document]: List of Documents selected by maximal marginal relevance.
"""
assert not self._async_engine, "This method must be called without async_mode"
embedding = self.embedding_function.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding,
@@ -1935,7 +1918,6 @@ async def amax_marginal_relevance_search(
Returns:
List[Document]: List of Documents selected by maximal marginal relevance.
"""
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
embedding = self.embedding_function.embed_query(query)
return await self.amax_marginal_relevance_search_by_vector(
@@ -1976,7 +1958,6 @@ def max_marginal_relevance_search_with_score(
List[Tuple[Document, float]]: List of Documents selected by maximal marginal
relevance to the query and score for each.
"""
assert not self._async_engine, "This method must be called without async_mode"
embedding = self.embedding_function.embed_query(query)
docs = self.max_marginal_relevance_search_with_score_by_vector(
embedding=embedding,
@@ -2017,7 +1998,6 @@ async def amax_marginal_relevance_search_with_score(
List[Tuple[Document, float]]: List of Documents selected by maximal marginal
relevance to the query and score for each.
"""
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
embedding = self.embedding_function.embed_query(query)
docs = await self.amax_marginal_relevance_search_with_score_by_vector(
@@ -2059,7 +2039,6 @@ def max_marginal_relevance_search_by_vector(
Returns:
List[Document]: List of Documents selected by maximal marginal relevance.
"""
assert not self._async_engine, "This method must be called without async_mode"
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
embedding,
k=k,
@@ -2100,7 +2079,6 @@ async def amax_marginal_relevance_search_by_vector(
Returns:
List[Document]: List of Documents selected by maximal marginal relevance.
"""
assert self._async_engine, "This method must be called with async_mode"
await self.__apost_init__() # Lazy async init
docs_and_scores = (
await self.amax_marginal_relevance_search_with_score_by_vector(
@@ -2114,3 +2092,25 @@ async def amax_marginal_relevance_search_by_vector(
)

return _results_to_docs(docs_and_scores)

@contextlib.contextmanager
def _make_sync_session(self) -> Generator[Session, None, None]:
"""Make an async session."""
if self.async_mode:
raise ValueError(
"Attempting to use a sync method in when async mode is turned on. "
"Please use the corresponding async method instead."
)
with self.session_maker() as session:
yield typing_cast(Session, session)

@contextlib.asynccontextmanager
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Make an async session."""
if not self.async_mode:
raise ValueError(
"Attempting to use an async method in when sync mode is turned on. "
"Please use the corresponding async method instead."
)
async with self.session_maker() as session:
yield typing_cast(AsyncSession, session)
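
The two helpers above replace the per-method assert-based mode checks with context managers that validate async_mode once and hand back a correctly typed session. What follows is a minimal, self-contained sketch of that guard pattern, not the library's actual API: ModeGuardedStore and _FakeSession are illustrative stand-ins for the store and for SQLAlchemy's Session / AsyncSession, and a real store would build session_maker from sessionmaker / async_sessionmaker.

import asyncio
import contextlib
from typing import AsyncGenerator, Generator


class _FakeSession:
    """Stand-in for a SQLAlchemy Session / AsyncSession (illustration only)."""

    def __enter__(self) -> "_FakeSession":
        return self

    def __exit__(self, *exc: object) -> None:
        return None

    async def __aenter__(self) -> "_FakeSession":
        return self

    async def __aexit__(self, *exc: object) -> None:
        return None


class ModeGuardedStore:
    """Sketch of the sync/async session-guard pattern shown in the diff above."""

    def __init__(self, async_mode: bool) -> None:
        self.async_mode = async_mode
        # Real code would use sessionmaker / async_sessionmaker bound to an engine.
        self.session_maker = _FakeSession

    @contextlib.contextmanager
    def _make_sync_session(self) -> Generator[_FakeSession, None, None]:
        # Fail fast with a clear error instead of an assert, which `python -O` strips.
        if self.async_mode:
            raise ValueError("Use the corresponding async method instead.")
        with self.session_maker() as session:
            yield session

    @contextlib.asynccontextmanager
    async def _make_async_session(self) -> AsyncGenerator[_FakeSession, None]:
        if not self.async_mode:
            raise ValueError("Use the corresponding sync method instead.")
        async with self.session_maker() as session:
            yield session

    def delete_collection(self) -> None:
        with self._make_sync_session() as session:
            ...  # a real store would delete the collection and commit here

    async def adelete_collection(self) -> None:
        async with self._make_async_session() as session:
            ...  # a real store would await the delete and commit here


store = ModeGuardedStore(async_mode=True)
try:
    store.delete_collection()  # wrong mode: raises ValueError immediately
except ValueError as err:
    print(err)
asyncio.run(store.adelete_collection())  # correct mode: runs normally

One benefit of raising ValueError rather than relying on assert is that the check still runs when Python is started with optimizations (-O), which removes assert statements; the mode check also lives in one place instead of being repeated in every public method.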
1 change: 0 additions & 1 deletion tests/unit_tests/test_vectorstore.py
@@ -18,7 +18,6 @@
TYPE_3_FILTERING_TEST_CASES,
TYPE_4_FILTERING_TEST_CASES,
TYPE_5_FILTERING_TEST_CASES,
TYPE_6_FILTERING_TEST_CASES,
)
from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING
