diff --git a/langgraph/checkpoint/mysql/aio.py b/langgraph/checkpoint/mysql/aio.py index 2236c16..cf841b0 100644 --- a/langgraph/checkpoint/mysql/aio.py +++ b/langgraph/checkpoint/mysql/aio.py @@ -71,15 +71,24 @@ async def from_conn_string( Returns: AIOMySQLSaver: A new AIOMySQLSaver instance. + + Example: + conn_string=mysql+aiomysql://user:password@localhost/db?unix_socket=/path/to/socket """ parsed = urllib.parse.urlparse(conn_string) + # In order to provide additional params via the connection string, + # we convert the parsed.query to a dict so we can access the values. + # This is necessary when using a unix socket, for example. + params_as_dict = dict(urllib.parse.parse_qsl(parsed.query)) + async with aiomysql.connect( host=parsed.hostname or "localhost", user=parsed.username, password=parsed.password or "", db=parsed.path[1:], port=parsed.port or 3306, + unix_socket=params_as_dict.get("unix_socket"), autocommit=True, ) as conn: # This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119 diff --git a/langgraph/checkpoint/mysql/pymysql.py b/langgraph/checkpoint/mysql/pymysql.py index 968ab1f..464e048 100644 --- a/langgraph/checkpoint/mysql/pymysql.py +++ b/langgraph/checkpoint/mysql/pymysql.py @@ -27,15 +27,24 @@ def from_conn_string( Returns: PyMySQLSaver: A new PyMySQLSaver instance. + + Example: + conn_string=mysql+aiomysql://user:password@localhost/db?unix_socket=/path/to/socket """ parsed = urllib.parse.urlparse(conn_string) + # In order to provide additional params via the connection string, + # we convert the parsed.query to a dict so we can access the values. + # This is necessary when using a unix socket, for example. + params_as_dict = dict(urllib.parse.parse_qsl(parsed.query)) + with pymysql.connect( host=parsed.hostname, user=parsed.username, password=parsed.password or "", database=parsed.path[1:], port=parsed.port or 3306, + unix_socket=params_as_dict.get("unix_socket"), autocommit=True, ) as conn: yield cls(conn)