Skip to content

Commit

Permalink
Deploying to gh-pages from @ a5f7482 🚀
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Apr 8, 2024
1 parent a4589bd commit c7977f0
Show file tree
Hide file tree
Showing 63 changed files with 131 additions and 137 deletions.
Binary file modified .doctrees/autoapi/blackjax/mcmc/trajectory/index.doctree
Binary file not shown.
Binary file modified .doctrees/environment.pickle
Binary file not shown.
Binary file modified .doctrees/examples/howto_custom_gradients.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_metropolis_within_gibbs.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_reproduce_the_blackjax_image.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_sample_multiple_chains.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_aesara.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_numpyro.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_oryx.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_pymc.doctree
Binary file not shown.
Binary file modified .doctrees/examples/howto_use_tfp.doctree
Binary file not shown.
Binary file modified .doctrees/examples/quickstart.doctree
Binary file not shown.
Binary file modified .doctrees/index.doctree
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
Binary file not shown.
4 changes: 2 additions & 2 deletions _modules/blackjax/_version.html
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,8 @@ <h1>Source code for blackjax._version</h1><div class="highlight"><pre>
<span class="n">version_tuple</span><span class="p">:</span> <span class="n">VERSION_TUPLE</span></div>


<span class="n">__version__</span> <span class="o">=</span> <span class="n">version</span> <span class="o">=</span> <span class="s1">&#39;0.1.dev1+g7cf4f9d&#39;</span>
<span class="n">__version_tuple__</span> <span class="o">=</span> <span class="n">version_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;dev1&#39;</span><span class="p">,</span> <span class="s1">&#39;g7cf4f9d&#39;</span><span class="p">)</span>
<span class="n">__version__</span> <span class="o">=</span> <span class="n">version</span> <span class="o">=</span> <span class="s1">&#39;0.1.dev1+ga5f7482&#39;</span>
<span class="n">__version_tuple__</span> <span class="o">=</span> <span class="n">version_tuple</span> <span class="o">=</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="s1">&#39;dev1&#39;</span><span class="p">,</span> <span class="s1">&#39;ga5f7482&#39;</span><span class="p">)</span>
</pre></div>

</article>
Expand Down
16 changes: 7 additions & 9 deletions _modules/blackjax/adaptation/chees_adaptation.html
Original file line number Diff line number Diff line change
Expand Up @@ -764,20 +764,18 @@ <h1>Source code for blackjax.adaptation.chees_adaptation</h1><div class="highlig
<span class="p">),</span> <span class="s2">&quot;initial `positions` leading dimension must be equal to the `num_chains`&quot;</span>
<span class="n">num_dim</span> <span class="o">=</span> <span class="n">pytree_size</span><span class="p">(</span><span class="n">positions</span><span class="p">)</span> <span class="o">//</span> <span class="n">num_chains</span>

<span class="n">key_init</span><span class="p">,</span> <span class="n">key_step</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_key</span><span class="p">)</span>
<span class="n">next_random_arg_fn</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">init_random_arg</span> <span class="o">=</span> <span class="mi">0</span>

<span class="k">if</span> <span class="n">jitter_generator</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">jitter_gn</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">key</span><span class="p">:</span> <span class="n">jitter_generator</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="o">*</span> <span class="n">jitter_amount</span> <span class="o">+</span> <span class="p">(</span>
<span class="mf">1.0</span> <span class="o">-</span> <span class="n">jitter_amount</span>
<span class="p">)</span>
<span class="n">next_random_arg_fn</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">key</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">init_random_arg</span> <span class="o">=</span> <span class="n">key_init</span>
<span class="n">rng_key</span><span class="p">,</span> <span class="n">carry_key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_key</span><span class="p">)</span>
<span class="n">jitter_gn</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">:</span> <span class="n">jitter_generator</span><span class="p">(</span>
<span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">fold_in</span><span class="p">(</span><span class="n">carry_key</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
<span class="p">)</span> <span class="o">*</span> <span class="n">jitter_amount</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">jitter_amount</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">jitter_gn</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">:</span> <span class="n">dynamic_hmc</span><span class="o">.</span><span class="n">halton_sequence</span><span class="p">(</span>
<span class="n">i</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">log2</span><span class="p">(</span><span class="n">num_steps</span> <span class="o">+</span> <span class="n">max_sampling_steps</span><span class="p">))</span>
<span class="p">)</span> <span class="o">*</span> <span class="n">jitter_amount</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">jitter_amount</span><span class="p">)</span>
<span class="n">next_random_arg_fn</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">i</span><span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span>
<span class="n">init_random_arg</span> <span class="o">=</span> <span class="mi">0</span>

<span class="k">def</span> <span class="nf">integration_steps_fn</span><span class="p">(</span><span class="n">random_generator_arg</span><span class="p">,</span> <span class="n">trajectory_length_adjusted</span><span class="p">):</span>
<span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span>
Expand Down Expand Up @@ -828,7 +826,7 @@ <h1>Source code for blackjax.adaptation.chees_adaptation</h1><div class="highlig
<span class="n">init_states</span> <span class="o">=</span> <span class="n">batch_init</span><span class="p">(</span><span class="n">positions</span><span class="p">)</span>
<span class="n">init_adaptation_state</span> <span class="o">=</span> <span class="n">init</span><span class="p">(</span><span class="n">init_random_arg</span><span class="p">,</span> <span class="n">step_size</span><span class="p">)</span>

<span class="n">keys_step</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">key_step</span><span class="p">,</span> <span class="n">num_steps</span><span class="p">)</span>
<span class="n">keys_step</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">rng_key</span><span class="p">,</span> <span class="n">num_steps</span><span class="p">)</span>
<span class="p">(</span><span class="n">last_states</span><span class="p">,</span> <span class="n">last_adaptation_state</span><span class="p">),</span> <span class="n">info</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">lax</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span>
<span class="n">one_step</span><span class="p">,</span> <span class="p">(</span><span class="n">init_states</span><span class="p">,</span> <span class="n">init_adaptation_state</span><span class="p">),</span> <span class="n">keys_step</span>
<span class="p">)</span>
Expand Down
4 changes: 2 additions & 2 deletions _modules/blackjax/adaptation/mclmc_adaptation.html
Original file line number Diff line number Diff line change
Expand Up @@ -613,12 +613,12 @@ <h1>Source code for blackjax.adaptation.mclmc_adaptation</h1><div class="highlig
<span class="n">kalman_state</span> <span class="o">=</span> <span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">dim</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">dim</span><span class="p">))</span>

<span class="c1"># run the steps</span>
<span class="n">kalman_state</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">lax</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span>
<span class="n">kalman_state</span><span class="p">,</span> <span class="o">*</span><span class="n">_</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">lax</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span>
<span class="n">step</span><span class="p">,</span>
<span class="n">init</span><span class="o">=</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">adap0</span><span class="p">,</span> <span class="n">kalman_state</span><span class="p">),</span>
<span class="n">xs</span><span class="o">=</span><span class="p">(</span><span class="n">outer_weights</span><span class="p">,</span> <span class="n">L_step_size_adaptation_keys</span><span class="p">),</span>
<span class="n">length</span><span class="o">=</span><span class="n">num_steps1</span> <span class="o">+</span> <span class="n">num_steps2</span><span class="p">,</span>
<span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
<span class="p">)</span>
<span class="n">state</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">kalman_state_output</span> <span class="o">=</span> <span class="n">kalman_state</span>

<span class="n">L</span> <span class="o">=</span> <span class="n">params</span><span class="o">.</span><span class="n">L</span>
Expand Down
8 changes: 4 additions & 4 deletions _modules/blackjax/adaptation/meads_adaptation.html
Original file line number Diff line number Diff line change
Expand Up @@ -481,16 +481,16 @@ <h1>Source code for blackjax.adaptation.meads_adaptation</h1><div class="highlig
<span class="sd"> of the generalized HMC algorithm.</span>

<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">mean_position</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">positions</span><span class="p">)</span>
<span class="n">sd_position</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">positions</span><span class="p">)</span>
<span class="n">normalized_positions</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span>
<span class="n">mean_position</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">positions</span><span class="p">)</span>
<span class="n">sd_position</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">positions</span><span class="p">)</span>
<span class="n">normalized_positions</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree</span><span class="o">.</span><span class="n">map</span><span class="p">(</span>
<span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">mu</span><span class="p">,</span> <span class="n">sd</span><span class="p">:</span> <span class="p">(</span><span class="n">p</span> <span class="o">-</span> <span class="n">mu</span><span class="p">)</span> <span class="o">/</span> <span class="n">sd</span><span class="p">,</span>
<span class="n">positions</span><span class="p">,</span>
<span class="n">mean_position</span><span class="p">,</span>
<span class="n">sd_position</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">batch_grad_scaled</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span>
<span class="n">batch_grad_scaled</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree</span><span class="o">.</span><span class="n">map</span><span class="p">(</span>
<span class="k">lambda</span> <span class="n">grad</span><span class="p">,</span> <span class="n">sd</span><span class="p">:</span> <span class="n">grad</span> <span class="o">*</span> <span class="n">sd</span><span class="p">,</span> <span class="n">logdensity_grad</span><span class="p">,</span> <span class="n">sd_position</span>
<span class="p">)</span>

Expand Down
Loading

0 comments on commit c7977f0

Please sign in to comment.