From 67417feba2c0ec1d15c4fc0a5727068767c300e0 Mon Sep 17 00:00:00 2001
From: Paul Prescod
Date: Mon, 6 Jun 2022 10:59:44 -0700
Subject: [PATCH] Refactor YAML loading to use add_representer

---
 snowfakery/data_generator.py         |  2 +-
 snowfakery/data_generator_runtime.py | 18 ++++++++---------
 snowfakery/object_rows.py            | 17 ++++++++++------
 snowfakery/plugins.py                | 18 +++--------------
 snowfakery/utils/yaml_utils.py       | 29 +++++++++++++++++++++++++++-
 5 files changed, 51 insertions(+), 33 deletions(-)

diff --git a/snowfakery/data_generator.py b/snowfakery/data_generator.py
index 53c1b6a6..89d385e9 100644
--- a/snowfakery/data_generator.py
+++ b/snowfakery/data_generator.py
@@ -95,7 +95,7 @@ def load_continuation_yaml(continuation_file: OpenFileLike):
 def save_continuation_yaml(continuation_data: Globals, continuation_file: OpenFileLike):
     """Save the global interpreter state from Globals into a continuation_file"""
     yaml.dump(
-        continuation_data.__getstate__(),
+        continuation_data,
         continuation_file,
         Dumper=SnowfakeryDumper,
     )
diff --git a/snowfakery/data_generator_runtime.py b/snowfakery/data_generator_runtime.py
index 400cb9c0..4f3d4740 100644
--- a/snowfakery/data_generator_runtime.py
+++ b/snowfakery/data_generator_runtime.py
@@ -27,6 +27,7 @@
 )
 from snowfakery.plugins import PluginContext, SnowfakeryPlugin, ScalarTypes
 from snowfakery.utils.collections import OrderedSet
+from snowfakery.utils.yaml_utils import register_for_continuation
 
 OutputStream = "snowfakery.output_streams.OutputStream"
 VariableDefinition = "snowfakery.data_generator_runtime_object_model.VariableDefinition"
@@ -60,6 +61,7 @@ def generate_id(self, table_name: str) -> int:
     def __getitem__(self, table_name: str) -> int:
         return self.last_used_ids[table_name]
 
+    # TODO: Fix this to use the new convention of get_continuation_state
     def __getstate__(self):
         return {"last_used_ids": dict(self.last_used_ids)}
 
@@ -195,21 +197,14 @@ def check_slots_filled(self):
     def first_new_id(self, tablename):
         return self.transients.first_new_id(tablename)
 
-    def __getstate__(self):
-        def serialize_dict_of_object_rows(dct):
-            return {k: v.__getstate__() for k, v in dct.items()}
-
-        persistent_nicknames = serialize_dict_of_object_rows(self.persistent_nicknames)
-        persistent_objects_by_table = serialize_dict_of_object_rows(
-            self.persistent_objects_by_table
-        )
+    def get_continuation_state(self):
         intertable_dependencies = [
             dict(v._asdict()) for v in self.intertable_dependencies
         ]  # converts ordered-dict to dict for Python 3.6 and 3.7
 
         state = {
-            "persistent_nicknames": persistent_nicknames,
-            "persistent_objects_by_table": persistent_objects_by_table,
+            "persistent_nicknames": self.persistent_nicknames,
+            "persistent_objects_by_table": self.persistent_objects_by_table,
             "id_manager": self.id_manager.__getstate__(),
             "today": self.today,
             "nicknames_and_tables": self.nicknames_and_tables,
@@ -244,6 +239,9 @@ def deserialize_dict_of_object_rows(dct):
         self.reset_slots()
 
 
+register_for_continuation(Globals, Globals.get_continuation_state)
+
+
 class JinjaTemplateEvaluatorFactory:
     def __init__(self, native_types: bool):
         if native_types:
diff --git a/snowfakery/object_rows.py b/snowfakery/object_rows.py
index 3e836a35..a3aa9b9e 100644
--- a/snowfakery/object_rows.py
+++ b/snowfakery/object_rows.py
@@ -2,7 +2,7 @@
 import yaml
 
 import snowfakery  # noqa
-from .utils.yaml_utils import SnowfakeryDumper
+from .utils.yaml_utils import register_for_continuation
 from contextvars import ContextVar
 
 IdManager = "snowfakery.data_generator_runtime.IdManager"
@@ -14,10 +14,6 @@ class ObjectRow:
     Uses __getattr__ so that the template evaluator can use dot-notation."""
 
-    yaml_loader = yaml.SafeLoader
-    yaml_dumper = SnowfakeryDumper
-    yaml_tag = "!snowfakery_objectrow"
-
     # be careful changing these slots because these objects must be serializable
     # to YAML and JSON
     __slots__ = ["_tablename", "_values", "_child_index"]
 
@@ -49,11 +45,17 @@ def __repr__(self):
         except Exception:
             return super().__repr__()
 
-    def __getstate__(self):
+    def get_continuation_state(self):
         """Get the state of this ObjectRow for serialization.
 
         Do not include related ObjectRows because circular references
         in serialization formats cause problems."""
+
+        # If we decided to try to serialize hierarchies, we could
+        # do it like this:
+        # * keep track of whether an object has already been serialized using a
+        #   property of the SnowfakeryDumper
+        # * If so, output an ObjectReference instead of an ObjectRow
         values = {k: v for k, v in self._values.items() if not isinstance(v, ObjectRow)}
         return {"_tablename": self._tablename, "_values": values}
 
@@ -62,6 +64,9 @@ def __setstate__(self, state):
             setattr(self, slot, value)
 
 
+register_for_continuation(ObjectRow, ObjectRow.get_continuation_state)
+
+
 class ObjectReference(yaml.YAMLObject):
     def __init__(self, tablename: str, id: int):
         self._tablename = tablename
diff --git a/snowfakery/plugins.py b/snowfakery/plugins.py
index c7c548ce..71b25429 100644
--- a/snowfakery/plugins.py
+++ b/snowfakery/plugins.py
@@ -8,13 +8,11 @@
 from functools import wraps
 import typing as T
 
-import yaml
-from yaml.representer import Representer
 from faker.providers import BaseProvider as FakerProvider
 from dateutil.relativedelta import relativedelta
 
 import snowfakery.data_gen_exceptions as exc
-from .utils.yaml_utils import SnowfakeryDumper
+from snowfakery.utils.yaml_utils import register_for_continuation
 from .utils.collections import CaseInsensitiveDict
 
 from numbers import Number
@@ -306,17 +304,7 @@ def _from_continuation(cls, args):
 
     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)
-        _register_for_continuation(cls)
-
-
-def _register_for_continuation(cls):
-    SnowfakeryDumper.add_representer(cls, Representer.represent_object)
-    yaml.SafeLoader.add_constructor(
-        f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
-        lambda loader, node: cls._from_continuation(
-            loader.construct_mapping(node.value[0])
-        ),
-    )
+        register_for_continuation(cls)
 
 
 class PluginResultIterator(PluginResult):
@@ -372,4 +360,4 @@ def convert(self, value):
 
 
 # round-trip PluginResult objects through continuation YAML if needed.
-_register_for_continuation(PluginResult)
+register_for_continuation(PluginResult)
diff --git a/snowfakery/utils/yaml_utils.py b/snowfakery/utils/yaml_utils.py
index 73a5a367..8483d045 100644
--- a/snowfakery/utils/yaml_utils.py
+++ b/snowfakery/utils/yaml_utils.py
@@ -1,4 +1,6 @@
-from yaml import SafeDumper
+from typing import Callable
+from yaml import SafeDumper, SafeLoader
+from yaml.representer import Representer
 
 
 class SnowfakeryDumper(SafeDumper):
@@ -9,3 +11,28 @@ def hydrate(cls, data):
     obj = cls.__new__(cls)
     obj.__setstate__(data)
     return obj
+
+
+# Evaluate whether it's cleaner for functions to bypass register_for_continuation
+# and go directly to SnowfakeryDumper.add_representer.
+#
+#
+
+
+def represent_continuation(dumper: SnowfakeryDumper, data):
+    if isinstance(data, dict):
+        return Representer.represent_dict(dumper, data)
+    else:
+        return Representer.represent_object(dumper, data)
+
+
+def register_for_continuation(cls, dump_transformer: Callable = lambda x: x):
+    SnowfakeryDumper.add_representer(
+        cls, lambda self, data: represent_continuation(self, dump_transformer(data))
+    )
+    SafeLoader.add_constructor(
+        f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
+        lambda loader, node: cls._from_continuation(
+            loader.construct_mapping(node.value[0])
+        ),
+    )
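
For reviewers: a minimal, runnable sketch of how the new registration hook in snowfakery/utils/yaml_utils.py behaves, outside of Snowfakery itself. The represent_continuation and register_for_continuation bodies are copied from the patch; the Row class, its fields, and the dump/load calls at the end are hypothetical stand-ins (loosely modeled on ObjectRow.get_continuation_state), not part of this change.

    from typing import Callable

    import yaml
    from yaml import SafeDumper, SafeLoader
    from yaml.representer import Representer


    class SnowfakeryDumper(SafeDumper):
        pass


    def represent_continuation(dumper: SnowfakeryDumper, data):
        # Dicts are dumped as plain YAML mappings; anything else falls back to
        # PyYAML's generic python-object representation.
        if isinstance(data, dict):
            return Representer.represent_dict(dumper, data)
        else:
            return Representer.represent_object(dumper, data)


    def register_for_continuation(cls, dump_transformer: Callable = lambda x: x):
        # Dump side: transform the object (e.g. into its continuation dict)
        # before representing it.
        SnowfakeryDumper.add_representer(
            cls, lambda self, data: represent_continuation(self, dump_transformer(data))
        )
        # Load side: objects dumped with a python/object/apply tag are rebuilt
        # through the class's _from_continuation hook.
        SafeLoader.add_constructor(
            f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
            lambda loader, node: cls._from_continuation(
                loader.construct_mapping(node.value[0])
            ),
        )


    class Row:  # hypothetical stand-in for ObjectRow
        def __init__(self, tablename, values):
            self._tablename = tablename
            self._values = values

        def get_continuation_state(self):
            # Same shape as ObjectRow.get_continuation_state in the patch.
            return {"_tablename": self._tablename, "_values": dict(self._values)}


    register_for_continuation(Row, Row.get_continuation_state)

    text = yaml.dump(Row("Account", {"id": 1}), Dumper=SnowfakeryDumper)
    print(text)                  # a plain mapping: _tablename: Account ...
    print(yaml.safe_load(text))  # -> {'_tablename': 'Account', '_values': {'id': 1}}

Because the transformer for Globals and ObjectRow returns a plain dict, the continuation file stays loadable with yaml.safe_load and those objects are rehydrated separately (e.g. via hydrate). PluginResult classes, registered with the default identity transformer, keep their object representation and are expected to come back through _from_continuation, as the removed _register_for_continuation in plugins.py did before.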