Skip to content

Commit

Permalink
try to not store mapper as a class
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Nov 16, 2024
1 parent 8c9bd0e commit 48d4e5e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 19 deletions.
8 changes: 6 additions & 2 deletions py/rvspecfit/make_interpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ def process_all(setupInfo,
dbfile=dbfile,
prefix=prefix,
wavefile=wavefile)
mapper = read_grid.LogParamMapper(log_parameters)
mapper_module = 'rvspecfit.read_grid'
mapper_class = 'LogParamMapper'
mapper_args = (log_parameters, )
HR, lamleft, lamright, resol_function, step, log = setupInfo
if templ_lam.min() > lamleft or templ_lam.max() < lamright:
raise RuntimeError(f'''Cannot generate the spectra as the wavelength
Expand Down Expand Up @@ -284,7 +286,9 @@ def process_all(setupInfo,
lam=lam,
parnames=parnames,
git_rev=git_rev,
mapper=mapper,
mapper_module=mapper_module,
mapper_class_name=mapper_class,
mapper_args=mapper_args,
revision=revision,
lognorms=lognorms,
logstep=log,
Expand Down
7 changes: 4 additions & 3 deletions py/rvspecfit/make_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ def execute(spec_setup, prefix=None, regular=False, perturb=True, revision=''):

cur_fname = ('%s/' + make_interpol.SPECS_H5_NAME) % (prefix, spec_setup)
D = serializer.load_dict_from_hdf5(cur_fname)
(vec, specs, lam, parnames, mapper, lognorms,
log_step) = (D['vec'], D['specs'], D['lam'], D['parnames'], D['mapper'],
D['lognorms'], D['log_step'])
(vec, specs, lam, parnames, mapper_class, lognorms,
log_step) = (D['vec'], D['specs'], D['lam'], D['parnames'],
D['mapper_class'], D['mapper_args'], D['lognorms'],
D['log_step'])
del D

vec = vec.astype(float)
Expand Down
20 changes: 12 additions & 8 deletions py/rvspecfit/nn/train_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ def getData(dir, setup, log_ids=[0]):
xvecs[:, ii] = np.log10(vecs[:, ii])
M = xvecs.mean(axis=0)
S = xvecs.std(axis=0)
mapper = Mapper(M, S, log_ids)
mapper_args = (M, S, log_ids)
mapper = Mapper(*mapper_args)

# vecs[:, 0] = np.log10(vecs[:, 0])

vecs_trans = mapper.forward(vecs)
return lam, vecs_trans, dats, mapper, vecs
return lam, vecs_trans, dats, mapper, vecs, mapper_args


def getSchedOptim(optim):
Expand Down Expand Up @@ -108,9 +109,9 @@ def main(args):
rstate = np.random.default_rng(44)

# vecs are transformed already
lam, vecs, dats, mapper, vecs_orig = getData(directory,
setup,
log_ids=log_ids)
lam, vecs, dats, mapper, vecs_orig, mapper_args = getData(directory,
setup,
log_ids=log_ids)
D_0 = np.mean(dats, axis=0)
SD_0 = np.std(dats, axis=0)
tD_0 = torch.tensor(D_0)
Expand Down Expand Up @@ -299,14 +300,17 @@ def main(args):
if os.path.exists(statefile_path):
os.unlink(statefile_path)
import rvspecfit.nn.RVSInterpolator # noqa

mapper_module = 'rvspecfit.nn.NNInterpolator'
mapper_class_name = 'Mapper'
ofname = f'{directory}/interp_{setup}.pkl'
D = {
'mapper': mapper,
'mapper_module': mapper_module,
'mapper_class_name': mapper_class_name,
'mapper_args': mapper_args,
'parnames': parnames,
'lam': lam,
'log_spec': True,
'logstep': True,
'log_step': True,
'module': 'rvspecfit.nn.RVSInterpolator',
'class_name': 'RVSInterpolator',
'device': device_name,
Expand Down
17 changes: 11 additions & 6 deletions py/rvspecfit/spec_inter.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def __init__(self,
revision='',
filename='',
creation_soft_version='',
logstep=None):
log_step=None):
""" Construct the interpolator object
Parameters
Expand Down Expand Up @@ -241,7 +241,7 @@ def __init__(self,
self.revision = revision
self.filename = filename
self.creation_soft_version = creation_soft_version
self.logstep = logstep
self.log_step = log_step
self.objid = hash((self.name, self.parnames, self.revision,
self.filename, self.creation_soft_version))

Expand Down Expand Up @@ -317,9 +317,14 @@ def getInterpolator(HR, config, warmup_cache=False, cache=None):
fd = pickle.load(fd0)
log_spec = fd.get('log_spec') or True

(templ_lam, mapper, parnames) = (fd['lam'], fd['mapper'],
fd['parnames'])
logstep = fd['logstep']
(templ_lam, parnames) = (fd['lam'], fd['parnames'])
mapper_module = fd['mapper_module']
mapper_class_name = fd['mapper_class_name']
mapper_args = fd['mapper_args']
mod = importlib.import_module(mapper_module)
mapper = getattr(mod, mapper_class_name)(*mapper_args)

log_step = fd['log_step']

if 'interpolation_type' in fd:
interp_type = fd['interpolation_type']
Expand Down Expand Up @@ -373,7 +378,7 @@ def getInterpolator(HR, config, warmup_cache=False, cache=None):
revision=revision,
creation_soft_version=creation_soft_version,
filename=savefile,
logstep=logstep)
log_step=log_step)
cache[HR] = interpObj
if system_cache:
interp_cache.template_lib = config['template_lib']
Expand Down

0 comments on commit 48d4e5e

Please sign in to comment.