Commit dc360fd
Merge pull request #1767 from langchain-ai/nc/19sep/mypy-postgres
ci: Enable mypy checks for checkpoint-postgres lib
nfcampos authored Sep 19, 2024
2 parents bc95a79 + fdf7bad commit dc360fd
Showing 10 changed files with 106 additions and 73 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/_lint.yml
@@ -27,8 +27,7 @@ jobs:
         # Starting new jobs is also relatively slow,
         # so linting on fewer versions makes CI faster.
         python-version:
-          - "3.9"
-          - "3.11"
+          - "3.12"
     name: "lint #${{ matrix.python-version }}"
     steps:
       - uses: actions/checkout@v4
3 changes: 2 additions & 1 deletion libs/checkpoint-postgres/Makefile
@@ -41,7 +41,8 @@ lint lint_diff lint_package lint_tests:
 	poetry run ruff check .
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
 	[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I $(PYTHON_FILES)
-	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
+	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE)
+	[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
 
 format format_diff:
 	poetry run ruff format $(PYTHON_FILES)
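Note on the Makefile fix: the old single-line recipe chained three commands with `||`, so `poetry run mypy` only ran when `mkdir -p` failed — and `mkdir -p` virtually always succeeds, meaning the type check was silently skipped. Splitting the recipe into two lines, each guarded by the same `[ "$(PYTHON_FILES)" = "" ]` test, is what actually enables mypy in CI, as the commit title says.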
22 changes: 14 additions & 8 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/__init__.py
@@ -5,7 +5,7 @@
 from langchain_core.runnables import RunnableConfig
 from psycopg import Connection, Cursor, Pipeline
 from psycopg.errors import UndefinedTable
-from psycopg.rows import dict_row
+from psycopg.rows import DictRow, dict_row
 from psycopg.types.json import Jsonb
 from psycopg_pool import ConnectionPool
 
@@ -22,9 +22,11 @@
 )
 from langgraph.checkpoint.serde.base import SerializerProtocol
 
+Conn = Union[Connection[DictRow], ConnectionPool[Connection[DictRow]]]
+
 
 @contextmanager
-def _get_connection(conn: Union[Connection, ConnectionPool]) -> Iterator[Connection]:
+def _get_connection(conn: Conn) -> Iterator[Connection[DictRow]]:
     if isinstance(conn, Connection):
         yield conn
     elif isinstance(conn, ConnectionPool):
@@ -39,7 +41,7 @@ class PostgresSaver(BasePostgresSaver):
 
     def __init__(
         self,
-        conn: Union[Connection, ConnectionPool],
+        conn: Conn,
         pipe: Optional[Pipeline] = None,
         serde: Optional[SerializerProtocol] = None,
     ) -> None:
@@ -85,9 +87,13 @@ def setup(self) -> None:
         """
         with self._cursor() as cur:
             try:
-                version = cur.execute(
+                row = cur.execute(
                     "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
-                ).fetchone()["v"]
+                ).fetchone()
+                if row is None:
+                    version = -1
+                else:
+                    version = row["v"]
             except UndefinedTable:
                 version = -1
             for v, migration in zip(
@@ -212,7 +218,7 @@ def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
         checkpoint_id = get_checkpoint_id(config)
         checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
         if checkpoint_id:
-            args = (thread_id, checkpoint_ns, checkpoint_id)
+            args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id)
             where = "WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s"
         else:
             args = (thread_id, checkpoint_ns)
@@ -306,7 +312,7 @@ def put(
                 self._dump_blobs(
                     thread_id,
                     checkpoint_ns,
-                    copy.pop("channel_values"),
+                    copy.pop("channel_values"),  # type: ignore[misc]
                     new_versions,
                 ),
             )
@@ -356,7 +362,7 @@ def put_writes(
             )
 
     @contextmanager
-    def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor]:
+    def _cursor(self, *, pipeline: bool = False) -> Iterator[Cursor[DictRow]]:
         with _get_connection(self.conn) as conn:
             if self.pipe:
                 # a connection in pipeline mode can be used concurrently
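The `Conn` alias introduced here documents exactly what `PostgresSaver` accepts: a single `Connection` or a `ConnectionPool`, either way parameterized to return `DictRow`. A minimal usage sketch, assuming a locally running Postgres (the DSN is a placeholder, not from this diff):

    from psycopg import Connection
    from psycopg.rows import dict_row

    from langgraph.checkpoint.postgres import PostgresSaver

    with Connection.connect(
        "postgres://user:pass@localhost:5432/db",  # placeholder DSN
        autocommit=True,
        row_factory=dict_row,  # cursors yield DictRow, matching the Conn alias
    ) as conn:
        saver = PostgresSaver(conn)
        saver.setup()  # runs migrations; a missing migrations table now maps to version -1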
27 changes: 18 additions & 9 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
@@ -5,7 +5,7 @@
 from langchain_core.runnables import RunnableConfig
 from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline
 from psycopg.errors import UndefinedTable
-from psycopg.rows import dict_row
+from psycopg.rows import DictRow, dict_row
 from psycopg.types.json import Jsonb
 from psycopg_pool import AsyncConnectionPool
 
@@ -20,11 +20,13 @@
 from langgraph.checkpoint.postgres.base import BasePostgresSaver
 from langgraph.checkpoint.serde.base import SerializerProtocol
 
+Conn = Union[AsyncConnection[DictRow], AsyncConnectionPool[AsyncConnection[DictRow]]]
+
 
 @asynccontextmanager
 async def _get_connection(
-    conn: Union[AsyncConnection, AsyncConnectionPool],
-) -> AsyncIterator[AsyncConnection]:
+    conn: Conn,
+) -> AsyncIterator[AsyncConnection[DictRow]]:
     if isinstance(conn, AsyncConnection):
         yield conn
     elif isinstance(conn, AsyncConnectionPool):
@@ -39,7 +41,7 @@ class AsyncPostgresSaver(BasePostgresSaver):
 
     def __init__(
         self,
-        conn: Union[AsyncConnection, AsyncConnectionPool],
+        conn: Conn,
         pipe: Optional[AsyncPipeline] = None,
         serde: Optional[SerializerProtocol] = None,
     ) -> None:
@@ -93,7 +95,11 @@ async def setup(self) -> None:
                 results = await cur.execute(
                     "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1"
                 )
-                version = (await results.fetchone())["v"]
+                row = await results.fetchone()
+                if row is None:
+                    version = -1
+                else:
+                    version = row["v"]
             except UndefinedTable:
                 version = -1
             for v, migration in zip(
@@ -180,7 +186,7 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
         checkpoint_id = get_checkpoint_id(config)
         checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
         if checkpoint_id:
-            args = (thread_id, checkpoint_ns, checkpoint_id)
+            args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id)
             where = "WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s"
         else:
             args = (thread_id, checkpoint_ns)
@@ -265,7 +271,7 @@ async def aput(
                     self._dump_blobs,
                     thread_id,
                     checkpoint_ns,
-                    copy.pop("channel_values"),
+                    copy.pop("channel_values"),  # type: ignore[misc]
                     new_versions,
                 ),
             )
@@ -314,7 +320,9 @@ async def aput_writes(
             await cur.executemany(query, params)
 
     @asynccontextmanager
-    async def _cursor(self, *, pipeline: bool = False) -> AsyncIterator[AsyncCursor]:
+    async def _cursor(
+        self, *, pipeline: bool = False
+    ) -> AsyncIterator[AsyncCursor[DictRow]]:
         async with _get_connection(self.conn) as conn:
             if self.pipe:
                 # a connection in pipeline mode can be used concurrently
@@ -365,7 +373,8 @@ def list(
         while True:
             try:
                 yield asyncio.run_coroutine_threadsafe(
-                    anext(aiter_), self.loop
+                    anext(aiter_),
+                    self.loop,
                 ).result()
             except StopAsyncIteration:
                 break
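The reformatted `list()` shim at the end is the sync/async bridge: each `anext()` is submitted to the event loop running in another thread, and the caller blocks on the resulting future until `StopAsyncIteration` ends the loop. A standalone sketch of the same pattern (names here are illustrative, not the library's; `anext()` is a builtin since Python 3.10, which the CI matrix covers):

    import asyncio
    from typing import AsyncIterator, Iterator, TypeVar

    T = TypeVar("T")

    def drain(aiter_: AsyncIterator[T], loop: asyncio.AbstractEventLoop) -> Iterator[T]:
        """Consume an async iterator from sync code, given a loop running in another thread."""
        while True:
            try:
                # Schedule one anext() on the background loop; block for its result.
                yield asyncio.run_coroutine_threadsafe(anext(aiter_), loop).result()
            except StopAsyncIteration:
                break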
18 changes: 10 additions & 8 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/base.py
@@ -1,13 +1,15 @@
 import random
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, cast
 
 from langchain_core.runnables import RunnableConfig
 from psycopg.types.json import Jsonb
 
 from langgraph.checkpoint.base import (
     WRITES_IDX_MAP,
     BaseCheckpointSaver,
+    ChannelVersions,
     Checkpoint,
     CheckpointMetadata,
     get_checkpoint_id,
 )
 from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer
@@ -122,7 +124,7 @@
 """
 
 
-class BasePostgresSaver(BaseCheckpointSaver):
+class BasePostgresSaver(BaseCheckpointSaver[str]):
     SELECT_SQL = SELECT_SQL
     MIGRATIONS = MIGRATIONS
     UPSERT_CHECKPOINT_BLOBS_SQL = UPSERT_CHECKPOINT_BLOBS_SQL
@@ -165,8 +167,8 @@ def _dump_blobs(
         thread_id: str,
         checkpoint_ns: str,
         values: dict[str, Any],
-        versions: dict[str, str],
-    ) -> list[tuple[str, str, str, str, str, bytes]]:
+        versions: ChannelVersions,
+    ) -> list[tuple[str, str, str, str, str, Optional[bytes]]]:
         if not versions:
             return []
 
@@ -175,7 +177,7 @@ def _dump_blobs(
                 thread_id,
                 checkpoint_ns,
                 k,
-                ver,
+                cast(str, ver),
                 *(
                     self.serde.dumps_typed(values[k])
                     if k in values
@@ -208,7 +210,7 @@ def _dump_writes(
         checkpoint_id: str,
         task_id: str,
         writes: list[tuple[str, Any]],
-    ) -> list[tuple[str, str, str, int, str, str, bytes]]:
+    ) -> list[tuple[str, str, str, str, int, str, str, bytes]]:
         return [
             (
                 thread_id,
@@ -222,10 +224,10 @@ def _dump_writes(
             for idx, (channel, value) in enumerate(writes)
         ]
 
-    def _load_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
+    def _load_metadata(self, metadata: dict[str, Any]) -> CheckpointMetadata:
         return self.jsonplus_serde.loads(self.jsonplus_serde.dumps(metadata))
 
-    def _dump_metadata(self, metadata) -> str:
+    def _dump_metadata(self, metadata: CheckpointMetadata) -> str:
         serialized_metadata = self.jsonplus_serde.dumps(metadata)
         return serialized_metadata.decode()
 
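The new `_load_metadata` signature makes explicit that metadata read back from Postgres is coerced into `CheckpointMetadata` by a serialize/deserialize round-trip through the same `JsonPlusSerializer` the savers use elsewhere. A small sketch of that round-trip (the sample dict is illustrative):

    from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

    serde = JsonPlusSerializer()
    raw = {"source": "update", "step": 1, "writes": {}}  # illustrative metadata
    # dumps() produces bytes, loads() parses them back, normalizing the values
    # to the same JSON-compatible shapes in which metadata is stored.
    metadata = serde.loads(serde.dumps(raw))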
56 changes: 28 additions & 28 deletions libs/checkpoint-postgres/poetry.lock

Some generated files are not rendered by default.

10 changes: 10 additions & 0 deletions libs/checkpoint-postgres/pyproject.toml
@@ -51,3 +51,13 @@ lint.select = [
     "I", # isort
 ]
 lint.ignore = ["E501", "B008", "UP007", "UP006"]
+
+[tool.mypy]
+# https://mypy.readthedocs.io/en/stable/config_file.html
+disallow_untyped_defs = "True"
+explicit_package_bases = "True"
+warn_no_return = "False"
+warn_unused_ignores = "True"
+warn_redundant_casts = "True"
+allow_redefinition = "True"
+disable_error_code = "typeddict-item, return-value"
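These mypy settings line up with the rest of the diff: `disallow_untyped_defs` is what forces the annotations added across the package, `warn_unused_ignores` will flag any `# type: ignore[...]` comment (like the new `# type: ignore[misc]` ones) once it stops suppressing a real error, and `disable_error_code` turns off the `typeddict-item` and `return-value` checks package-wide instead of ignoring them line by line.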
8 changes: 5 additions & 3 deletions libs/checkpoint-postgres/tests/conftest.py
@@ -1,21 +1,23 @@
+from typing import AsyncIterator
+
 import pytest
 from psycopg import AsyncConnection
 from psycopg.errors import UndefinedTable
-from psycopg.rows import dict_row
+from psycopg.rows import DictRow, dict_row
 
 DEFAULT_URI = "postgres://postgres:postgres@localhost:5441/postgres?sslmode=disable"
 
 
 @pytest.fixture(scope="function")
-async def conn():
+async def conn() -> AsyncIterator[AsyncConnection[DictRow]]:
     async with await AsyncConnection.connect(
         DEFAULT_URI, autocommit=True, prepare_threshold=0, row_factory=dict_row
     ) as conn:
         yield conn
 
 
 @pytest.fixture(scope="function", autouse=True)
-async def clear_test_db(conn):
+async def clear_test_db(conn: AsyncConnection[DictRow]) -> None:
     """Delete all tables before each test."""
     try:
         await conn.execute("DELETE FROM checkpoints")