From 605fc8a07aa32b52cdb7fe9a28d6f4865d8b7278 Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Fri, 28 Jun 2024 12:23:56 +0200 Subject: [PATCH 1/3] Migrate to Scikit-build tools --- CMakeLists.txt | 4 +++ pyproject.toml | 36 +++++++++++++++++++ setup.py | 96 -------------------------------------------------- 3 files changed, 40 insertions(+), 96 deletions(-) create mode 100644 pyproject.toml delete mode 100644 setup.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e0d58b..c70f143 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,3 +67,7 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA) target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY}) target_link_libraries(_jaxdecomp PRIVATE cudecomp) set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX) + + +install(TARGETS cudecomp LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}) +install(TARGETS _jaxdecomp LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ea6b09d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = [ "scikit-build-core","pybind11"] +build-backend = "scikit_build_core.build" + +[project] +name = "jaxdecomp" +version = "0.1.0" +description = "JAX bindings for the cuDecomp library" +authors = [ + { name = "Wassim Kabalan" }, + { name = "Francois Lanusse"} +] +urls = { "Homepage" = "https://github.com/DifferentiableUniverseInitiative/jaxDecomp" } +readme = "README.md" +license = { file = "LICENSE" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent" +] +dependencies = [] + +[project.optional-dependencies] +test = ["pytest"] + +[tool.scikit-build] +minimum-version = "0.8" +cmake.version = ">=3.19" +build-dir = "build/{wheel_tag}" +wheel.py-api = "py3" +cmake.build-type = "Release" +# Add any additional configurations for scikit-build if necessary +cmake.args = [ + "-DCMAKE_INSTALL_PREFIX:PATH=jaxdecomp/_src/_jaxdecomp/", + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY:PATH=jaxdecomp/_src/_jaxdecomp/", +] diff --git a/setup.py b/setup.py deleted file mode 100644 index b07affb..0000000 --- a/setup.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import subprocess -import sys -from pathlib import Path - -from setuptools import Extension, find_packages, setup -from setuptools.command.build_ext import build_ext - - -# A CMakeExtension needs a sourcedir instead of a file list. -# The name must be the _single_ output extension from the CMake build. -# If you need multiple extensions, see scikit-build. -class CMakeExtension(Extension): - - def __init__(self, name: str, sourcedir: str = "") -> None: - super().__init__(name, sources=[]) - self.sourcedir = os.fspath(Path(sourcedir).resolve()) - - -class CMakeBuild(build_ext): - - def build_extension(self, ext: CMakeExtension) -> None: - # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ - ext_fullpath = Path.cwd() / self.get_ext_fullpath( - ext.name) # type: ignore[no-untyped-call] - extdir = ext_fullpath.parent.resolve() - - # Using this requires trailing slash for auto-detection & inclusion of - # auxiliary "native" libs - - debug = int(os.environ.get("DEBUG", - 0)) if self.debug is None else self.debug - cfg = "Debug" if debug else "Release" - - # CMake lets you override the generator - we need to check this. - # Can be set with Conda-Build, for example. - cmake_generator = os.environ.get("CMAKE_GENERATOR", "") - - # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON - # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code - # from Python. - cmake_args = [ - f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", - f"-DPYTHON_EXECUTABLE={sys.executable}", - f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm - ] - build_args = [] - # Adding CMake arguments set as environment variable - # (needed e.g. to build for ARM OSx on conda-forge) - if "CMAKE_ARGS" in os.environ: - cmake_args += [ - item for item in os.environ["CMAKE_ARGS"].split(" ") if item - ] - - # Single config generators are handled "normally" - single_config = any(x in cmake_generator for x in {"NMake", "Ninja"}) - - # CMake allows an arch-in-generator style for backward compatibility - contains_arch = any(x in cmake_generator for x in {"ARM", "Win64"}) - - # Multi-config generators have a different way to specify configs - if not single_config: - cmake_args += [f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{cfg.upper()}={extdir}"] - build_args += ["--config", cfg] - - # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level - # across all generators. - if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: - # self.parallel is a Python 3 only way to set parallel jobs by hand - # using -j in the build_ext call, not supported by pip or PyPA-build. - if hasattr(self, "parallel") and self.parallel: - # CMake 3.12+ only. - build_args += [f"-j{self.parallel}"] - - build_temp = Path(self.build_temp) / ext.name - if not build_temp.exists(): - build_temp.mkdir(parents=True) - - subprocess.run( - ["cmake", ext.sourcedir] + cmake_args, cwd=build_temp, check=True) - - subprocess.run( - ["cmake", "--build", "."] + build_args, cwd=build_temp, check=True) - - -setup( - name='jaxDecomp', - url='https://github.com/DifferentiableUniverseInitiative/jaxDecomp', - author='Wassim Kabalan, Francois Lanusse', - description='JAX bindings for the cuDecomp library', - ext_modules=[CMakeExtension("jaxdecomp/_src/_jaxdecomp")], - cmdclass={"build_ext": CMakeBuild}, - packages=find_packages(), - include_package_data=True, - use_scm_version=True, - setup_requires=["setuptools_scm"]) From 015d1cc209fa0d88d084ba04738cf3903ef2723d Mon Sep 17 00:00:00 2001 From: Wassim KABALAN Date: Fri, 28 Jun 2024 17:05:37 +0200 Subject: [PATCH 2/3] fix --- pyproject.toml | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ea6b09d..80e9360 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,8 @@ build-dir = "build/{wheel_tag}" wheel.py-api = "py3" cmake.build-type = "Release" # Add any additional configurations for scikit-build if necessary -cmake.args = [ - "-DCMAKE_INSTALL_PREFIX:PATH=jaxdecomp/_src/_jaxdecomp/", - "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY:PATH=jaxdecomp/_src/_jaxdecomp/", -] +wheel.install-dir = "jaxdecomp/_src" + + +[tool.scikit-build.cmake.define] +CMAKE_LIBRARY_OUTPUT_DIRECTORY = "" From 815b15505f77e4120e3f4aaf8dbc8d4db4518067 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Fri, 28 Jun 2024 12:07:22 -0400 Subject: [PATCH 3/3] fix: some CMake fixes --- CMakeLists.txt | 17 +++++++++++------ pyproject.toml | 4 ---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c70f143..c7ef9c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,15 +1,19 @@ cmake_minimum_required(VERSION 3.19...3.25) +project(jaxdecomp LANGUAGES CXX CUDA) + # NVCC 12 does not support C++20 set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) + # Latest JAX v0.4.26 no longer supports cuda 11.8 -set(NVHPC_CUDA_VERSION 12.2) +find_package(CUDAToolkit REQUIRED VERSION 12) +set(NVHPC_CUDA_VERSION ${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}) + # Build debug # set(CMAKE_BUILD_TYPE Debug) -add_subdirectory(third_party/cuDecomp) -project(jaxdecomp LANGUAGES CXX CUDA) +add_subdirectory(third_party/cuDecomp) option(CUDECOMP_BUILD_FORTRAN "Build Fortran bindings" OFF) option(CUDECOMP_ENABLE_NVSHMEM "Enable NVSHMEM" OFF) @@ -32,7 +36,7 @@ find_library(NCCL_LIBRARY NAMES nccl HINTS ${NVHPC_NCCL_LIBRARY_DIR} ) - string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR}) +string(REPLACE "/lib" "/include" NCCL_INCLUDE_DIR ${NVHPC_NCCL_LIBRARY_DIR}) message(STATUS "Using NCCL library: ${NCCL_LIBRARY}") @@ -66,8 +70,9 @@ target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUTENSOR) target_link_libraries(_jaxdecomp PRIVATE NVHPC::CUDA) target_link_libraries(_jaxdecomp PRIVATE ${NCCL_LIBRARY}) target_link_libraries(_jaxdecomp PRIVATE cudecomp) +target_link_libraries(_jaxdecomp PRIVATE stdc++fs) set_target_properties(_jaxdecomp PROPERTIES LINKER_LANGUAGE CXX) +set_target_properties(_jaxdecomp PROPERTIES INSTALL_RPATH "$ORIGIN/lib") -install(TARGETS cudecomp LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}) -install(TARGETS _jaxdecomp LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX} PUBLIC_HEADER DESTINATION ${CMAKE_INSTALL_PREFIX}) +install(TARGETS _jaxdecomp LIBRARY DESTINATION . PUBLIC_HEADER DESTINATION .) diff --git a/pyproject.toml b/pyproject.toml index 80e9360..d016159 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,3 @@ wheel.py-api = "py3" cmake.build-type = "Release" # Add any additional configurations for scikit-build if necessary wheel.install-dir = "jaxdecomp/_src" - - -[tool.scikit-build.cmake.define] -CMAKE_LIBRARY_OUTPUT_DIRECTORY = ""