Skip to content

Commit

Permalink
generalize param mapper
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Jan 16, 2024
1 parent 08b9ae5 commit a1826c5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 10 deletions.
47 changes: 43 additions & 4 deletions py/rvspecfit/make_interpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __call__(self, x):


def process_all(setupInfo,
postf='',
parnames=('teff', 'logg', 'feh', 'alpha'),
dbfile='/tmp/files.db',
oprefix='psavs/',
prefix=None,
Expand All @@ -158,12 +158,31 @@ def process_all(setupInfo,
resolution0=None,
normalize=True,
revision='',
nthreads=8):
nthreads=8,
log_parameters=None):
"""
Process the whole library of spectra and prepare the pickle file
with arrays of convolved spectra, wavelength arrays, transformed
parameters
Parameters
-----------
setupInfo: string
The name of spectral configuration
parnames: list of strings
The parameter names of spectra
log_parameters: integer positions of parameters that needs
to be log10() for interpolation. I.e. if the first parameter si teff
and we want to perform interpolation in log(teff) space
this needs to be [0]
air: boolean
Transform from vacuum to air
"""
if not os.path.exists(dbfile):
raise RuntimeError('The template database file %s does not exist' %
dbfile)
conn = sqlite3.connect(dbfile)
parnames = ('teff', 'logg', 'feh', 'alpha')
parname_str = ','.join(list(parnames))
cur = conn.execute(f'''select id, {parname_str} from files
where not bad order by {parname_str}''')
Expand All @@ -180,7 +199,7 @@ def process_all(setupInfo,
dbfile=dbfile,
prefix=prefix,
wavefile=wavefile)
mapper = read_grid.ParamMapper()
mapper = read_grid.ParamMapper(log_parameters)
HR, lamleft, lamright, resol_function, step, log = setupInfo

deltav = 1000. # extra padding
Expand Down Expand Up @@ -281,6 +300,20 @@ def main(args):
help='The revision of the templates',
default='',
required=False)
parser.add_argument(
'--parameter_names',
type=str,
default='teff,logg,feh,alpha',
help=
'comma separated list of parameters defined to make the interpolator',
required=False)

parser.add_argument(
'--log_parameters',
type=str,
default='0',
help='Which parameters we are taking the log() of when interpolating',
required=False)

parser.add_argument(
'--resol_func',
Expand Down Expand Up @@ -365,8 +398,14 @@ def main(args):
else:
resol_func = Resolution(resol_func=args.resol_func)

log_parameters = [int(_) for _ in args.log_parameters.split(',')]

parnames = args.parameter_names.split(',')

process_all((args.setup, args.lambda0, args.lambda1, resol_func, args.step,
args.log),
parnames=parnames,
log_parameters=log_parameters,
dbfile=args.templdb,
oprefix=args.oprefix,
prefix=args.templprefix,
Expand Down
20 changes: 14 additions & 6 deletions py/rvspecfit/read_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,16 @@ def pix_integrator(x1, x2, l1, l2, s):
return ret1, ret2


class ParamMapper:
class LogParamMapper:
"""
Class used to map stellar atmospheric parameters into more suitable
space used for interpolation
"""

def __init__(self):
pass
def __init__(self, log_ids):
# Specify which parameter numbers to log() for
# interpolation
self.log_ids = log_ids

def forward(self, vec):
"""
Expand All @@ -117,14 +119,17 @@ def forward(self, vec):
Parameters
-----------
vec: numpy array
The vector of atmospheric parameters Teff, logg, feh, alpha
The vector of atmospheric parameters i.e. Teff, logg, feh, alpha
Returns
----------
ret: numpy array
The vector of transformed parameters used in interpolation
"""
return np.array([np.log10(vec[0]), vec[1], vec[2], vec[3]])
vec1 = np.array(vec)
for i in self.log_ids:
vec1[i] = np.log10(vec1[i])
return vec1

def inverse(self, vec):
"""
Expand All @@ -143,7 +148,10 @@ def inverse(self, vec):
ret: numpy array
The vector of original atmospheric parameters.
"""
return np.array([10**vec[0], vec[1], vec[2], vec[3]])
vec1 = np.array(vec)
for i in self.log_ids:
vec1[i] = 10**(vec1[i])
return vec1


def makedb(prefix='/PHOENIX-ACES-AGSS-COND-2011/',
Expand Down

0 comments on commit a1826c5

Please sign in to comment.