From 8c9bd0e3eb5e5da1ea0a9644720bc4505e4a7f5b Mon Sep 17 00:00:00 2001 From: "Sergey E. Koposov" Date: Sat, 16 Nov 2024 21:20:50 +0000 Subject: [PATCH] allow serialization of lists/tuples with different types --- py/rvspecfit/serializer.py | 48 ++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 10 deletions(-) diff --git a/py/rvspecfit/serializer.py b/py/rvspecfit/serializer.py index 6b6b0cb..4c68530 100644 --- a/py/rvspecfit/serializer.py +++ b/py/rvspecfit/serializer.py @@ -1,6 +1,7 @@ import h5py import numpy as np import pickle +import re CURRENT_VERSION = 1 @@ -23,21 +24,34 @@ def recursively_save_dict_contents_to_group(h5file, is_list = isinstance(item, list) if all(isinstance(x, type(item[0])) for x in item): # Ensure all elements are of the same type - array = np.array(item) - if array.dtype.char == 'U': + arr = np.array(item) + if arr.dtype.char == 'U': ds = h5file.create_dataset(key_path, shape=len(item), dtype=h5py.string_dtype()) - ds[:] = array + ds[:] = arr else: - h5file.create_dataset(key_path, data=array) + h5file.create_dataset(key_path, data=arr) if is_list: h5file[key_path].attrs['type'] = 'list' else: h5file[key_path].attrs['type'] = 'tuple' - else: # Empty list or tuple + elif len(item) == 0: # Empty list or tuple h5file.create_dataset(key_path, data=np.array(item)) h5file[key_path].attrs['type'] = 'empty_array' + else: + fake_dict = {} + pref = '_tuple' + if is_list: + pref = '_list' + for i, cur_it in enumerate(item): + fake_dict['__item_%d' % i] = cur_it + recursively_save_dict_contents_to_group( + h5file, key_path, fake_dict, allow_pickle=allow_pickle) + h5file[key_path].attrs['type'] = 'flattened' + pref + # raise Exception('x') + # different types + elif isinstance(item, np.ndarray): if item.dtype.char == 'U': # Unicode strings are handled differently @@ -91,16 +105,14 @@ def recursively_load_dict_contents_from_group(h5file, path): curtyp = item.attrs['type'] if curtyp in ['list', 'tuple']: ans[key] = item[:] - print('xx', key, item, curtyp) if item.dtype.kind == 'O': # Decode strings properly ans[key] = ans[key].astype(str) if curtyp == 'list': ans[key] = list(ans[key]) else: ans[key] = tuple(ans[key]) - if item.attrs['type'] == 'ndarray': + elif item.attrs['type'] == 'ndarray': ans[key] = item[:] - print('aa', item.dtype.kind) if item.dtype.kind == 'O': # Decode strings properly ans[key] = ans[key].astype(str) elif item.attrs['type'] == 'str': @@ -109,10 +121,21 @@ def recursively_load_dict_contents_from_group(h5file, path): ans[key] = item[()] elif item.attrs['type'] == 'pickle': ans[key] = pickle.loads(item[()]) + else: + raise Exception('unsupported') elif isinstance(item, h5py.Group): # Handle groups (nested dictionaries) ans[key] = recursively_load_dict_contents_from_group( h5file, f"{path}/{key}") + curtyp = item.attrs.get('type') + if curtyp in ['flattened_tuple', 'flattened_list']: + keys = sorted(ans[key].keys()) + assert all( + [re.match('__item_.*', _) is not None for _ in keys]) + keys = ['__item_%d' % i for i in range(len(keys))] + ans[key] = [ans[key][_] for _ in keys] + if curtyp == 'flattened_tuple': + ans[key] = tuple(ans[key]) return ans @@ -133,7 +156,7 @@ def verify_data(original, loaded, path='/'): and type. """ if not isinstance(loaded, type(original)): - print('fail1', path, (original), (loaded)) + print('fail1', path, original, loaded, type(original), type(loaded)) return False if isinstance(original, dict): if original.keys() != loaded.keys(): @@ -167,7 +190,11 @@ def test_code(): 'x': np.int64(2), 'vv': np.arange(3, dtype=np.float64), 'y': { - 'inside_y': np.arange(5) + 'inside_y': np.arange(5), + 'intside_y_dict': { + 'x': np.int64(55), + 'y': np.int64(66) + }, }, 'z': 'Hello world!', 'tuple_data': (np.int64(1), np.int64(2), np.int64(3)), @@ -176,6 +203,7 @@ def test_code(): 'qq': np.array(['x', 'y', 'z']), 'a1': [], 'a2': tuple(), + 'a3': (np.int64(1), 'x'), 'myclass': TestClass(1, 2) }