Skip to content

Commit

Permalink
core[patch]: handle serializable fields that cant be converted to bool (
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan authored Sep 1, 2024
1 parent 7f857a0 commit d19e074
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
22 changes: 21 additions & 1 deletion libs/core/langchain_core/load/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,27 @@ def _is_field_useful(inst: Serializable, key: str, value: Any) -> bool:
field = inst.__fields__.get(key)
if not field:
return False
return field.required is True or value or field.get_default() != value
# Handle edge case: a value cannot be converted to a boolean (e.g. a
# Pandas DataFrame).
try:
value_is_truthy = bool(value)
except Exception as _:
value_is_truthy = False

# Handle edge case: inequality of two objects does not evaluate to a bool (e.g. two
# Pandas DataFrames).
try:
value_neq_default = bool(field.get_default() != value)
except Exception as _:
try:
value_neq_default = all(field.get_default() != value)
except Exception as _:
try:
value_neq_default = value is not field.default
except Exception as _:
value_neq_default = False

return field.required is True or value_is_truthy or value_neq_default


def _replace_secrets(
Expand Down
38 changes: 38 additions & 0 deletions libs/core/tests/unit_tests/load/test_serializable.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Dict

from langchain_core.load import Serializable, dumpd
from langchain_core.load.serializable import _is_field_useful
from langchain_core.pydantic_v1 import Field


def test_simple_serialization() -> None:
Expand Down Expand Up @@ -69,3 +71,39 @@ def lc_secrets(self) -> Dict[str, str]:
"lc": 1,
"type": "constructor",
}


def test__is_field_useful() -> None:
class ArrayObj:
def __bool__(self) -> bool:
raise ValueError("Truthiness can't be determined")

def __eq__(self, other: object) -> bool:
return self # type: ignore[return-value]

class NonBoolObj:
def __bool__(self) -> bool:
raise ValueError("Truthiness can't be determined")

def __eq__(self, other: object) -> bool:
raise ValueError("Equality can't be determined")

default_x = ArrayObj()
default_y = NonBoolObj()

class Foo(Serializable):
x: ArrayObj = Field(default=default_x)
y: NonBoolObj = Field(default=default_y)
# Make sure works for fields without default.
z: ArrayObj

class Config:
arbitrary_types_allowed = True

foo = Foo(x=ArrayObj(), y=NonBoolObj(), z=ArrayObj())
assert _is_field_useful(foo, "x", foo.x)
assert _is_field_useful(foo, "y", foo.y)

foo = Foo(x=default_x, y=default_y, z=ArrayObj())
assert not _is_field_useful(foo, "x", foo.x)
assert not _is_field_useful(foo, "y", foo.y)

0 comments on commit d19e074

Please sign in to comment.