Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Nov 16, 2024
1 parent 22c6190 commit b288bad
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions py/rvspecfit/nn/train_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
def getData(dir, setup, log_ids=[0]):

fname = f'{dir}/specs_{setup}.h5'
dat = serializer.load_dict_from_hdf5(fname))
dat = serializer.load_dict_from_hdf5(fname)
lam = dat['lam']
vecs = dat['vec'].T[:]
dats = dat['specs']
Expand Down Expand Up @@ -287,13 +287,15 @@ def main(args):

torch.save(myint.state_dict(), finalfile_path)
pred = get_predictions(myint, Tvecs0, train_dev, batch)
with open(f'{directory}/pred_{setup}.psav', 'wb') as fp:
DD = {}
DD['pred'] = pred,
DD['vecs'] = vecs
DD['dats'] = dats,
DD['mapper'] = mapper
DD['vecs_orig'] = vecs_orig
cur_name = f'{directory}/pred_{setup}.psav'
DD = {}
DD['pred'] = pred,
DD['vecs'] = vecs
DD['dats'] = dats,
DD['mapper'] = mapper
DD['vecs_orig'] = vecs_orig
serializer.save_dict_to_hdf5(cur_name, DD)

if os.path.exists(statefile_path):
os.unlink(statefile_path)
import rvspecfit.nn.RVSInterpolator # noqa
Expand Down

0 comments on commit b288bad

Please sign in to comment.