
Commit

attempt to switch from pickle to hdf5 to avoid numpy issues with desi setup..
segasai committed Nov 16, 2024
1 parent 3e719a1 commit 127a6db
Showing 8 changed files with 257 additions and 69 deletions.
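The new rvspecfit.serializer module itself is not expanded in this view, so the sketch below is only a guess at what the two entry points used throughout the diff might look like: save_dict_to_hdf5(fname, d, allow_pickle=...) and load_dict_from_hdf5(fname), backed by h5py, with a pickle fallback for values (such as the mapper object) that have no native HDF5 representation. Only the function names and the allow_pickle keyword are taken from the call sites below; everything else here is an assumption for illustration.

```python
# Minimal sketch of the serializer interface assumed by the call sites in this
# diff. The real rvspecfit.serializer is not shown in this view; names of
# behaviour details below are assumptions, not the actual implementation.
import pickle

import h5py
import numpy as np


def save_dict_to_hdf5(fname, d, allow_pickle=False):
    """Write a flat dict of arrays/scalars/strings to an HDF5 file.

    Values that h5py cannot store natively (e.g. arbitrary Python objects)
    are optionally pickled into opaque byte datasets when allow_pickle=True.
    """
    with h5py.File(fname, 'w') as fp:
        for key, val in d.items():
            try:
                fp.create_dataset(key, data=val)
            except TypeError:
                if not allow_pickle:
                    raise
                # store the pickled object as raw bytes under a marked key
                fp.create_dataset(key + '__pickle__',
                                  data=np.void(pickle.dumps(val)))


def load_dict_from_hdf5(fname):
    """Read back a dict written by save_dict_to_hdf5."""
    ret = {}
    with h5py.File(fname, 'r') as fp:
        for key in fp.keys():
            if key.endswith('__pickle__'):
                ret[key[:-len('__pickle__')]] = pickle.loads(
                    fp[key][()].tobytes())
            else:
                val = fp[key][()]
                if isinstance(val, bytes):
                    # h5py returns scalar strings as bytes
                    val = val.decode()
                ret[key] = val
    return ret
```

Storing the arrays as plain HDF5 datasets rather than pickled numpy objects is what sidesteps the numpy version sensitivity mentioned in the commit message.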
6 changes: 4 additions & 2 deletions py/rvspecfit/fitter_ccf.py
@@ -1,9 +1,9 @@
-import pickle
import numpy as np
import scipy.optimize
import scipy.interpolate
from rvspecfit import make_ccf
from rvspecfit.spec_fit import SpecData
+from rvspecfit import serializer
import logging


@@ -44,7 +44,9 @@ def get_ccf_info(spec_setup, config):
spec_setup, ccf_continuum)
ccf_mod_fname = prefix + make_ccf.get_ccf_mod_name(
spec_setup, ccf_continuum)
-CCFCache.ccf_info[spec_setup] = pickle.load(open(ccf_info_fname, 'rb'))

+CCFCache.ccf_info[spec_setup] = serializer.load_dict_from_hdf5(
+ccf_info_fname)
C = np.load(ccf_dat_fname, mmap_mode='r')
CCFCache.ccfs[spec_setup] = C['fft']
CCFCache.ccf2s[spec_setup] = C['fft2']
23 changes: 10 additions & 13 deletions py/rvspecfit/make_ccf.py
@@ -1,4 +1,3 @@
-import pickle
import argparse
import multiprocessing as mp
import numpy as np
@@ -10,6 +9,7 @@

from rvspecfit import spec_fit
from rvspecfit import make_interpol
+from rvspecfit import serializer
import rvspecfit

git_rev = rvspecfit.__version__
@@ -23,8 +23,8 @@ def get_continuum_prefix(continuum):
return pref


-def get_ccf_pkl_name(setup, continuum=True):
-return 'ccf_' + get_continuum_prefix(continuum) + '%s.pkl' % setup
+def get_ccf_info_name(setup, continuum=True):
+return 'ccf_' + get_continuum_prefix(continuum) + '%s.h5' % setup


def get_ccf_dat_name(setup, continuum=True):
@@ -447,13 +447,11 @@ def ccf_executor(spec_setup,
"""

-with open(('%s/' + make_interpol.SPEC_PKL_NAME) % (prefix, spec_setup),
-'rb') as fp:
-D = pickle.load(fp)
-vec, specs, lam, parnames = D['vec'], D['specs'], D['lam'], D[
-'parnames']
-log_spec = D['log_spec']
-del D
+cur_fname = ('%s/' + make_interpol.SPECS_H5_NAME) % (prefix, spec_setup)
+D = serializer.load_dict_from_hdf5(cur_fname)
+vec, specs, lam, parnames = D['vec'], D['specs'], D['lam'], D['parnames']
+log_spec = D['log_spec']
+del D

morton_id = get_mortoncurve_id(vec.T)
inds = np.argsort(morton_id)[::every]
@@ -473,7 +471,7 @@
ffts = np.array([np.fft.rfft(x) for x in models])
fft2s = np.array([np.fft.rfft(x**2) for x in models])
savefile = (oprefix + '/' +
-get_ccf_pkl_name(spec_setup, ccfconf.continuum))
+get_ccf_info_name(spec_setup, ccfconf.continuum))
datsavefile = (oprefix + '/' +
get_ccf_dat_name(spec_setup, ccfconf.continuum))
modsavefile = (oprefix + '/' +
@@ -486,8 +484,7 @@
dHash['parnames'] = parnames
dHash['revision'] = revision

-with open(savefile, 'wb') as fp:
-pickle.dump(dHash, fp)
+serializer.save_dict_to_hdf5(savefile, dHash, allow_pickle=True)
np.savez(datsavefile, fft=np.array(ffts), fft2=np.array(fft2s))
np.save(modsavefile, np.array(models))

28 changes: 14 additions & 14 deletions py/rvspecfit/make_interpol.py
@@ -3,18 +3,18 @@
import os
import sys
import argparse
-import pickle
import logging
import scipy.constants
import scipy.optimize
import numpy as np
import sqlite3
from rvspecfit import read_grid
+from rvspecfit import serializer
import rvspecfit

git_rev = rvspecfit.__version__

-SPEC_PKL_NAME = 'specs_%s.pkl'
+SPECS_H5_NAME = 'specs_%s.h5'


class FakePoolResult:
@@ -278,18 +278,18 @@ def process_all(setupInfo,
except OSError:
raise RuntimeError('Failed to create output directory: %s' %
(oprefix, ))
-with open(('%s/' + SPEC_PKL_NAME) % (oprefix, HR), 'wb') as fp:
-pickle.dump(
-dict(specs=specs,
-vec=vec,
-lam=lam,
-parnames=parnames,
-git_rev=git_rev,
-mapper=mapper,
-revision=revision,
-lognorms=lognorms,
-logstep=log,
-log_spec=log_spec), fp)
+curfname = ('%s/' + SPECS_H5_NAME) % (oprefix, HR)
+DD = dict(specs=specs,
+vec=vec,
+lam=lam,
+parnames=parnames,
+git_rev=git_rev,
+mapper=mapper,
+revision=revision,
+lognorms=lognorms,
+logstep=log,
+log_spec=log_spec)
+serializer.save_dict_to_hdf5(curfname, DD)


def add_bool_arg(parser, name, default=False, help=None):
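As a side benefit of the switch, the files written by process_all() above can be inspected with h5py directly. A quick illustrative look (the path and setup name below are made up; the keys are the ones assembled in DD above):

```python
# Inspect one of the new HDF5 template files produced by make_interpol.
# 'templ_data/specs_test_setup.h5' is an illustrative path, not from the diff.
import h5py

with h5py.File('templ_data/specs_test_setup.h5', 'r') as fp:
    print(sorted(fp.keys()))   # expect 'specs', 'vec', 'lam', 'parnames', ...
    print(fp['specs'].shape)   # one template spectrum per row
    print(fp['lam'][:10])      # start of the wavelength grid
```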
24 changes: 11 additions & 13 deletions py/rvspecfit/make_nd.py
@@ -1,16 +1,16 @@
-import pickle
import argparse
import sys
import numpy as np
import numpy.random
import scipy.spatial

from rvspecfit import make_interpol
+from rvspecfit import serializer
import rvspecfit

git_rev = rvspecfit.__version__

-INTERPOL_PKL_NAME = 'interp_%s.pkl'
+INTERPOL_H5_NAME = 'interp_%s.h5'
INTERPOL_DAT_NAME = 'interpdat_%s.npy'


@@ -75,13 +75,12 @@ def execute(spec_setup, prefix=None, regular=False, perturb=True, revision=''):
"""

-with open(('%s/' + make_interpol.SPEC_PKL_NAME) % (prefix, spec_setup),
-'rb') as fp:
-D = pickle.load(fp)
-(vec, specs, lam, parnames, mapper, lognorms,
-logstep) = (D['vec'], D['specs'], D['lam'], D['parnames'],
-D['mapper'], D['lognorms'], D['logstep'])
-del D
+cur_fname = ('%s/' + make_interpol.SPECS_H5_NAME) % (prefix, spec_setup)
+D = serializer.load_dict_from_hdf5(cur_fname)
+(vec, specs, lam, parnames, mapper, lognorms,
+log_step) = (D['vec'], D['specs'], D['lam'], D['parnames'], D['mapper'],
+D['lognorms'], D['logstep'])
+del D

vec = vec.astype(float)
vec = mapper.forward(vec)
@@ -150,18 +149,17 @@ def execute(spec_setup, prefix=None, regular=False, perturb=True, revision=''):
ret_dict['regular'] = True
ret_dict['idgrid'] = idgrid
ret_dict['interpolation_type'] = 'regulargrid'
-savefile = ('%s/' + INTERPOL_PKL_NAME) % (prefix, spec_setup)
+savefile = ('%s/' + INTERPOL_H5_NAME) % (prefix, spec_setup)
ret_dict['lam'] = lam
-ret_dict['logstep'] = logstep
+ret_dict['log_step'] = log_step
ret_dict['vec'] = vec
ret_dict['parnames'] = parnames
ret_dict['mapper'] = mapper
ret_dict['revision'] = revision
ret_dict['lognorms'] = lognorms
ret_dict['git_rev'] = git_rev

-with open(savefile, 'wb') as fp:
-pickle.dump(ret_dict, fp)
+serializer.save_dict_to_hdf5(savefile, ret_dict)
np.save(('%s/' + INTERPOL_DAT_NAME) % (prefix, spec_setup),
np.ascontiguousarray(specs))

44 changes: 22 additions & 22 deletions py/rvspecfit/nn/train_interpolator.py
@@ -1,4 +1,3 @@
-import pickle
import os
import sys
import argparse
@@ -10,14 +9,15 @@
import numpy as np
import torch
from .NNInterpolator import Mapper, NNInterpolator
+from rvspecfit import serializer

git_rev = rvspecfit.__version__


def getData(dir, setup, log_ids=[0]):

-fname = (f'{dir}/specs_{setup}.pkl')
-dat = pickle.load(open(fname, 'rb'))
+fname = f'{dir}/specs_{setup}.h5'
+dat = serializer.load_dict_from_hdf5(fname)
lam = dat['lam']
vecs = dat['vec'].T[:]
dats = dat['specs']
@@ -298,25 +298,25 @@ def main(args):
os.unlink(statefile_path)
import rvspecfit.nn.RVSInterpolator # noqa

-with open(f'{directory}/interp_{setup}.pkl', 'wb') as fp:
-D = {
-'mapper': mapper,
-'parnames': parnames,
-'lam': lam,
-'log_spec': True,
-'logstep': True,
-'module': 'rvspecfit.nn.RVSInterpolator',
-'class_name': 'RVSInterpolator',
-'device': device_name,
-'class_kwargs': kwargs,
-'outside_class_name': 'OutsideInterpolator',
-'outside_kwargs': dict(pts=vecs),
-'nn_file': finalfile,
-'revision': revision,
-'git_rev': git_rev,
-'interpolation_type': 'generic'
-}
-pickle.dump(D, fp)
+ofname = f'{directory}/interp_{setup}.h5'
+D = {
+'mapper': mapper,
+'parnames': parnames,
+'lam': lam,
+'log_spec': True,
+'logstep': True,
+'module': 'rvspecfit.nn.RVSInterpolator',
+'class_name': 'RVSInterpolator',
+'device': device_name,
+'class_kwargs': kwargs,
+'outside_class_name': 'OutsideInterpolator',
+'outside_kwargs': dict(pts=vecs),
+'nn_file': finalfile,
+'revision': revision,
+'git_rev': git_rev,
+'interpolation_type': 'generic'
+}
+serializer.save_dict_to_hdf5(ofname, D)


if __name__ == '__main__':
7 changes: 3 additions & 4 deletions py/rvspecfit/regularize_grid.py
@@ -1,10 +1,10 @@
-import pickle
import sys
import argparse
import scipy.stats
import scipy.interpolate
import numpy as np
import scipy.version
+from rvspecfit import serializer


def findbestoverlaps(x, intervals):
@@ -59,7 +59,7 @@ def converter(path,
newalphagrid = np.arange(min_alpha, max_alpha + step_alpha / 2.,
step_alpha)

-dat = pickle.load(open(path, 'rb'))
+dat = serializer.load_dict_from_hdf5(path)

vec = dat['vec']
specs = dat['specs']
@@ -145,8 +145,7 @@ def converter(path,
dat['specs'] = res_spec
dat['vec'] = res_vec

-with open(opath, 'wb') as fp:
-pickle.dump(dat, fp, protocol=4)
+serializer.save_dict_to_hdf5(opath, dat)


def check_scipy_version():

