Please follow our established coding style including variable names, module imports, and function definitions.
The NumPyro codebase follows the PEP8 style guide (which you can check with make lint) and follows isort import order (which you can enforce with make format).
To set up a local development environment, install NumPyro from source:
git clone
# install jax/jaxlib first for CUDA support
pip install -e .[dev,test] # contains additional dependencies for NumPyro development
Before submitting a pull request, please autoformat your code and ensure that unit tests pass locally:
make lint # linting
make format # runs black and isort
make test # linting and unit tests
make doctest # test module's docstrings
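As a rough illustration of what a docstring test checks, here is a minimal sketch using Python's standard doctest module. The function scale and the helper run_docstring_examples are hypothetical names made up for this example, and how the project's make doctest target is actually wired is not shown here, so treat this only as a picture of the docstring format.

```python
import doctest


def scale(values, factor=2.0):
    """Multiply every element of ``values`` by ``factor``.

    Hypothetical function, used only to illustrate the docstring format.

    >>> scale([1.0, 2.0])
    [2.0, 4.0]
    >>> scale([1.0, 2.0], factor=0.5)
    [0.5, 1.0]
    """
    return [v * factor for v in values]


def run_docstring_examples():
    """Run the examples embedded in ``scale`` and return (failed, attempted)."""
    finder = doctest.DocTestFinder()
    runner = doctest.DocTestRunner(verbose=False)
    # Pass globs explicitly so the examples can see ``scale`` wherever this runs.
    for test in finder.find(scale, name="scale", globs={"scale": scale}):
        runner.run(test)
    results = runner.summarize(verbose=False)
    return results.failed, results.attempted
```

If an example's printed output drifts from the code's real behaviour, the failed count becomes non-zero, which is the kind of regression a docstring test run is meant to catch.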
To run all tests locally in parallel, use pytest-xdist:
pip install pytest-xdist
pytest -vs -n auto
To run a single test from the command line:
pytest -vs {path_to_test}::{test_name}
# or in cuda mode and double precision
JAX_PLATFORM_NAME=gpu JAX_ENABLE_X64=1 pytest -vs {path_to_test}::{test_name}
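The same invocation can be scripted, which may help when you rerun one test under several configurations. The sketch below is an assumption-laden illustration, not part of the project's tooling: the helper names pytest_node_id, pytest_env and run_single_test are hypothetical, and it launches pytest through python -m pytest, which behaves slightly differently from the bare pytest command (it adds the current directory to sys.path).

```python
import os
import subprocess
import sys


def pytest_node_id(path_to_test, test_name):
    """Join a test file path and a test function name into a pytest node ID."""
    return f"{path_to_test}::{test_name}"


def pytest_env(gpu=False, x64=False, base_env=None):
    """Copy the environment and add the JAX variables used in the commands above."""
    env = dict(os.environ if base_env is None else base_env)
    if gpu:
        env["JAX_PLATFORM_NAME"] = "gpu"  # same variable as in the shell command
    if x64:
        env["JAX_ENABLE_X64"] = "1"  # enables double precision
    return env


def run_single_test(path_to_test, test_name, gpu=False, x64=False):
    """Run one test in a subprocess and return pytest's exit code."""
    cmd = [
        sys.executable, "-m", "pytest", "-vs",
        pytest_node_id(path_to_test, test_name),
    ]
    return subprocess.run(cmd, env=pytest_env(gpu=gpu, x64=x64)).returncode
```

Calling run_single_test with a placeholder path and test name of your own returns 0 when the selected test passes; setting the variables in a copied environment keeps them from leaking into the parent process.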
TensorBoard can be used to profile NumPyro by following the instructions in the JAX documentation.
For relevant design questions to consider, see past design documents.
For larger changes, please open an issue for discussion before submitting a pull request.
In your PR, please include:
- Changes made
- Links to related issues/PRs
- Tests
- Dependencies
If you add new files, please run make license to automatically add copyright headers.
For speculative changes meant for early-stage review, include [WIP] in the PR's title.
(One of the maintainers will add the WIP