From 127a6db003d34e2e7250db62ddc4e81d867425e5 Mon Sep 17 00:00:00 2001
From: "Sergey E. Koposov"
Date: Sat, 16 Nov 2024 18:51:46 +0000
Subject: [PATCH] Switch from pickle to HDF5 serialization to avoid numpy
 compatibility issues with the DESI setup

---
 py/rvspecfit/fitter_ccf.py            |   6 +-
 py/rvspecfit/make_ccf.py              |  23 ++-
 py/rvspecfit/make_interpol.py         |  28 ++--
 py/rvspecfit/make_nd.py               |  22 ++--
 py/rvspecfit/nn/train_interpolator.py |  44 +++---
 py/rvspecfit/regularize_grid.py       |   7 +-
 py/rvspecfit/serializer.py            | 211 ++++++++++++++++++++++++++
 py/rvspecfit/spec_inter.py            |   6 +++---
 8 files changed, 276 insertions(+), 70 deletions(-)
 create mode 100644 py/rvspecfit/serializer.py

diff --git a/py/rvspecfit/fitter_ccf.py b/py/rvspecfit/fitter_ccf.py
index 789512c..e160c9a 100644
--- a/py/rvspecfit/fitter_ccf.py
+++ b/py/rvspecfit/fitter_ccf.py
@@ -1,9 +1,9 @@
-import pickle
 import numpy as np
 import scipy.optimize
 import scipy.interpolate
 from rvspecfit import make_ccf
 from rvspecfit.spec_fit import SpecData
+from rvspecfit import serializer
 import logging
 
 
@@ -44,7 +44,9 @@ def get_ccf_info(spec_setup, config):
             spec_setup, ccf_continuum)
         ccf_mod_fname = prefix + make_ccf.get_ccf_mod_name(
             spec_setup, ccf_continuum)
-        CCFCache.ccf_info[spec_setup] = pickle.load(open(ccf_info_fname, 'rb'))
+
+        CCFCache.ccf_info[spec_setup] = serializer.load_dict_from_hdf5(
+            ccf_info_fname)
         C = np.load(ccf_dat_fname, mmap_mode='r')
         CCFCache.ccfs[spec_setup] = C['fft']
         CCFCache.ccf2s[spec_setup] = C['fft2']
diff --git a/py/rvspecfit/make_ccf.py b/py/rvspecfit/make_ccf.py
index 7e7cb86..a2a9a56 100644
--- a/py/rvspecfit/make_ccf.py
+++ b/py/rvspecfit/make_ccf.py
@@ -1,4 +1,3 @@
-import pickle
 import argparse
 import multiprocessing as mp
 import numpy as np
@@ -10,6 +9,7 @@
 
 from rvspecfit import spec_fit
 from rvspecfit import make_interpol
+from rvspecfit import serializer
 import rvspecfit
 
 git_rev = rvspecfit.__version__
@@ -23,8 +23,8 @@ def get_continuum_prefix(continuum):
     return pref
 
 
-def get_ccf_pkl_name(setup, continuum=True):
-    return 'ccf_' + get_continuum_prefix(continuum) + '%s.pkl' % setup
+def get_ccf_info_name(setup, continuum=True):
+    return 'ccf_' + get_continuum_prefix(continuum) + '%s.h5' % setup
 
 
 def get_ccf_dat_name(setup, continuum=True):
@@ -447,13 +447,11 @@
 
 
     """
-    with open(('%s/' + make_interpol.SPEC_PKL_NAME) % (prefix, spec_setup),
-              'rb') as fp:
-        D = pickle.load(fp)
-        vec, specs, lam, parnames = D['vec'], D['specs'], D['lam'], D[
-            'parnames']
-        log_spec = D['log_spec']
-        del D
+    cur_fname = ('%s/' + make_interpol.SPECS_H5_NAME) % (prefix, spec_setup)
+    D = serializer.load_dict_from_hdf5(cur_fname)
+    vec, specs, lam, parnames = D['vec'], D['specs'], D['lam'], D['parnames']
+    log_spec = D['log_spec']
+    del D
 
     morton_id = get_mortoncurve_id(vec.T)
     inds = np.argsort(morton_id)[::every]
@@ -473,7 +471,7 @@
     ffts = np.array([np.fft.rfft(x) for x in models])
     fft2s = np.array([np.fft.rfft(x**2) for x in models])
     savefile = (oprefix + '/' +
-                get_ccf_pkl_name(spec_setup, ccfconf.continuum))
+                get_ccf_info_name(spec_setup, ccfconf.continuum))
     datsavefile = (oprefix + '/' +
                    get_ccf_dat_name(spec_setup, ccfconf.continuum))
     modsavefile = (oprefix + '/' +
@@ -486,8 +484,7 @@
     dHash['parnames'] = parnames
     dHash['revision'] = revision
 
-    with open(savefile, 'wb') as fp:
-        pickle.dump(dHash, fp)
+    serializer.save_dict_to_hdf5(savefile, dHash, allow_pickle=True)
 
     np.savez(datsavefile, fft=np.array(ffts), fft2=np.array(fft2s))
     np.save(modsavefile,
             np.array(models))
diff --git a/py/rvspecfit/make_interpol.py b/py/rvspecfit/make_interpol.py
index 8df3299..3850818 100644
--- a/py/rvspecfit/make_interpol.py
+++ b/py/rvspecfit/make_interpol.py
@@ -3,18 +3,18 @@
 import os
 import sys
 import argparse
-import pickle
 import logging
 import scipy.constants
 import scipy.optimize
 import numpy as np
 import sqlite3
 from rvspecfit import read_grid
+from rvspecfit import serializer
 import rvspecfit
 
 git_rev = rvspecfit.__version__
 
-SPEC_PKL_NAME = 'specs_%s.pkl'
+SPECS_H5_NAME = 'specs_%s.h5'
 
 
 class FakePoolResult:
@@ -278,18 +278,18 @@ def process_all(setupInfo,
     except OSError:
         raise RuntimeError('Failed to create output directory: %s' %
                            (oprefix, ))
-    with open(('%s/' + SPEC_PKL_NAME) % (oprefix, HR), 'wb') as fp:
-        pickle.dump(
-            dict(specs=specs,
-                 vec=vec,
-                 lam=lam,
-                 parnames=parnames,
-                 git_rev=git_rev,
-                 mapper=mapper,
-                 revision=revision,
-                 lognorms=lognorms,
-                 logstep=log,
-                 log_spec=log_spec), fp)
+    curfname = ('%s/' + SPECS_H5_NAME) % (oprefix, HR)
+    DD = dict(specs=specs,
+              vec=vec,
+              lam=lam,
+              parnames=parnames,
+              git_rev=git_rev,
+              mapper=mapper,
+              revision=revision,
+              lognorms=lognorms,
+              logstep=log,
+              log_spec=log_spec)
+    serializer.save_dict_to_hdf5(curfname, DD, allow_pickle=True)
 
 
 def add_bool_arg(parser, name, default=False, help=None):
diff --git a/py/rvspecfit/make_nd.py b/py/rvspecfit/make_nd.py
index c88110b..09d0138 100644
--- a/py/rvspecfit/make_nd.py
+++ b/py/rvspecfit/make_nd.py
@@ -1,4 +1,3 @@
-import pickle
 import argparse
 import sys
 import numpy as np
@@ -6,11 +5,12 @@
 import scipy.spatial
 
 from rvspecfit import make_interpol
+from rvspecfit import serializer
 import rvspecfit
 
 git_rev = rvspecfit.__version__
 
-INTERPOL_PKL_NAME = 'interp_%s.pkl'
+INTERPOL_H5_NAME = 'interp_%s.h5'
 INTERPOL_DAT_NAME = 'interpdat_%s.npy'
 
 
@@ -75,13 +75,12 @@ def execute(spec_setup, prefix=None, regular=False, perturb=True, revision=''):
 
 
     """
-    with open(('%s/' + make_interpol.SPEC_PKL_NAME) % (prefix, spec_setup),
-              'rb') as fp:
-        D = pickle.load(fp)
-        (vec, specs, lam, parnames, mapper, lognorms,
-         logstep) = (D['vec'], D['specs'], D['lam'], D['parnames'],
-                     D['mapper'], D['lognorms'], D['logstep'])
-        del D
+    cur_fname = ('%s/' + make_interpol.SPECS_H5_NAME) % (prefix, spec_setup)
+    D = serializer.load_dict_from_hdf5(cur_fname)
+    (vec, specs, lam, parnames, mapper, lognorms,
+     logstep) = (D['vec'], D['specs'], D['lam'], D['parnames'], D['mapper'],
+                 D['lognorms'], D['logstep'])
+    del D
 
     vec = vec.astype(float)
     vec = mapper.forward(vec)
@@ -150,9 +149,9 @@
         ret_dict['regular'] = True
         ret_dict['idgrid'] = idgrid
         ret_dict['interpolation_type'] = 'regulargrid'
-    savefile = ('%s/' + INTERPOL_PKL_NAME) % (prefix, spec_setup)
+    savefile = ('%s/' + INTERPOL_H5_NAME) % (prefix, spec_setup)
     ret_dict['lam'] = lam
     ret_dict['logstep'] = logstep
     ret_dict['vec'] = vec
     ret_dict['parnames'] = parnames
     ret_dict['mapper'] = mapper
@@ -160,8 +159,7 @@
     ret_dict['lognorms'] = lognorms
     ret_dict['git_rev'] = git_rev
 
-    with open(savefile, 'wb') as fp:
-        pickle.dump(ret_dict, fp)
+    serializer.save_dict_to_hdf5(savefile, ret_dict, allow_pickle=True)
 
     np.save(('%s/' + INTERPOL_DAT_NAME) % (prefix, spec_setup),
             np.ascontiguousarray(specs))
diff --git a/py/rvspecfit/nn/train_interpolator.py b/py/rvspecfit/nn/train_interpolator.py
index 6a13b1e..e9df376 100644
--- a/py/rvspecfit/nn/train_interpolator.py
+++
 b/py/rvspecfit/nn/train_interpolator.py
@@ -1,4 +1,3 @@
-import pickle
 import os
 import sys
 import argparse
@@ -10,14 +9,15 @@
 import numpy as np
 import torch
 
 from .NNInterpolator import Mapper, NNInterpolator
+from rvspecfit import serializer
 
 git_rev = rvspecfit.__version__
 
 
 def getData(dir, setup, log_ids=[0]):
-    fname = (f'{dir}/specs_{setup}.pkl')
-    dat = pickle.load(open(fname, 'rb'))
+    fname = f'{dir}/specs_{setup}.h5'
+    dat = serializer.load_dict_from_hdf5(fname)
     lam = dat['lam']
    vecs = dat['vec'].T[:]
     dats = dat['specs']
@@ -298,25 +298,25 @@ def main(args):
             os.unlink(statefile_path)
 
     import rvspecfit.nn.RVSInterpolator  # noqa
-    with open(f'{directory}/interp_{setup}.pkl', 'wb') as fp:
-        D = {
-            'mapper': mapper,
-            'parnames': parnames,
-            'lam': lam,
-            'log_spec': True,
-            'logstep': True,
-            'module': 'rvspecfit.nn.RVSInterpolator',
-            'class_name': 'RVSInterpolator',
-            'device': device_name,
-            'class_kwargs': kwargs,
-            'outside_class_name': 'OutsideInterpolator',
-            'outside_kwargs': dict(pts=vecs),
-            'nn_file': finalfile,
-            'revision': revision,
-            'git_rev': git_rev,
-            'interpolation_type': 'generic'
-        }
-        pickle.dump(D, fp)
+    ofname = f'{directory}/interp_{setup}.h5'
+    D = {
+        'mapper': mapper,
+        'parnames': parnames,
+        'lam': lam,
+        'log_spec': True,
+        'logstep': True,
+        'module': 'rvspecfit.nn.RVSInterpolator',
+        'class_name': 'RVSInterpolator',
+        'device': device_name,
+        'class_kwargs': kwargs,
+        'outside_class_name': 'OutsideInterpolator',
+        'outside_kwargs': dict(pts=vecs),
+        'nn_file': finalfile,
+        'revision': revision,
+        'git_rev': git_rev,
+        'interpolation_type': 'generic'
+    }
+    serializer.save_dict_to_hdf5(ofname, D, allow_pickle=True)
 
 
 if __name__ == '__main__':
diff --git a/py/rvspecfit/regularize_grid.py b/py/rvspecfit/regularize_grid.py
index d376c7a..35b0921 100644
--- a/py/rvspecfit/regularize_grid.py
+++ b/py/rvspecfit/regularize_grid.py
@@ -1,10 +1,10 @@
-import pickle
 import sys
 import argparse
 import scipy.stats
 import scipy.interpolate
 import numpy as np
 import scipy.version
+from rvspecfit import serializer
 
 
 def findbestoverlaps(x, intervals):
@@ -59,7 +59,7 @@ def converter(path,
     newalphagrid = np.arange(min_alpha, max_alpha + step_alpha / 2.,
                              step_alpha)
 
-    dat = pickle.load(open(path, 'rb'))
+    dat = serializer.load_dict_from_hdf5(path)
 
     vec = dat['vec']
     specs = dat['specs']
@@ -145,8 +145,7 @@
     dat['specs'] = res_spec
     dat['vec'] = res_vec
 
-    with open(opath, 'wb') as fp:
-        pickle.dump(dat, fp, protocol=4)
+    serializer.save_dict_to_hdf5(opath, dat, allow_pickle=True)
 
 
 def check_scipy_version():
diff --git a/py/rvspecfit/serializer.py b/py/rvspecfit/serializer.py
new file mode 100644
index 0000000..e84d495
--- /dev/null
+++ b/py/rvspecfit/serializer.py
@@ -0,0 +1,211 @@
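+"""Save and load nested Python dictionaries as HDF5 files.
+
+This module replaces the pickle-based storage of the template grids.
+A minimal usage sketch (the keys and values are illustrative only):
+
+    import numpy as np
+    from rvspecfit import serializer
+    D = {'vec': np.arange(3), 'meta': {'name': 'test'}}
+    serializer.save_dict_to_hdf5('data.h5', D)
+    D1 = serializer.load_dict_from_hdf5('data.h5')
+
+Values that are not dictionaries, numpy arrays, strings, scalars or
+same-type lists/tuples can only be stored by pickling them inside the
+file, i.e. when allow_pickle=True is passed to save_dict_to_hdf5().
+"""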
+ """ + for key, item in dic.items(): + key_path = f"{path}/{key}" + if isinstance(item, dict): + recursively_save_dict_contents_to_group(h5file, + key_path, + item, + allow_pickle=allow_pickle) + elif isinstance(item, (list, tuple)): + is_list = isinstance(item, list) + if all(isinstance(x, type(item[0])) + for x in item): # Ensure all elements are of the same type + array = np.array(item) + if array.dtype.char == 'U': + ds = h5file.create_dataset(key_path, + shape=len(item), + dtype=h5py.string_dtype()) + ds[:] = array + else: + h5file.create_dataset(key_path, data=array) + if is_list: + h5file[key_path].attrs['type'] = 'list' + else: + h5file[key_path].attrs['type'] = 'tuple' + else: # Empty list or tuple + h5file.create_dataset(key_path, data=np.array(item)) + h5file[key_path].attrs['type'] = 'empty_array' + elif isinstance(item, np.ndarray): + if item.dtype.char == 'U': + # Unicode strings are handled differently + ds = h5file.create_dataset(key_path, + shape=len(item), + dtype=h5py.string_dtype()) + ds[:] = item + + else: + h5file.create_dataset(key_path, + data=item) # Directly save numpy arrays + h5file[key_path].attrs['type'] = 'ndarray' + elif isinstance(item, str): + dt = h5py.string_dtype('utf-8') + h5file.create_dataset(key_path, data=item, dtype=dt) + h5file[key_path].attrs['type'] = 'str' + elif isinstance(item, (int, float, complex, np.generic)): + # Handle numbers: int, float + h5file.create_dataset(key_path, data=item) + h5file[key_path].attrs['type'] = 'scalar' + else: + if allow_pickle: + print('Warning, type not understood, pickling', type(item)) + item = pickle.dumps(item) + h5file[key_path] = np.void(item) + h5file[key_path].attrs['type'] = 'pickle' + else: + raise ValueError( + f'Cannot save {type(item)} and pickling is not allowed') + + +def save_dict_to_hdf5(filename, dic, allow_pickle=False): + """ + Saves the provided dictionary to an HDF5 file. + """ + with h5py.File(filename, 'w') as h5file: + h5file.attrs['version'] = CURRENT_VERSION + recursively_save_dict_contents_to_group(h5file, + '/', + dic, + allow_pickle=allow_pickle) + + +def recursively_load_dict_contents_from_group(h5file, path): + """ + Recursively loads dictionary contents from HDF5 groups and datasets. + """ + ans = {} + for key, item in h5file[path].items(): + if isinstance(item, h5py.Dataset): # Handle datasets + curtyp = item.attrs['type'] + if curtyp in ['list', 'tuple']: + ans[key] = item[:] + print('xx', key, item, curtyp) + if item.dtype.kind == 'O': # Decode strings properly + ans[key] = ans[key].astype(str) + if curtyp == 'list': + ans[key] = list(ans[key]) + else: + ans[key] = tuple(ans[key]) + if item.attrs['type'] == 'ndarray': + ans[key] = item[:] + print('aa', item.dtype.kind) + if item.dtype.kind == 'O': # Decode strings properly + ans[key] = ans[key].astype(str) + elif item.attrs['type'] == 'str': + ans[key] = item[()].decode('utf-8') + elif item.attrs['type'] in ['scalar', 'empty_array']: + ans[key] = item[()] + elif item.attrs['type'] == 'pickle': + ans[key] = pickle.loads(item[()]) + elif isinstance(item, + h5py.Group): # Handle groups (nested dictionaries) + ans[key] = recursively_load_dict_contents_from_group( + h5file, f"{path}/{key}") + return ans + + +def load_dict_from_hdf5(filename): + """ + Loads the dictionary from an HDF5 file. 
+ """ + with h5py.File(filename, 'r') as h5file: + version = h5file.attrs.get('version', None) + if version != CURRENT_VERSION: + raise ValueError(f'Incompatible version: {version}') + return recursively_load_dict_contents_from_group(h5file, '/') + + +def verify_data(original, loaded, path='/'): + """ + Recursively verify that two dictionaries are identical in both value + and type. + """ + if not isinstance(loaded, type(original)): + print('fail1', path, (original), (loaded)) + return False + if isinstance(original, dict): + if original.keys() != loaded.keys(): + print('fail2', path) + return False + return all( + verify_data(original[key], loaded[key], path + '/' + key) + for key in original) + if isinstance(original, (list, tuple, np.ndarray)): + if len(original) != len(loaded): + print('fail3', path) + return False + return all(verify_data(o, l, path) for o, l in zip(original, loaded)) + return original == loaded + + +class TestClass: + + def __init__(self, x, y): + self.x = x + self.y = y + + def __eq__(self, other): + return self.x == other.x and self.y == other.y + + +def test_code(): + + # Example data containing various data types + data = { + 'x': np.int64(2), + 'vv': np.arange(1000, dtype=np.float64), + 'y': { + 'inside_y': np.arange(5) + }, + 'z': 'Hello world!', + 'tuple_data': (np.int64(1), np.int64(2), np.int64(3)), + 'list_data': [1.1, 2.2, 3.3], + 'xliststr': ['test', 'example'], + 'qq': np.array(['x', 'y', 'z']), + 'a1': [], + 'a2': tuple(), + 'myclass': TestClass(1, 2) + } + + # Save dictionary to an HDF5 file + save_dict_to_hdf5(data, 'data.h5', allow_pickle=True) + + # Load dictionary from HDF5 file + loaded_data = load_dict_from_hdf5('data.h5') + print(loaded_data) + print(verify_data(data, loaded_data)) + + +if __name__ == '__main__': + test_code() diff --git a/py/rvspecfit/spec_inter.py b/py/rvspecfit/spec_inter.py index 62b8b32..3d3ad45 100644 --- a/py/rvspecfit/spec_inter.py +++ b/py/rvspecfit/spec_inter.py @@ -312,7 +312,7 @@ def getInterpolator(HR, config, warmup_cache=False, cache=None): system_cache = False if HR not in cache: savefile = (config['template_lib'] + '/' + - make_nd.INTERPOL_PKL_NAME % HR) + make_nd.INTERPOL_H5_NAME % HR) with open(savefile, 'rb') as fd0: fd = pickle.load(fd0) log_spec = fd.get('log_spec') or True