extend_params fix after n_particles removed #62
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fix examples after change to
extend_params
in blackjax-devs/blackjax#694.Additionally in the TemperedSMC example
max_num_doublings
was changed to 6 instead of the default 10 since we regularly hitmax_num_doublings
due to the small step size (I believe this is for illustrative purposes). On a gpu device the example is extraordinarily slow without the change - and still takes ~2 mins with it. It seems far too slow but I haven't been able to find any explanation.For reference:
CPU (10000 samples, max_num_doublings=10):
step_size = 1e-2:
HMC: 50 steps / 1.14s
NUTS: 30 steps / .964s
step_size = 1e-3
HMC: 50 / 1.14s
NUTS: 273 / 1.9s
step_size = 1e-4
HMC: 50 / 1.18s
NUTS: 926 / 4.23s
GPU (1000 samples - 10x fewer samples..., max_num_doublings=10):
step_size = 1e-2:
HMC: 50 / 3.31s
NUTS: 30 / 7.3s
step_size = 1e-3:
HMC: 50 / 3.32s
NUTS: 267 / 63s
step_size = 1e-4
HMC: 50 / 3.31s
NUTS: 926.4 / 237s