Skip to content

Commit

Permalink
core[patch]: make get_all_basemodel_annotations public
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Oct 30, 2024
1 parent 8073146 commit fa7c082
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
8 changes: 4 additions & 4 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def args(self) -> dict:
def tool_call_schema(self) -> Type[BaseModel]:
full_schema = self.get_input_schema()
fields = []
for name, type_ in _get_all_basemodel_annotations(full_schema).items():
for name, type_ in get_all_basemodel_annotations(full_schema).items():
if not _is_injected_arg_type(type_):
fields.append(name)
return _create_subset_model(
Expand Down Expand Up @@ -858,7 +858,7 @@ def _is_injected_arg_type(type_: Type) -> bool:
)


def _get_all_basemodel_annotations(
def get_all_basemodel_annotations(
cls: Union[TypeBaseModel, Any], *, default_to_bound: bool = True
) -> Dict[str, Type]:
# cls has no subscript: cls = FooBar
Expand All @@ -876,7 +876,7 @@ def _get_all_basemodel_annotations(
orig_bases: Tuple = getattr(cls, "__orig_bases__", tuple())
# cls has subscript: cls = FooBar[int]
else:
annotations = _get_all_basemodel_annotations(
annotations = get_all_basemodel_annotations(
get_origin(cls), default_to_bound=False
)
orig_bases = (cls,)
Expand All @@ -890,7 +890,7 @@ def _get_all_basemodel_annotations(
# if class = FooBar inherits from Baz, parent = Baz
if isinstance(parent, type) and is_pydantic_v1_subclass(parent):
annotations.update(
_get_all_basemodel_annotations(parent, default_to_bound=False)
get_all_basemodel_annotations(parent, default_to_bound=False)
)
continue

Expand Down
26 changes: 13 additions & 13 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@
from langchain_core.tools.base import (
InjectedToolArg,
SchemaAnnotationError,
_get_all_basemodel_annotations,
_is_message_content_block,
_is_message_content_type,
get_all_basemodel_annotations,
)
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, _create_subset_model
Expand Down Expand Up @@ -1773,19 +1773,19 @@ class ModelC(Mixin, ModelB):
c: dict

expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
actual = _get_all_basemodel_annotations(ModelC)
actual = get_all_basemodel_annotations(ModelC)
assert actual == expected

expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
actual = _get_all_basemodel_annotations(ModelB)
actual = get_all_basemodel_annotations(ModelB)
assert actual == expected

expected = {"a": Any}
actual = _get_all_basemodel_annotations(ModelA)
actual = get_all_basemodel_annotations(ModelA)
assert actual == expected

expected = {"a": int}
actual = _get_all_basemodel_annotations(ModelA[int])
actual = get_all_basemodel_annotations(ModelA[int])
assert actual == expected

D = TypeVar("D", bound=Union[str, int])
Expand All @@ -1799,7 +1799,7 @@ class ModelD(ModelC, Generic[D]):
"c": dict,
"d": Union[str, int, None],
}
actual = _get_all_basemodel_annotations(ModelD)
actual = get_all_basemodel_annotations(ModelD)
assert actual == expected

expected = {
Expand All @@ -1808,7 +1808,7 @@ class ModelD(ModelC, Generic[D]):
"c": dict,
"d": Union[int, None],
}
actual = _get_all_basemodel_annotations(ModelD[int])
actual = get_all_basemodel_annotations(ModelD[int])
assert actual == expected


Expand All @@ -1830,19 +1830,19 @@ class ModelC(Mixin, ModelB):
c: dict

expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"], "c": dict}
actual = _get_all_basemodel_annotations(ModelC)
actual = get_all_basemodel_annotations(ModelC)
assert actual == expected

expected = {"a": str, "b": Annotated[ModelA[Dict[str, Any]], "foo"]}
actual = _get_all_basemodel_annotations(ModelB)
actual = get_all_basemodel_annotations(ModelB)
assert actual == expected

expected = {"a": Any}
actual = _get_all_basemodel_annotations(ModelA)
actual = get_all_basemodel_annotations(ModelA)
assert actual == expected

expected = {"a": int}
actual = _get_all_basemodel_annotations(ModelA[int])
actual = get_all_basemodel_annotations(ModelA[int])
assert actual == expected

D = TypeVar("D", bound=Union[str, int])
Expand All @@ -1856,7 +1856,7 @@ class ModelD(ModelC, Generic[D]):
"c": dict,
"d": Union[str, int, None],
}
actual = _get_all_basemodel_annotations(ModelD)
actual = get_all_basemodel_annotations(ModelD)
assert actual == expected

expected = {
Expand All @@ -1865,7 +1865,7 @@ class ModelD(ModelC, Generic[D]):
"c": dict,
"d": Union[int, None],
}
actual = _get_all_basemodel_annotations(ModelD[int])
actual = get_all_basemodel_annotations(ModelD[int])
assert actual == expected


Expand Down

0 comments on commit fa7c082

Please sign in to comment.