From b288bad83cced080743a9db6ac7f25b42e9d8756 Mon Sep 17 00:00:00 2001 From: "Sergey E. Koposov" Date: Sat, 16 Nov 2024 19:02:24 +0000 Subject: [PATCH] fixes --- py/rvspecfit/nn/train_interpolator.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/py/rvspecfit/nn/train_interpolator.py b/py/rvspecfit/nn/train_interpolator.py index e9df376..40fb163 100644 --- a/py/rvspecfit/nn/train_interpolator.py +++ b/py/rvspecfit/nn/train_interpolator.py @@ -17,7 +17,7 @@ def getData(dir, setup, log_ids=[0]): fname = f'{dir}/specs_{setup}.h5' - dat = serializer.load_dict_from_hdf5(fname)) + dat = serializer.load_dict_from_hdf5(fname) lam = dat['lam'] vecs = dat['vec'].T[:] dats = dat['specs'] @@ -287,13 +287,15 @@ def main(args): torch.save(myint.state_dict(), finalfile_path) pred = get_predictions(myint, Tvecs0, train_dev, batch) - with open(f'{directory}/pred_{setup}.psav', 'wb') as fp: - DD = {} - DD['pred'] = pred, - DD['vecs'] = vecs - DD['dats'] = dats, - DD['mapper'] = mapper - DD['vecs_orig'] = vecs_orig + cur_name = f'{directory}/pred_{setup}.psav' + DD = {} + DD['pred'] = pred, + DD['vecs'] = vecs + DD['dats'] = dats, + DD['mapper'] = mapper + DD['vecs_orig'] = vecs_orig + serializer.save_dict_to_hdf5(cur_name, DD) + if os.path.exists(statefile_path): os.unlink(statefile_path) import rvspecfit.nn.RVSInterpolator # noqa