From 48d4e5e1b8eee20cd32124f12b89ce41506c0ba7 Mon Sep 17 00:00:00 2001 From: "Sergey E. Koposov" Date: Sat, 16 Nov 2024 21:24:04 +0000 Subject: [PATCH] try to not store mapper as a class --- py/rvspecfit/make_interpol.py | 8 ++++++-- py/rvspecfit/make_nd.py | 7 ++++--- py/rvspecfit/nn/train_interpolator.py | 20 ++++++++++++-------- py/rvspecfit/spec_inter.py | 17 +++++++++++------ 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/py/rvspecfit/make_interpol.py b/py/rvspecfit/make_interpol.py index 3850818..a7dfedd 100644 --- a/py/rvspecfit/make_interpol.py +++ b/py/rvspecfit/make_interpol.py @@ -221,7 +221,9 @@ def process_all(setupInfo, dbfile=dbfile, prefix=prefix, wavefile=wavefile) - mapper = read_grid.LogParamMapper(log_parameters) + mapper_module = 'rvspecfit.read_grid' + mapper_class = 'LogParamMapper' + mapper_args = (log_parameters, ) HR, lamleft, lamright, resol_function, step, log = setupInfo if templ_lam.min() > lamleft or templ_lam.max() < lamright: raise RuntimeError(f'''Cannot generate the spectra as the wavelength @@ -284,7 +286,9 @@ def process_all(setupInfo, lam=lam, parnames=parnames, git_rev=git_rev, - mapper=mapper, + mapper_module=mapper_module, + mapper_class_name=mapper_class, + mapper_args=mapper_args, revision=revision, lognorms=lognorms, logstep=log, diff --git a/py/rvspecfit/make_nd.py b/py/rvspecfit/make_nd.py index 09d0138..a5478ad 100644 --- a/py/rvspecfit/make_nd.py +++ b/py/rvspecfit/make_nd.py @@ -77,9 +77,10 @@ def execute(spec_setup, prefix=None, regular=False, perturb=True, revision=''): cur_fname = ('%s/' + make_interpol.SPECS_H5_NAME) % (prefix, spec_setup) D = serializer.load_dict_from_hdf5(cur_fname) - (vec, specs, lam, parnames, mapper, lognorms, - log_step) = (D['vec'], D['specs'], D['lam'], D['parnames'], D['mapper'], - D['lognorms'], D['log_step']) + (vec, specs, lam, parnames, mapper_class, lognorms, + log_step) = (D['vec'], D['specs'], D['lam'], D['parnames'], + D['mapper_class'], D['mapper_args'], D['lognorms'], + D['log_step']) del D vec = vec.astype(float) diff --git a/py/rvspecfit/nn/train_interpolator.py b/py/rvspecfit/nn/train_interpolator.py index e396a3c..a39a8d7 100644 --- a/py/rvspecfit/nn/train_interpolator.py +++ b/py/rvspecfit/nn/train_interpolator.py @@ -27,12 +27,13 @@ def getData(dir, setup, log_ids=[0]): xvecs[:, ii] = np.log10(vecs[:, ii]) M = xvecs.mean(axis=0) S = xvecs.std(axis=0) - mapper = Mapper(M, S, log_ids) + mapper_args = (M, S, log_ids) + mapper = Mapper(*mapper_args) # vecs[:, 0] = np.log10(vecs[:, 0]) vecs_trans = mapper.forward(vecs) - return lam, vecs_trans, dats, mapper, vecs + return lam, vecs_trans, dats, mapper, vecs, mapper_args def getSchedOptim(optim): @@ -108,9 +109,9 @@ def main(args): rstate = np.random.default_rng(44) # vecs are transformed already - lam, vecs, dats, mapper, vecs_orig = getData(directory, - setup, - log_ids=log_ids) + lam, vecs, dats, mapper, vecs_orig, mapper_args = getData(directory, + setup, + log_ids=log_ids) D_0 = np.mean(dats, axis=0) SD_0 = np.std(dats, axis=0) tD_0 = torch.tensor(D_0) @@ -299,14 +300,17 @@ def main(args): if os.path.exists(statefile_path): os.unlink(statefile_path) import rvspecfit.nn.RVSInterpolator # noqa - + mapper_module = 'rvspecfit.nn.NNInterpolator' + mapper_class_name = 'Mapper' ofname = f'{directory}/interp_{setup}.pkl' D = { - 'mapper': mapper, + 'mapper_module': mapper_module, + 'mapper_class_name': mapper_class_name, + 'mapper_args': mapper_args, 'parnames': parnames, 'lam': lam, 'log_spec': True, - 'logstep': True, + 'log_step': True, 'module': 'rvspecfit.nn.RVSInterpolator', 'class_name': 'RVSInterpolator', 'device': device_name, diff --git a/py/rvspecfit/spec_inter.py b/py/rvspecfit/spec_inter.py index 3d3ad45..0cfe04b 100644 --- a/py/rvspecfit/spec_inter.py +++ b/py/rvspecfit/spec_inter.py @@ -204,7 +204,7 @@ def __init__(self, revision='', filename='', creation_soft_version='', - logstep=None): + log_step=None): """ Construct the interpolator object Parameters @@ -241,7 +241,7 @@ def __init__(self, self.revision = revision self.filename = filename self.creation_soft_version = creation_soft_version - self.logstep = logstep + self.log_step = log_step self.objid = hash((self.name, self.parnames, self.revision, self.filename, self.creation_soft_version)) @@ -317,9 +317,14 @@ def getInterpolator(HR, config, warmup_cache=False, cache=None): fd = pickle.load(fd0) log_spec = fd.get('log_spec') or True - (templ_lam, mapper, parnames) = (fd['lam'], fd['mapper'], - fd['parnames']) - logstep = fd['logstep'] + (templ_lam, parnames) = (fd['lam'], fd['parnames']) + mapper_module = fd['mapper_module'] + mapper_class_name = fd['mapper_class_name'] + mapper_args = fd['mapper_args'] + mod = importlib.import_module(mapper_module) + mapper = getattr(mod, mapper_class_name)(*mapper_args) + + log_step = fd['log_step'] if 'interpolation_type' in fd: interp_type = fd['interpolation_type'] @@ -373,7 +378,7 @@ def getInterpolator(HR, config, warmup_cache=False, cache=None): revision=revision, creation_soft_version=creation_soft_version, filename=savefile, - logstep=logstep) + log_step=log_step) cache[HR] = interpObj if system_cache: interp_cache.template_lib = config['template_lib']