Skip to content

Commit f71ce8f

Browse files
ppradoseyurtsev
authored andcommitted
community[minor]: Add native async support to SQLChatMessageHistory (#22065)
# package community: Fix SQLChatMessageHistory ## Description Here is a rewrite of `SQLChatMessageHistory` to properly implement the asynchronous approach. The code circumvents [issue 22021](#22021) by accepting a synchronous call to `def add_messages()` in an asynchronous scenario. This bypasses the bug. For the same reasons as in [PR 22](langchain-ai/langchain-postgres#32) of `langchain-postgres`, we use a lazy strategy for table creation. Indeed, the promise of the constructor cannot be fulfilled without this. It is not possible to invoke a synchronous call in a constructor. We compensate for this by waiting for the next asynchronous method call to create the table. The goal of the `PostgresChatMessageHistory` class (in `langchain-postgres`) is, among other things, to be able to recycle database connections. The implementation of the class is problematic, as we have demonstrated in [issue 22021](#22021). Our new implementation of `SQLChatMessageHistory` achieves this by using a singleton of type (`Async`)`Engine` for the database connection. The connection pool is managed by this singleton, and the code is then reentrant. We also accept the type `str` (optionally complemented by `async_mode`. I know you don't like this much, but it's the only way to allow an asynchronous connection string). In order to unify the different classes handling database connections, we have renamed `connection_string` to `connection`, and `Session` to `session_maker`. Now, a single transaction is used to add a list of messages. Thus, a crash during this write operation will not leave the database in an unstable state with a partially added message list. This makes the code resilient. We believe that the `PostgresChatMessageHistory` class is no longer necessary and can be replaced by: ``` PostgresChatMessageHistory = SQLChatMessageHistory ``` This also fixes the bug. ## Issue - [issue 22021](#22021) - Bug in _exit_history() - Bugs in PostgresChatMessageHistory and sync usage - Bugs in PostgresChatMessageHistory and async usage - [issue 36](langchain-ai/langchain-postgres#36) ## Twitter handle: pprados ## Tests - libs/community/tests/unit_tests/chat_message_histories/test_sql.py (add async test) @baskaryan, @eyurtsev or @hwchase17 can you check this PR ? And, I've been waiting a long time for validation from other PRs. Can you take a look? - [PR 32](langchain-ai/langchain-postgres#32) - [PR 15575](#15575) - [PR 13200](#13200) --------- Co-authored-by: Eugene Yurtsev <[email protected]>
1 parent d17efe3 commit f71ce8f

File tree

4 files changed

+325
-45
lines changed

4 files changed

+325
-45
lines changed

libs/community/langchain_community/chat_message_histories/sql.py

+188-11
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1+
import asyncio
2+
import contextlib
13
import json
24
import logging
35
from abc import ABC, abstractmethod
4-
from typing import Any, List, Optional
6+
from typing import (
7+
Any,
8+
AsyncGenerator,
9+
Dict,
10+
Generator,
11+
List,
12+
Optional,
13+
Sequence,
14+
Union,
15+
cast,
16+
)
517

6-
from sqlalchemy import Column, Integer, Text, create_engine
18+
from langchain_core._api import deprecated, warn_deprecated
19+
from sqlalchemy import Column, Integer, Text, delete, select
720

821
try:
922
from sqlalchemy.orm import declarative_base
@@ -15,7 +28,22 @@
1528
message_to_dict,
1629
messages_from_dict,
1730
)
18-
from sqlalchemy.orm import sessionmaker
31+
from sqlalchemy import create_engine
32+
from sqlalchemy.engine import Engine
33+
from sqlalchemy.ext.asyncio import (
34+
AsyncEngine,
35+
AsyncSession,
36+
async_sessionmaker,
37+
create_async_engine,
38+
)
39+
from sqlalchemy.orm import (
40+
Session as SQLSession,
41+
)
42+
from sqlalchemy.orm import (
43+
declarative_base,
44+
scoped_session,
45+
sessionmaker,
46+
)
1947

2048
logger = logging.getLogger(__name__)
2149

@@ -80,36 +108,98 @@ def get_sql_model_class(self) -> Any:
80108
return self.model_class
81109

82110

111+
DBConnection = Union[AsyncEngine, Engine, str]
112+
113+
_warned_once_already = False
114+
115+
83116
class SQLChatMessageHistory(BaseChatMessageHistory):
84117
"""Chat message history stored in an SQL database."""
85118

119+
@property
120+
@deprecated("0.2.2", removal="0.3.0", alternative="session_maker")
121+
def Session(self) -> Union[scoped_session, async_sessionmaker]:
122+
return self.session_maker
123+
86124
def __init__(
87125
self,
88126
session_id: str,
89-
connection_string: str,
127+
connection_string: Optional[str] = None,
90128
table_name: str = "message_store",
91129
session_id_field_name: str = "session_id",
92130
custom_message_converter: Optional[BaseMessageConverter] = None,
131+
connection: Union[None, DBConnection] = None,
132+
engine_args: Optional[Dict[str, Any]] = None,
133+
async_mode: Optional[bool] = None, # Use only if connection is a string
93134
):
94-
self.connection_string = connection_string
95-
self.engine = create_engine(connection_string, echo=False)
135+
assert not (
136+
connection_string and connection
137+
), "connection_string and connection are mutually exclusive"
138+
if connection_string:
139+
global _warned_once_already
140+
if not _warned_once_already:
141+
warn_deprecated(
142+
since="0.2.2",
143+
removal="0.3.0",
144+
name="connection_string",
145+
alternative="Use connection instead",
146+
)
147+
_warned_once_already = True
148+
connection = connection_string
149+
self.connection_string = connection_string
150+
if isinstance(connection, str):
151+
self.async_mode = async_mode
152+
if async_mode:
153+
self.async_engine = create_async_engine(
154+
connection, **(engine_args or {})
155+
)
156+
else:
157+
self.engine = create_engine(url=connection, **(engine_args or {}))
158+
elif isinstance(connection, Engine):
159+
self.async_mode = False
160+
self.engine = connection
161+
elif isinstance(connection, AsyncEngine):
162+
self.async_mode = True
163+
self.async_engine = connection
164+
else:
165+
raise ValueError(
166+
"connection should be a connection string or an instance of "
167+
"sqlalchemy.engine.Engine or sqlalchemy.ext.asyncio.engine.AsyncEngine"
168+
)
169+
170+
# To be consistent with others SQL implementations, rename to session_maker
171+
self.session_maker: Union[scoped_session, async_sessionmaker]
172+
if self.async_mode:
173+
self.session_maker = async_sessionmaker(bind=self.async_engine)
174+
else:
175+
self.session_maker = scoped_session(sessionmaker(bind=self.engine))
176+
96177
self.session_id_field_name = session_id_field_name
97178
self.converter = custom_message_converter or DefaultMessageConverter(table_name)
98179
self.sql_model_class = self.converter.get_sql_model_class()
99180
if not hasattr(self.sql_model_class, session_id_field_name):
100181
raise ValueError("SQL model class must have session_id column")
101-
self._create_table_if_not_exists()
182+
self._table_created = False
183+
if not self.async_mode:
184+
self._create_table_if_not_exists()
102185

103186
self.session_id = session_id
104-
self.Session = sessionmaker(self.engine)
105187

106188
def _create_table_if_not_exists(self) -> None:
107189
self.sql_model_class.metadata.create_all(self.engine)
190+
self._table_created = True
191+
192+
async def _acreate_table_if_not_exists(self) -> None:
193+
if not self._table_created:
194+
assert self.async_mode, "This method must be called with async_mode"
195+
async with self.async_engine.begin() as conn:
196+
await conn.run_sync(self.sql_model_class.metadata.create_all)
197+
self._table_created = True
108198

109199
@property
110200
def messages(self) -> List[BaseMessage]: # type: ignore
111201
"""Retrieve all messages from db"""
112-
with self.Session() as session:
202+
with self._make_sync_session() as session:
113203
result = (
114204
session.query(self.sql_model_class)
115205
.where(
@@ -123,18 +213,105 @@ def messages(self) -> List[BaseMessage]: # type: ignore
123213
messages.append(self.converter.from_sql_model(record))
124214
return messages
125215

216+
def get_messages(self) -> List[BaseMessage]:
217+
return self.messages
218+
219+
async def aget_messages(self) -> List[BaseMessage]:
220+
"""Retrieve all messages from db"""
221+
await self._acreate_table_if_not_exists()
222+
async with self._make_async_session() as session:
223+
stmt = (
224+
select(self.sql_model_class)
225+
.where(
226+
getattr(self.sql_model_class, self.session_id_field_name)
227+
== self.session_id
228+
)
229+
.order_by(self.sql_model_class.id.asc())
230+
)
231+
result = await session.execute(stmt)
232+
messages = []
233+
for record in result.scalars():
234+
messages.append(self.converter.from_sql_model(record))
235+
return messages
236+
126237
def add_message(self, message: BaseMessage) -> None:
127238
"""Append the message to the record in db"""
128-
with self.Session() as session:
239+
with self._make_sync_session() as session:
129240
session.add(self.converter.to_sql_model(message, self.session_id))
130241
session.commit()
131242

243+
async def aadd_message(self, message: BaseMessage) -> None:
244+
"""Add a Message object to the store.
245+
246+
Args:
247+
message: A BaseMessage object to store.
248+
"""
249+
await self._acreate_table_if_not_exists()
250+
async with self._make_async_session() as session:
251+
session.add(self.converter.to_sql_model(message, self.session_id))
252+
await session.commit()
253+
254+
def add_messages(self, messages: Sequence[BaseMessage]) -> None:
255+
# The method RunnableWithMessageHistory._exit_history() call
256+
# add_message method by mistake and not aadd_message.
257+
# See https://github.com/langchain-ai/langchain/issues/22021
258+
if self.async_mode:
259+
loop = asyncio.get_event_loop()
260+
loop.run_until_complete(self.aadd_messages(messages))
261+
else:
262+
with self._make_sync_session() as session:
263+
for message in messages:
264+
session.add(self.converter.to_sql_model(message, self.session_id))
265+
session.commit()
266+
267+
async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
268+
# Add all messages in one transaction
269+
await self._acreate_table_if_not_exists()
270+
async with self.session_maker() as session:
271+
for message in messages:
272+
session.add(self.converter.to_sql_model(message, self.session_id))
273+
await session.commit()
274+
132275
def clear(self) -> None:
133276
"""Clear session memory from db"""
134277

135-
with self.Session() as session:
278+
with self._make_sync_session() as session:
136279
session.query(self.sql_model_class).filter(
137280
getattr(self.sql_model_class, self.session_id_field_name)
138281
== self.session_id
139282
).delete()
140283
session.commit()
284+
285+
async def aclear(self) -> None:
286+
"""Clear session memory from db"""
287+
288+
await self._acreate_table_if_not_exists()
289+
async with self._make_async_session() as session:
290+
stmt = delete(self.sql_model_class).filter(
291+
getattr(self.sql_model_class, self.session_id_field_name)
292+
== self.session_id
293+
)
294+
await session.execute(stmt)
295+
await session.commit()
296+
297+
@contextlib.contextmanager
298+
def _make_sync_session(self) -> Generator[SQLSession, None, None]:
299+
"""Make an async session."""
300+
if self.async_mode:
301+
raise ValueError(
302+
"Attempting to use a sync method in when async mode is turned on. "
303+
"Please use the corresponding async method instead."
304+
)
305+
with self.session_maker() as session:
306+
yield cast(SQLSession, session)
307+
308+
@contextlib.asynccontextmanager
309+
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
310+
"""Make an async session."""
311+
if not self.async_mode:
312+
raise ValueError(
313+
"Attempting to use an async method in when sync mode is turned on. "
314+
"Please use the corresponding async method instead."
315+
)
316+
async with self.session_maker() as session:
317+
yield cast(AsyncSession, session)

libs/community/poetry.lock

+8-13
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

libs/community/pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ extended_testing = [
291291
"pyjwt",
292292
"oracledb",
293293
"simsimd",
294+
"aiosqlite"
294295
]
295296

296297
[tool.ruff]

0 commit comments

Comments
 (0)