From f19934bcbae8c6f9e4969904266569a7367e5ebe Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 15 May 2024 14:25:19 +0200 Subject: [PATCH] fix: Adjust serialization to handle PEP-585 generic types (#7690) * Adjust serialization to handle PEP-585 generic types * Add reno note * Simplify * PEP 585 serialization handling in sys.version_info < (3, 9) --- haystack/utils/type_serialization.py | 24 ++++++++++++++--- ...erialization-support-18822a5b978b1e77.yaml | 4 +++ test/utils/test_type_serialization.py | 27 +++++++++++++++++++ 3 files changed, 51 insertions(+), 4 deletions(-) create mode 100644 releasenotes/notes/improve-type-serialization-support-18822a5b978b1e77.yaml diff --git a/haystack/utils/type_serialization.py b/haystack/utils/type_serialization.py index 0e133c010b..a750cf1c69 100644 --- a/haystack/utils/type_serialization.py +++ b/haystack/utils/type_serialization.py @@ -1,7 +1,8 @@ import importlib import inspect import sys -from typing import Any, get_origin +import typing +from typing import Any, get_args, get_origin from haystack import DeserializationError @@ -28,19 +29,23 @@ def serialize_type(target: Any) -> str: # Determine if the target is a type or an instance of a typing object is_type_or_typing = isinstance(target, type) or bool(get_origin(target)) type_obj = target if is_type_or_typing else type(target) - module = inspect.getmodule(type_obj) type_obj_repr = repr(type_obj) if type_obj_repr.startswith("typing."): # e.g., typing.List[int] -> List[int], we'll add the module below type_name = type_obj_repr.split(".", 1)[1] + elif origin := get_origin(type_obj): # get the origin (base type of the parameterized generic type) + # get the arguments of the generic type + args = get_args(type_obj) + args_repr = ", ".join(serialize_type(arg) for arg in args) + type_name = f"{origin.__name__}[{args_repr}]" elif hasattr(type_obj, "__name__"): type_name = type_obj.__name__ else: # If type cannot be serialized, raise an error raise ValueError(f"Could not serialize type: {type_obj_repr}") - # Construct the full path with module name if available + module = inspect.getmodule(type_obj) if module and hasattr(module, "__name__"): if module.__name__ == "builtins": # omit the module name for builtins, it just clutters the output @@ -69,6 +74,14 @@ def deserialize_type(type_str: str) -> Any: If the type cannot be deserialized due to missing module or type. """ + type_mapping = { + list: typing.List, + dict: typing.Dict, + set: typing.Set, + tuple: typing.Tuple, + frozenset: typing.FrozenSet, + } + def parse_generic_args(args_str): args = [] bracket_count = 0 @@ -100,7 +113,10 @@ def parse_generic_args(args_str): generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str)) # Reconstruct - return main_type[generic_args] + if sys.version_info >= (3, 9) or repr(main_type).startswith("typing."): + return main_type[generic_args] + else: + return type_mapping[main_type][generic_args] # type: ignore else: # Handle non-generics diff --git a/releasenotes/notes/improve-type-serialization-support-18822a5b978b1e77.yaml b/releasenotes/notes/improve-type-serialization-support-18822a5b978b1e77.yaml new file mode 100644 index 0000000000..b2c96fe488 --- /dev/null +++ b/releasenotes/notes/improve-type-serialization-support-18822a5b978b1e77.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Improves/fixes type serialization of PEP 585 types (e.g. list[Document], and their nested version). This improvement enables better serialization of generics and nested types and improves/fixes matching of list[X] and List[X] types in component connections after serialization. diff --git a/test/utils/test_type_serialization.py b/test/utils/test_type_serialization.py index aee5968538..b3decfaf47 100644 --- a/test/utils/test_type_serialization.py +++ b/test/utils/test_type_serialization.py @@ -1,7 +1,10 @@ import copy +import sys import typing from typing import List, Dict +import pytest + from haystack.dataclasses import ChatMessage from haystack.components.routers.conditional_router import serialize_type, deserialize_type @@ -21,6 +24,18 @@ def test_output_type_serialization(): assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.dataclasses.chat_message.ChatMessage" +@pytest.mark.skipif(sys.version_info < (3, 9), reason="PEP 585 types are only available in Python 3.9+") +def test_output_type_serialization_pep585(): + # Only Python 3.9+ supports PEP 585 types and can serialize them + # PEP 585 types + assert serialize_type(list[int]) == "list[int]" + assert serialize_type(list[list[int]]) == "list[list[int]]" + + # more nested types + assert serialize_type(list[list[list[int]]]) == "list[list[list[int]]]" + assert serialize_type(dict[str, int]) == "dict[str, int]" + + def test_output_type_deserialization(): assert deserialize_type("str") == str assert deserialize_type("typing.List[int]") == typing.List[int] @@ -35,3 +50,15 @@ def test_output_type_deserialization(): ) assert deserialize_type("haystack.dataclasses.chat_message.ChatMessage") == ChatMessage assert deserialize_type("int") == int + + +def test_output_type_deserialization_pep585(): + is_pep585 = sys.version_info >= (3, 9) + + # Although only Python 3.9+ supports PEP 585 types, we can still deserialize them in older Python versions + # as their typing equivalents + assert deserialize_type("list[int]") == list[int] if is_pep585 else List[int] + assert deserialize_type("dict[str, int]") == dict[str, int] if is_pep585 else Dict[str, int] + # more nested types + assert deserialize_type("list[list[int]]") == list[list[int]] if is_pep585 else List[List[int]] + assert deserialize_type("list[list[list[int]]]") == list[list[list[int]]] if is_pep585 else List[List[List[int]]]