Skip to content

Commit

Permalink
Refactor YAML loading to use add_representer
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Prescod committed Jun 6, 2022
1 parent 7844d2c commit 67417fe
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 33 deletions.
2 changes: 1 addition & 1 deletion snowfakery/data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def load_continuation_yaml(continuation_file: OpenFileLike):
def save_continuation_yaml(continuation_data: Globals, continuation_file: OpenFileLike):
"""Save the global interpreter state from Globals into a continuation_file"""
yaml.dump(
continuation_data.__getstate__(),
continuation_data,
continuation_file,
Dumper=SnowfakeryDumper,
)
Expand Down
18 changes: 8 additions & 10 deletions snowfakery/data_generator_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
)
from snowfakery.plugins import PluginContext, SnowfakeryPlugin, ScalarTypes
from snowfakery.utils.collections import OrderedSet
from snowfakery.utils.yaml_utils import register_for_continuation

OutputStream = "snowfakery.output_streams.OutputStream"
VariableDefinition = "snowfakery.data_generator_runtime_object_model.VariableDefinition"
Expand Down Expand Up @@ -60,6 +61,7 @@ def generate_id(self, table_name: str) -> int:
def __getitem__(self, table_name: str) -> int:
return self.last_used_ids[table_name]

# TODO: Fix this to use the new convention of get_continuation_data
def __getstate__(self):
return {"last_used_ids": dict(self.last_used_ids)}

Expand Down Expand Up @@ -195,21 +197,14 @@ def check_slots_filled(self):
def first_new_id(self, tablename):
return self.transients.first_new_id(tablename)

def __getstate__(self):
def serialize_dict_of_object_rows(dct):
return {k: v.__getstate__() for k, v in dct.items()}

persistent_nicknames = serialize_dict_of_object_rows(self.persistent_nicknames)
persistent_objects_by_table = serialize_dict_of_object_rows(
self.persistent_objects_by_table
)
def get_continuation_state(self):
intertable_dependencies = [
dict(v._asdict()) for v in self.intertable_dependencies
] # converts ordered-dict to dict for Python 3.6 and 3.7

state = {
"persistent_nicknames": persistent_nicknames,
"persistent_objects_by_table": persistent_objects_by_table,
"persistent_nicknames": self.persistent_nicknames,
"persistent_objects_by_table": self.persistent_objects_by_table,
"id_manager": self.id_manager.__getstate__(),
"today": self.today,
"nicknames_and_tables": self.nicknames_and_tables,
Expand Down Expand Up @@ -244,6 +239,9 @@ def deserialize_dict_of_object_rows(dct):
self.reset_slots()


register_for_continuation(Globals, Globals.get_continuation_state)


class JinjaTemplateEvaluatorFactory:
def __init__(self, native_types: bool):
if native_types:
Expand Down
17 changes: 11 additions & 6 deletions snowfakery/object_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import yaml
import snowfakery # noqa
from .utils.yaml_utils import SnowfakeryDumper
from .utils.yaml_utils import register_for_continuation
from contextvars import ContextVar

IdManager = "snowfakery.data_generator_runtime.IdManager"
Expand All @@ -14,10 +14,6 @@ class ObjectRow:
Uses __getattr__ so that the template evaluator can use dot-notation."""

yaml_loader = yaml.SafeLoader
yaml_dumper = SnowfakeryDumper
yaml_tag = "!snowfakery_objectrow"

# be careful changing these slots because these objects must be serializable
# to YAML and JSON
__slots__ = ["_tablename", "_values", "_child_index"]
Expand Down Expand Up @@ -49,11 +45,17 @@ def __repr__(self):
except Exception:
return super().__repr__()

def __getstate__(self):
def get_continuation_state(self):
"""Get the state of this ObjectRow for serialization.
Do not include related ObjectRows because circular
references in serialization formats cause problems."""

# If we decided to try to serialize hierarchies, we could
# do it like this:
# * keep track of if an object has already been serialized using a
# property of the SnowfakeryDumper
# * If so, output an ObjectReference instead of an ObjectRow
values = {k: v for k, v in self._values.items() if not isinstance(v, ObjectRow)}
return {"_tablename": self._tablename, "_values": values}

Expand All @@ -62,6 +64,9 @@ def __setstate__(self, state):
setattr(self, slot, value)


register_for_continuation(ObjectRow, ObjectRow.get_continuation_state)


class ObjectReference(yaml.YAMLObject):
def __init__(self, tablename: str, id: int):
self._tablename = tablename
Expand Down
18 changes: 3 additions & 15 deletions snowfakery/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
from functools import wraps
import typing as T

import yaml
from yaml.representer import Representer
from faker.providers import BaseProvider as FakerProvider
from dateutil.relativedelta import relativedelta

import snowfakery.data_gen_exceptions as exc
from .utils.yaml_utils import SnowfakeryDumper
from snowfakery.utils.yaml_utils import register_for_continuation
from .utils.collections import CaseInsensitiveDict

from numbers import Number
Expand Down Expand Up @@ -306,17 +304,7 @@ def _from_continuation(cls, args):

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
_register_for_continuation(cls)


def _register_for_continuation(cls):
SnowfakeryDumper.add_representer(cls, Representer.represent_object)
yaml.SafeLoader.add_constructor(
f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
lambda loader, node: cls._from_continuation(
loader.construct_mapping(node.value[0])
),
)
register_for_continuation(cls)


class PluginResultIterator(PluginResult):
Expand Down Expand Up @@ -372,4 +360,4 @@ def convert(self, value):


# round-trip PluginResult objects through continuation YAML if needed.
_register_for_continuation(PluginResult)
register_for_continuation(PluginResult)
29 changes: 28 additions & 1 deletion snowfakery/utils/yaml_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from yaml import SafeDumper
from typing import Callable
from yaml import SafeDumper, SafeLoader
from yaml.representer import Representer


class SnowfakeryDumper(SafeDumper):
Expand All @@ -9,3 +11,28 @@ def hydrate(cls, data):
obj = cls.__new__(cls)
obj.__setstate__(data)
return obj


# Evaluate whether its cleaner for functions to bypass register_for_continuation
# and go directly to SnowfakeryDumper.add_representer.
#
#


def represent_continuation(dumper: SnowfakeryDumper, data):
if isinstance(data, dict):
return Representer.represent_dict(dumper, data)
else:
return Representer.represent_object(dumper, data)


def register_for_continuation(cls, dump_transformer: Callable = lambda x: x):
SnowfakeryDumper.add_representer(
cls, lambda self, data: represent_continuation(self, dump_transformer(data))
)
SafeLoader.add_constructor(
f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
lambda loader, node: cls._from_continuation(
loader.construct_mapping(node.value[0])
),
)

0 comments on commit 67417fe

Please sign in to comment.