From 90056715192c30f2777a9bde8c53f2f629c59c08 Mon Sep 17 00:00:00 2001 From: Shyue Ping Ong Date: Wed, 17 Apr 2024 11:13:46 -0700 Subject: [PATCH] Revert json.py to old version. --- monty/json.py | 293 ++++++++++---------------------------------------- 1 file changed, 55 insertions(+), 238 deletions(-) diff --git a/monty/json.py b/monty/json.py index d7effb19c..4b74eb362 100644 --- a/monty/json.py +++ b/monty/json.py @@ -4,12 +4,10 @@ from __future__ import annotations -import contextlib import datetime import json import os import pathlib -import pickle import traceback import types from collections import OrderedDict, defaultdict @@ -18,8 +16,7 @@ from importlib import import_module from inspect import getfullargspec from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict -from uuid import UUID, uuid4 +from uuid import UUID try: import numpy as np @@ -61,17 +58,12 @@ except ImportError: torch = None # type: ignore -if TYPE_CHECKING: - from typing import Any, Generator, Union - - from typing_extensions import Self - __version__ = "3.0.0" -def _load_redirect(redirect_file: Union[str, Path]) -> dict: +def _load_redirect(redirect_file): try: - with open(redirect_file, encoding="utf-8") as f: + with open(redirect_file) as f: yaml = YAML() d = yaml.load(f) except OSError: @@ -80,7 +72,7 @@ def _load_redirect(redirect_file: Union[str, Path]) -> dict: return {} # Convert the full paths to module/class - redirect_dict: dict = defaultdict(dict) + redirect_dict = defaultdict(dict) for old_path, new_path in d.items(): old_class = old_path.split(".")[-1] old_module = ".".join(old_path.split(".")[:-1]) @@ -96,7 +88,7 @@ def _load_redirect(redirect_file: Union[str, Path]) -> dict: return dict(redirect_dict) -def _check_type(obj: object, type_str: Union[str, tuple[str, ...]]) -> bool: +def _check_type(obj, type_str) -> bool: """Alternative to isinstance that avoids imports. Checks whether obj is an instance of the type defined by type_str. This @@ -124,7 +116,7 @@ class B(A): pass mro = type(obj).mro() except TypeError: return False - return any(f"{o.__module__}.{o.__name__}" == ts for o in mro for ts in type_str) + return any(o.__module__ + "." + o.__name__ == ts for o in mro for ts in type_str) class MSONable: @@ -149,12 +141,12 @@ class MSONable: class MSONClass(MSONable): - def __init__(self, a, b, c, d=1, **kwargs): - self.a = a - self.b = b - self._c = c - self._d = d - self.kwargs = kwargs + def __init__(self, a, b, c, d=1, **kwargs): + self.a = a + self.b = b + self._c = c + self._d = d + self.kwargs = kwargs For such classes, you merely need to inherit from MSONable and you do not need to implement your own as_dict or from_dict protocol. @@ -164,8 +156,8 @@ def __init__(self, a, b, c, d=1, **kwargs): fully qualified path and new fully qualified path into .monty.yaml in the home folder - Examples: - old_module.old_class: new_module.new_class + Example: + old_module.old_class: new_module.new_class """ REDIRECT = _load_redirect(os.path.join(os.path.expanduser("~"), ".monty.yaml")) @@ -209,7 +201,7 @@ def recursive_as_dict(obj): a = getattr(self, c) except AttributeError: try: - a = getattr(self, f"_{c}") + a = getattr(self, "_" + c) except AttributeError: raise NotImplementedError( "Unable to automatically determine as_dict " @@ -232,30 +224,17 @@ def recursive_as_dict(obj): d.update({"value": self.value}) # pylint: disable=E1101 return d - @staticmethod - def decoded_from_dict(d, name_object_map): - decoder = MontyDecoder() - decoder._name_object_map = name_object_map - decoded = { - k: decoder.process_decoded(v) for k, v in d.items() if not k.startswith("@") - } - return decoded - - @classmethod - def _from_dict(cls, d, name_object_map): - decoded = MSONable.decoded_from_dict(d, name_object_map=name_object_map) - return cls(**decoded) - @classmethod - def from_dict(cls, d: dict) -> Self: + def from_dict(cls, d): """ - Args: - d: Dict representation. - - Returns: - MSONable class. + :param d: Dict representation. + :return: MSONable class. """ - decoded = MSONable.decoded_from_dict(d, name_object_map=None) + decoded = { + k: MontyDecoder().process_decoded(v) + for k, v in d.items() + if not k.startswith("@") + } return cls(**decoded) def to_json(self) -> str: @@ -264,111 +243,6 @@ def to_json(self) -> str: """ return json.dumps(self, cls=MontyEncoder) - def save( - self, - save_dir=None, - mkdir=True, - pickle_kwargs=None, - json_kwargs=None, - return_results=False, - strict=True, - ): - """Utility that uses the standard tools of MSONable to convert the - class to json format, but also save it to disk. In addition, this - method intelligently uses pickle to individually pickle class objects - that are not serializable, saving them separately. This maximizes the - readability of the saved class information while allowing _any_ - class to be at least partially serializable to disk. - - For a fully MSONable class, only a class.json file will be saved to - the location {save_dir}/class.json. For a partially MSONable class, - additional information will be saved to the save directory at - {save_dir}. This includes a pickled object for each attribute that - e serialized. - - Parameters - ---------- - save_dir : os.PathLike - The directory to which to save the class information. - mkdir : bool - If True, makes the provided directory, including all parent - directories. - pickle_kwargs : dict - Keyword arguments to pass to pickle.dump. - json_kwargs : dict - Keyword arguments to pass to the serializer. - return_results : bool - If True, also returns the dictionary to save to disk, as well - as the mapping between the object_references and the objects - themselves. - strict : bool - If True, will not allow you to overwrite existing files. - - Returns - ------- - None or tuple - """ - - if save_dir is None and not return_results: - raise ValueError("save_dir must be set and/or return_results must be True") - - if pickle_kwargs is None: - pickle_kwargs = {} - if json_kwargs is None: - json_kwargs = {} - encoder = MontyEncoder(allow_unserializable_objects=True, **json_kwargs) - encoded = encoder.encode(self) - - if save_dir is not None: - save_dir = Path(save_dir) - if mkdir: - save_dir.mkdir(exist_ok=True, parents=True) - json_path = save_dir / "class.json" - pickle_path = save_dir / "class.pkl" - if strict and json_path.exists(): - raise FileExistsError(f"strict is true and file {json_path} exists") - if strict and pickle_path.exists(): - raise FileExistsError(f"strict is true and file {pickle_path} exists") - - with open(json_path, "w") as outfile: - outfile.write(encoded) - pickle.dump( - encoder._name_object_map, - open(pickle_path, "wb"), - **pickle_kwargs, - ) - - if return_results: - return encoded, encoder._name_object_map - - @classmethod - def load(cls, load_dir): - """Loads a class from a provided {load_dir}/class.json and - {load_dir}/class.pkl file (if necessary). - - Parameters - ---------- - load_dir : os.PathLike - The directory from which to reload the class from. - - Returns - ------- - MSONable - An instance of the class being reloaded. - """ - - load_dir = Path(load_dir) - - json_path = load_dir / "class.json" - pickle_path = load_dir / "class.pkl" - - with open(json_path, "r") as infile: - d = json.loads(infile.read()) - name_object_map = pickle.load(open(pickle_path, "rb")) - decoded = MSONable.decoded_from_dict(d, name_object_map) - klass = cls(**decoded) - return klass - def unsafe_hash(self): """ Returns an hash of the current object. This uses a generic but low @@ -376,10 +250,11 @@ def unsafe_hash(self): any nested keys, and then performing a hash on the resulting object """ - def flatten(dct: dict, separator: str = ".") -> dict: - """Flattens a dictionary""" + def flatten(obj, separator="."): + # Flattens a dictionary + flat_dict = {} - for key, value in dct.items(): + for key, value in obj.items(): if isinstance(value, dict): flat_dict.update( { @@ -404,7 +279,7 @@ def flatten(dct: dict, separator: str = ".") -> dict: return sha1(json.dumps(OrderedDict(ordered_keys)).encode("utf-8")) @classmethod - def _validate_monty(cls, __input_value) -> Self: + def _validate_monty(cls, __input_value): """ pydantic Validator for MSONable pattern """ @@ -429,21 +304,21 @@ def _validate_monty(cls, __input_value) -> Self: ) @classmethod - def validate_monty_v1(cls, __input_value) -> Self: + def validate_monty_v1(cls, __input_value): """ Pydantic validator with correct signature for pydantic v1.x """ return cls._validate_monty(__input_value) @classmethod - def validate_monty_v2(cls, __input_value, _) -> Self: + def validate_monty_v2(cls, __input_value, _): """ Pydantic validator with correct signature for pydantic v2.x """ return cls._validate_monty(__input_value) @classmethod - def __get_validators__(cls) -> Generator: + def __get_validators__(cls): """Return validators for use in pydantic""" yield cls.validate_monty_v1 @@ -460,7 +335,7 @@ def __get_pydantic_core_schema__(cls, source_type, handler): return core_schema.json_or_python_schema(json_schema=s, python_schema=s) @classmethod - def _generic_json_schema(cls) -> dict: + def _generic_json_schema(cls): return { "type": "object", "properties": { @@ -472,12 +347,12 @@ def _generic_json_schema(cls) -> dict: } @classmethod - def __get_pydantic_json_schema__(cls, core_schema, handler) -> dict: + def __get_pydantic_json_schema__(cls, core_schema, handler): """JSON schema for MSONable pattern""" return cls._generic_json_schema() @classmethod - def __modify_schema__(cls, field_schema) -> None: + def __modify_schema__(cls, field_schema): """JSON schema for MSONable pattern""" custom_schema = cls._generic_json_schema() field_schema.update(custom_schema) @@ -492,18 +367,6 @@ class MontyEncoder(json.JSONEncoder): json.dumps(object, cls=MontyEncoder) """ - def __init__(self, *args, allow_unserializable_objects=False, **kwargs): - super().__init__(*args, **kwargs) - self._track_unserializable_objects = allow_unserializable_objects - self._name_object_map: Dict[str, Any] = {} - self._index = 0 - - def _update_name_object_map(self, o): - name = f"{self._index:012}-{str(uuid4())}" - self._index += 1 - self._name_object_map[name] = o - return {"@object_reference": name} - def default(self, o) -> dict: # pylint: disable=E0202 """ Overriding default method for JSON encoding. This method does two @@ -511,11 +374,9 @@ def default(self, o) -> dict: # pylint: disable=E0202 output. (b) If the @module and @class keys are not in the to_dict, add them to the output automatically. If the object has no to_dict property, the default Python json encoder default method is called. - Args: o: Python object. - - Returns: + Return: Python dict representation. """ if isinstance(o, datetime.datetime): @@ -573,13 +434,7 @@ def default(self, o) -> dict: # pylint: disable=E0202 return {"@module": "bson.objectid", "@class": "ObjectId", "oid": str(o)} if callable(o) and not isinstance(o, MSONable): - try: - return _serialize_callable(o) - except AttributeError as e: - # Some callables may not have instance __name__ - if self._track_unserializable_objects: - return self._update_name_object_map(o) - raise AttributeError(e) + return _serialize_callable(o) try: if pydantic is not None and isinstance(o, pydantic.BaseModel): @@ -595,11 +450,6 @@ def default(self, o) -> dict: # pylint: disable=E0202 d = o.as_dict() elif isinstance(o, Enum): d = {"value": o.value} - elif self._track_unserializable_objects: - # Last resort logic. We keep track of some name of the object - # as a reference, and instead of the object, store that - # name, which of course is json-serializable - d = self._update_name_object_map(o) else: raise TypeError( f"Object of type {o.__class__.__name__} is not JSON serializable" @@ -636,19 +486,13 @@ class MontyDecoder(json.JSONDecoder): json.loads(json_string, cls=MontyDecoder) """ - _name_object_map = None - def process_decoded(self, d): """ Recursive method to support decoding dicts and lists containing pymatgen objects. """ - if isinstance(d, dict): - if "@object_reference" in d and self._name_object_map is not None: - name = d["@object_reference"] - return self._name_object_map.pop(name) - elif "@module" in d and "@class" in d: + if "@module" in d and "@class" in d: modname = d["@module"] classname = d["@class"] if cls_redirect := MSONable.REDIRECT.get(modname, {}).get(classname): @@ -711,11 +555,6 @@ def process_decoded(self, d): if hasattr(mod, classname): cls_ = getattr(mod, classname) data = {k: v for k, v in d.items() if not k.startswith("@")} - if hasattr(cls_, "_from_dict"): - # New functionality with save/load requires this - return cls_._from_dict( - data, name_object_map=self._name_object_map - ) if hasattr(cls_, "from_dict"): return cls_.from_dict(data) if issubclass(cls_, Enum): @@ -776,24 +615,21 @@ def process_decoded(self, d): return d - def decode(self, s: str) -> object: # type: ignore[override] + def decode(self, s): """ Overrides decode from JSONDecoder. - Args: - s: string - - Returns: - Object. + :param s: string + :return: Object. """ if orjson is not None: try: - _d = orjson.loads(s) # pylint: disable=E1101 + d = orjson.loads(s) # pylint: disable=E1101 except orjson.JSONDecodeError: # pylint: disable=E1101 - _d = json.loads(s) + d = json.loads(s) else: - _d = json.loads(s) - return self.process_decoded(_d) + d = json.loads(s) + return self.process_decoded(d) class MSONError(Exception): @@ -803,12 +639,8 @@ class MSONError(Exception): def jsanitize( - obj: Any, - strict: bool = False, - allow_bson: bool = False, - enum_values: bool = False, - recursive_msonable: bool = False, -) -> Any: + obj, strict=False, allow_bson=False, enum_values=False, recursive_msonable=False +): """ This method cleans an input json-like object, either a list or a dict or some sequence, nested or otherwise, by converting all non-string @@ -846,34 +678,18 @@ def jsanitize( or (bson is not None and isinstance(obj, bson.objectid.ObjectId)) ): return obj - if isinstance(obj, (list, tuple)): return [ - jsanitize( - i, - strict=strict, - allow_bson=allow_bson, - enum_values=enum_values, - recursive_msonable=recursive_msonable, - ) + jsanitize(i, strict=strict, allow_bson=allow_bson, enum_values=enum_values) for i in obj ] - if np is not None and isinstance(obj, np.ndarray): return [ - jsanitize( - i, - strict=strict, - allow_bson=allow_bson, - enum_values=enum_values, - recursive_msonable=recursive_msonable, - ) + jsanitize(i, strict=strict, allow_bson=allow_bson, enum_values=enum_values) for i in obj.tolist() ] - if np is not None and isinstance(obj, np.generic): return obj.item() - if _check_type( obj, ( @@ -883,7 +699,6 @@ def jsanitize( ), ): return obj.to_dict() - if isinstance(obj, dict): return { str(k): jsanitize( @@ -895,22 +710,24 @@ def jsanitize( ) for k, v in obj.items() } - if isinstance(obj, (int, float)): return obj - if obj is None: return None if isinstance(obj, (pathlib.Path, datetime.datetime)): return str(obj) if callable(obj) and not isinstance(obj, MSONable): - with contextlib.suppress(TypeError): + try: return _serialize_callable(obj) + except TypeError: + pass if recursive_msonable: - with contextlib.suppress(AttributeError): + try: return obj.as_dict() + except AttributeError: + pass if not strict: return str(obj) @@ -936,7 +753,7 @@ def jsanitize( ) -def _serialize_callable(o: Any) -> dict: +def _serialize_callable(o): if isinstance(o, types.BuiltinFunctionType): # don't care about what builtin functions (sum, open, etc) are bound to bound = None