From 75c610b7285b24f161fa785a5d2884e8e3393612 Mon Sep 17 00:00:00 2001 From: "Sergey E. Koposov" Date: Mon, 12 Feb 2024 18:41:30 +0000 Subject: [PATCH] add nn interpolator script --- bin/rvs_train_nn_interpolator | 7 ++++++ py/rvspecfit/nn/train_interpolator.py | 33 ++++++++++++++++----------- 2 files changed, 27 insertions(+), 13 deletions(-) create mode 100644 bin/rvs_train_nn_interpolator diff --git a/bin/rvs_train_nn_interpolator b/bin/rvs_train_nn_interpolator new file mode 100644 index 0000000..f2f0a0f --- /dev/null +++ b/bin/rvs_train_nn_interpolator @@ -0,0 +1,7 @@ +#!/usr/bin/env python + +import sys +import rvspecfit.nn.train_interpolator as rvt + +if __name__ == '__main__': + rvt.main(sys.argv[1:]) diff --git a/py/rvspecfit/nn/train_interpolator.py b/py/rvspecfit/nn/train_interpolator.py index 7a9e572..fd8b4ce 100644 --- a/py/rvspecfit/nn/train_interpolator.py +++ b/py/rvspecfit/nn/train_interpolator.py @@ -1,16 +1,15 @@ -import rvspecfit +import pickle +import os +import sys +import argparse import time +import rvspecfit import torch.utils.data as toda -# import matplotlib.pyplot as plt import torch.nn.functional as tofu import sklearn.decomposition as skde -import pickle import numpy as np import torch -import os -from idlsave import idlsave from NNInterpolator import Mapper, NNInterpolator -import argparse git_rev = rvspecfit.__version__ @@ -46,7 +45,7 @@ def getSchedOptim(optim): return sched -def get_predictions(myint, Tvecs0, dev): +def get_predictions(myint, Tvecs0, dev, batch): alldat = toda.TensorDataset(Tvecs0) pred = [] batchdatF = toda.DataLoader( @@ -64,7 +63,7 @@ def get_predictions(myint, Tvecs0, dev): return pred -if __name__ == '__main__': +def main(args): parser = argparse.ArgumentParser() parser.add_argument('--cpu', action='store_true', default=False) @@ -89,7 +88,7 @@ def get_predictions(myint, Tvecs0, dev): parser.add_argument('--mask_ids', type=str, default=None) parser.add_argument('--setup', type=str) parser.add_argument('--batch', type=int, default=100) - args = parser.parse_args() + args = parser.parse_args(args) log_ids = [int(_) for _ in (args.log_ids).split(',')] setup = args.setup revision = args.revision @@ -304,10 +303,14 @@ def get_predictions(myint, Tvecs0, dev): myint.pc_layer.bias.data[:] = tD_0[:] + myint.pc_layer.bias.data[:] * tSD_0 torch.save(myint.state_dict(), finalfile_path) - pred = get_predictions(myint, Tvecs0, train_dev) - idlsave.save(f'{directory}/pred_{setup}.psav', - 'pred,vecs,dats,mapper,vecs_orig,mask_ids', pred, vecs, dats, - mapper, vecs_orig, mask_ids) + pred = get_predictions(myint, Tvecs0, train_dev, batch) + with open(f'{directory}/pred_{setup}.psav', 'wb') as fp: + DD = {} + DD['pred'] = pred, + DD['vecs'] = vecs + DD['dats'] = dats, + DD['mapper'] = mapper + DD['vecs_orig'] = vecs_orig if os.path.exists(statefile_path): os.unlink(statefile_path) import rvspecfit.nn.RVSInterpolator @@ -331,3 +334,7 @@ def get_predictions(myint, Tvecs0, dev): 'interpolation_type': 'generic' } pickle.dump(D, fp) + + +if __name__ == '__main__': + main(sys.argv[1:])