From d19e0743742efe0b7f9bc8391cba90e3c0a5d666 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Sun, 1 Sep 2024 16:44:33 -0700 Subject: [PATCH] core[patch]: handle serializable fields that cant be converted to bool (#25903) --- libs/core/langchain_core/load/serializable.py | 22 ++++++++++- .../unit_tests/load/test_serializable.py | 38 +++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/load/serializable.py b/libs/core/langchain_core/load/serializable.py index cf00cad514ef8..06e46b2667c0b 100644 --- a/libs/core/langchain_core/load/serializable.py +++ b/libs/core/langchain_core/load/serializable.py @@ -262,7 +262,27 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool: field = inst.__fields__.get(key) if not field: return False - return field.required is True or value or field.get_default() != value + # Handle edge case: a value cannot be converted to a boolean (e.g. a + # Pandas DataFrame). + try: + value_is_truthy = bool(value) + except Exception as _: + value_is_truthy = False + + # Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two + # Pandas DataFrames). + try: + value_neq_default = bool(field.get_default() != value) + except Exception as _: + try: + value_neq_default = all(field.get_default() != value) + except Exception as _: + try: + value_neq_default = value is not field.default + except Exception as _: + value_neq_default = False + + return field.required is True or value_is_truthy or value_neq_default def _replace_secrets( diff --git a/libs/core/tests/unit_tests/load/test_serializable.py b/libs/core/tests/unit_tests/load/test_serializable.py index ce6964933f8f7..9ef83a6c89c20 100644 --- a/libs/core/tests/unit_tests/load/test_serializable.py +++ b/libs/core/tests/unit_tests/load/test_serializable.py @@ -1,6 +1,8 @@ from typing import Dict from langchain_core.load import Serializable, dumpd +from langchain_core.load.serializable import _is_field_useful +from langchain_core.pydantic_v1 import Field def test_simple_serialization() -> None: @@ -69,3 +71,39 @@ def lc_secrets(self) -> Dict[str, str]: "lc": 1, "type": "constructor", } + + +def test__is_field_useful() -> None: + class ArrayObj: + def __bool__(self) -> bool: + raise ValueError("Truthiness can't be determined") + + def __eq__(self, other: object) -> bool: + return self # type: ignore[return-value] + + class NonBoolObj: + def __bool__(self) -> bool: + raise ValueError("Truthiness can't be determined") + + def __eq__(self, other: object) -> bool: + raise ValueError("Equality can't be determined") + + default_x = ArrayObj() + default_y = NonBoolObj() + + class Foo(Serializable): + x: ArrayObj = Field(default=default_x) + y: NonBoolObj = Field(default=default_y) + # Make sure works for fields without default. + z: ArrayObj + + class Config: + arbitrary_types_allowed = True + + foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj()) + assert _is_field_useful(foo, "x", foo.x) + assert _is_field_useful(foo, "y", foo.y) + + foo = Foo(x=default_x, y=default_y, z=ArrayObj()) + assert not _is_field_useful(foo, "x", foo.x) + assert not _is_field_useful(foo, "y", foo.y)