Skip to content

Commit

Permalink
WIP for fixing connection pools.
Browse files Browse the repository at this point in the history
  • Loading branch information
asanger committed Nov 27, 2024
1 parent d29231d commit 861dc08
Show file tree
Hide file tree
Showing 7 changed files with 191 additions and 124 deletions.
11 changes: 0 additions & 11 deletions langgraph/checkpoint/mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,6 @@ def get_connection(self) -> C:
Conn = Union[C, ConnectionPool[C]]


@contextmanager
def _get_connection(conn: Conn[C]) -> Iterator[C]:
if hasattr(conn, "cursor"):
yield cast(C, conn)
elif hasattr(conn, "get_connection"):
with cast(ConnectionPool[C], conn).get_connection() as _conn:
yield _conn
else:
raise TypeError(f"Invalid connection type: {type(conn)}")


class BaseSyncMySQLSaver(BaseMySQLSaver, Generic[C, R]):
lock: threading.Lock

Expand Down
23 changes: 23 additions & 0 deletions langgraph/checkpoint/mysql/_ainternal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Shared async utility functions for the AIOMYSQL checkpoint & storage classes."""

from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Union

Check failure on line 4 in langgraph/checkpoint/mysql/_ainternal.py

View workflow job for this annotation

GitHub Actions / lint / lint #3.12

Ruff (F401)

langgraph/checkpoint/mysql/_ainternal.py:4:20: F401 `typing.Any` imported but unused

Check failure on line 4 in langgraph/checkpoint/mysql/_ainternal.py

View workflow job for this annotation

GitHub Actions / lint / lint #3.12

Ruff (F401)

langgraph/checkpoint/mysql/_ainternal.py:4:40: F401 `typing.Iterator` imported but unused

Check failure on line 4 in langgraph/checkpoint/mysql/_ainternal.py

View workflow job for this annotation

GitHub Actions / lint / lint #3.12

Ruff (F401)

langgraph/checkpoint/mysql/_ainternal.py:4:50: F401 `typing.Optional` imported but unused

Check failure on line 4 in langgraph/checkpoint/mysql/_ainternal.py

View workflow job for this annotation

GitHub Actions / lint / lint #3.12

Ruff (F401)

langgraph/checkpoint/mysql/_ainternal.py:4:60: F401 `typing.Sequence` imported but unused

import aiomysql
import logging

logger = logging.getLogger(__name__)

Check failure on line 9 in langgraph/checkpoint/mysql/_ainternal.py

View workflow job for this annotation

GitHub Actions / lint / lint #3.12

Ruff (I001)

langgraph/checkpoint/mysql/_ainternal.py:3:1: I001 Import block is un-sorted or un-formatted
Conn = Union[aiomysql.Connection, aiomysql.Pool]

@asynccontextmanager
async def get_connection(
conn: Conn,
) -> AsyncIterator[aiomysql.Connection]:
if isinstance(conn, aiomysql.Connection):
yield conn
elif isinstance(conn, aiomysql.Pool):
async with conn.acquire() as _conn:
await _conn.set_charset("utf8mb4")
yield _conn
else:
raise TypeError(f"Invalid connection type: {type(conn)}")
68 changes: 31 additions & 37 deletions langgraph/checkpoint/mysql/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
CheckpointTuple,
get_checkpoint_id,
)
from langgraph.checkpoint.mysql import _ainternal
from langgraph.checkpoint.mysql.base import BaseMySQLSaver
from langgraph.checkpoint.mysql.utils import (
deserialize_channel_values,
Expand All @@ -26,28 +27,18 @@
)
from langgraph.checkpoint.serde.base import SerializerProtocol

Conn = Union[aiomysql.Connection, aiomysql.Pool]
import logging

logger = logging.getLogger(__name__)

Check failure on line 32 in langgraph/checkpoint/mysql/aio.py

View workflow job for this annotation

GitHub Actions / lint / lint #3.12

Ruff (I001)

langgraph/checkpoint/mysql/aio.py:1:1: I001 Import block is un-sorted or un-formatted

@asynccontextmanager
async def _get_connection(
conn: Conn,
) -> AsyncIterator[aiomysql.Connection]:
if isinstance(conn, aiomysql.Connection):
yield conn
elif isinstance(conn, aiomysql.Pool):
async with conn.acquire() as _conn:
yield _conn
else:
raise TypeError(f"Invalid connection type: {type(conn)}")

Conn = _ainternal.Conn # For backward compatibility

class AIOMySQLSaver(BaseMySQLSaver):
lock: asyncio.Lock

def __init__(
self,
conn: Conn,
conn: _ainternal.Conn,
serde: Optional[SerializerProtocol] = None,
) -> None:
super().__init__(serde=serde)
Expand Down Expand Up @@ -82,20 +73,17 @@ async def from_conn_string(
# This is necessary when using a unix socket, for example.
params_as_dict = dict(urllib.parse.parse_qsl(parsed.query))

async with aiomysql.connect(
async with aiomysql.create_pool(
host=parsed.hostname or "localhost",
user=parsed.username,
password=parsed.password or "",
db=parsed.path[1:],
port=parsed.port or 3306,
unix_socket=params_as_dict.get("unix_socket"),
autocommit=True,
) as conn:
# This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119
# is merged into aiomysql.
await conn.set_charset(pymysql.connections.DEFAULT_CHARSET)
) as conn:

yield AIOMySQLSaver(conn=conn, serde=serde)
yield cls(conn=conn, serde=serde)

async def setup(self) -> None:
"""Set up the checkpoint database asynchronously.
Expand Down Expand Up @@ -170,15 +158,17 @@ async def alist(
deserialize_pending_sends(value["pending_sends"]),
),
self._load_metadata(value["metadata"]),
{
"configurable": {
"thread_id": value["thread_id"],
"checkpoint_ns": value["checkpoint_ns"],
"checkpoint_id": value["parent_checkpoint_id"],
(
{
"configurable": {
"thread_id": value["thread_id"],
"checkpoint_ns": value["checkpoint_ns"],
"checkpoint_id": value["parent_checkpoint_id"],
}
if value["parent_checkpoint_id"]
else None
}
}
if value["parent_checkpoint_id"]
else None,
),
await asyncio.to_thread(
self._load_writes,
deserialize_pending_writes(value["pending_writes"]),
Expand Down Expand Up @@ -231,15 +221,17 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
deserialize_pending_sends(value["pending_sends"]),
),
self._load_metadata(value["metadata"]),
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": value["parent_checkpoint_id"],
(
{
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": checkpoint_ns,
"checkpoint_id": value["parent_checkpoint_id"],
}
}
}
if value["parent_checkpoint_id"]
else None,
if value["parent_checkpoint_id"]
else None
),
await asyncio.to_thread(
self._load_writes,
deserialize_pending_writes(value["pending_writes"]),
Expand Down Expand Up @@ -340,7 +332,7 @@ async def aput_writes(

@asynccontextmanager
async def _cursor(self) -> AsyncIterator[aiomysql.DictCursor]:
async with _get_connection(self.conn) as conn:
async with _ainternal.get_connection(self.conn) as conn:
async with self.lock, conn.cursor(aiomysql.DictCursor) as cur:
yield cur

Expand Down Expand Up @@ -449,3 +441,5 @@ def put_writes(
return asyncio.run_coroutine_threadsafe(
self.aput_writes(config, writes, task_id), self.loop
).result()

__all__ = ["AIOMySQLSaver", "Conn"]
Empty file.
Loading

0 comments on commit 861dc08

Please sign in to comment.