Skip to content

Commit

Permalink
Merge pull request #1771 from langchain-ai/nc/19sep/mypy-langgraph-pa…
Browse files Browse the repository at this point in the history
…rtial

Enable mypy for langgraph lib
  • Loading branch information
nfcampos authored Sep 19, 2024
2 parents 0cf1a13 + 7643c11 commit 0bbe461
Show file tree
Hide file tree
Showing 47 changed files with 593 additions and 401 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading
from contextlib import contextmanager
from typing import Any, Iterator, List, Optional, Union
from typing import Any, Iterator, Optional, Sequence, Union

from langchain_core.runnables import RunnableConfig
from psycopg import Connection, Cursor, Pipeline
Expand Down Expand Up @@ -332,7 +332,7 @@ def put(
def put_writes(
self,
config: RunnableConfig,
writes: List[tuple[str, Any]],
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
Expand Down
6 changes: 3 additions & 3 deletions libs/checkpoint-postgres/langgraph/checkpoint/postgres/aio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Iterator, List, Optional, Union
from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Union

from langchain_core.runnables import RunnableConfig
from psycopg import AsyncConnection, AsyncCursor, AsyncPipeline
Expand Down Expand Up @@ -291,7 +291,7 @@ async def aput(
async def aput_writes(
self,
config: RunnableConfig,
writes: list[tuple[str, Any]],
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint asynchronously.
Expand Down Expand Up @@ -425,7 +425,7 @@ def put(
def put_writes(
self,
config: RunnableConfig,
writes: List[tuple[str, Any]],
writes: Sequence[tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import random
from typing import Any, List, Optional, Tuple, cast
from typing import Any, List, Optional, Sequence, Tuple, cast

from langchain_core.runnables import RunnableConfig
from psycopg.types.json import Jsonb
Expand Down Expand Up @@ -209,7 +209,7 @@ def _dump_writes(
checkpoint_ns: str,
checkpoint_id: str,
task_id: str,
writes: list[tuple[str, Any]],
writes: Sequence[tuple[str, Any]],
) -> list[tuple[str, str, str, str, int, str, str, bytes]]:
return [
(
Expand Down
3 changes: 1 addition & 2 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Callable,
Dict,
Iterator,
List,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -216,7 +215,7 @@ def put(
).result()

def put_writes(
self, config: RunnableConfig, writes: List[Tuple[str, Any]], task_id: str
self, config: RunnableConfig, writes: Sequence[Tuple[str, Any]], task_id: str
) -> None:
return asyncio.run_coroutine_threadsafe(
self.aput_writes(config, writes, task_id), self.loop
Expand Down
9 changes: 5 additions & 4 deletions libs/checkpoint/langgraph/checkpoint/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
TypedDict,
TypeVar,
Union,
)

from langchain_core.runnables import ConfigurableFieldSpec, RunnableConfig
from typing_extensions import TypeVar

from langgraph.checkpoint.base.id import uuid6
from langgraph.checkpoint.serde.base import SerializerProtocol, maybe_add_typed_methods
Expand All @@ -28,7 +29,7 @@
SendProtocol,
)

V = TypeVar("V", int, float, str, default=int)
V = TypeVar("V", int, float, str)
PendingWrite = Tuple[str, str, Any]


Expand Down Expand Up @@ -301,7 +302,7 @@ def put(
def put_writes(
self,
config: RunnableConfig,
writes: List[Tuple[str, Any]],
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Store intermediate writes linked to a checkpoint.
Expand Down Expand Up @@ -393,7 +394,7 @@ async def aput(
async def aput_writes(
self,
config: RunnableConfig,
writes: List[Tuple[str, Any]],
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Asynchronously store intermediate writes linked to a checkpoint.
Expand Down
6 changes: 3 additions & 3 deletions libs/checkpoint/langgraph/checkpoint/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from functools import partial
from types import TracebackType
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple
from typing import Any, AsyncIterator, Dict, Iterator, Optional, Sequence, Tuple

from langchain_core.runnables import RunnableConfig

Expand Down Expand Up @@ -344,7 +344,7 @@ def put(
def put_writes(
self,
config: RunnableConfig,
writes: List[Tuple[str, Any]],
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Save a list of writes to the in-memory storage.
Expand Down Expand Up @@ -447,7 +447,7 @@ async def aput(
async def aput_writes(
self,
config: RunnableConfig,
writes: List[Tuple[str, Any]],
writes: Sequence[Tuple[str, Any]],
task_id: str,
) -> None:
"""Asynchronous version of put_writes.
Expand Down
11 changes: 1 addition & 10 deletions libs/checkpoint/langgraph/checkpoint/serde/types.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from typing import (
Any,
AsyncGenerator,
Generator,
Optional,
Protocol,
Sequence,
TypeVar,
runtime_checkable,
)

from langchain_core.runnables import RunnableConfig
from typing_extensions import Self

ERROR = "__error__"
Expand All @@ -31,13 +28,7 @@ def UpdateType(self) -> Any: ...

def checkpoint(self) -> Optional[C]: ...

def from_checkpoint(
self, checkpoint: Optional[C], config: RunnableConfig
) -> Generator[Self, None, None]: ...

async def afrom_checkpoint(
self, checkpoint: Optional[C], config: RunnableConfig
) -> AsyncGenerator[Self, None]: ...
def from_checkpoint(self, checkpoint: Optional[C]) -> Self: ...

def update(self, values: Sequence[Update]) -> bool: ...

Expand Down
3 changes: 2 additions & 1 deletion libs/langgraph/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ lint lint_diff lint_package lint_tests:
poetry run ruff check .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff check --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy langgraph --cache-dir $(MYPY_CACHE)

format format_diff:
poetry run ruff format $(PYTHON_FILES)
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/bench/react_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from langgraph.pregel import Pregel


def react_agent(n_tools: int, checkpointer: BaseCheckpointSaver) -> Pregel:
def react_agent(n_tools: int, checkpointer: Optional[BaseCheckpointSaver]) -> Pregel:
class FakeFuntionChatModel(FakeMessagesListChatModel):
def bind_tools(self, functions: list):
return self
Expand Down
8 changes: 4 additions & 4 deletions libs/langgraph/langgraph/_api/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def decorator(obj: Union[F, C]) -> Union[F, C]:
f" removed in {removal_str}. Use {alternative} instead.{example}"
)
if isinstance(obj, type):
original_init = obj.__init__
original_init = obj.__init__ # type: ignore[misc]

@functools.wraps(original_init)
def new_init(self, *args: Any, **kwargs: Any) -> None:
def new_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore[no-untyped-def]
warnings.warn(message, LangGraphDeprecationWarning, stacklevel=2)
original_init(self, *args, **kwargs)

obj.__init__ = new_init
obj.__init__ = new_init # type: ignore[misc]

docstring = (
f"**Deprecated**: This class is deprecated as of version {since}. "
Expand Down Expand Up @@ -68,7 +68,7 @@ def deprecated_parameter(
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@functools.wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs): # type: ignore[no-untyped-def]
if arg_name in kwargs:
warnings.warn(
f"Parameter '{arg_name}' in function '{func.__name__}' is "
Expand Down
2 changes: 1 addition & 1 deletion libs/langgraph/langgraph/channels/binop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


# Adapted from typing_extensions
def _strip_extras(t):
def _strip_extras(t): # type: ignore[no-untyped-def]
"""Strips Annotated, Required and NotRequired from a given type."""
if hasattr(t, "__origin__"):
return _strip_extras(t.__origin__)
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/channels/dynamic_barrier_value.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, NamedTuple, Optional, Sequence, Type, Union
from typing import Any, Generic, NamedTuple, Optional, Sequence, Type, Union

from typing_extensions import Self

Expand All @@ -7,7 +7,7 @@


class WaitForNames(NamedTuple):
names: set[Value]
names: set[Any]


class DynamicBarrierValue(
Expand Down
4 changes: 2 additions & 2 deletions libs/langgraph/langgraph/channels/ephemeral_value.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, Optional, Sequence, Type
from typing import Any, Generic, Optional, Sequence, Type

from typing_extensions import Self

Expand All @@ -11,7 +11,7 @@ class EphemeralValue(Generic[Value], BaseChannel[Value, Value, Value]):

__slots__ = ("value", "guard")

def __init__(self, typ: Type[Value], guard: bool = True) -> None:
def __init__(self, typ: Any, guard: bool = True) -> None:
super().__init__(typ)
self.guard = guard

Expand Down
5 changes: 4 additions & 1 deletion libs/langgraph/langgraph/channels/named_barrier_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ class NamedBarrierValue(Generic[Value], BaseChannel[Value, Value, set[Value]]):

__slots__ = ("names", "seen")

names: set[Value]
seen: set[Value]

def __init__(self, typ: Type[Value], names: set[Value]) -> None:
super().__init__(typ)
self.names = names
self.seen = set()
self.seen: set[str] = set()

def __eq__(self, value: object) -> bool:
return isinstance(value, NamedBarrierValue) and value.names == self.names
Expand Down
Loading

0 comments on commit 0bbe461

Please sign in to comment.