diff --git a/elfi/methods/inference/samplers.py b/elfi/methods/inference/samplers.py index 09de508a..abfc9cee 100644 --- a/elfi/methods/inference/samplers.py +++ b/elfi/methods/inference/samplers.py @@ -459,9 +459,21 @@ def prepare_new_batch(self, batch_index): return batch def _init_new_round(self): - round = self.state['round'] - self._update_round_info(round) + self._set_rejection_round(self.state['round']) + + if self.state['round'] == 0 and self._quantiles is not None: + self._rejection.set_objective( + self.objective['n_samples'], quantile=self._quantiles[0]) + else: + if self._quantiles is not None: + self._set_threshold() + self._rejection.set_objective( + self.objective['n_samples'], threshold=self.current_population_threshold) + + def _set_rejection_round(self, round): + + self._update_round_info(self.state['round']) # Get a subseed for this round for ensuring consistent results for the round seed = self.seed if round == 0 else get_sub_seed(self.seed, round) @@ -474,15 +486,6 @@ def _init_new_round(self): seed=seed, max_parallel_batches=self.max_parallel_batches) - if self.state['round'] == 0 and self._quantiles is not None: - self._rejection.set_objective( - self.objective['n_samples'], quantile=self._quantiles[0]) - else: - if self._quantiles is not None: - self._set_threshold() - self._rejection.set_objective( - self.objective['n_samples'], threshold=self.current_population_threshold) - def _update_round_info(self, round): if self.bar: reinit_msg = 'ABC-SMC Round {0} / {1}'.format( @@ -709,15 +712,12 @@ def __init__(self, self.q_threshold = q_threshold self.initial_quantile = initial_quantile - if densratio_estimation is None: - self.densratio = DensityRatioEstimation(n=100, - epsilon=0.001, - max_iter=200, - abs_tol=0.01, - fold=5, - optimize=False) - else: - self.densratio = densratio_estimation + self.densratio = densratio_estimation or DensityRatioEstimation(n=100, + epsilon=0.001, + max_iter=200, + abs_tol=0.01, + fold=5, + optimize=False) def set_objective(self, n_samples, @@ -742,6 +742,8 @@ def set_objective(self, # Initialise threshold selection and adaptive quantile thresholds = np.full((rounds+1), None) + self._quantiles = np.full((rounds+1), None) + self._quantiles[0] = self.initial_quantile self.objective.update( dict( @@ -779,53 +781,13 @@ def update(self, batch, batch_index): self._set_adaptive_quantile() - if self.adaptive_quantile_value < self.q_threshold: + if self._quantiles[self.state['round']+1] < self.q_threshold: self._populations.append(self._new_population) self.state['round'] += 1 self._init_new_round() self._update_objective() - def _init_new_round(self): - round = self.state['round'] - - self._update_round_info(round) - - # Get a subseed for this round for ensuring consistent results for the round - seed = self.seed if round == 0 else get_sub_seed(self.seed, round) - self._round_random_state = np.random.RandomState(seed) - self._rejection = Rejection( - self.model, - discrepancy_name=self.discrepancy_name, - output_names=self.output_names, - batch_size=self.batch_size, - seed=seed, - max_parallel_batches=self.max_parallel_batches) - - if self.state['round'] == 0: - self._rejection.set_objective( - self.objective['n_samples'], quantile=self.initial_quantile) - else: - self._rejection.set_objective( - self.objective['n_samples'], - threshold=self.current_population_threshold) - - @property - def current_population_threshold(self): - """Return the threshold for current population.""" - if self.state['round'] > 0: - self._set_threshold() - return self.objective['thresholds'][self.state['round']] - - def _set_threshold(self): - """Set current population threshold as previous population quantile.""" - threshold = weighted_sample_quantile( - x=self._populations[self.state['round']-1].discrepancies, - alpha=self.adaptive_quantile_value, - weights=self._populations[self.state['round']-1].weights) - logger.info('ABC-SMC: Selected threshold for next population %.3f' % (threshold)) - self.objective['thresholds'][self.state['round']] = threshold - def _set_adaptive_quantile(self): """Set adaptively the new threshold for current population.""" logger.info("ABC-SMC: Adapting quantile threshold...") @@ -847,8 +809,7 @@ def _set_adaptive_quantile(self): max_value = self.densratio.max_ratio() max_value = 1.0 if max_value < 1.0 else max_value - self.adaptive_quantile_value = max(1 / max_value, 0.05) - + self._quantiles[self.state['round']+1] = max(1 / max_value, 0.05) logger.info('ABC-SMC: Estimated maximum density ratio %.5f' % (max_value)) def _resolve_sample(self, backwards_index):