Skip to content

Commit

Permalink
#1294 allow setting of hyper parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
fcooper8472 committed Mar 16, 2021
1 parent fdbf30d commit 1225905
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions pints/functionaltests/_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class RunMcmcMethodOnTwoDimGaussian(RunMcmcMethodOnProblem):
distribution.
"""

def __init__(self, method, n_chains, n_iterations, n_warmup):
def __init__(self, method, n_chains, n_iterations, n_warmup, method_hyper_parameters=None):
pdf = pints.toy.GaussianLogPDF(mean=[0, 0], sigma=[1, 1])

# Get initial parameters
Expand All @@ -59,13 +59,21 @@ def __init__(self, method, n_chains, n_iterations, n_warmup):
initial_parameters = log_prior.sample(n=n_chains)

# Set up sampler
sampler = pints.MCMCController(
controller = pints.MCMCController(
pdf, n_chains, initial_parameters, method=method)
sampler.set_max_iterations(n_iterations)
sampler.set_log_to_screen(False)
controller.set_max_iterations(n_iterations)
controller.set_log_to_screen(False)

# Set hyper parameters, if required. This is different based on single/multi chain
if method_hyper_parameters is not None:
if issubclass(method, pints.MultiChainMCMC):
controller.sampler().set_hyper_parameters(method_hyper_parameters)
else:
for sampler in controller.samplers():
sampler.set_hyper_parameters(method_hyper_parameters)

# Infer posterior and throw away warm-up
chains = sampler.run()
chains = controller.run()
chains = chains[:, n_warmup:]

super().__init__(pdf, chains)

0 comments on commit 1225905

Please sign in to comment.