Skip to content

Commit

Permalink
Merge pull request #21 from henryiii/henryiii/fix/cmake
Browse files Browse the repository at this point in the history
fix: some CMake fixes
  • Loading branch information
ASKabalan authored Jul 4, 2024
2 parents 6e26027 + 0b355c0 commit 5ae64ba
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
17 changes: 13 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
cmake_minimum_required(VERSION 3.19...3.25)

project(jaxdecomp LANGUAGES CXX CUDA)

# NVCC 12 does not support C++20
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

# Latest JAX v0.4.26 no longer supports cuda 11.8
set(NVHPC_CUDA_VERSION 12.2)
find_package(CUDAToolkit REQUIRED VERSION 12)
set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR})

# Build debug
# set(CMAKE_BUILD_TYPE Debug)
add_subdirectory(third_party/cuDecomp)

project(jaxdecomp LANGUAGES CXX CUDA)
add_subdirectory(third_party/cuDecomp)

option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF)
option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF)
Expand All @@ -32,7 +36,7 @@ find_library(NCCL_LIBRARY
NAMES nccl
HINTS ${NVHPC_NCCL_LIBRARY_DIR}
)
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})
string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR})


message(STATUS "Using NCCL library: ${NCCL_LIBRARY}")
Expand Down Expand Up @@ -66,4 +70,9 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR)
target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA)
target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY})
target_link_libraries(_jaxdecomp PRIVATE cudecomp)
target_link_libraries(_jaxdecomp PRIVATE stdc++fs)
set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX)

set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib")

install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .)
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,5 @@ cmake.build-type = "Release"
# Add any additional configurations for scikit-build if necessary
wheel.install-dir = "jaxdecomp/_src"


[tool.scikit-build.cmake.define]
CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""

0 comments on commit 5ae64ba

Please sign in to comment.