allow serialization of lists/tuples with different types
segasai committed Nov 16, 2024
1 parent a5ae73a commit 8c9bd0e
Showing 1 changed file with 38 additions and 10 deletions.
48 changes: 38 additions & 10 deletions py/rvspecfit/serializer.py
@@ -1,6 +1,7 @@
 import h5py
 import numpy as np
 import pickle
+import re
 
 CURRENT_VERSION = 1
 
@@ -23,21 +24,34 @@ def recursively_save_dict_contents_to_group(h5file,
             is_list = isinstance(item, list)
             if all(isinstance(x, type(item[0]))
                    for x in item): # Ensure all elements are of the same type
-                array = np.array(item)
-                if array.dtype.char == 'U':
+                arr = np.array(item)
+                if arr.dtype.char == 'U':
                     ds = h5file.create_dataset(key_path,
                                                shape=len(item),
                                                dtype=h5py.string_dtype())
-                    ds[:] = array
+                    ds[:] = arr
                 else:
-                    h5file.create_dataset(key_path, data=array)
+                    h5file.create_dataset(key_path, data=arr)
                 if is_list:
                     h5file[key_path].attrs['type'] = 'list'
                 else:
                     h5file[key_path].attrs['type'] = 'tuple'
-            else: # Empty list or tuple
+            elif len(item) == 0: # Empty list or tuple
                 h5file.create_dataset(key_path, data=np.array(item))
                 h5file[key_path].attrs['type'] = 'empty_array'
+            else:
+                fake_dict = {}
+                pref = '_tuple'
+                if is_list:
+                    pref = '_list'
+                for i, cur_it in enumerate(item):
+                    fake_dict['__item_%d' % i] = cur_it
+                recursively_save_dict_contents_to_group(
+                    h5file, key_path, fake_dict, allow_pickle=allow_pickle)
+                h5file[key_path].attrs['type'] = 'flattened' + pref
+                # raise Exception('x')
+                # different types
+
         elif isinstance(item, np.ndarray):
             if item.dtype.char == 'U':
                 # Unicode strings are handled differently
@@ -91,16 +105,14 @@ def recursively_load_dict_contents_from_group(h5file, path):
             curtyp = item.attrs['type']
             if curtyp in ['list', 'tuple']:
                 ans[key] = item[:]
-                print('xx', key, item, curtyp)
                 if item.dtype.kind == 'O': # Decode strings properly
                     ans[key] = ans[key].astype(str)
                 if curtyp == 'list':
                     ans[key] = list(ans[key])
                 else:
                     ans[key] = tuple(ans[key])
-            if item.attrs['type'] == 'ndarray':
+            elif item.attrs['type'] == 'ndarray':
                 ans[key] = item[:]
-                print('aa', item.dtype.kind)
                 if item.dtype.kind == 'O': # Decode strings properly
                     ans[key] = ans[key].astype(str)
             elif item.attrs['type'] == 'str':
@@ -109,10 +121,21 @@
                 ans[key] = item[()]
             elif item.attrs['type'] == 'pickle':
                 ans[key] = pickle.loads(item[()])
+            else:
+                raise Exception('unsupported')
         elif isinstance(item,
                         h5py.Group): # Handle groups (nested dictionaries)
             ans[key] = recursively_load_dict_contents_from_group(
                 h5file, f"{path}/{key}")
+            curtyp = item.attrs.get('type')
+            if curtyp in ['flattened_tuple', 'flattened_list']:
+                keys = sorted(ans[key].keys())
+                assert all(
+                    [re.match('__item_.*', _) is not None for _ in keys])
+                keys = ['__item_%d' % i for i in range(len(keys))]
+                ans[key] = [ans[key][_] for _ in keys]
+                if curtyp == 'flattened_tuple':
+                    ans[key] = tuple(ans[key])
     return ans
 
 
Expand All @@ -133,7 +156,7 @@ def verify_data(original, loaded, path='/'):
and type.
"""
if not isinstance(loaded, type(original)):
print('fail1', path, (original), (loaded))
print('fail1', path, original, loaded, type(original), type(loaded))
return False
if isinstance(original, dict):
if original.keys() != loaded.keys():
@@ -167,7 +190,11 @@ def test_code():
         'x': np.int64(2),
         'vv': np.arange(3, dtype=np.float64),
         'y': {
-            'inside_y': np.arange(5)
+            'inside_y': np.arange(5),
+            'intside_y_dict': {
+                'x': np.int64(55),
+                'y': np.int64(66)
+            },
         },
         'z': 'Hello world!',
         'tuple_data': (np.int64(1), np.int64(2), np.int64(3)),
@@ -176,6 +203,7 @@ def test_code():
         'qq': np.array(['x', 'y', 'z']),
         'a1': [],
         'a2': tuple(),
+        'a3': (np.int64(1), 'x'),
         'myclass': TestClass(1, 2)
     }
 

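For reference, the new behaviour stores a list or tuple whose elements are not all of one type as an HDF5 group of __item_0, __item_1, ... entries carrying a 'flattened_list' or 'flattened_tuple' type attribute; on load, the group is reassembled in index order and converted back to a tuple where needed. Below is a minimal round-trip sketch, not part of the commit: it assumes the two recursive helpers shown in the diff can be called directly with a root path of '/', and that the module is importable as rvspecfit.serializer; the package's public save/load wrappers (if any) are not shown in this diff.

# Round-trip sketch (assumed usage, not part of commit 8c9bd0e).
# The root-path argument '/' and direct use of the recursive helpers are
# assumptions; only the argument order (h5file, path, dic, allow_pickle=...)
# is visible in the diff above.
import h5py
import numpy as np

from rvspecfit import serializer

data = {'a3': (np.int64(1), 'x')}  # tuple whose elements have different types

with h5py.File('demo.h5', 'w') as fp:
    # Mixed-type sequences are now flattened into a sub-group of
    # __item_0, __item_1, ... datasets tagged with a 'flattened_tuple'
    # (or 'flattened_list') type attribute.
    serializer.recursively_save_dict_contents_to_group(fp,
                                                       '/',
                                                       data,
                                                       allow_pickle=True)

with h5py.File('demo.h5', 'r') as fp:
    loaded = serializer.recursively_load_dict_contents_from_group(fp, '/')

print(loaded['a3'])  # expected: a tuple equivalent to (1, 'x')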