From a1826c5f34b5e4ed668772988153e39b89813591 Mon Sep 17 00:00:00 2001 From: "Sergey E. Koposov" Date: Tue, 16 Jan 2024 15:01:39 +0000 Subject: [PATCH] generalize param mapper --- py/rvspecfit/make_interpol.py | 47 ++++++++++++++++++++++++++++++++--- py/rvspecfit/read_grid.py | 20 ++++++++++----- 2 files changed, 57 insertions(+), 10 deletions(-) diff --git a/py/rvspecfit/make_interpol.py b/py/rvspecfit/make_interpol.py index 376f910..54a4fe0 100644 --- a/py/rvspecfit/make_interpol.py +++ b/py/rvspecfit/make_interpol.py @@ -149,7 +149,7 @@ def __call__(self, x): def process_all(setupInfo, - postf='', + parnames=('teff', 'logg', 'feh', 'alpha'), dbfile='/tmp/files.db', oprefix='psavs/', prefix=None, @@ -158,12 +158,31 @@ def process_all(setupInfo, resolution0=None, normalize=True, revision='', - nthreads=8): + nthreads=8, + log_parameters=None): + """ + Process the whole library of spectra and prepare the pickle file + with arrays of convolved spectra, wavelength arrays, transformed + parameters + + Parameters + ----------- + setupInfo: string + The name of spectral configuration + parnames: list of strings + The parameter names of spectra + log_parameters: integer positions of parameters that needs + to be log10() for interpolation. I.e. if the first parameter si teff + and we want to perform interpolation in log(teff) space + this needs to be [0] + air: boolean + Transform from vacuum to air + + """ if not os.path.exists(dbfile): raise RuntimeError('The template database file %s does not exist' % dbfile) conn = sqlite3.connect(dbfile) - parnames = ('teff', 'logg', 'feh', 'alpha') parname_str = ','.join(list(parnames)) cur = conn.execute(f'''select id, {parname_str} from files where not bad order by {parname_str}''') @@ -180,7 +199,7 @@ def process_all(setupInfo, dbfile=dbfile, prefix=prefix, wavefile=wavefile) - mapper = read_grid.ParamMapper() + mapper = read_grid.ParamMapper(log_parameters) HR, lamleft, lamright, resol_function, step, log = setupInfo deltav = 1000. # extra padding @@ -281,6 +300,20 @@ def main(args): help='The revision of the templates', default='', required=False) + parser.add_argument( + '--parameter_names', + type=str, + default='teff,logg,feh,alpha', + help= + 'comma separated list of parameters defined to make the interpolator', + required=False) + + parser.add_argument( + '--log_parameters', + type=str, + default='0', + help='Which parameters we are taking the log() of when interpolating', + required=False) parser.add_argument( '--resol_func', @@ -365,8 +398,14 @@ def main(args): else: resol_func = Resolution(resol_func=args.resol_func) + log_parameters = [int(_) for _ in args.log_parameters.split(',')] + + parnames = args.parameter_names.split(',') + process_all((args.setup, args.lambda0, args.lambda1, resol_func, args.step, args.log), + parnames=parnames, + log_parameters=log_parameters, dbfile=args.templdb, oprefix=args.oprefix, prefix=args.templprefix, diff --git a/py/rvspecfit/read_grid.py b/py/rvspecfit/read_grid.py index c1ff4a9..17e9f67 100644 --- a/py/rvspecfit/read_grid.py +++ b/py/rvspecfit/read_grid.py @@ -100,14 +100,16 @@ def pix_integrator(x1, x2, l1, l2, s): return ret1, ret2 -class ParamMapper: +class LogParamMapper: """ Class used to map stellar atmospheric parameters into more suitable space used for interpolation """ - def __init__(self): - pass + def __init__(self, log_ids): + # Specify which parameter numbers to log() for + # interpolation + self.log_ids = log_ids def forward(self, vec): """ @@ -117,14 +119,17 @@ def forward(self, vec): Parameters ----------- vec: numpy array - The vector of atmospheric parameters Teff, logg, feh, alpha + The vector of atmospheric parameters i.e. Teff, logg, feh, alpha Returns ---------- ret: numpy array The vector of transformed parameters used in interpolation """ - return np.array([np.log10(vec[0]), vec[1], vec[2], vec[3]]) + vec1 = np.array(vec) + for i in self.log_ids: + vec1[i] = np.log10(vec1[i]) + return vec1 def inverse(self, vec): """ @@ -143,7 +148,10 @@ def inverse(self, vec): ret: numpy array The vector of original atmospheric parameters. """ - return np.array([10**vec[0], vec[1], vec[2], vec[3]]) + vec1 = np.array(vec) + for i in self.log_ids: + vec1[i] = 10**(vec1[i]) + return vec1 def makedb(prefix='/PHOENIX-ACES-AGSS-COND-2011/',