From c2052d11c25354cfb5effaaba3e015f90eaa831a Mon Sep 17 00:00:00 2001
From: vbarda
Date: Wed, 13 Nov 2024 21:33:42 -0500
Subject: [PATCH 1/4] checkpoint-postgres: remove pipeline flag in cursor

---
 .../langgraph/checkpoint/postgres/__init__.py | 60 ++++++++----------
 .../langgraph/checkpoint/postgres/aio.py      | 63 ++++++++-----------
 2 files changed, 54 insertions(+), 69 deletions(-)

diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
index 37e7c2831..7085e107a 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
@@ -1,5 +1,5 @@
 import threading
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from typing import Any, Iterator, Optional, Sequence, Union
 
 from langchain_core.runnables import RunnableConfig
@@ -308,27 +308,29 @@ def put(
             }
         }
 
-        with self._cursor(pipeline=True) as cur:
-            cur.executemany(
-                self.UPSERT_CHECKPOINT_BLOBS_SQL,
-                self._dump_blobs(
-                    thread_id,
-                    checkpoint_ns,
-                    copy.pop("channel_values"),  # type: ignore[misc]
-                    new_versions,
-                ),
-            )
-            cur.execute(
-                self.UPSERT_CHECKPOINTS_SQL,
-                (
-                    thread_id,
-                    checkpoint_ns,
-                    checkpoint["id"],
-                    checkpoint_id,
-                    Jsonb(self._dump_checkpoint(copy)),
-                    self._dump_metadata(metadata),
-                ),
-            )
+        with self._cursor() as cur:
+            # Use connection's transaction context manager when not in pipeline mode
+            with cur.connection.transaction() if self.pipe is None else nullcontext():
+                cur.executemany(
+                    self.UPSERT_CHECKPOINT_BLOBS_SQL,
+                    self._dump_blobs(
+                        thread_id,
+                        checkpoint_ns,
+                        copy.pop("channel_values"),  # type: ignore[misc]
+                        new_versions,
+                    ),
+                )
+                cur.execute(
+                    self.UPSERT_CHECKPOINTS_SQL,
+                    (
+                        thread_id,
+                        checkpoint_ns,
+                        checkpoint["id"],
+                        checkpoint_id,
+                        Jsonb(self._dump_checkpoint(copy)),
+                        self._dump_metadata(metadata),
+                    ),
+                )
         return next_config
 
     def put_writes(
@@ -351,7 +353,7 @@
             if all(w[0] in WRITES_IDX_MAP for w in writes)
             else self.INSERT_CHECKPOINT_WRITES_SQL
         )
-        with self._cursor(pipeline=True) as cur:
+        with self._cursor() as cur:
             cur.executemany(
                 query,
                 self._dump_writes(
@@ -364,7 +366,7 @@ def put_writes(
             ),
         )
 
     @contextmanager
-    def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
+    def _cursor(self) -> Iterator[Cursor[DictRow]]:
         with _get_connection(self.conn) as conn:
             if self.pipe:
                 # a connection in pipeline mode can be used concurrently
@@ -374,15 +376,7 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
                 with conn.cursor(binary=True, row_factory=dict_row) as cur:
                     yield cur
             finally:
-                if pipeline:
-                    self.pipe.sync()
-            elif pipeline:
-                # a connection not in pipeline mode can only be used by one
-                # thread/coroutine at a time, so we acquire a lock
-                with self.lock, conn.pipeline(), conn.cursor(
-                    binary=True, row_factory=dict_row
-                ) as cur:
-                    yield cur
+                self.pipe.sync()
             else:
                 with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur:
                     yield cur
diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
index 59ee7cbf9..eac629605 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
@@ -1,5 +1,5 @@
 import asyncio
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, nullcontext
 from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Union
 
 from langchain_core.runnables import RunnableConfig
@@ -264,28 +264,29 @@ async def aput(
             }
         }
 
-        async with self._cursor(pipeline=True) as cur:
-            await cur.executemany(
-                self.UPSERT_CHECKPOINT_BLOBS_SQL,
-                await asyncio.to_thread(
-                    self._dump_blobs,
-                    thread_id,
-                    checkpoint_ns,
-                    copy.pop("channel_values"),  # type: ignore[misc]
-                    new_versions,
-                ),
-            )
-            await cur.execute(
-                self.UPSERT_CHECKPOINTS_SQL,
-                (
-                    thread_id,
-                    checkpoint_ns,
-                    checkpoint["id"],
-                    checkpoint_id,
-                    Jsonb(self._dump_checkpoint(copy)),
-                    self._dump_metadata(metadata),
-                ),
-            )
+        async with self._cursor() as cur:
+            async with cur.connection.transaction() if self.pipe is None else nullcontext():
+                await cur.executemany(
+                    self.UPSERT_CHECKPOINT_BLOBS_SQL,
+                    await asyncio.to_thread(
+                        self._dump_blobs,
+                        thread_id,
+                        checkpoint_ns,
+                        copy.pop("channel_values"),  # type: ignore[misc]
+                        new_versions,
+                    ),
+                )
+                await cur.execute(
+                    self.UPSERT_CHECKPOINTS_SQL,
+                    (
+                        thread_id,
+                        checkpoint_ns,
+                        checkpoint["id"],
+                        checkpoint_id,
+                        Jsonb(self._dump_checkpoint(copy)),
+                        self._dump_metadata(metadata),
+                    ),
+                )
         return next_config
 
     async def aput_writes(
@@ -316,13 +317,11 @@ async def aput_writes(
             task_id,
             writes,
         )
-        async with self._cursor(pipeline=True) as cur:
+        async with self._cursor() as cur:
             await cur.executemany(query, params)
 
     @asynccontextmanager
-    async def _cursor(
-        self, *, pipeline: bool = False
-    ) -> AsyncIterator[AsyncCursor[DictRow]]:
+    async def _cursor(self) -> AsyncIterator[AsyncCursor[DictRow]]:
         async with _get_connection(self.conn) as conn:
             if self.pipe:
                 # a connection in pipeline mode can be used concurrently
@@ -332,15 +331,7 @@ async def _cursor(
                 async with conn.cursor(binary=True, row_factory=dict_row) as cur:
                     yield cur
             finally:
-                if pipeline:
-                    await self.pipe.sync()
-            elif pipeline:
-                # a connection not in pipeline mode can only be used by one
-                # thread/coroutine at a time, so we acquire a lock
-                async with self.lock, conn.pipeline(), conn.cursor(
-                    binary=True, row_factory=dict_row
-                ) as cur:
-                    yield cur
+                await self.pipe.sync()
             else:
                 async with self.lock, conn.cursor(
                     binary=True, row_factory=dict_row
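Note: the conditional context manager above composes transaction() with nullcontext() so the blob and checkpoint upserts are atomic only when no pipeline owns the connection; an active pipeline batches and syncs the statements itself. A minimal standalone sketch of the same pattern; the function and table names here are illustrative placeholders, not part of the patch:

    from contextlib import nullcontext
    from typing import Optional

    from psycopg import Connection, Pipeline


    def save_checkpoint_pair(conn: Connection, pipe: Optional[Pipeline]) -> None:
        # Same shape as put() after this patch: wrap the two writes in an
        # explicit transaction only when no pipeline is active.
        with conn.cursor() as cur:
            with conn.transaction() if pipe is None else nullcontext():
                cur.execute("INSERT INTO demo_blobs VALUES (%s)", ("blob",))
                cur.execute("INSERT INTO demo_checkpoints VALUES (%s)", ("ckpt",))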
From 0a5220aa0715a0756d748e68c991c5167ef239ca Mon Sep 17 00:00:00 2001
From: vbarda
Date: Thu, 14 Nov 2024 18:55:02 -0500
Subject: [PATCH 2/4] code review

---
 .../langgraph/checkpoint/postgres/__init__.py | 76 +++++++++++-------
 .../langgraph/checkpoint/postgres/aio.py      | 79 ++++++++++++-------
 2 files changed, 99 insertions(+), 56 deletions(-)

diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
index 7085e107a..260d12863 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
@@ -1,10 +1,10 @@
 import threading
-from contextlib import contextmanager, nullcontext
+from contextlib import contextmanager
 from typing import Any, Iterator, Optional, Sequence, Union
 
 from langchain_core.runnables import RunnableConfig
 from psycopg import Connection, Cursor, Pipeline
-from psycopg.errors import UndefinedTable
+from psycopg.errors import NotSupportedError, UndefinedTable
 from psycopg.rows import DictRow, dict_row
 from psycopg.types.json import Jsonb
 from psycopg_pool import ConnectionPool
@@ -308,29 +308,27 @@ def put(
             }
         }
 
-        with self._cursor() as cur:
-            # Use connection's transaction context manager when not in pipeline mode
-            with cur.connection.transaction() if self.pipe is None else nullcontext():
-                cur.executemany(
-                    self.UPSERT_CHECKPOINT_BLOBS_SQL,
-                    self._dump_blobs(
-                        thread_id,
-                        checkpoint_ns,
-                        copy.pop("channel_values"),  # type: ignore[misc]
-                        new_versions,
-                    ),
-                )
-                cur.execute(
-                    self.UPSERT_CHECKPOINTS_SQL,
-                    (
-                        thread_id,
-                        checkpoint_ns,
-                        checkpoint["id"],
-                        checkpoint_id,
-                        Jsonb(self._dump_checkpoint(copy)),
-                        self._dump_metadata(metadata),
-                    ),
-                )
+        with self._cursor(pipeline=True) as cur:
+            cur.executemany(
+                self.UPSERT_CHECKPOINT_BLOBS_SQL,
+                self._dump_blobs(
+                    thread_id,
+                    checkpoint_ns,
+                    copy.pop("channel_values"),  # type: ignore[misc]
+                    new_versions,
+                ),
+            )
+            cur.execute(
+                self.UPSERT_CHECKPOINTS_SQL,
+                (
+                    thread_id,
+                    checkpoint_ns,
+                    checkpoint["id"],
+                    checkpoint_id,
+                    Jsonb(self._dump_checkpoint(copy)),
+                    self._dump_metadata(metadata),
+                ),
+            )
         return next_config
 
     def put_writes(
@@ -353,7 +351,7 @@
             if all(w[0] in WRITES_IDX_MAP for w in writes)
             else self.INSERT_CHECKPOINT_WRITES_SQL
         )
-        with self._cursor() as cur:
+        with self._cursor(pipeline=True) as cur:
             cur.executemany(
                 query,
                 self._dump_writes(
@@ -366,7 +364,14 @@ def put_writes(
             ),
         )
 
     @contextmanager
-    def _cursor(self) -> Iterator[Cursor[DictRow]]:
+    def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
+        """Create a database cursor as a context manager.
+
+        Args:
+            pipeline (bool): whether to use pipeline for the DB operations inside the context manager.
+                Will be applied regardless of whether the PostgresSaver instance was initialized with a pipeline.
+                If pipeline mode is not supported, will fall back to using transaction context manager.
+        """
         with _get_connection(self.conn) as conn:
             if self.pipe:
                 # a connection in pipeline mode can be used concurrently
@@ -376,7 +381,22 @@ def _cursor(self) -> Iterator[Cursor[DictRow]]:
                 with conn.cursor(binary=True, row_factory=dict_row) as cur:
                     yield cur
             finally:
-                self.pipe.sync()
+                if pipeline:
+                    self.pipe.sync()
+            elif pipeline:
+                # a connection not in pipeline mode can only be used by one
+                # thread/coroutine at a time, so we acquire a lock
+                try:
+                    with self.lock, conn.pipeline(), conn.cursor(
+                        binary=True, row_factory=dict_row
+                    ) as cur:
+                        yield cur
+                except NotSupportedError:
+                    # Use connection's transaction context manager when pipeline mode not supported
+                    with self.lock, conn.transaction(), conn.cursor(
+                        binary=True, row_factory=dict_row
+                    ) as cur:
+                        yield cur
             else:
                 with self.lock, conn.cursor(binary=True, row_factory=dict_row) as cur:
                     yield cur
diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
index eac629605..43429ae4d 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
@@ -1,10 +1,10 @@
 import asyncio
-from contextlib import asynccontextmanager, nullcontext
+from contextlib import asynccontextmanager
 from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Union
 
 from langchain_core.runnables import RunnableConfig
 from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline
-from psycopg.errors import UndefinedTable
+from psycopg.errors import NotSupportedError, UndefinedTable
 from psycopg.rows import DictRow, dict_row
 from psycopg.types.json import Jsonb
 from psycopg_pool import AsyncConnectionPool
@@ -264,29 +264,28 @@ async def aput(
             }
         }
 
-        async with self._cursor() as cur:
-            async with cur.connection.transaction() if self.pipe is None else nullcontext():
-                await cur.executemany(
-                    self.UPSERT_CHECKPOINT_BLOBS_SQL,
-                    await asyncio.to_thread(
-                        self._dump_blobs,
-                        thread_id,
-                        checkpoint_ns,
-                        copy.pop("channel_values"),  # type: ignore[misc]
-                        new_versions,
-                    ),
-                )
-                await cur.execute(
-                    self.UPSERT_CHECKPOINTS_SQL,
-                    (
-                        thread_id,
-                        checkpoint_ns,
-                        checkpoint["id"],
-                        checkpoint_id,
-                        Jsonb(self._dump_checkpoint(copy)),
-                        self._dump_metadata(metadata),
-                    ),
-                )
+        async with self._cursor(pipeline=True) as cur:
+            await cur.executemany(
+                self.UPSERT_CHECKPOINT_BLOBS_SQL,
+                await asyncio.to_thread(
+                    self._dump_blobs,
+                    thread_id,
+                    checkpoint_ns,
+                    copy.pop("channel_values"),  # type: ignore[misc]
+                    new_versions,
+                ),
+            )
+            await cur.execute(
+                self.UPSERT_CHECKPOINTS_SQL,
+                (
+                    thread_id,
+                    checkpoint_ns,
+                    checkpoint["id"],
+                    checkpoint_id,
+                    Jsonb(self._dump_checkpoint(copy)),
+                    self._dump_metadata(metadata),
+                ),
+            )
         return next_config
 
     async def aput_writes(
@@ -317,11 +316,20 @@ async def aput_writes(
             task_id,
             writes,
         )
-        async with self._cursor() as cur:
+        async with self._cursor(pipeline=True) as cur:
             await cur.executemany(query, params)
 
     @asynccontextmanager
-    async def _cursor(self) -> AsyncIterator[AsyncCursor[DictRow]]:
+    async def _cursor(
+        self, *, pipeline: bool = False
+    ) -> AsyncIterator[AsyncCursor[DictRow]]:
+        """Create a database cursor as a context manager.
+
+        Args:
+            pipeline (bool): whether to use pipeline for the DB operations inside the context manager.
+                Will be applied regardless of whether the AsyncPostgresSaver instance was initialized with a pipeline.
+                If pipeline mode is not supported, will fall back to using transaction context manager.
+        """
         async with _get_connection(self.conn) as conn:
             if self.pipe:
                 # a connection in pipeline mode can be used concurrently
@@ -331,7 +339,22 @@ async def _cursor(self) -> AsyncIterator[AsyncCursor[DictRow]]:
                 async with conn.cursor(binary=True, row_factory=dict_row) as cur:
                     yield cur
             finally:
-                await self.pipe.sync()
+                if pipeline:
+                    await self.pipe.sync()
+            elif pipeline:
+                # a connection not in pipeline mode can only be used by one
+                # thread/coroutine at a time, so we acquire a lock
+                try:
+                    async with self.lock, conn.pipeline(), conn.cursor(
+                        binary=True, row_factory=dict_row
+                    ) as cur:
+                        yield cur
+                except NotSupportedError:
+                    # Use connection's transaction context manager when pipeline mode not supported
+                    async with self.lock, conn.transaction(), conn.cursor(
+                        binary=True, row_factory=dict_row
+                    ) as cur:
+                        yield cur
             else:
                 async with self.lock, conn.cursor(
                     binary=True, row_factory=dict_row
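Note: the try/except added in this patch degrades gracefully on clients whose libpq cannot enter pipeline mode. A sketch of the same fallback in isolation; "demo" is a placeholder table, and it is assumed (as the patch's except clause implies) that psycopg raises NotSupportedError when pipeline mode is unavailable, typically with libpq older than 14:

    from psycopg import Connection
    from psycopg.errors import NotSupportedError


    def batched_inserts(conn: Connection) -> None:
        try:
            # Entering pipeline mode fails up front on unsupported clients.
            with conn.pipeline(), conn.cursor() as cur:
                cur.execute("INSERT INTO demo VALUES (1)")
                cur.execute("INSERT INTO demo VALUES (2)")
        except NotSupportedError:
            # Fall back to a plain transaction: same atomicity, no batching.
            with conn.transaction(), conn.cursor() as cur:
                cur.execute("INSERT INTO demo VALUES (1)")
                cur.execute("INSERT INTO demo VALUES (2)")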
From f0505155a2395ab846a4556c1c12cb3fbc05d5f8 Mon Sep 17 00:00:00 2001
From: vbarda
Date: Mon, 18 Nov 2024 11:12:21 -0500
Subject: [PATCH 3/4] cache

---
 .../langgraph/checkpoint/postgres/__init__.py | 15 +++++++++++++--
 .../langgraph/checkpoint/postgres/aio.py      | 15 +++++++++++++--
 .../langgraph/checkpoint/postgres/base.py     |  1 +
 3 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
index 260d12863..5b2dbff13 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
@@ -363,6 +363,16 @@ def put_writes(
             ),
         )
 
+    def _check_pipeline_support(self, conn: Connection[DictRow]) -> None:
+        if self.supports_pipeline is not None:
+            return
+
+        try:
+            with conn.pipeline():
+                self.supports_pipeline = True
+        except NotSupportedError:
+            self.supports_pipeline = False
+
     @contextmanager
     def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
         """Create a database cursor as a context manager.
@@ -384,14 +394,15 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
             if pipeline:
                 self.pipe.sync()
         elif pipeline:
+                self._check_pipeline_support(conn)
                 # a connection not in pipeline mode can only be used by one
                 # thread/coroutine at a time, so we acquire a lock
-                try:
+                if self.supports_pipeline:
                     with self.lock, conn.pipeline(), conn.cursor(
                         binary=True, row_factory=dict_row
                     ) as cur:
                         yield cur
-                except NotSupportedError:
+                else:
                     # Use connection's transaction context manager when pipeline mode not supported
                     with self.lock, conn.transaction(), conn.cursor(
                         binary=True, row_factory=dict_row
diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
index 43429ae4d..4440e2807 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
@@ -319,6 +319,16 @@ async def aput_writes(
         async with self._cursor(pipeline=True) as cur:
             await cur.executemany(query, params)
 
+    async def _check_pipeline_support(self, conn: AsyncConnection[DictRow]) -> None:
+        if self.supports_pipeline is not None:
+            return
+
+        try:
+            async with conn.pipeline():
+                self.supports_pipeline = True
+        except NotSupportedError:
+            self.supports_pipeline = False
+
     @asynccontextmanager
     async def _cursor(
         self, *, pipeline: bool = False
@@ -342,14 +352,15 @@ async def _cursor(
             if pipeline:
                 await self.pipe.sync()
         elif pipeline:
+                await self._check_pipeline_support(conn)
                 # a connection not in pipeline mode can only be used by one
                 # thread/coroutine at a time, so we acquire a lock
-                try:
+                if self.supports_pipeline:
                     async with self.lock, conn.pipeline(), conn.cursor(
                         binary=True, row_factory=dict_row
                     ) as cur:
                         yield cur
-                except NotSupportedError:
+                else:
                     # Use connection's transaction context manager when pipeline mode not supported
                     async with self.lock, conn.transaction(), conn.cursor(
                         binary=True, row_factory=dict_row
diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
index 5f6a2ab1b..755192beb 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
@@ -133,6 +133,7 @@ class BasePostgresSaver(BaseCheckpointSaver[str]):
     INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL
 
     jsonplus_serde = JsonPlusSerializer()
+    supports_pipeline: Optional[bool] = None
 
     def _load_checkpoint(
         self,
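Note: with this patch the probe is paid at most once per saver: the first pipeline=True call without a pipe enters an empty pipeline() block and caches the verdict in supports_pipeline. A module-level sketch of the same memoization, with illustrative names not taken from the patch:

    from typing import Optional

    from psycopg import Connection
    from psycopg.errors import NotSupportedError

    _supports_pipeline: Optional[bool] = None


    def pipeline_supported(conn: Connection) -> bool:
        # Probe once with an empty pipeline block; reuse the cached answer after.
        global _supports_pipeline
        if _supports_pipeline is None:
            try:
                with conn.pipeline():
                    _supports_pipeline = True
            except NotSupportedError:
                _supports_pipeline = False
        return bool(_supports_pipeline)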
From f807b730926ffbb891a18eaa8c4cb704091e6b1e Mon Sep 17 00:00:00 2001
From: vbarda
Date: Mon, 18 Nov 2024 12:15:18 -0500
Subject: [PATCH 4/4] use capabilities

---
 .../langgraph/checkpoint/postgres/__init__.py | 16 +++------------
 .../langgraph/checkpoint/postgres/aio.py      | 16 +++------------
 .../langgraph/checkpoint/postgres/base.py     |  2 +-
 3 files changed, 7 insertions(+), 27 deletions(-)

diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
index 5b2dbff13..b8138a945 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
@@ -3,8 +3,8 @@
 from typing import Any, Iterator, Optional, Sequence, Union
 
 from langchain_core.runnables import RunnableConfig
-from psycopg import Connection, Cursor, Pipeline
-from psycopg.errors import NotSupportedError, UndefinedTable
+from psycopg import Capabilities, Connection, Cursor, Pipeline
+from psycopg.errors import UndefinedTable
 from psycopg.rows import DictRow, dict_row
 from psycopg.types.json import Jsonb
 from psycopg_pool import ConnectionPool
@@ -52,6 +52,7 @@ def __init__(
         self.conn = conn
         self.pipe = pipe
         self.lock = threading.Lock()
+        self.supports_pipeline = Capabilities().has_pipeline()
 
     @classmethod
     @contextmanager
@@ -363,16 +364,6 @@ def put_writes(
             ),
         )
 
-    def _check_pipeline_support(self, conn: Connection[DictRow]) -> None:
-        if self.supports_pipeline is not None:
-            return
-
-        try:
-            with conn.pipeline():
-                self.supports_pipeline = True
-        except NotSupportedError:
-            self.supports_pipeline = False
-
     @contextmanager
     def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
         """Create a database cursor as a context manager.
@@ -394,7 +385,6 @@ def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
             if pipeline:
                 self.pipe.sync()
         elif pipeline:
-                self._check_pipeline_support(conn)
                 # a connection not in pipeline mode can only be used by one
                 # thread/coroutine at a time, so we acquire a lock
                 if self.supports_pipeline:
diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
index 4440e2807..5b67e4ca9 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
@@ -3,8 +3,8 @@
 from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Union
 
 from langchain_core.runnables import RunnableConfig
-from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline
-from psycopg.errors import NotSupportedError, UndefinedTable
+from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline, Capabilities
+from psycopg.errors import UndefinedTable
 from psycopg.rows import DictRow, dict_row
 from psycopg.types.json import Jsonb
 from psycopg_pool import AsyncConnectionPool
@@ -55,6 +55,7 @@ def __init__(
         self.pipe = pipe
         self.lock = asyncio.Lock()
         self.loop = asyncio.get_running_loop()
+        self.supports_pipeline = Capabilities().has_pipeline()
 
     @classmethod
     @asynccontextmanager
@@ -319,16 +320,6 @@ async def aput_writes(
         async with self._cursor(pipeline=True) as cur:
             await cur.executemany(query, params)
 
-    async def _check_pipeline_support(self, conn: AsyncConnection[DictRow]) -> None:
-        if self.supports_pipeline is not None:
-            return
-
-        try:
-            async with conn.pipeline():
-                self.supports_pipeline = True
-        except NotSupportedError:
-            self.supports_pipeline = False
-
     @asynccontextmanager
     async def _cursor(
         self, *, pipeline: bool = False
@@ -352,7 +343,6 @@ async def _cursor(
             if pipeline:
                 await self.pipe.sync()
         elif pipeline:
-                await self._check_pipeline_support(conn)
                 # a connection not in pipeline mode can only be used by one
                 # thread/coroutine at a time, so we acquire a lock
                 if self.supports_pipeline:
diff --git a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
index 755192beb..ae65cab68 100644
--- a/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
+++ b/libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
@@ -133,7 +133,7 @@ class BasePostgresSaver(BaseCheckpointSaver[str]):
     INSERT_CHECKPOINT_WRITES_SQL = INSERT_CHECKPOINT_WRITES_SQL
 
     jsonplus_serde = JsonPlusSerializer()
-    supports_pipeline: Optional[bool] = None
+    supports_pipeline: bool
 
     def _load_checkpoint(
         self,
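Note: Capabilities removes the probe entirely: pipeline support is a property of the loaded client library, so it can be read once at construction time without touching a connection, as both savers' __init__ now does. Shown standalone below; the print is illustrative only:

    from psycopg import Capabilities

    # True when the client library can enter pipeline mode (libpq >= 14);
    # no database connection is needed to answer this.
    supports_pipeline = Capabilities().has_pipeline()
    print(f"pipeline mode available: {supports_pipeline}")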