Skip to content

Commit

Permalink
fix: Adjust serialization to handle PEP-585 generic types (#7690)
Browse files Browse the repository at this point in the history
* Adjust serialization to handle PEP-585 generic types

* Add reno note

* Simplify

* PEP 585 serialization handling in sys.version_info < (3, 9)
  • Loading branch information
vblagoje authored and davidsbatista committed May 16, 2024
1 parent dfc243c commit f19934b
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 4 deletions.
24 changes: 20 additions & 4 deletions haystack/utils/type_serialization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import importlib
import inspect
import sys
from typing import Any, get_origin
import typing
from typing import Any, get_args, get_origin

from haystack import DeserializationError

Expand All @@ -28,19 +29,23 @@ def serialize_type(target: Any) -> str:
# Determine if the target is a type or an instance of a typing object
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
type_obj = target if is_type_or_typing else type(target)
module = inspect.getmodule(type_obj)
type_obj_repr = repr(type_obj)

if type_obj_repr.startswith("typing."):
# e.g., typing.List[int] -> List[int], we'll add the module below
type_name = type_obj_repr.split(".", 1)[1]
elif origin := get_origin(type_obj): # get the origin (base type of the parameterized generic type)
# get the arguments of the generic type
args = get_args(type_obj)
args_repr = ", ".join(serialize_type(arg) for arg in args)
type_name = f"{origin.__name__}[{args_repr}]"
elif hasattr(type_obj, "__name__"):
type_name = type_obj.__name__
else:
# If type cannot be serialized, raise an error
raise ValueError(f"Could not serialize type: {type_obj_repr}")

# Construct the full path with module name if available
module = inspect.getmodule(type_obj)
if module and hasattr(module, "__name__"):
if module.__name__ == "builtins":
# omit the module name for builtins, it just clutters the output
Expand Down Expand Up @@ -69,6 +74,14 @@ def deserialize_type(type_str: str) -> Any:
If the type cannot be deserialized due to missing module or type.
"""

type_mapping = {
list: typing.List,
dict: typing.Dict,
set: typing.Set,
tuple: typing.Tuple,
frozenset: typing.FrozenSet,
}

def parse_generic_args(args_str):
args = []
bracket_count = 0
Expand Down Expand Up @@ -100,7 +113,10 @@ def parse_generic_args(args_str):
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))

# Reconstruct
return main_type[generic_args]
if sys.version_info >= (3, 9) or repr(main_type).startswith("typing."):
return main_type[generic_args]
else:
return type_mapping[main_type][generic_args] # type: ignore

else:
# Handle non-generics
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Fixes type serialization of PEP 585 generic types (e.g. list[Document]) and their nested forms (e.g. list[list[Document]]). This improves serialization of generics and nested types, and fixes matching of list[X] and List[X] types in component connections after serialization.
27 changes: 27 additions & 0 deletions test/utils/test_type_serialization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import copy
import sys
import typing
from typing import List, Dict

import pytest

from haystack.dataclasses import ChatMessage
from haystack.components.routers.conditional_router import serialize_type, deserialize_type

Expand All @@ -21,6 +24,18 @@ def test_output_type_serialization():
assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.dataclasses.chat_message.ChatMessage"


@pytest.mark.skipif(sys.version_info < (3, 9), reason="PEP 585 types are only available in Python 3.9+")
def test_output_type_serialization_pep585():
    """Builtin (PEP 585) generic aliases serialize to their plain textual form."""
    # Subscripting builtins such as list[int] only works on Python 3.9+,
    # which is why this test is skipped on older interpreters.
    expected_by_type = [
        # flat and singly nested generics
        (list[int], "list[int]"),
        (list[list[int]], "list[list[int]]"),
        # deeper nesting and multi-argument generics
        (list[list[list[int]]], "list[list[list[int]]]"),
        (dict[str, int], "dict[str, int]"),
    ]

    for pep585_type, expected in expected_by_type:
        assert serialize_type(pep585_type) == expected


def test_output_type_deserialization():
assert deserialize_type("str") == str
assert deserialize_type("typing.List[int]") == typing.List[int]
Expand All @@ -35,3 +50,15 @@ def test_output_type_deserialization():
)
assert deserialize_type("haystack.dataclasses.chat_message.ChatMessage") == ChatMessage
assert deserialize_type("int") == int


def test_output_type_deserialization_pep585():
    """Check that PEP 585 style type strings deserialize to the expected generic.

    On Python 3.9+ the result is compared against the builtin generic alias
    (e.g. ``list[int]``); on older interpreters, where builtins cannot be
    subscripted, it is compared against the ``typing`` equivalent
    (e.g. ``List[int]``).
    """
    is_pep585 = sys.version_info >= (3, 9)

    # The parentheses around the conditional are required. Without them
    # ``assert a == b if cond else c`` parses as ``assert ((a == b) if cond else c)``,
    # so on Python < 3.9 the assert would only evaluate the always-truthy
    # ``List[int]`` and never call ``deserialize_type``.
    # The builtin form is only evaluated when ``is_pep585`` is true, so this
    # remains safe on interpreters that cannot subscript ``list``/``dict``.
    assert deserialize_type("list[int]") == (list[int] if is_pep585 else List[int])
    assert deserialize_type("dict[str, int]") == (dict[str, int] if is_pep585 else Dict[str, int])
    # more nested types
    assert deserialize_type("list[list[int]]") == (list[list[int]] if is_pep585 else List[List[int]])
    assert deserialize_type("list[list[list[int]]]") == (
        list[list[list[int]]] if is_pep585 else List[List[List[int]]]
    )

0 comments on commit f19934b

Please sign in to comment.