From f3ab79d779ecd1c9314452a5bd67b7427f5e84dd Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Sat, 25 Nov 2023 20:42:53 -0500 Subject: [PATCH 1/7] Only compile CUDA extensions if requested --- CMakeLists.txt | 26 ++++++++++++++------------ ci/Jenkinsfile | 3 +++ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7a15776..c4cc543 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,14 +6,19 @@ message(STATUS "Using CMake version: " ${CMAKE_VERSION}) # https://github.com/pybind/pybind11/issues/4825 set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF) -include(CheckLanguage) -check_language(CUDA) - -if(CMAKE_CUDA_COMPILER) - enable_language(CUDA) - set(FINUFFT_USE_CUDA ON) +# Enable CUDA if requested and available +option(JAX_FINUFFT_USE_CUDA "Enable CUDA build if available" OFF) +if(JAX_FINUFFT_USE_CUDA) + include(CheckLanguage) + check_language(CUDA) + if(CMAKE_CUDA_COMPILER) + enable_language(CUDA) + set(FINUFFT_USE_CUDA ON) + else() + message(STATUS "No CUDA compiler found; GPU support will be disabled") + set(FINUFFT_USE_CUDA OFF) + endif() else() - message(STATUS "No CUDA compiler found; GPU support will be disabled") set(FINUFFT_USE_CUDA OFF) endif() @@ -36,11 +41,8 @@ pybind11_add_module(jax_finufft_cpu ${CMAKE_CURRENT_LIST_DIR}/lib/jax_finufft_cp target_link_libraries(jax_finufft_cpu PRIVATE finufft_static) install(TARGETS jax_finufft_cpu LIBRARY DESTINATION .) -# Include the CUDA extensions if possible -include(CheckLanguage) -check_language(CUDA) - -if(CMAKE_CUDA_COMPILER) +# Include the CUDA extensions if possible - see above for where this is set +if(FINUFFT_USE_CUDA) enable_language(CUDA) set(CMAKE_CUDA_SEPARABLE_COMPILATION ON) diff --git a/ci/Jenkinsfile b/ci/Jenkinsfile index 8e6c014..468b953 100644 --- a/ci/Jenkinsfile +++ b/ci/Jenkinsfile @@ -14,6 +14,9 @@ pipeline { } stages { stage('Build') { + environment { + CMAKE_ARGS=-DJAX_FINUFFT_USE_CUDA=ON + } steps { sh 'python3 -m pip install -U pip' sh 'python3 -m pip install "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' From e2057bb3c58fbcf82b86cc611cc5eb7dc10d23bb Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Sat, 25 Nov 2023 20:46:50 -0500 Subject: [PATCH 2/7] more messaged --- CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c4cc543..a0e6710 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -12,13 +12,15 @@ if(JAX_FINUFFT_USE_CUDA) include(CheckLanguage) check_language(CUDA) if(CMAKE_CUDA_COMPILER) + message(STATUS "CUDA compiler found; compiling with GPU support") enable_language(CUDA) set(FINUFFT_USE_CUDA ON) else() - message(STATUS "No CUDA compiler found; GPU support will be disabled") + message(STATUS "No CUDA compiler found; compiling without GPU support") set(FINUFFT_USE_CUDA OFF) endif() else() + message(STATUS "GPU support was not requested") set(FINUFFT_USE_CUDA OFF) endif() From b2ab169db2cbb31d96c9884caa112abd2088920f Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Sat, 25 Nov 2023 20:52:47 -0500 Subject: [PATCH 3/7] fixing Jenkinsfile syntax --- ci/Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/Jenkinsfile b/ci/Jenkinsfile index 468b953..21cf400 100644 --- a/ci/Jenkinsfile +++ b/ci/Jenkinsfile @@ -15,7 +15,7 @@ pipeline { stages { stage('Build') { environment { - CMAKE_ARGS=-DJAX_FINUFFT_USE_CUDA=ON + CMAKE_ARGS = "-DJAX_FINUFFT_USE_CUDA=ON" } steps { sh 'python3 -m pip install -U pip' From 7268e332fa5aea374dcd8c08514eebea71f891d0 Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Sat, 25 Nov 2023 20:55:46 -0500 Subject: [PATCH 4/7] adding cmake_flags to CUDA build GA workflow --- .github/workflows/gpu-build.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/gpu-build.yml b/.github/workflows/gpu-build.yml index 2780280..3b06249 100644 --- a/.github/workflows/gpu-build.yml +++ b/.github/workflows/gpu-build.yml @@ -34,6 +34,8 @@ jobs: - name: Compile the extension run: python -m pip install -v . + env: + CMAKE_FLAGS: "-DJAX_FINUFFT_USE_CUDA=ON" - name: Check that the GPU extension was built run: python -c "import jax_finufft.jax_finufft_gpu" From 12bfee4c37476524da7ee279fbe30dd7caf97960 Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Sat, 25 Nov 2023 21:04:14 -0500 Subject: [PATCH 5/7] Flag name --- .github/workflows/gpu-build.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/gpu-build.yml b/.github/workflows/gpu-build.yml index 3b06249..0a5cfec 100644 --- a/.github/workflows/gpu-build.yml +++ b/.github/workflows/gpu-build.yml @@ -35,7 +35,7 @@ jobs: - name: Compile the extension run: python -m pip install -v . env: - CMAKE_FLAGS: "-DJAX_FINUFFT_USE_CUDA=ON" + CMAKE_ARGS: "-DJAX_FINUFFT_USE_CUDA=ON" - name: Check that the GPU extension was built run: python -c "import jax_finufft.jax_finufft_gpu" diff --git a/pyproject.toml b/pyproject.toml index 3da8250..717fdde 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "jax-finufft" description = "Unofficial JAX bindings for finufft" readme = "README.md" authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] -requires-python = ">=3.7" +requires-python = ">=3.9" license = { file = "LICENSE" } urls = { Homepage = "https://github.com/dfm/jax-finufft" } dependencies = ["jax", "jaxlib"] From c13533b738d0ada425eb09617653be6fa7158764 Mon Sep 17 00:00:00 2001 From: Lehman Garrison Date: Mon, 27 Nov 2023 09:22:07 -0500 Subject: [PATCH 6/7] readme: require JAX_FINUFFT_USE_CUDA for CUDA compilation --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5c2fbd5..032e27d 100644 --- a/README.md +++ b/README.md @@ -60,11 +60,12 @@ from NVIDIA may work best (see [related advice for Horovod](https://horovod.read conda create -n gpu-jax-finufft -c conda-forge python=3.10 numpy scipy fftw 'gxx<12' conda activate gpu-jax-finufft export CPATH=$CONDA_PREFIX/include:$CPATH +export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=70 -DJAX_FINUFFT_USE_CUDA=ON" python -m pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -CMAKE_ARGS=-DCMAKE_CUDA_ARCHITECTURES=70 python -m pip install . +python -m pip install . ``` -In the last line, you'll need to select the CUDA architecture(s) you wish to compile for. See the [FINUFFT docs](https://finufft.readthedocs.io/en/latest/install_gpu.html#cmake-installation). +In the `CMAKE_ARGS` line, you'll need to select the CUDA architecture(s) you wish to compile for. See the [FINUFFT docs](https://finufft.readthedocs.io/en/latest/install_gpu.html#cmake-installation). At runtime, you may also need: ```bash @@ -86,7 +87,7 @@ ml cudnn ml nccl export LD_LIBRARY_PATH=$CUDA_HOME/extras/CUPTI/lib64:$LD_LIBRARY_PATH -export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90" +export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=60;70;80;90 -DJAX_FINUFFT_USE_CUDA=ON" ``` From ddd15fbea8172f423109d141123130902ea5b143 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 27 Nov 2023 14:07:29 -0500 Subject: [PATCH 7/7] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 717fdde..7a615de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "jax-finufft" description = "Unofficial JAX bindings for finufft" readme = "README.md" authors = [{ name = "Dan Foreman-Mackey", email = "foreman.mackey@gmail.com" }] -requires-python = ">=3.9" +requires-python = ">=3.8" license = { file = "LICENSE" } urls = { Homepage = "https://github.com/dfm/jax-finufft" } dependencies = ["jax", "jaxlib"]