Skip to content

Commit

Permalink
core[patch]: Add pydantic metadata to subset model (#25032)
Browse files Browse the repository at this point in the history
- **Description:** This includes Pydantic field metadata in
`_create_subset_model_v2` so that it gets included in the final
serialized form that get sent out.
- **Issue:** #25031 
- **Dependencies:** n/a
- **Twitter handle:** @gramliu

---------

Co-authored-by: Bagatur <[email protected]>
Co-authored-by: Bagatur <[email protected]>
  • Loading branch information
3 people authored Aug 6, 2024
1 parent 8f33fce commit 88a9a6a
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 12 deletions.
4 changes: 3 additions & 1 deletion libs/core/langchain_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,10 +1055,12 @@ def add(a: int, b: int) -> int:
)


# TODO: Type args_schema as TypeBaseModel if we can get mypy to correctly recognize
# pydantic v2 BaseModel classes.
def tool(
*args: Union[str, Callable, Runnable],
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
args_schema: Optional[Type] = None,
infer_schema: bool = True,
response_format: Literal["content", "content_and_artifact"] = "content",
parse_docstring: bool = False,
Expand Down
16 changes: 7 additions & 9 deletions libs/core/langchain_core/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ def get_pydantic_major_version() -> int:
PydanticBaseModel = pydantic.BaseModel
TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel # pydantic: ignore

# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore
Expand Down Expand Up @@ -199,12 +197,12 @@ def _create_subset_model_v1(

def _create_subset_model_v2(
name: str,
model: Type[BaseModel],
model: Type[pydantic.BaseModel],
field_names: List[str],
*,
descriptions: Optional[dict] = None,
fn_description: Optional[str] = None,
) -> Type[BaseModel]:
) -> Type[pydantic.BaseModel]:
"""Create a pydantic model with a subset of the model fields."""
from pydantic import create_model # pydantic: ignore
from pydantic.fields import FieldInfo # pydantic: ignore
Expand All @@ -214,10 +212,10 @@ def _create_subset_model_v2(
for field_name in field_names:
field = model.model_fields[field_name] # type: ignore
description = descriptions_.get(field_name, field.description)
fields[field_name] = (
field.annotation,
FieldInfo(description=description, default=field.default),
)
field_info = FieldInfo(description=description, default=field.default)
if field.metadata:
field_info.metadata = field.metadata
fields[field_name] = (field.annotation, field_info)
rtn = create_model(name, **fields) # type: ignore

rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
Expand All @@ -230,7 +228,7 @@ def _create_subset_model_v2(
# However, can't find a way to type hint this.
def _create_subset_model(
name: str,
model: Type[BaseModel],
model: TypeBaseModel,
field_names: List[str],
*,
descriptions: Optional[dict] = None,
Expand Down
38 changes: 38 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1863,3 +1863,41 @@ class ModelD(ModelC, Generic[D]):
}
actual = _get_all_basemodel_annotations(ModelD[int])
assert actual == expected


@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic import Field as FieldV2 # pydantic: ignore
from pydantic import ValidationError as ValidationErrorV2 # pydantic: ignore

class Foo(BaseModelV2):
x: List[int] = FieldV2(
description="List of integers", min_length=10, max_length=15
)

@tool(args_schema=Foo)
def foo(x): # type: ignore[no-untyped-def]
"""foo"""
return x

assert foo.tool_call_schema.schema() == {
"description": "foo",
"properties": {
"x": {
"description": "List of integers",
"items": {"type": "integer"},
"maxItems": 15,
"minItems": 10,
"title": "X",
"type": "array",
}
},
"required": ["x"],
"title": "foo",
"type": "object",
}

assert foo.invoke({"x": [0] * 10})
with pytest.raises(ValidationErrorV2):
foo.invoke({"x": [0] * 9})
34 changes: 33 additions & 1 deletion libs/core/tests/unit_tests/utils/test_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Test for some custom pydantic decorators."""

from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import pytest

from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
_create_subset_model_v2,
is_basemodel_instance,
is_basemodel_subclass,
pre_init,
Expand Down Expand Up @@ -121,3 +124,32 @@ class Bar(BaseModelV1):
assert is_basemodel_instance(Bar(x=5))
else:
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")


@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")
def test_with_field_metadata() -> None:
"""Test pydantic with field metadata"""
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic import Field as FieldV2 # pydantic: ignore

class Foo(BaseModelV2):
x: List[int] = FieldV2(
description="List of integers", min_length=10, max_length=15
)

subset_model = _create_subset_model_v2("Foo", Foo, ["x"])
assert subset_model.model_json_schema() == {
"properties": {
"x": {
"description": "List of integers",
"items": {"type": "integer"},
"maxItems": 15,
"minItems": 10,
"title": "X",
"type": "array",
}
},
"required": ["x"],
"title": "Foo",
"type": "object",
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import tool
from pydantic import BaseModel as RawBaseModel
from pydantic import Field as RawField

from langchain_standard_tests.unit_tests.chat_models import (
ChatModelTests,
Expand All @@ -26,7 +28,11 @@
from langchain_standard_tests.utils.pydantic import PYDANTIC_MAJOR_VERSION


@tool
class MagicFunctionSchema(RawBaseModel):
input: int = RawField(..., gt=-1000, lt=1000)


@tool(args_schema=MagicFunctionSchema)
def magic_function(input: int) -> int:
"""Applies a magic function to an input."""
return input + 2
Expand Down

0 comments on commit 88a9a6a

Please sign in to comment.