Skip to content

Commit

Permalink
Merge pull request #2691 from langchain-ai/dqbd/enhanced-config-type-…
Browse files Browse the repository at this point in the history
…extraction

fix(config): extract default values, description from pydantic models, typeddict and dataclass
  • Loading branch information
nfcampos authored Dec 10, 2024
2 parents 70a5ef6 + 17c1a8d commit 1fd9da6
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 5 deletions.
13 changes: 10 additions & 3 deletions libs/langgraph/langgraph/pregel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
Type,
Union,
cast,
get_type_hints,
overload,
)
from uuid import UUID, uuid5
Expand Down Expand Up @@ -117,6 +116,7 @@
patch_config,
patch_configurable,
)
from langgraph.utils.fields import get_enhanced_type_hints
from langgraph.utils.pydantic import create_model
from langgraph.utils.queue import AsyncQueue, SyncQueue # type: ignore[attr-defined]

Expand Down Expand Up @@ -319,8 +319,15 @@ def config_specs(self) -> list[ConfigurableFieldSpec]:
)
+ (
[
ConfigurableFieldSpec(id=name, annotation=typ)
for name, typ in get_type_hints(self.config_type).items()
ConfigurableFieldSpec(
id=name,
annotation=typ,
default=default,
description=description,
)
for name, typ, default, description in get_enhanced_type_hints(
self.config_type
)
]
if self.config_type is not None
else []
Expand Down
43 changes: 42 additions & 1 deletion libs/langgraph/langgraph/utils/fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Any, Optional, Type, Union
from typing import Any, Generator, Optional, Type, Union, get_type_hints

from typing_extensions import Annotated, NotRequired, ReadOnly, Required, get_origin

Expand Down Expand Up @@ -106,3 +106,44 @@ def get_field_default(name: str, type_: Any, schema: Type[Any]) -> Any:
if _is_optional_type(type_):
return None
return ...


def get_enhanced_type_hints(
type: Type[Any],
) -> Generator[tuple[str, Any, Any, Optional[str]], None, None]:
"""Attempt to extract default values and descriptions from provided type, used for config schema."""
for name, typ in get_type_hints(type).items():
default = None
description = None

# Pydantic models
try:
if hasattr(type, "__fields__") and name in type.__fields__:
field = type.__fields__[name]

if hasattr(field, "description") and field.description is not None:
description = field.description

if hasattr(field, "default") and field.default is not None:
default = field.default
if (
hasattr(default, "__class__")
and getattr(default.__class__, "__name__", "")
== "PydanticUndefinedType"
):
default = None

except (AttributeError, KeyError, TypeError):
pass

# TypedDict, dataclass
try:
if hasattr(type, "__dict__"):
type_dict = getattr(type, "__dict__")

if name in type_dict:
default = type_dict[name]
except (AttributeError, KeyError, TypeError):
pass

yield name, typ, default, description
60 changes: 59 additions & 1 deletion libs/langgraph/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@

from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.utils.fields import _is_optional_type, get_field_default
from langgraph.utils.fields import (
_is_optional_type,
get_enhanced_type_hints,
get_field_default,
)
from langgraph.utils.runnable import is_async_callable, is_async_generator

pytestmark = pytest.mark.anyio
Expand Down Expand Up @@ -227,3 +231,57 @@ class MyGrandChildDict(MyChildDict, total=False):
assert get_field_default("val_12", gcannos["val_12"], MyGrandChildDict) is None
assert get_field_default("val_9", gcannos["val_9"], MyGrandChildDict) is None
assert get_field_default("val_13", gcannos["val_13"], MyGrandChildDict) == ...


def test_enhanced_type_hints() -> None:
from dataclasses import dataclass
from typing import Annotated

from pydantic import BaseModel, Field

class MyTypedDict(TypedDict):
val_1: str
val_2: int = 42
val_3: str = "default"

hints = list(get_enhanced_type_hints(MyTypedDict))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, None)
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", None)

@dataclass
class MyDataclass:
val_1: str
val_2: int = 42
val_3: str = "default"

hints = list(get_enhanced_type_hints(MyDataclass))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, None)
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", None)

class MyPydanticModel(BaseModel):
val_1: str
val_2: int = 42
val_3: str = Field(default="default", description="A description")

hints = list(get_enhanced_type_hints(MyPydanticModel))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, None)
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", "A description")

class MyPydanticModelWithAnnotated(BaseModel):
val_1: Annotated[str, Field(description="A description")]
val_2: Annotated[int, Field(default=42)]
val_3: Annotated[
str, Field(default="default", description="Another description")
]

hints = list(get_enhanced_type_hints(MyPydanticModelWithAnnotated))
assert len(hints) == 3
assert hints[0] == ("val_1", str, None, "A description")
assert hints[1] == ("val_2", int, 42, None)
assert hints[2] == ("val_3", str, "default", "Another description")

0 comments on commit 1fd9da6

Please sign in to comment.