Skip to content

Commit

Permalink
add nn interpolator script
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Feb 12, 2024
1 parent a1187f7 commit 75c610b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
7 changes: 7 additions & 0 deletions bin/rvs_train_nn_interpolator
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/usr/bin/env python

import sys
import rvspecfit.nn.train_interpolator as rvt

if __name__ == '__main__':
rvt.main(sys.argv[1:])
33 changes: 20 additions & 13 deletions py/rvspecfit/nn/train_interpolator.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import rvspecfit
import pickle
import os
import sys
import argparse
import time
import rvspecfit
import torch.utils.data as toda
# import matplotlib.pyplot as plt
import torch.nn.functional as tofu
import sklearn.decomposition as skde
import pickle
import numpy as np
import torch
import os
from idlsave import idlsave
from NNInterpolator import Mapper, NNInterpolator
import argparse

git_rev = rvspecfit.__version__

Expand Down Expand Up @@ -46,7 +45,7 @@ def getSchedOptim(optim):
return sched


def get_predictions(myint, Tvecs0, dev):
def get_predictions(myint, Tvecs0, dev, batch):
alldat = toda.TensorDataset(Tvecs0)
pred = []
batchdatF = toda.DataLoader(
Expand All @@ -64,7 +63,7 @@ def get_predictions(myint, Tvecs0, dev):
return pred


if __name__ == '__main__':
def main(args):

parser = argparse.ArgumentParser()
parser.add_argument('--cpu', action='store_true', default=False)
Expand All @@ -89,7 +88,7 @@ def get_predictions(myint, Tvecs0, dev):
parser.add_argument('--mask_ids', type=str, default=None)
parser.add_argument('--setup', type=str)
parser.add_argument('--batch', type=int, default=100)
args = parser.parse_args()
args = parser.parse_args(args)
log_ids = [int(_) for _ in (args.log_ids).split(',')]
setup = args.setup
revision = args.revision
Expand Down Expand Up @@ -304,10 +303,14 @@ def get_predictions(myint, Tvecs0, dev):
myint.pc_layer.bias.data[:] = tD_0[:] + myint.pc_layer.bias.data[:] * tSD_0

torch.save(myint.state_dict(), finalfile_path)
pred = get_predictions(myint, Tvecs0, train_dev)
idlsave.save(f'{directory}/pred_{setup}.psav',
'pred,vecs,dats,mapper,vecs_orig,mask_ids', pred, vecs, dats,
mapper, vecs_orig, mask_ids)
pred = get_predictions(myint, Tvecs0, train_dev, batch)
with open(f'{directory}/pred_{setup}.psav', 'wb') as fp:
DD = {}
DD['pred'] = pred,
DD['vecs'] = vecs
DD['dats'] = dats,
DD['mapper'] = mapper
DD['vecs_orig'] = vecs_orig
if os.path.exists(statefile_path):
os.unlink(statefile_path)
import rvspecfit.nn.RVSInterpolator
Expand All @@ -331,3 +334,7 @@ def get_predictions(myint, Tvecs0, dev):
'interpolation_type': 'generic'
}
pickle.dump(D, fp)


if __name__ == '__main__':
main(sys.argv[1:])

0 comments on commit 75c610b

Please sign in to comment.