diff --git a/langgraph/checkpoint/mysql/__init__.py b/langgraph/checkpoint/mysql/__init__.py
index 8e207e6..6a070cc 100644
--- a/langgraph/checkpoint/mysql/__init__.py
+++ b/langgraph/checkpoint/mysql/__init__.py
@@ -69,17 +69,6 @@ def get_connection(self) -> C:
 Conn = Union[C, ConnectionPool[C]]
 
 
-@contextmanager
-def _get_connection(conn: Conn[C]) -> Iterator[C]:
-    if hasattr(conn, "cursor"):
-        yield cast(C, conn)
-    elif hasattr(conn, "get_connection"):
-        with cast(ConnectionPool[C], conn).get_connection() as _conn:
-            yield _conn
-    else:
-        raise TypeError(f"Invalid connection type: {type(conn)}")
-
-
 class BaseSyncMySQLSaver(BaseMySQLSaver, Generic[C, R]):
     lock: threading.Lock
diff --git a/langgraph/checkpoint/mysql/_ainternal.py b/langgraph/checkpoint/mysql/_ainternal.py
new file mode 100644
index 0000000..45e112c
--- /dev/null
+++ b/langgraph/checkpoint/mysql/_ainternal.py
@@ -0,0 +1,27 @@
+"""Shared async utility functions for the aiomysql checkpoint & storage classes."""
+
+import logging
+from contextlib import asynccontextmanager
+from typing import AsyncIterator, Union
+
+import aiomysql
+
+logger = logging.getLogger(__name__)
+
+Conn = Union[aiomysql.Connection, aiomysql.Pool]
+
+
+@asynccontextmanager
+async def get_connection(
+    conn: Conn,
+) -> AsyncIterator[aiomysql.Connection]:
+    if isinstance(conn, aiomysql.Connection):
+        yield conn
+    elif isinstance(conn, aiomysql.Pool):
+        async with conn.acquire() as _conn:
+            # Needed until https://github.com/PyMySQL/PyMySQL/pull/1119 lands in
+            # aiomysql: ensure every pooled connection uses utf8mb4.
+            await _conn.set_charset("utf8mb4")
+            yield _conn
+    else:
+        raise TypeError(f"Invalid connection type: {type(conn)}")
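A minimal sketch of how the new helper is consumed, assuming a reachable local server (host and credentials below are placeholders): both a bare aiomysql.Connection and an aiomysql.Pool satisfy Conn, and the pool path sets the charset on every acquired connection.

    import asyncio

    import aiomysql

    from langgraph.checkpoint.mysql import _ainternal

    async def main() -> None:
        # A bare connection is yielded unchanged...
        conn = await aiomysql.connect(host="localhost", user="mysql", password="mysql", db="mysql")
        async with _ainternal.get_connection(conn) as c:
            async with c.cursor() as cur:
                await cur.execute("SELECT 1")
        conn.close()

        # ...while a pool has a connection acquired (and its charset set) per use.
        pool = await aiomysql.create_pool(host="localhost", user="mysql", password="mysql", db="mysql")
        async with _ainternal.get_connection(pool) as c:
            async with c.cursor() as cur:
                await cur.execute("SELECT 1")
        pool.close()
        await pool.wait_closed()

    asyncio.run(main())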
diff --git a/langgraph/checkpoint/mysql/aio.py b/langgraph/checkpoint/mysql/aio.py
index cf841b0..6acbdfc 100644
--- a/langgraph/checkpoint/mysql/aio.py
+++ b/langgraph/checkpoint/mysql/aio.py
@@ -18,6 +18,7 @@
     CheckpointTuple,
     get_checkpoint_id,
 )
+from langgraph.checkpoint.mysql import _ainternal
 from langgraph.checkpoint.mysql.base import BaseMySQLSaver
 from langgraph.checkpoint.mysql.utils import (
     deserialize_channel_values,
@@ -26,28 +27,18 @@
 )
 from langgraph.checkpoint.serde.base import SerializerProtocol
 
-Conn = Union[aiomysql.Connection, aiomysql.Pool]
-
-
-@asynccontextmanager
-async def _get_connection(
-    conn: Conn,
-) -> AsyncIterator[aiomysql.Connection]:
-    if isinstance(conn, aiomysql.Connection):
-        yield conn
-    elif isinstance(conn, aiomysql.Pool):
-        async with conn.acquire() as _conn:
-            yield _conn
-    else:
-        raise TypeError(f"Invalid connection type: {type(conn)}")
+import logging
+
+logger = logging.getLogger(__name__)
+
+Conn = _ainternal.Conn  # For backward compatibility.
 
 
 class AIOMySQLSaver(BaseMySQLSaver):
     lock: asyncio.Lock
 
     def __init__(
         self,
-        conn: Conn,
+        conn: _ainternal.Conn,
         serde: Optional[SerializerProtocol] = None,
     ) -> None:
         super().__init__(serde=serde)
@@ -82,7 +73,7 @@ async def from_conn_string(
         # This is necessary when using a unix socket, for example.
         params_as_dict = dict(urllib.parse.parse_qsl(parsed.query))
 
-        async with aiomysql.connect(
+        async with aiomysql.create_pool(
             host=parsed.hostname or "localhost",
             user=parsed.username,
             password=parsed.password or "",
@@ -90,12 +81,9 @@
             port=parsed.port or 3306,
             unix_socket=params_as_dict.get("unix_socket"),
             autocommit=True,
-        ) as conn:
-            # This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119
-            # is merged into aiomysql.
-            await conn.set_charset(pymysql.connections.DEFAULT_CHARSET)
-            yield AIOMySQLSaver(conn=conn, serde=serde)
+        ) as pool:
+            # The charset is now set per acquired connection in
+            # _ainternal.get_connection.
+            yield cls(conn=pool, serde=serde)
 
     async def setup(self) -> None:
         """Set up the checkpoint database asynchronously.
@@ -170,15 +158,17 @@ async def alist(
                         deserialize_pending_sends(value["pending_sends"]),
                     ),
                     self._load_metadata(value["metadata"]),
-                    {
-                        "configurable": {
-                            "thread_id": value["thread_id"],
-                            "checkpoint_ns": value["checkpoint_ns"],
-                            "checkpoint_id": value["parent_checkpoint_id"],
-                        }
-                    }
-                    if value["parent_checkpoint_id"]
-                    else None,
+                    (
+                        {
+                            "configurable": {
+                                "thread_id": value["thread_id"],
+                                "checkpoint_ns": value["checkpoint_ns"],
+                                "checkpoint_id": value["parent_checkpoint_id"],
+                            }
+                        }
+                        if value["parent_checkpoint_id"]
+                        else None
+                    ),
                     await asyncio.to_thread(
                         self._load_writes,
                         deserialize_pending_writes(value["pending_writes"]),
@@ -231,15 +221,17 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
                 deserialize_pending_sends(value["pending_sends"]),
             ),
             self._load_metadata(value["metadata"]),
-            {
-                "configurable": {
-                    "thread_id": thread_id,
-                    "checkpoint_ns": checkpoint_ns,
-                    "checkpoint_id": value["parent_checkpoint_id"],
-                }
-            }
-            if value["parent_checkpoint_id"]
-            else None,
+            (
+                {
+                    "configurable": {
+                        "thread_id": thread_id,
+                        "checkpoint_ns": checkpoint_ns,
+                        "checkpoint_id": value["parent_checkpoint_id"],
+                    }
+                }
+                if value["parent_checkpoint_id"]
+                else None
+            ),
             await asyncio.to_thread(
                 self._load_writes,
                 deserialize_pending_writes(value["pending_writes"]),
@@ -340,7 +332,7 @@ async def aput_writes(
 
     @asynccontextmanager
     async def _cursor(self) -> AsyncIterator[aiomysql.DictCursor]:
-        async with _get_connection(self.conn) as conn:
+        async with _ainternal.get_connection(self.conn) as conn:
             async with self.lock, conn.cursor(aiomysql.DictCursor) as cur:
                 yield cur
@@ -449,3 +441,5 @@ def put_writes(
         return asyncio.run_coroutine_threadsafe(
            self.aput_writes(config, writes, task_id), self.loop
         ).result()
+
+
+__all__ = ["AIOMySQLSaver", "Conn"]
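With from_conn_string now backed by a pool, typical use looks like the sketch below (the DSN is copied from the test default; adjust host, port, and credentials for your server):

    import asyncio

    from langgraph.checkpoint.mysql.aio import AIOMySQLSaver

    async def main() -> None:
        async with AIOMySQLSaver.from_conn_string(
            "mysql://mysql:mysql@localhost:5441/mysql"
        ) as saver:
            await saver.setup()  # first use only: creates tables and runs migrations
            config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
            print(await saver.aget_tuple(config))  # None until a checkpoint is saved

    asyncio.run(main())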
diff --git a/langgraph/checkpoint/mysql/py.typed b/langgraph/checkpoint/mysql/py.typed
deleted file mode 100644
index e69de29..0000000
diff --git a/langgraph/store/mysql/aio.py b/langgraph/store/mysql/aio.py
index eb9cb61..75693dc 100644
--- a/langgraph/store/mysql/aio.py
+++ b/langgraph/store/mysql/aio.py
@@ -18,10 +18,13 @@
 import pymysql
 import pymysql.constants.ER
 
+from langgraph.checkpoint.mysql import _ainternal
+
 from langgraph.store.base import GetOp, ListNamespacesOp, Op, PutOp, Result, SearchOp
 from langgraph.store.base.batch import AsyncBatchedBaseStore
 from langgraph.store.mysql.base import (
     BaseMySQLStore,
+    PoolConfig,
     Row,
     _decode_ns_bytes,
     _group_ops,
@@ -31,12 +34,12 @@
 logger = logging.getLogger(__name__)
 
 
-class AIOMySQLStore(AsyncBatchedBaseStore, BaseMySQLStore[aiomysql.Connection]):
+class AIOMySQLStore(AsyncBatchedBaseStore, BaseMySQLStore[_ainternal.Conn]):
     __slots__ = ("_deserializer",)
 
     def __init__(
         self,
-        conn: aiomysql.Connection,
+        conn: _ainternal.Conn,
         *,
         deserializer: Optional[
             Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]]
@@ -51,44 +54,47 @@ async def abatch(self, ops: Iterable[Op]) -> list[Result]:
         grouped_ops, num_ops = _group_ops(ops)
         results: list[Result] = [None] * num_ops
 
-        tasks = []
-
-        if GetOp in grouped_ops:
-            tasks.append(
-                self._batch_get_ops(
-                    cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), results
-                )
-            )
-        if PutOp in grouped_ops:
-            tasks.append(
-                self._batch_put_ops(
-                    cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp])
-                )
-            )
-        if SearchOp in grouped_ops:
-            tasks.append(
-                self._batch_search_ops(
-                    cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]),
-                    results,
-                )
-            )
-        if ListNamespacesOp in grouped_ops:
-            tasks.append(
-                self._batch_list_namespaces_ops(
-                    cast(
-                        Sequence[tuple[int, ListNamespacesOp]],
-                        grouped_ops[ListNamespacesOp],
-                    ),
-                    results,
-                )
-            )
-
-        await asyncio.gather(*tasks)
+        async with _ainternal.get_connection(self.conn) as conn:
+            await self._execute_batch(grouped_ops, results, conn)
 
         return results
 
+    async def _execute_batch(
+        self,
+        grouped_ops: dict,
+        results: list[Result],
+        conn: aiomysql.Connection,
+    ) -> None:
+        async with self._cursor(conn) as cur:
+            if GetOp in grouped_ops:
+                await self._batch_get_ops(
+                    cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]),
+                    results,
+                    cur,
+                )
+            if PutOp in grouped_ops:
+                await self._batch_put_ops(
+                    cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]),
+                    cur,
+                )
+            if SearchOp in grouped_ops:
+                await self._batch_search_ops(
+                    cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]),
+                    results,
+                    cur,
+                )
+            if ListNamespacesOp in grouped_ops:
+                await self._batch_list_namespaces_ops(
+                    cast(
+                        Sequence[tuple[int, ListNamespacesOp]],
+                        grouped_ops[ListNamespacesOp],
+                    ),
+                    results,
+                    cur,
+                )
+
     def batch(self, ops: Iterable[Op]) -> list[Result]:
         return asyncio.run_coroutine_threadsafe(self.abatch(ops), self.loop).result()
@@ -97,14 +103,10 @@ async def _batch_get_ops(
         self,
         get_ops: Sequence[tuple[int, GetOp]],
         results: list[Result],
+        cur: aiomysql.DictCursor,
     ) -> None:
-        cursors = []
         for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops):
-            cur = await self._cursor()
             await cur.execute(query, params)
-            cursors.append((cur, namespace, items))
-
-        for cur, namespace, items in cursors:
             rows = cast(list[Row], await cur.fetchall())
             key_to_row = {row["key"]: row for row in rows}
             for idx, key in items:
@@ -119,26 +121,23 @@ async def _batch_put_ops(
         self,
         put_ops: Sequence[tuple[int, PutOp]],
+        cur: aiomysql.DictCursor,
     ) -> None:
         queries = self._get_batch_PUT_queries(put_ops)
         for query, params in queries:
-            cur = await self._cursor()
             await cur.execute(query, params)
 
     async def _batch_search_ops(
         self,
         search_ops: Sequence[tuple[int, SearchOp]],
         results: list[Result],
+        cur: aiomysql.DictCursor,
     ) -> None:
         queries = self._get_batch_search_queries(search_ops)
-        cursors: list[tuple[aiomysql.DictCursor, int]] = []
 
         for (query, params), (idx, _) in zip(queries, search_ops):
-            cur = await self._cursor()
             await cur.execute(query, params)
-            cursors.append((cur, idx))
-
-        for cur, idx in cursors:
             rows = cast(list[Row], await cur.fetchall())
             items = [
                 _row_to_item(
@@ -165,35 +164,71 @@ async def _batch_list_namespaces_ops(
             namespaces = [_decode_ns_bytes(row["truncated_prefix"]) for row in rows]
             results[idx] = namespaces
 
+    @asynccontextmanager
+    async def _cursor(
+        self, conn: aiomysql.Connection
+    ) -> AsyncIterator[aiomysql.DictCursor]:
+        """Create a database cursor as a context manager.
+
+        Args:
+            conn: The database connection to use.
+        """
+        async with conn.cursor(aiomysql.DictCursor) as cur:
+            yield cur
+
     @classmethod
     @asynccontextmanager
     async def from_conn_string(
         cls,
         conn_string: str,
+        *,
+        pool_config: Optional[PoolConfig] = None,
     ) -> AsyncIterator["AIOMySQLStore"]:
         """Create a new AIOMySQLStore instance from a connection string.
 
         Args:
             conn_string (str): The MySQL connection info string.
+            pool_config (Optional[PoolConfig]): Configuration for the connection pool.
+                If provided, a connection pool is created and used instead of a
+                single connection.
 
         Returns:
             AIOMySQLStore: A new AIOMySQLStore instance.
         """
+        logger.debug("Creating AIOMySQLStore from connection string")
         parsed = urllib.parse.urlparse(conn_string)
 
-        async with aiomysql.connect(
-            host=parsed.hostname or "localhost",
-            user=parsed.username,
-            password=parsed.password or "",
-            db=parsed.path[1:],
-            port=parsed.port or 3306,
-            autocommit=True,
-        ) as conn:
-            # This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119
-            # is merged into aiomysql.
-            await conn.set_charset(pymysql.connections.DEFAULT_CHARSET)
-            yield cls(conn=conn)
+        # In order to provide additional params via the connection string,
+        # we convert the parsed.query to a dict so we can access the values.
+        # This is necessary when using a unix socket, for example.
+        params_as_dict = dict(urllib.parse.parse_qsl(parsed.query))
+
+        if pool_config is not None:
+            pc = pool_config.copy()
+            # `kwargs` holds per-connection arguments, not create_pool() options,
+            # so it is unpacked separately.
+            conn_kwargs = cast(dict, pc.pop("kwargs", {}))
+            async with aiomysql.create_pool(
+                host=parsed.hostname or "localhost",
+                user=parsed.username,
+                password=parsed.password or "",
+                db=parsed.path[1:],
+                port=parsed.port or 3306,
+                unix_socket=params_as_dict.get("unix_socket"),
+                autocommit=True,
+                **conn_kwargs,
+                **cast(dict, pc),
+            ) as pool:
+                # The charset is set per acquired connection in
+                # _ainternal.get_connection.
+                yield cls(conn=pool)
+        else:
+            async with aiomysql.connect(
+                host=parsed.hostname or "localhost",
+                user=parsed.username,
+                password=parsed.password or "",
+                db=parsed.path[1:],
+                port=parsed.port or 3306,
+                unix_socket=params_as_dict.get("unix_socket"),
+                autocommit=True,
+            ) as conn:
+                # This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119
+                # is merged into aiomysql.
+                await conn.set_charset(pymysql.connections.DEFAULT_CHARSET)
+                yield cls(conn=conn)
 
     async def setup(self) -> None:
         """Set up the store database asynchronously.
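A sketch of the two construction paths the new keyword enables (the DSN and pool sizes are illustrative):

    import asyncio

    from langgraph.store.mysql.aio import AIOMySQLStore

    async def main() -> None:
        # Single connection (previous behavior).
        async with AIOMySQLStore.from_conn_string(
            "mysql://mysql:mysql@localhost:5441/mysql"
        ) as store:
            await store.setup()

        # Connection pool: each batch acquires a pooled connection.
        async with AIOMySQLStore.from_conn_string(
            "mysql://mysql:mysql@localhost:5441/mysql",
            pool_config={"minsize": 1, "maxsize": 5},
        ) as store:
            await store.setup()

    asyncio.run(main())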
""" + logger.info(f"Creating AIOMySQLStore from connection string: {conn_string}") parsed = urllib.parse.urlparse(conn_string) - async with aiomysql.connect( - host=parsed.hostname or "localhost", - user=parsed.username, - password=parsed.password or "", - db=parsed.path[1:], - port=parsed.port or 3306, - autocommit=True, - ) as conn: - # This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119 - # is merged into aiomysql. - await conn.set_charset(pymysql.connections.DEFAULT_CHARSET) + # In order to provide additional params via the connection string, + # we convert the parsed.query to a dict so we can access the values. + # This is necessary when using a unix socket, for example. + params_as_dict = dict(urllib.parse.parse_qsl(parsed.query)) + + if pool_config is not None: + pc = pool_config.copy() + async with aiomysql.create_pool( + host=parsed.hostname or "localhost", + user=parsed.username, + password=parsed.password or "", + db=parsed.path[1:], + port=parsed.port or 3306, + unix_socket=params_as_dict.get("unix_socket"), + autocommit=True + **cast(dict, pc), + ) as pool: + pool.set_charset(pymysql.connections.DEFAULT_CHARSET) + yield cls(conn=pool) + else: + async with aiomysql.connect( + host=parsed.hostname or "localhost", + user=parsed.username, + password=parsed.password or "", + db=parsed.path[1:], + port=parsed.port or 3306, + autocommit=True, + ) as conn: + # This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119 + # is merged into aiomysql. + await conn.set_charset(pymysql.connections.DEFAULT_CHARSET) - yield cls(conn=conn) + yield cls(conn=conn) async def setup(self) -> None: """Set up the store database asynchronously. @@ -202,33 +237,31 @@ async def setup(self) -> None: already exist and runs database migrations. It MUST be called directly by the user the first time the store is used. 
""" - async with self.conn.cursor(aiomysql.DictCursor) as cur: - try: - await cur.execute( - "SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1" - ) - row = cast(dict, await cur.fetchone()) - if row is None: + async with _ainternal.get_connection(self.conn) as conn: + async with conn.cursor(aiomysql.DictCursor) as cur: + try: + await cur.execute( + "SELECT v FROM store_migrations ORDER BY v DESC LIMIT 1" + ) + row = cast(dict, await cur.fetchone()) + if row is None: + version = -1 + else: + version = row["v"] + except pymysql.ProgrammingError as e: + if e.args[0] != pymysql.constants.ER.NO_SUCH_TABLE: + raise version = -1 - else: - version = row["v"] - except pymysql.ProgrammingError as e: - if e.args[0] != pymysql.constants.ER.NO_SUCH_TABLE: - raise - version = -1 - # Create store_migrations table if it doesn't exist - await cur.execute( - """ - CREATE TABLE IF NOT EXISTS store_migrations ( - v INTEGER PRIMARY KEY + # Create store_migrations table if it doesn't exist + await cur.execute( + """ + CREATE TABLE IF NOT EXISTS store_migrations ( + v INTEGER PRIMARY KEY + ) + """ ) - """ - ) - for v, migration in enumerate( - self.MIGRATIONS[version + 1 :], start=version + 1 - ): - await cur.execute(migration) - await cur.execute("INSERT INTO store_migrations (v) VALUES (%s)", (v,)) - - async def _cursor(self) -> aiomysql.DictCursor: - return await self.conn.cursor(aiomysql.DictCursor) + for v, migration in enumerate( + self.MIGRATIONS[version + 1 :], start=version + 1 + ): + await cur.execute(migration) + await cur.execute("INSERT INTO store_migrations (v) VALUES (%s)", (v,)) diff --git a/langgraph/store/mysql/base.py b/langgraph/store/mysql/base.py index 1de3c41..c9fd633 100644 --- a/langgraph/store/mysql/base.py +++ b/langgraph/store/mysql/base.py @@ -32,6 +32,8 @@ SearchOp, ) +from langgraph.checkpoint.mysql import _ainternal as _ainternal + logger = logging.getLogger(__name__) @@ -54,6 +56,32 @@ ] +C = TypeVar("C", bound=Union[_ainternal.Conn]) + + +class PoolConfig(TypedDict, total=False): + """Connection pool settings for PostgreSQL connections. + Controls connection lifecycle and resource utilization: + - Small pools (1-5) suit low-concurrency workloads + - Larger pools handle concurrent requests but consume more resources + - Setting maxsize prevents resource exhaustion under load + """ + + minsize: int + """Minimum number of connections maintained in the pool. Defaults to 1.""" + + maxsize: Optional[int] + """Maximum number of connections allowed in the pool. None means unlimited.""" + + kwargs: dict + """Additional connection arguments passed to each connection in the pool. + + Default kwargs set automatically: + - autocommit: True + - prepare_threshold: 0 + - row_factory: dict_row + """ + class DictCursor(Protocol): """ Protocol that a cursor should implement. diff --git a/tests/conftest.py b/tests/conftest.py index b51416c..519c62b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ import pytest DEFAULT_URI = "mysql://mysql:mysql@localhost:5441/mysql" - +DEFAULT_URI_WITH_SOCKET = f"{DEFAULT_URI}?query=unix_socket=/path/to/socket" @pytest.fixture(scope="function") async def conn() -> AsyncIterator[aiomysql.Connection]: