diff --git a/cmake/LoadHIP.cmake b/cmake/LoadHIP.cmake index 09ae6385cf..9227e43d99 100644 --- a/cmake/LoadHIP.cmake +++ b/cmake/LoadHIP.cmake @@ -1,14 +1,22 @@ set(PYTORCH_FOUND_HIP FALSE) if(NOT DEFINED ENV{ROCM_PATH}) - set(ROCM_PATH /opt/rocm) + if(UNIX) + set(ROCM_PATH /opt/rocm) + else() # Win32 + set(ROCM_PATH C:/opt/rocm) + endif() else() set(ROCM_PATH $ENV{ROCM_PATH}) endif() # HIP_PATH if(NOT DEFINED ENV{HIP_PATH}) - set(HIP_PATH ${ROCM_PATH}/hip) + if(UNIX) + set(HIP_PATH ${ROCM_PATH}/hip) + else() #Win32 + set(HIP_PATH ${ROCM_PATH}) + endif() else() set(HIP_PATH $ENV{HIP_PATH}) endif() @@ -129,7 +137,9 @@ else() endif() # Add HIP to the CMAKE Module Path -set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH}) +# needed because the find_package call to this module uses the Module mode search +# https://cmake.org/cmake/help/latest/command/find_package.html#search-modes +set(CMAKE_MODULE_PATH ${HIP_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH}) # Disable Asserts In Code (Can't use asserts on HIP stack.) add_definitions(-DNDEBUG) @@ -145,29 +155,49 @@ find_package_and_print_version(HIP 1.0) if(HIP_FOUND) set(PYTORCH_FOUND_HIP TRUE) - # Find ROCM version for checks - file(READ "${ROCM_PATH}/.info/version-dev" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) - if(ROCM_VERSION_DEV_MATCH) - set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) - set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) - set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) - set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") + if(UNIX) + set(ROCM_LIB_NAME "ROCM") + else() # Win32 + set(ROCM_LIB_NAME "HIP") + endif() + if(UNIX) + # Find ROCM version for checks + file(READ "${ROCM_PATH}/.info/version-dev" ${ROCM_LIB_NAME}_VERSION_DEV_RAW) + else() #Win32 + # Find HIP version from hipconfig execution + execute_process( + COMMAND ${ROCM_PATH}/bin/hipconfig.bat --version + OUTPUT_VARIABLE ${ROCM_LIB_NAME}_VERSION_DEV_RAW + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + endif() + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+).*$" ${ROCM_LIB_NAME}_VERSION_DEV_MATCH ${${ROCM_LIB_NAME}_VERSION_DEV_RAW}) + if(${ROCM_LIB_NAME}_VERSION_DEV_MATCH) + set(${ROCM_LIB_NAME}_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) + set(${ROCM_LIB_NAME}_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) + set(${ROCM_LIB_NAME}_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) + set(${ROCM_LIB_NAME}_VERSION_DEV "${${ROCM_LIB_NAME}_VERSION_DEV_MAJOR}.${${ROCM_LIB_NAME}_VERSION_DEV_MINOR}.${${ROCM_LIB_NAME}_VERSION_DEV_PATCH}") + endif() + if(UNIX) + message("\n***** ROCm version from ${ROCM_PATH}/.info/version-dev ****\n") + else() #Win32 + message("\n***** HIP version from ${ROCM_PATH}/bin/hipconfig.bat --version ****\n") + endif() + message("${ROCM_LIB_NAME}_VERSION_DEV: ${${ROCM_LIB_NAME}_VERSION_DEV}") + message("${ROCM_LIB_NAME}_VERSION_DEV_MAJOR: ${${ROCM_LIB_NAME}_VERSION_DEV_MAJOR}") + message("${ROCM_LIB_NAME}_VERSION_DEV_MINOR: ${${ROCM_LIB_NAME}_VERSION_DEV_MINOR}") + message("${ROCM_LIB_NAME}_VERSION_DEV_PATCH: ${${ROCM_LIB_NAME}_VERSION_DEV_PATCH}") + + if(UNIX) + message("\n***** Library versions from dpkg *****\n") + execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep hip_base COMMAND awk "{print $2 \" VERSION: \" $3}") + execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}") endif() - message("\n***** ROCm version from ${ROCM_PATH}/.info/version-dev ****\n") - message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") - message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") - message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") - message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}") - - message("\n***** Library versions from dpkg *****\n") - execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep hip_base COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}") message("\n***** Library versions from cmake find_package *****\n") @@ -176,7 +206,6 @@ if(HIP_FOUND) ### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.### set(hip_DIR ${HIP_PATH}/lib/cmake/hip) - set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64) set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs) set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr) set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand) @@ -186,13 +215,11 @@ if(HIP_FOUND) set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft) set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft) set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse) - set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl) set(rocprim_DIR ${ROCPRIM_PATH}/lib/cmake/rocprim) set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub) set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust) find_package_and_print_version(hip REQUIRED) - find_package_and_print_version(hsa-runtime64 REQUIRED) find_package_and_print_version(amd_comgr REQUIRED) find_package_and_print_version(rocrand REQUIRED) find_package_and_print_version(hiprand REQUIRED) @@ -203,7 +230,6 @@ if(HIP_FOUND) find_package_and_print_version(hipfft REQUIRED) endif() find_package_and_print_version(hipsparse REQUIRED) - find_package_and_print_version(rccl) find_package_and_print_version(rocprim REQUIRED) find_package_and_print_version(hipcub REQUIRED) find_package_and_print_version(rocthrust REQUIRED) @@ -223,12 +249,22 @@ if(HIP_FOUND) # TODO: miopen_LIBRARIES should return fullpath to the library file, # however currently it's just the lib name find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib) - # TODO: rccl_LIBRARIES should return fullpath to the library file, - # however currently it's just the lib name - find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib) # hiprtc is part of HIP find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib) - # roctx is part of roctracer - find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib) - set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include) + + if(UNIX) + set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64) + set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl) + + find_package_and_print_version(hsa-runtime64 REQUIRED) + find_package_and_print_version(rccl) + + # TODO: rccl_LIBRARIES should return fullpath to the library file, + # however currently it's just the lib name + find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib) + # roctx is part of roctracer + find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib) + set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include) + endif() endif() + diff --git a/tools/setup_helpers/extension.py b/tools/setup_helpers/extension.py index 2415bbaedb..58f5087854 100644 --- a/tools/setup_helpers/extension.py +++ b/tools/setup_helpers/extension.py @@ -161,9 +161,13 @@ def build_extension(self, ext): import sys python_version = sys.version_info + + cxx_compiler = os.environ.get('CXX', 'cl') + c_compiler = os.environ.get('CC', 'cl') + cmake_args += [ - "-DCMAKE_C_COMPILER=cl", - "-DCMAKE_CXX_COMPILER=cl", + f"-DCMAKE_C_COMPILER={c_compiler}", + f"-DCMAKE_CXX_COMPILER={cxx_compiler}", f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}", ]