Updating the installation documentation (#85)
* Updating the installation documentation

* More examples

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Formatting

* Undo formatting

* Update README.md

* Update README.md

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
dfm and pre-commit-ci[bot] authored Apr 22, 2024
1 parent 21e3528 commit 79a3fb0
Showing 1 changed file, README.md, with 126 additions and 43 deletions.

## Installation

The easiest way to install jax-finufft is to install a pre-compiled binary from
PyPI or conda-forge, but if you need GPU support or want tuned performance,
you'll want to install from source as described below.

### Install binary from PyPI

> [!NOTE]
> Only the CPU-enabled build of jax-finufft is available as a binary wheel on
> PyPI. For a GPU-enabled build, you'll need to build from source as described
> below.

To install a binary wheel from [PyPI](https://pypi.org/project/jax-finufft/)
using pip, run the following commands:

```bash
python -m pip install "jax[cpu]"
python -m pip install jax-finufft
```

If this fails, you may need to use a conda-forge binary, or install from source.
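
As a quick post-install check, the hypothetical snippet below confirms that both packages can be found on the current Python path (the installed module is named `jax_finufft`, with an underscore):

```python
# Post-install sanity check: confirm both packages are discoverable
# without importing (and initializing) them.
import importlib.util


def is_installed(pkg: str) -> bool:
    """Return True if `pkg` is importable in this environment."""
    return importlib.util.find_spec(pkg) is not None


for pkg in ("jax", "jax_finufft"):
    print(pkg, "found" if is_installed(pkg) else "MISSING")
```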

### Install binary from conda-forge

> [!NOTE]
> Only the CPU-enabled build of jax-finufft is available as a binary from
> conda-forge. For a GPU-enabled build, you'll need to build from source as
> described below.

To install using [mamba](https://github.com/mamba-org/mamba) (or
[conda](https://docs.conda.io)), run:

```bash
mamba install -c conda-forge jax-finufft
```

### Install from source

#### Dependencies

Unsurprisingly, a key dependency is JAX, which can be installed following the
directions in [the JAX
documentation](https://jax.readthedocs.io/en/latest/installation.html). If
you plan to run on a GPU, make sure that you install the appropriate JAX build.
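
For example, a CPU-only JAX build (the same command used elsewhere in these instructions) can be installed with:

```shell
# CPU-only JAX; for GPU support, follow the JAX installation docs to pick
# a build matching your CUDA version.
python -m pip install "jax[cpu]"
```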

The non-Python dependencies that you'll need are:

- [FFTW](https://www.fftw.org),
- [OpenMP](https://www.openmp.org) (for CPU, optional),
- CUDA (for GPU, >= 11.8), and
- cuDNN (for GPU).

Older versions of CUDA may work, but they are untested.

Below we provide some example workflows for installing the required dependencies:

<details>
<summary>Install CPU dependencies with mamba or conda</summary>

```bash
mamba create -n jax-finufft -c conda-forge python jax fftw cxx-compiler
mamba activate jax-finufft
```
</details>

<details>
<summary>Install GPU dependencies with mamba or conda</summary>

For a GPU build, while the CUDA libraries and compiler are nominally available
through conda, our experience trying to install them this way suggests that the
"traditional" way of obtaining the [CUDA
Toolkit](https://developer.nvidia.com/cuda-downloads) directly from NVIDIA may
work best (see [related advice for
Horovod](https://horovod.readthedocs.io/en/stable/conda_include.html)). After
installing the CUDA Toolkit, one can set up the rest of the dependencies with:

```bash
mamba create -n gpu-jax-finufft -c conda-forge python numpy scipy fftw 'gxx<12'
mamba activate gpu-jax-finufft
export CMAKE_PREFIX_PATH=$CONDA_PREFIX:$CMAKE_PREFIX_PATH
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

Other ways of installing JAX are given on the JAX website; the ["local CUDA"
install
methods](https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-locally-harder)
are preferred for jax-finufft as this ensures the CUDA extensions are compiled
with the same Toolkit version as the CUDA runtime.
</details>

<details>
<summary>Install GPU dependencies using Flatiron module system</summary>

```bash
ml modules/2.2
ml gcc
ml python/3.11
ml fftw
ml cuda/11
ml cudnn
ml nccl

export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"
```
</details>

#### GPU build configuration

You'll need to configure your build to select the appropriate CUDA
architecture(s) using the environment variable `CMAKE_ARGS`. To query your GPU's
CUDA architecture (compute capability), you can run:

```bash
$ nvidia-smi --query-gpu=compute_cap --format=csv,noheader
7.0
```
This corresponds to `CMAKE_CUDA_ARCHITECTURES=70`, i.e.:

```bash
export CMAKE_ARGS="$CMAKE_ARGS -DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
```

Note that the pip installation below uses CMake, so `CMAKE_ARGS` must be set
beforehand, but it is not needed at runtime.

At runtime, you may also need:

```bash
export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
```

If `CUDA_PATH` isn't set, you'll need to replace it with the path to your CUDA
installation in the above line, often something like `/usr/local/cuda`.
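
If you're not sure where your CUDA installation lives, one common heuristic (an assumption to verify on your system) is to derive the prefix from the location of `nvcc`:

```shell
# Derive CUDA_PATH from nvcc's location if it isn't already set
# (hypothetical heuristic; adjust for your system).
export CUDA_PATH="${CUDA_PATH:-$(dirname "$(dirname "$(command -v nvcc)")")}"
export LD_LIBRARY_PATH="$CUDA_PATH/extras/CUPTI/lib64:$LD_LIBRARY_PATH"
```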

#### Install source from PyPI

The source code for all released versions of jax-finufft is available on PyPI,
and it can be installed using:

```bash
python -m pip install --no-binary jax-finufft jax-finufft
```

#### Install source from GitHub

Alternatively, you can check out the source repository from GitHub:

```bash
git clone --recurse-submodules https://github.com/flatironinstitute/jax-finufft
cd jax-finufft
```

> [!NOTE]
> Don't forget the `--recurse-submodules` argument when cloning the repo because
> the upstream FINUFFT library is included as a git submodule. If you do forget,
> you can run `git submodule update --init --recursive` in your local copy to
> check out the submodule after the initial clone.

After cloning the repository, you can install the local copy using:

```bash
python -m pip install -e .
```

where the optional `-e` flag requests an "editable" install, so that local
changes to the source are picked up without reinstalling.
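
As a sanity check on what you've installed, the sketch below evaluates the type-1 transform that `nufft1` approximates, `f_k = sum_j c_j * exp(i * k * x_j)`, by direct summation in NumPy; the commented-out lines (an assumption about the call signature, see the Usage section) show how you'd compare against jax-finufft itself:

```python
# Direct (slow) evaluation of the type-1 NUFFT that jax-finufft approximates.
import numpy as np

M, N = 500, 64  # number of nonuniform points, number of output modes
rng = np.random.default_rng(0)
x = 2 * np.pi * rng.uniform(size=M)  # nonuniform points in [0, 2*pi)
c = rng.standard_normal(M) + 1j * rng.standard_normal(M)

# f_k = sum_j c_j exp(i k x_j) for k = -N//2, ..., N//2 - 1
k = np.arange(-(N // 2), N - N // 2)
f_direct = np.exp(1j * k[:, None] * x[None, :]) @ c

# Compare against the compiled transform once jax-finufft is installed:
# from jax_finufft import nufft1
# f = nufft1(N, c, x, eps=1e-6, iflag=1)
# np.testing.assert_allclose(f, f_direct, rtol=1e-5)
print(f_direct.shape)
```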

## Usage
