Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(config): extract default values, description from pydantic models, typeddict and dataclass #2691

Merged
merged 4 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import annotations

Check notice on line 1 in libs/langgraph/langgraph/pregel/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Benchmark results

......................................... fanout_to_subgraph_10x: Mean +- std dev: 62.0 ms +- 1.2 ms ......................................... fanout_to_subgraph_10x_sync: Mean +- std dev: 52.5 ms +- 0.6 ms ......................................... fanout_to_subgraph_10x_checkpoint: Mean +- std dev: 93.3 ms +- 7.0 ms ......................................... fanout_to_subgraph_10x_checkpoint_sync: Mean +- std dev: 95.1 ms +- 1.5 ms ......................................... fanout_to_subgraph_100x: Mean +- std dev: 614 ms +- 10 ms ......................................... fanout_to_subgraph_100x_sync: Mean +- std dev: 515 ms +- 8 ms ......................................... fanout_to_subgraph_100x_checkpoint: Mean +- std dev: 954 ms +- 53 ms ......................................... fanout_to_subgraph_100x_checkpoint_sync: Mean +- std dev: 948 ms +- 17 ms ......................................... react_agent_10x: Mean +- std dev: 31.1 ms +- 0.6 ms ......................................... react_agent_10x_sync: Mean +- std dev: 22.9 ms +- 0.2 ms ......................................... react_agent_10x_checkpoint: Mean +- std dev: 47.0 ms +- 0.8 ms ......................................... react_agent_10x_checkpoint_sync: Mean +- std dev: 36.9 ms +- 0.5 ms ......................................... react_agent_100x: Mean +- std dev: 348 ms +- 6 ms ......................................... react_agent_100x_sync: Mean +- std dev: 277 ms +- 3 ms ......................................... react_agent_100x_checkpoint: Mean +- std dev: 932 ms +- 9 ms ......................................... react_agent_100x_checkpoint_sync: Mean +- std dev: 831 ms +- 7 ms ......................................... wide_state_25x300: Mean +- std dev: 23.7 ms +- 0.4 ms ......................................... wide_state_25x300_sync: Mean +- std dev: 15.0 ms +- 0.2 ms ......................................... wide_state_25x300_checkpoint: Mean +- std dev: 286 ms +- 14 ms ......................................... wide_state_25x300_checkpoint_sync: Mean +- std dev: 274 ms +- 12 ms ......................................... wide_state_15x600: Mean +- std dev: 27.6 ms +- 0.4 ms ......................................... wide_state_15x600_sync: Mean +- std dev: 17.4 ms +- 0.1 ms ......................................... wide_state_15x600_checkpoint: Mean +- std dev: 487 ms +- 14 ms ......................................... wide_state_15x600_checkpoint_sync: Mean +- std dev: 473 ms +- 15 ms ......................................... wide_state_9x1200: Mean +- std dev: 27.5 ms +- 0.5 ms ......................................... wide_state_9x1200_sync: Mean +- std dev: 17.5 ms +- 0.3 ms ......................................... wide_state_9x1200_checkpoint: Mean +- std dev: 323 ms +- 16 ms ......................................... wide_state_9x1200_checkpoint_sync: Mean +- std dev: 307 ms +- 13 ms

Check notice on line 1 in libs/langgraph/langgraph/pregel/__init__.py

View workflow job for this annotation

GitHub Actions / benchmark

Comparison against main

+----------------------------------------+---------+-----------------------+ | Benchmark | main | changes | +========================================+=========+=======================+ | fanout_to_subgraph_100x | 628 ms | 614 ms: 1.02x faster | +----------------------------------------+---------+-----------------------+ | react_agent_100x | 346 ms | 348 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | react_agent_10x_sync | 22.8 ms | 22.9 ms: 1.00x slower | +----------------------------------------+---------+-----------------------+ | react_agent_100x_sync | 275 ms | 277 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint_sync | 36.6 ms | 36.9 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_15x600_sync | 17.3 ms | 17.4 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_sync | 52.1 ms | 52.5 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_10x_checkpoint_sync | 94.4 ms | 95.1 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | react_agent_10x_checkpoint | 46.6 ms | 47.0 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | fanout_to_subgraph_100x_sync | 510 ms | 515 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_25x300_sync | 14.9 ms | 15.0 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200_checkpoint | 319 ms | 323 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | wide_state_9x1200_sync | 17.3 ms | 17.5 ms: 1.01x slower | +----------------------------------------+---------+-----------------------+ | Geometric mean | (ref) | 1.00x slower | +----------------------------------------+---------+-----------------------+ Benchmark hidden because not significant (15): fanout_to_subgraph_10x_checkpoint, react_agent_100x_checkpoint_sync, fanout_to_subgraph_10x, wide_state_9x1200_checkpoint_sync, wide_state_25x300, wide_state_9x1200, wide_state_25x300_checkpoint, fanout_to_subgraph_100x_checkpoint, react_agent_100x_checkpoint, wide_state_15x600, react_agent_10x, wide_state_25x300_checkpoint_sync, fanout_to_subgraph_100x_checkpoint_sync, wide_state_15x600_checkpoint, wide_state_15x600_checkpoint_sync

import asyncio
import concurrent
Expand All @@ -18,7 +18,6 @@
Type,
Union,
cast,
get_type_hints,
overload,
)
from uuid import UUID, uuid5
Expand Down Expand Up @@ -117,6 +116,7 @@
patch_config,
patch_configurable,
)
from langgraph.utils.fields import get_enhanced_type_hints
from langgraph.utils.pydantic import create_model
from langgraph.utils.queue import AsyncQueue, SyncQueue # type: ignore[attr-defined]

Expand Down Expand Up @@ -319,8 +319,15 @@
)
+ (
[
ConfigurableFieldSpec(id=name, annotation=typ)
for name, typ in get_type_hints(self.config_type).items()
ConfigurableFieldSpec(
id=name,
annotation=typ,
default=default,
description=description,
)
for name, typ, default, description in get_enhanced_type_hints(
self.config_type
)
]
if self.config_type is not None
else []
Expand Down
43 changes: 42 additions & 1 deletion libs/langgraph/langgraph/utils/fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Any, Optional, Type, Union
from typing import Any, Generator, Optional, Type, Union, get_type_hints

from typing_extensions import Annotated, NotRequired, ReadOnly, Required, get_origin

Expand Down Expand Up @@ -106,3 +106,44 @@ def get_field_default(name: str, type_: Any, schema: Type[Any]) -> Any:
if _is_optional_type(type_):
return None
return ...


def get_enhanced_type_hints(
type: Type[Any],
) -> Generator[tuple[str, Any, Any, Optional[str]], None, None]:
"""Attempt to extract default values and descriptions from provided type, used for config schema."""
for name, typ in get_type_hints(type).items():
default = None
description = None

# Pydantic models
try:
if hasattr(type, "__fields__") and name in type.__fields__:
field = type.__fields__[name]

if hasattr(field, "description") and field.description is not None:
description = field.description

if hasattr(field, "default") and field.default is not None:
default = field.default
if (
hasattr(default, "__class__")
and getattr(default.__class__, "__name__", "")
== "PydanticUndefinedType"
):
default = None

except (AttributeError, KeyError, TypeError):
pass

# TypedDict, dataclass
try:
if hasattr(type, "__dict__"):
type_dict = getattr(type, "__dict__")

if name in type_dict:
default = type_dict[name]
except (AttributeError, KeyError, TypeError):
pass

yield name, typ, default, description
60 changes: 59 additions & 1 deletion libs/langgraph/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@

from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.utils.fields import _is_optional_type, get_field_default
from langgraph.utils.fields import (
_is_optional_type,
get_enhanced_type_hints,
get_field_default,
)
from langgraph.utils.runnable import is_async_callable, is_async_generator

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -227,3 +231,57 @@ class MyGrandChildDict(MyChildDict, total=False):
assert get_field_default("val_12", gcannos["val_12"], MyGrandChildDict) is None
assert get_field_default("val_9", gcannos["val_9"], MyGrandChildDict) is None
assert get_field_default("val_13", gcannos["val_13"], MyGrandChildDict) == ...


def test_enhanced_type_hints() -> None:
from dataclasses import dataclass
from typing import Annotated

from pydantic import BaseModel, Field

class MyTypedDict(TypedDict):
val_1: str
val_2: int = 42
val_3: str = "default"

hints = list(get_enhanced_type_hints(MyTypedDict))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, None)
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", None)

@dataclass
class MyDataclass:
val_1: str
val_2: int = 42
val_3: str = "default"

hints = list(get_enhanced_type_hints(MyDataclass))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, None)
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", None)

class MyPydanticModel(BaseModel):
val_1: str
val_2: int = 42
val_3: str = Field(default="default", description="A description")

hints = list(get_enhanced_type_hints(MyPydanticModel))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, None)
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", "A description")

class MyPydanticModelWithAnnotated(BaseModel):
val_1: Annotated[str, Field(description="A description")]
val_2: Annotated[int, Field(default=42)]
val_3: Annotated[
str, Field(default="default", description="Another description")
]

hints = list(get_enhanced_type_hints(MyPydanticModelWithAnnotated))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, "A description")
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", "Another description")
Loading