Skip to content

Commit

Permalink
fixed bug reported by Paul + refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Javier Sanchez authored and Javier Sanchez committed Jan 28, 2025
1 parent 876d918 commit aaf2dd2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 102 deletions.
69 changes: 14 additions & 55 deletions augur/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
from augur.utils.diff_utils import five_pt_stencil
from augur import generate
from augur.utils.config_io import parse_config
from firecrown.parameters import ParamsMap
from augur.utils.theory_utils import compute_new_theory_vector
from astropy.table import Table
import warnings
from packaging.version import Version
import firecrown


class Analyze(object):
Expand Down Expand Up @@ -138,7 +136,6 @@ def __init__(self, config, likelihood=None, tools=None, req_params=None,
# Normalize the pivot point given the sampling region
if self.norm_step:
self.norm = self.par_bounds[:, 1] - self.par_bounds[:, 0]
self.x = (self.x - self.par_bounds[:, 0]) * 1/self.norm

def f(self, x, labels, pars_fid, sys_fid, donorm=False):
"""
Expand Down Expand Up @@ -223,21 +220,29 @@ def get_derivatives(self, force=False, method='5pt_stencil', step=None):
# Compute the derivatives with respect to the parameters in var_pars at x
if (self.derivatives is None) or (force):
if '5pt_stencil' in method:
if self.norm_step:
x_here = (self.x - self.par_bounds[:, 0]) * 1/self.norm
else:
x_here = self.x
self.derivatives = five_pt_stencil(lambda y: self.f(y, self.var_pars, self.pars_fid,
self.req_params, donorm=self.norm_step),
self.x, h=step)
x_here, h=step)
elif 'numdifftools' in method:
import numdifftools as nd
if 'numdifftools_kwargs' in self.config.keys():
ndkwargs = self.config['numdifftools_kwargs']
else:
ndkwargs = {}
if self.norm_step:
x_here = (self.x - self.par_bounds[:, 0]) * 1/self.norm
else:
x_here = self.x
jacobian_calc = nd.Jacobian(lambda y: self.f(y, self.var_pars, self.pars_fid,
self.req_params,
donorm=self.norm_step),
step=step,
**ndkwargs)
self.derivatives = jacobian_calc(self.x).T
self.derivatives = jacobian_calc(x_here).T
else:
raise ValueError(f'Selected method: `{method}` is not available. \
Please select 5pt_stencil or numdifftools.')
Expand Down Expand Up @@ -350,52 +355,6 @@ def compute_new_theory_vector(self, _sys_pars, _pars):
f_out : ndarray,
Predicted data vector for the given input parameters _sys_pars, _pars.
"""
self.lk.reset()
self.tools.reset()
if Version(firecrown.__version__) < Version('1.8.0a'):
pmap = ParamsMap(_sys_pars)
cosmo = ccl.Cosmology(**_pars)
self.lk.update(pmap)
self.tools.update(pmap)
self.tools.prepare(cosmo)
f_out = self.lk.compute_theory_vector(self.tools)
return f_out
else:
from firecrown.ccl_factory import CCLFactory
dict_all = {**_sys_pars, **_pars}
extra_dict = {}
if dict_all['A_s'] is None:
extra_dict['amplitude_parameter'] = 'sigma8'
dict_all.pop('A_s')
else:
extra_dict['amplitude_parameter'] = 'A_s'
dict_all.pop('sigma8')

extra_dict['mass_split'] = dict_all['mass_split']
dict_all.pop('mass_split')
if 'extra_parameters' in dict_all.keys():
if 'camb' in dict_all['extra_parameters'].keys():
extra_dict['camb_extra_params'] = dict_all['extra_parameters']['camb']
if 'kmin' in dict_all['extra_parameters']['camb'].keys():
extra_dict['camb_extra_params'].pop('kmin')
dict_all.pop('extra_parameters')
keys = list(dict_all.keys())

# Remove None values
for key in keys:
if (dict_all[key] is None) or (dict_all[key] == 'None'):
dict_all.pop(key)
if self.cf is None:
for key in extra_dict.keys():
print(extra_dict[key], type(extra_dict[key]))
self.cf = CCLFactory(**extra_dict)
self.tools = firecrown.modeling_tools.ModelingTools(ccl_factory=self.cf)
self.tools.reset()
pmap = ParamsMap(dict_all)
self.cf.update(pmap)
self.tools.update(pmap)
self.tools.prepare()
self.lk.update(pmap)
f_out = self.lk.compute_theory_vector(self.tools)

return f_out
f_out = compute_new_theory_vector(self.lk, self.tools, _sys_pars, _pars)

return f_out
52 changes: 5 additions & 47 deletions augur/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@
from augur.tracers.two_point import ZDistFromFile
from augur.utils.cov_utils import get_gaus_cov, get_SRD_cov, get_noise_power
from augur.utils.cov_utils import TJPCovGaus
from augur.utils.theory_utils import compute_new_theory_vector
from packaging.version import Version
import firecrown

if Version(firecrown.__version__) >= Version('1.8.0a'):
import firecrown.likelihood.weak_lensing as wl
import firecrown.likelihood.number_counts as nc
from firecrown.likelihood.two_point import TwoPoint
from firecrown.likelihood.gaussian import ConstGaussian
from firecrown.ccl_factory import CCLFactory
elif Version(firecrown.__version__) >= Version('1.7.4'):
import firecrown.likelihood.gauss_family.statistic.source.weak_lensing as wl
import firecrown.likelihood.gauss_family.statistic.source.number_counts as nc
from firecrown.likelihood.gauss_family.statistic.two_point import TwoPoint
from firecrown.likelihood.gauss_family.gaussian import ConstGaussian
from firecrown.modeling_tools import ModelingTools
from firecrown.parameters import ParamsMap
from augur.utils.config_io import parse_config

Expand Down Expand Up @@ -349,52 +349,10 @@ def generate(config, return_all_outputs=False, write_sacc=True):
lk.read(S)

cosmo.compute_nonlin_power()

_pars = cosmo.__dict__['_params_init_kwargs']
# Populate ModelingTools and likelihood

# Old firecrown
if Version(firecrown.__version__) < Version('1.8.0a'):
tools = ModelingTools()
lk.update(sys_params)
tools.update(sys_params)
tools.prepare(cosmo)
# Run the likelihood (to get the theory)
lk.compute_loglike(tools)
# New firecrown with CCLFactory
else:
_pars = cosmo.__dict__['_params_init_kwargs']
dict_all = {**sys_params, **_pars}
extra_dict = {}
if dict_all['A_s'] is None:
extra_dict['amplitude_parameter'] = 'sigma8'
dict_all.pop('A_s')
else:
extra_dict['amplitude_parameter'] = 'A_s'
dict_all.pop('sigma8')

extra_dict['mass_split'] = dict_all['mass_split']
dict_all.pop('mass_split')
if 'extra_parameters' in dict_all.keys():
if 'camb' in dict_all['extra_parameters'].keys():
extra_dict['camb_extra_params'] = dict_all['extra_parameters']['camb']
if 'kmin' in dict_all['extra_parameters']['camb'].keys():
extra_dict['camb_extra_params'].pop('kmin')
dict_all.pop('extra_parameters')
keys = list(dict_all.keys())

# Remove None values from dict_all
for key in keys:
if (dict_all[key] is None) or (dict_all[key] == 'None'):
dict_all.pop(key)
cf = CCLFactory(**extra_dict)
tools = ModelingTools(ccl_factory=cf)
tools.reset()
pmap = ParamsMap(dict_all)
cf.update(pmap)
tools.update(pmap)
tools.prepare()
lk.update(pmap)
lk.compute_theory_vector(tools)
tools = firecrown.modeling_tools.ModelingTools()
_, lk, tools = compute_new_theory_vector(lk, tools, sys_params, _pars, return_all=True)

# Get all bandpower windows before erasing the placeholder sacc
win_dict = {}
Expand Down

0 comments on commit aaf2dd2

Please sign in to comment.