Warning
This code is not the code used for the paper but a reimplementation in JAX. The original code was written in PyTorch. The code currently lacks the functionality to fully reproduce the experiments in the original paper.
Warning
The code is lacking fundamental functionality described in the paper. Most importantly, it currently does not apply the Laplace approximation at the end.
This is a reimplementation of the Bayesian Stein Network (BSN) [1] in JAX. The BSN is a network architecture that allows computing integrals using a neural network and obtaining an uncertainty estimate for the predicted value.
The original code used for the paper was written in PyTorch. I wanted to test whether some functionality is easier to implement in JAX than in PyTorch. Indeed, computing the gradients inside the Stein operator seems to be easier in JAX than in PyTorch thanks to vmap.
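To illustrate why vmap helps here, the following is a minimal sketch (not the repository's code) of the Langevin Stein operator $(\mathcal{S}_\pi u)(x) = \nabla_x \cdot u(x) + u(x) \cdot \nabla_x \log \pi(x)$, written for a single point and lifted to a batch with jax.vmap. The choice of the Langevin operator, the standard Gaussian target and the vector field u are assumptions made for illustration. By the Stein identity, the sample mean of the output should be close to zero.

```python
import jax
import jax.numpy as jnp

def log_density(x):
    # Assumed target pi: standard Gaussian (an unnormalised log-density suffices)
    return -0.5 * jnp.sum(x ** 2)

def u(x):
    # Hypothetical vector field R^d -> R^d standing in for the neural network
    return jnp.sin(x) + 0.5 * x

def langevin_stein(u, log_density):
    # Returns the map x -> div u(x) + u(x) . grad log pi(x) for a single point x
    def apply(x):
        divergence = jnp.trace(jax.jacfwd(u)(x))  # div u(x) from the Jacobian
        score = jax.grad(log_density)(x)          # grad log pi(x)
        return divergence + jnp.dot(u(x), score)
    return apply

# vmap turns the single-point operator into a batched one without manual loops
batched_stein = jax.vmap(langevin_stein(u, log_density))

samples = jax.random.normal(jax.random.PRNGKey(0), (100_000, 2))
values = batched_stein(samples)
print(values.shape)           # (100000,)
print(float(values.mean()))   # should be near 0 by the Stein identity
```

jax.jacfwd builds the full d-by-d Jacobian with d forward-mode passes, which is cheap for the low-dimensional problems considered here.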
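The BSN is used to compute integrals of the form
$$\Pi[f] = \int_{\mathcal{X}} f(x)\,\pi(x)\,\mathrm{d}x,$$
where $f:\mathcal{X}\to\mathbb{R}$ is the integrand and $\pi$ is a probability density on $\mathcal{X}\subseteq\mathbb{R}^d$.
Given a standard neural network $u_\theta:\mathcal{X}\to\mathbb{R}^d$, the Stein network is defined as
$$g_\theta(x) = \theta_0 + (\mathcal{S}_\pi u_\theta)(x), \qquad (\mathcal{S}_\pi u)(x) = \nabla_x \cdot u(x) + u(x) \cdot \nabla_x \log \pi(x),$$
where $\mathcal{S}_\pi$ is the Langevin Stein operator [2] and $\theta_0 \in \mathbb{R}$ is an additional scalar parameter. Since $\int_{\mathcal{X}} (\mathcal{S}_\pi u_\theta)(x)\,\pi(x)\,\mathrm{d}x = 0$ under mild regularity conditions, the network integrates exactly to $\theta_0$; once $g_\theta$ has been trained to approximate $f$, $\theta_0$ serves as the estimate of $\Pi[f]$.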
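The construction of a Stein network $g_\theta(x) = \theta_0 + (\mathcal{S}_\pi u_\theta)(x)$ can be sketched in a few lines of JAX. This is a minimal sketch, assuming the Langevin Stein operator, a standard Gaussian $\pi$, a hypothetical one-hidden-layer MLP for $u_\theta$ and plain gradient descent on a least-squares loss; the repository's architecture, loss and optimizer may differ, and no Laplace approximation is included. The names init_params, stein_net and step are hypothetical. Because the Stein term integrates to zero under $\pi$, the scalar theta0 is the integral estimate once $g_\theta$ fits f.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3  # illustrative value, not tuned

def log_density(x):
    # Assumed measure pi: standard Gaussian (unnormalised log-density suffices)
    return -0.5 * jnp.sum(x ** 2)

def init_params(key, d=1, width=16):
    # Hypothetical one-hidden-layer MLP u_theta: R^d -> R^d plus the offset theta0
    k1, k2 = jax.random.split(key)
    return {
        "W1": 0.5 * jax.random.normal(k1, (width, d)),
        "b1": jnp.zeros(width),
        "W2": 0.5 * jax.random.normal(k2, (d, width)),
        "b2": jnp.zeros(d),
        "theta0": jnp.array(0.0),
    }

def u(params, x):
    return params["W2"] @ jnp.tanh(params["W1"] @ x + params["b1"]) + params["b2"]

def stein_net(params, x):
    # g_theta(x) = theta0 + div u(x) + u(x) . grad log pi(x) at a single point x
    jac = jax.jacfwd(lambda y: u(params, y))(x)
    score = jax.grad(log_density)(x)
    return params["theta0"] + jnp.trace(jac) + jnp.dot(u(params, x), score)

# Evaluate g_theta on a batch of points, sharing the parameters
batched_g = jax.vmap(stein_net, in_axes=(None, 0))

def f(xs):
    # Example integrand; its integral under N(0, 1) is exactly 1
    return jnp.sum(xs ** 2, axis=-1)

def loss(params, xs):
    # Least-squares fit of g_theta to f on samples from pi
    return jnp.mean((batched_g(params, xs) - f(xs)) ** 2)

@jax.jit
def step(params, xs):
    # One step of plain gradient descent (illustrative, not L-BFGS-B)
    grads = jax.grad(loss)(params, xs)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

key_params, key_data = jax.random.split(jax.random.PRNGKey(0))
params = init_params(key_params)
xs = jax.random.normal(key_data, (256, 1))

initial = loss(params, xs)
for _ in range(500):
    params = step(params, xs)
final = loss(params, xs)
print(float(initial), float(final))
print("integral estimate theta0:", float(params["theta0"]))
```

The learning rate and number of steps are arbitrary, so theta0 is not expected to be accurate here; the sketch only shows where the integral estimate lives in the model.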
For more details, see the original paper [1].
Currently, the implementation only includes the 1-dimensional Genz-family data set [3]. Here is the result of running the experiment on the 1-dimensional continuous Genz data set. Each value is computed as the mean of three runs:
As described in the paper, on a 1-dimensional data set like this, the Stein network might not be the best option (for a Bayesian alternative, one could consider Bayesian quadrature with an appropriate kernel). However, the interpolation capabilities of the Stein network already give it a large advantage over plain Monte Carlo sampling. Of course, on the Genz data set where both
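For reference, the plain Monte Carlo baseline on a 1-dimensional continuous Genz function can be sketched as follows. The form $f(x) = \exp(-a|x-u|)$ is the continuous member of the Genz family; the uniform measure on [0, 1] and the values a = 5 and u = 0.5 are assumptions for illustration and not necessarily the settings used in run_experiment.py. The closed-form integral $(2 - e^{-au} - e^{-a(1-u)})/a$ makes the Monte Carlo error easy to measure.

```python
import jax
import jax.numpy as jnp

def genz_continuous(x, a=5.0, u=0.5):
    # Continuous Genz function in 1D: f(x) = exp(-a |x - u|)
    return jnp.exp(-a * jnp.abs(x - u))

def genz_continuous_integral(a=5.0, u=0.5):
    # Closed-form integral of f over [0, 1] under the uniform measure
    return (2.0 - jnp.exp(-a * u) - jnp.exp(-a * (1.0 - u))) / a

def monte_carlo_estimate(key, n, a=5.0, u=0.5):
    # Plain Monte Carlo: average f over n uniform samples on [0, 1]
    xs = jax.random.uniform(key, (n,))
    return jnp.mean(genz_continuous(xs, a, u))

exact = genz_continuous_integral()
for n in (10, 100, 1_000, 10_000):
    estimate = monte_carlo_estimate(jax.random.PRNGKey(n), n)
    print(n, float(jnp.abs(estimate - exact)))
```

The Monte Carlo error shrinks only like $n^{-1/2}$, whereas a model that interpolates the smooth integrand can reach a given accuracy with far fewer evaluations of f.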
To run the experiment yourself, run
python run_experiment.py
The parameters can be changed in run_experiment.py. By default, the code uses SciPy's L-BFGS-B optimizer, which does not work with CUDA and does not allow jitting.
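Because SciPy's optimizer works on flat NumPy arrays, a JAX loss defined over a parameter pytree has to be wrapped before it can be passed to scipy.optimize.minimize. The following is a minimal sketch, assuming a hypothetical quadratic loss in place of the real training objective: jax.flatten_util.ravel_pytree flattens the parameters, jax.value_and_grad supplies the gradient (jac=True), and the optimizer loop itself runs in SciPy on the host, which is why the training procedure as a whole cannot be jitted.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree
from scipy.optimize import minimize

def loss(params):
    # Hypothetical stand-in for the real training loss (a simple quadratic)
    return jnp.sum((params["w"] - 3.0) ** 2) + (params["theta0"] - 1.0) ** 2

params0 = {"w": jnp.zeros(3), "theta0": jnp.array(0.0)}
flat0, unravel = ravel_pytree(params0)  # pytree <-> flat vector

# Loss and gradient with respect to the flat parameter vector
value_and_grad = jax.value_and_grad(lambda flat: loss(unravel(flat)))

def fun(flat_np):
    # SciPy passes float64 NumPy arrays and expects (float, float64 array) back
    value, grad = value_and_grad(jnp.asarray(flat_np, dtype=flat0.dtype))
    return float(value), np.asarray(grad, dtype=np.float64)

result = minimize(fun, np.asarray(flat0, dtype=np.float64), jac=True, method="L-BFGS-B")
params_opt = unravel(jnp.asarray(result.x, dtype=flat0.dtype))
print(result.message)
print(params_opt)
```

Each call to fun moves data between NumPy and JAX, so this pattern trades device-side speed for SciPy's robust line search.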
To install the package you can use
pip install git+https://github.com/katharina-ott/bsn-jax
[1] Ott, K., Tiemann, M., Hennig, P. & Briol, F.X. (2023). Bayesian numerical integration with neural networks. Proceedings of the Thirty-Ninth Conference on Uncertainty in Artificial Intelligence, in Proceedings of Machine Learning Research 216:1606-1617 Available from https://proceedings.mlr.press/v216/ott23a.html.
[2] Anastasiou, A., Barp, A., Briol, F. X., Ebner, B., Gaunt, R. E., Ghaderinezhad, F., ... & Swan, Y. (2023). Stein’s method meets computational statistics: A review of some recent developments. Statistical Science, 38(1), 120-139.