Bayesian Stein Network

Warning

This code is not the code used for the paper, but a reimplementation in JAX. The original code was written in PyTorch. The code currently lacks the functionality to fully reproduce the experiments in the original paper.

Warning

The code lacks fundamental functionality described in the paper. Most importantly, it currently does not apply the Laplace approximation in the final step.

This is a reimplementation of the Bayesian Stein Network (BSN) [1] in JAX. The BSN is a network architecture that allows computing integrals with a neural network while also obtaining an uncertainty estimate for the predicted value.

The original code used for the paper was written in PyTorch. I wanted to test whether some of the functionality is easier to implement in JAX than in PyTorch. And indeed, computing the gradients inside the Stein operator is more convenient in JAX thanks to vmap.
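As a small illustration of this point (not code from this repository), the score term $\nabla_x \log \pi(x)$ for a whole batch of points can be written as a composition of `jax.grad` and `jax.vmap`; the standard-normal `log_pi` below is only a placeholder:

```python
import jax
import jax.numpy as jnp

def log_pi(x):
    # Log-density of a d-dimensional standard normal (illustrative choice).
    return -0.5 * jnp.sum(x**2) - 0.5 * x.shape[-1] * jnp.log(2 * jnp.pi)

score = jax.grad(log_pi)         # single point: (d,) -> (d,)
batched_score = jax.vmap(score)  # batch of points: (n, d) -> (n, d)

X = jax.random.normal(jax.random.PRNGKey(0), (5, 2))
print(batched_score(X).shape)    # (5, 2)
```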

Task

The BSN is used to compute integrals of the form

$$\Pi[f] = \int_{\mathcal{X}} f(x) \pi(x) dx,$$

where $\mathcal{X} \subseteq \mathbb{R}^d$, $f: \mathcal{X} \rightarrow \mathbb{R}$ is a function, and $\pi: \mathcal{X} \rightarrow \mathbb{R}^+$ is a probability density function.

Given a standard neural network $u_{\theta_u}$, we define the following network architecture (based on the Stein operator, see [2]):

$$g_{\theta}(x) = \left(\nabla_x \log \pi(x)\right)^\top u_{\theta_u}(x) + \nabla_x \cdot u_{\theta_u}(x) + \theta_0,$$

where $\theta = \{\theta_u, \theta_0\}$ are the parameters of the network $g_\theta$. Since the Stein-operator terms integrate to zero under $\pi$, we have $\Pi[g_\theta] = \theta_0$, so after fitting the network to the data, the constant $\theta_0$ serves as the estimate of the integral. To train the network, we use data $\{(x_i, y_i)\}_{i=1}^n$ where $x_i \sim \pi$ and $y_i = f(x_i)$.
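The following is a minimal sketch of how $g_\theta$ can be expressed in JAX. It is my own illustration rather than the repository's implementation; the two-layer MLP `u`, the standard-normal `log_pi`, and all parameter names are placeholders:

```python
import jax
import jax.numpy as jnp

def u(params, x):
    # Hypothetical two-layer MLP u_{theta_u}: R^d -> R^d.
    w1, b1, w2, b2 = params
    h = jnp.tanh(w1 @ x + b1)
    return w2 @ h + b2

def log_pi(x):
    # Illustrative choice: standard normal log-density (up to a constant).
    return -0.5 * jnp.sum(x**2)

def g(params, theta_0, x):
    # Stein network: (grad_x log pi(x))^T u(x) + div_x u(x) + theta_0.
    score = jax.grad(log_pi)(x)                                 # (d,)
    div_u = jnp.trace(jax.jacfwd(lambda z: u(params, z))(x))    # scalar
    return score @ u(params, x) + div_u + theta_0

# Evaluate the network on a whole batch of points via vmap.
batched_g = jax.vmap(g, in_axes=(None, None, 0))

d, hidden = 2, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (hidden, d)), jnp.zeros(hidden),
          jax.random.normal(k2, (d, hidden)), jnp.zeros(d))
X = jax.random.normal(k3, (100, d))
print(batched_g(params, 0.0, X).shape)   # (100,)
```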

For more details, see the original paper [1].

Experiments

Currently, the implementation only includes the 1-dimensional Genz-family data set [3]. Here is the result of running the experiment on the 1-dimensional continuous Genz data set. Each value is computed as the mean of three runs:

*(Figure: results of the BSN and Monte Carlo on the 1-dimensional continuous Genz data set, averaged over three runs.)*

As described in the paper, on a 1-dimensional data set like this, the Stein network might not be the best option (for a Bayesian alternative, one could consider Bayesian quadrature with an appropriate kernel). However, the interpolation capabilities of the Stein network already lead to a large advantage over plain Monte Carlo sampling. Of course, on the Genz data set, where both $f$ and $\pi$ are cheap to evaluate, sampling is the superior option, as we can obtain large data sets quickly. For illustrative purposes, we compare Monte Carlo and the BSN for a fixed number of data points.
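For reference, the plain Monte Carlo baseline in this comparison is just the sample mean of $f$ at points drawn from $\pi$. A minimal sketch, assuming a standard-normal $\pi$ and a toy integrand rather than the Genz function used in the experiment:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy integrand for illustration only.
    return jnp.cos(x).sum()

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (1000, 1))    # x_i ~ pi (standard normal)
mc_estimate = jnp.mean(jax.vmap(f)(X))   # Pi[f] ~ (1/n) * sum_i f(x_i)
print(mc_estimate)
```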

To run the experiment yourself, run

python run_experiment.py

The parameters can be changed in run_experiment.py. By default, the code uses SciPy's L-BFGS-B optimizer, which does not work with CUDA and does not allow jitting.
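As a rough sketch of how SciPy's L-BFGS-B can be combined with JAX gradients (my own illustration, not the repository's training loop; the quadratic `loss` stands in for the actual Stein-network loss):

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)  # SciPy's L-BFGS-B works in float64

def loss(theta_flat):
    # Stand-in quadratic loss; in the real code this would be the
    # least-squares loss of the Stein network on the training data.
    return jnp.sum((theta_flat - 1.0) ** 2)

# The loss and its gradient can still be compiled with jit ...
value_and_grad = jax.jit(jax.value_and_grad(loss))

def scipy_objective(theta_flat):
    # ... but SciPy drives the optimization loop in Python with NumPy arrays.
    val, grad = value_and_grad(jnp.asarray(theta_flat))
    return float(val), np.asarray(grad)

theta0 = np.zeros(10)
result = minimize(scipy_objective, theta0, jac=True, method="L-BFGS-B")
print(result.x)
```

The loss evaluation itself can be jitted, but the L-BFGS-B iterations run in Python outside JAX, which is why the optimization as a whole cannot be jitted or kept on the GPU.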

Installation

To install the package, you can use

pip install git+https://github.com/katharina-ott/bsn-jax

References

[1] Ott, K., Tiemann, M., Hennig, P. & Briol, F. X. (2023). Bayesian numerical integration with neural networks. Proceedings of the Thirty-Ninth Conference on Uncertainty in Artificial Intelligence, in Proceedings of Machine Learning Research 216:1606-1617. Available from https://proceedings.mlr.press/v216/ott23a.html.

[2] Anastasiou, A., Barp, A., Briol, F. X., Ebner, B., Gaunt, R. E., Ghaderinezhad, F., ... & Swan, Y. (2023). Stein’s method meets computational statistics: A review of some recent developments. Statistical Science, 38(1), 120-139.

[3] https://www.sfu.ca/~ssurjano/integration.html
