Skip to content

Commit

Permalink
Fix set state when loading from pickle.
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiocaccamo committed Mar 9, 2023
1 parent 39301e7 commit 5def5b3
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 2 deletions.
3 changes: 2 additions & 1 deletion benedict/dicts/base/base_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def __setitem__(self, key, value):
super().__setitem__(key, value)

def __setstate__(self, state):
self.__dict__ = state
self._dict = state["_dict"]
self._pointer = state["_pointer"]

def __str__(self):
if self._pointer:
Expand Down
4 changes: 4 additions & 0 deletions benedict/dicts/keyattr/keyattr_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,7 @@ def __setattr__(self, attr, value):
if not self._keyattr_enabled:
raise AttributeError
self.__setitem__(attr, value)

def __setstate__(self, state):
super().__setstate__(state)
self._keyattr_enabled = state["_keyattr_enabled"]
4 changes: 4 additions & 0 deletions benedict/dicts/keypath/keypath_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def __setitem__(self, key, value):
keypath_util.check_keys(value, self._keypath_separator)
super().__setitem__(self._parse_key(key), value)

def __setstate__(self, state):
super().__setstate__(state)
self._keypath_separator = state["_keypath_separator"]

def _parse_key(self, key):
keys = keypath_util.parse_keys(key, self._keypath_separator)
keys_count = len(keys)
Expand Down
7 changes: 6 additions & 1 deletion tests/dicts/io/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@ def test_pickle(self):
"h": "0",
"i": benedict({"h": True}),
}
b = benedict(d, keypath_separator="/")
b = benedict(
d,
keyattr_enabled=False,
keypath_separator="/",
)
b_encoded = pickle.dumps(b)
# print(b_encoded)
b_decoded = pickle.loads(b_encoded)
# print(b_decoded)
# print(b_decoded.keypath_separator)
self.assertTrue(isinstance(b_decoded, benedict))
self.assertEqual(b_decoded.keyattr_enabled, b.keyattr_enabled)
self.assertEqual(b_decoded.keypath_separator, b.keypath_separator)
self.assertEqual(b_decoded, b)

0 comments on commit 5def5b3

Please sign in to comment.