This is my JAX implementation of several Bayesian learning methods, including SGLD (stochastic gradient Langevin dynamics), pSGLD (preconditioned SGLD), and MALA (Metropolis-adjusted Langevin algorithm). The implementation is used to estimate the posterior of the parameter of a RingMixture model. You can read my report for more details.
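For readers unfamiliar with SGLD, here is a minimal sketch of the update in JAX. It assumes a hypothetical toy model (a Gaussian likelihood `x ~ N(theta, I)` with a Gaussian prior `theta ~ N(0, 10 I)`), not the RingMixture model, and the function names `grad_estimate` and `sgld` are illustrative rather than taken from this repo. Each step draws a minibatch, forms an unbiased estimate of the log-posterior gradient, takes half a step along it, and adds Gaussian noise with variance equal to the step size.

```python
import jax
import jax.numpy as jnp


def log_prior(theta):
    # Hypothetical prior (assumption): theta ~ N(0, 10 I).
    return -0.5 * jnp.sum(theta ** 2) / 10.0


def log_lik(theta, x):
    # Hypothetical likelihood (assumption): each observation x ~ N(theta, I).
    return -0.5 * jnp.sum((x - theta) ** 2)


def grad_estimate(theta, batch, n_total):
    # Unbiased minibatch estimate of the log-posterior gradient:
    # grad log p(theta) + (N / n) * sum_{i in batch} grad log p(x_i | theta).
    def batch_log_lik(t):
        return jnp.sum(jax.vmap(log_lik, in_axes=(None, 0))(t, batch))

    scale = n_total / batch.shape[0]
    return jax.grad(log_prior)(theta) + scale * jax.grad(batch_log_lik)(theta)


def sgld(key, data, theta0, step_size=1e-4, batch_size=32, num_steps=2000):
    n_total = data.shape[0]

    def body(theta, key):
        k_batch, k_noise = jax.random.split(key)
        # Sample a minibatch of indices without replacement.
        idx = jax.random.choice(k_batch, n_total, (batch_size,), replace=False)
        g = grad_estimate(theta, data[idx], n_total)
        # Langevin update: half a step along the gradient plus N(0, step_size) noise.
        noise = jnp.sqrt(step_size) * jax.random.normal(k_noise, theta.shape)
        theta = theta + 0.5 * step_size * g + noise
        return theta, theta  # carry the state and record it as a sample

    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(body, theta0, keys)
    return samples


# Usage on synthetic data centred at 1.5 in each dimension.
data = 1.5 + jax.random.normal(jax.random.PRNGKey(0), (1000, 2))
samples = sgld(jax.random.PRNGKey(1), data, jnp.zeros(2))
print(samples[1000:].mean(axis=0))  # should be close to data.mean(axis=0)
```

The step size is kept fixed here for simplicity; the original SGLD formulation uses a decreasing step-size schedule. pSGLD additionally rescales the gradient and the noise with a diagonal preconditioner, and MALA replaces the minibatch gradient with the full gradient and adds a Metropolis-Hastings accept/reject step.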
To run the code locally, JAX must be installed first. Alternatively, you can run the code through this Google Colab link without installing anything. Selecting a GPU runtime type in Google Colab accelerates the computation.
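One way to confirm that JAX is installed and, in Colab, that the GPU runtime is actually being used is a quick check like the following. It is a suggestion, not part of this repo, and it only relies on `jax.default_backend()` and `jax.devices()`.

```python
import jax
import jax.numpy as jnp

# Which backend JAX will use by default: "gpu" once a GPU runtime is active, otherwise typically "cpu".
print(jax.default_backend())
# The list of devices JAX can see.
print(jax.devices())

# A tiny computation to confirm the installation works; it runs on the default device.
x = jnp.arange(4.0)
print(float(jnp.dot(x, x)))  # 14.0
```

If the first line prints `cpu` in Colab, the GPU runtime has probably not been selected yet.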
If you find the code useful, please star this repo. Thank you!