From 2f6ae41a727b46c7fbe713a312860c54d5c5d397 Mon Sep 17 00:00:00 2001
From: Madeesh Kannan
Date: Fri, 14 Jun 2024 17:52:32 +0200
Subject: [PATCH] feat: Extend YAML serialization to allow Python tuples (#7853)

---
 haystack/marshal/yaml.py                      | 12 +++++-
 ...zation-tuple-support-ffe176417e7099f5.yaml |  4 ++
 test/marshal/__init__.py                      |  3 ++
 test/marshal/test_yaml.py                     | 42 +++++++++++++++++++
 4 files changed, 60 insertions(+), 1 deletion(-)
 create mode 100644 releasenotes/notes/serialization-tuple-support-ffe176417e7099f5.yaml
 create mode 100644 test/marshal/__init__.py
 create mode 100644 test/marshal/test_yaml.py

diff --git a/haystack/marshal/yaml.py b/haystack/marshal/yaml.py
index 94d3e37761..c25498a3c5 100644
--- a/haystack/marshal/yaml.py
+++ b/haystack/marshal/yaml.py
@@ -7,6 +7,16 @@
 import yaml
 
 
+# Custom YAML safe loader that supports loading Python tuples
+class YamlLoader(yaml.SafeLoader):  # pylint: disable=too-many-ancestors
+    def construct_python_tuple(self, node: yaml.SequenceNode):
+        """Construct a Python tuple from the sequence."""
+        return tuple(self.construct_sequence(node))
+
+
+YamlLoader.add_constructor("tag:yaml.org,2002:python/tuple", YamlLoader.construct_python_tuple)
+
+
 class YamlMarshaller:
     def marshal(self, dict_: Dict[str, Any]) -> str:
         """Return a YAML representation of the given dictionary."""
@@ -14,4 +24,4 @@ def marshal(self, dict_: Dict[str, Any]) -> str:
 
     def unmarshal(self, data_: Union[str, bytes, bytearray]) -> Dict[str, Any]:
         """Return a dictionary from the given YAML data."""
-        return yaml.safe_load(data_)
+        return yaml.load(data_, Loader=YamlLoader)
diff --git a/releasenotes/notes/serialization-tuple-support-ffe176417e7099f5.yaml b/releasenotes/notes/serialization-tuple-support-ffe176417e7099f5.yaml
new file mode 100644
index 0000000000..a3053ee6f2
--- /dev/null
+++ b/releasenotes/notes/serialization-tuple-support-ffe176417e7099f5.yaml
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Pipeline serialization to YAML now supports tuples as field values.
diff --git a/test/marshal/__init__.py b/test/marshal/__init__.py
new file mode 100644
index 0000000000..c1764a6e03
--- /dev/null
+++ b/test/marshal/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
diff --git a/test/marshal/test_yaml.py b/test/marshal/test_yaml.py
new file mode 100644
index 0000000000..552370e9ac
--- /dev/null
+++ b/test/marshal/test_yaml.py
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: 2022-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+
+from haystack.marshal.yaml import YamlMarshaller
+
+
+@pytest.fixture
+def yaml_data():
+    return {"key": "value", 1: 0.221, "list": [1, 2, 3], "tuple": (1, None, True), "dict": {"set": {False}}}
+
+
+@pytest.fixture
+def serialized_yaml_str():
+    return """key: value
+1: 0.221
+list:
+- 1
+- 2
+- 3
+tuple: !!python/tuple
+- 1
+- null
+- true
+dict:
+  set: !!set
+    false: null
+"""
+
+
+def test_yaml_marshal(yaml_data, serialized_yaml_str):
+    marshaller = YamlMarshaller()
+    marshalled = marshaller.marshal(yaml_data)
+    assert isinstance(marshalled, str)
+    assert marshalled.strip().replace("\n", "") == serialized_yaml_str.strip().replace("\n", "")
+
+
+def test_yaml_unmarshal(yaml_data, serialized_yaml_str):
+    marshaller = YamlMarshaller()
+    unmarshalled = marshaller.unmarshal(serialized_yaml_str)
+    assert unmarshalled == yaml_data