Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating CUDA build to be opt in #43

Merged
merged 7 commits into from
Nov 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/gpu-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ jobs:

- name: Compile the extension
run: python -m pip install -v .
env:
CMAKE_ARGS: "-DJAX_FINUFFT_USE_CUDA=ON"

- name: Check that the GPU extension was built
run: python -c "import jax_finufft.jax_finufft_gpu"
28 changes: 16 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@ message(STATUS "Using CMake version: " ${CMAKE_VERSION})
# https://github.com/pybind/pybind11/issues/4825
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF)

include(CheckLanguage)
check_language(CUDA)

if(CMAKE_CUDA_COMPILER)
enable_language(CUDA)
set(FINUFFT_USE_CUDA ON)
# Enable CUDA if requested and available
option(JAX_FINUFFT_USE_CUDA "Enable CUDA build if available" OFF)
if(JAX_FINUFFT_USE_CUDA)
include(CheckLanguage)
check_language(CUDA)
if(CMAKE_CUDA_COMPILER)
message(STATUS "CUDA compiler found; compiling with GPU support")
enable_language(CUDA)
set(FINUFFT_USE_CUDA ON)
else()
message(STATUS "No CUDA compiler found; compiling without GPU support")
set(FINUFFT_USE_CUDA OFF)
endif()
else()
message(STATUS "No CUDA compiler found; GPU support will be disabled")
message(STATUS "GPU support was not requested")
set(FINUFFT_USE_CUDA OFF)
endif()

Expand All @@ -36,11 +43,8 @@ pybind11_add_module(jax_finufft_cpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cp
target_link_libraries(jax_finufft_cpu PRIVATE finufft_static)
install(TARGETS jax_finufft_cpu LIBRARY DESTINATION .)

# Include the CUDA extensions if possible
include(CheckLanguage)
check_language(CUDA)

if(CMAKE_CUDA_COMPILER)
# Include the CUDA extensions if possible - see above for where this is set
if(FINUFFT_USE_CUDA)
enable_language(CUDA)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)

Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,12 @@ from NVIDIA may work best (see [related advice for Horovod](https://horovod.read
conda create -n gpu-jax-finufft -c conda-forge python=3.10 numpy scipy fftw 'gxx<12'
conda activate gpu-jax-finufft
export CPATH=$CONDA_PREFIX/include:$CPATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON"
python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
CMAKE_ARGS=-DCMAKE_CUDA_ARCHITECTURES=70 python -m pip install .
python -m pip install .
```

In the last line, you'll need to select the CUDA architecture(s) you wish to compile for. See the [FINUFFT docs](https://finufft.readthedocs.io/en/latest/install_gpu.html#cmake-installation).
In the `CMAKE_ARGS` line, you'll need to select the CUDA architecture(s) you wish to compile for. See the [FINUFFT docs](https://finufft.readthedocs.io/en/latest/install_gpu.html#cmake-installation).

At runtime, you may also need:
```bash
Expand All @@ -86,7 +87,7 @@ ml cudnn
ml nccl

export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90"
export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON"
```
</details>

Expand Down
3 changes: 3 additions & 0 deletions ci/Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ pipeline {
}
stages {
stage('Build') {
environment {
CMAKE_ARGS = "-DJAX_FINUFFT_USE_CUDA=ON"
}
steps {
sh 'python3 -m pip install -U pip'
sh 'python3 -m pip install "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "jax-finufft"
description = "Unofficial JAX bindings for finufft"
readme = "README.md"
authors = [{ name = "Dan Foreman-Mackey", email = "[email protected]" }]
requires-python = ">=3.7"
requires-python = ">=3.8"
license = { file = "LICENSE" }
urls = { Homepage = "https://github.com/dfm/jax-finufft" }
dependencies = ["jax", "jaxlib"]
Expand Down