Skip to content

[ROCm][Windows] Enable build with ROCm on Windows #3883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 71 additions & 35 deletions cmake/LoadHIP.cmake
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
set(PYTORCH_FOUND_HIP FALSE)

if(NOT DEFINED ENV{ROCM_PATH})
set(ROCM_PATH /opt/rocm)
if(UNIX)
set(ROCM_PATH /opt/rocm)
else() # Win32
set(ROCM_PATH C:/opt/rocm)
endif()
else()
set(ROCM_PATH $ENV{ROCM_PATH})
endif()

# HIP_PATH
if(NOT DEFINED ENV{HIP_PATH})
set(HIP_PATH ${ROCM_PATH}/hip)
if(UNIX)
set(HIP_PATH ${ROCM_PATH}/hip)
else() #Win32
set(HIP_PATH ${ROCM_PATH})
endif()
else()
set(HIP_PATH $ENV{HIP_PATH})
endif()
Expand Down Expand Up @@ -129,7 +137,9 @@ else()
endif()

# Add HIP to the CMAKE Module Path
set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})
# needed because the find_package call to this module uses the Module mode search
# https://cmake.org/cmake/help/latest/command/find_package.html#search-modes
set(CMAKE_MODULE_PATH ${HIP_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH})

# Disable Asserts In Code (Can't use asserts on HIP stack.)
add_definitions(-DNDEBUG)
Expand All @@ -145,29 +155,49 @@ find_package_and_print_version(HIP 1.0)
if(HIP_FOUND)
set(PYTORCH_FOUND_HIP TRUE)

# Find ROCM version for checks
file(READ "${ROCM_PATH}/.info/version-dev" ROCM_VERSION_DEV_RAW)
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW})
if(ROCM_VERSION_DEV_MATCH)
set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}")
if(UNIX)
set(ROCM_LIB_NAME "ROCM")
else() # Win32
set(ROCM_LIB_NAME "HIP")
endif()
if(UNIX)
# Find ROCM version for checks
file(READ "${ROCM_PATH}/.info/version-dev" ${ROCM_LIB_NAME}_VERSION_DEV_RAW)
else() #Win32
# Find HIP version from hipconfig execution
execute_process(
COMMAND ${ROCM_PATH}/bin/hipconfig.bat --version
OUTPUT_VARIABLE ${ROCM_LIB_NAME}_VERSION_DEV_RAW
OUTPUT_STRIP_TRAILING_WHITESPACE
)
endif()
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+).*$" ${ROCM_LIB_NAME}_VERSION_DEV_MATCH ${${ROCM_LIB_NAME}_VERSION_DEV_RAW})
if(${ROCM_LIB_NAME}_VERSION_DEV_MATCH)
set(${ROCM_LIB_NAME}_VERSION_DEV_MAJOR ${CMAKE_MATCH_1})
set(${ROCM_LIB_NAME}_VERSION_DEV_MINOR ${CMAKE_MATCH_2})
set(${ROCM_LIB_NAME}_VERSION_DEV_PATCH ${CMAKE_MATCH_3})
set(${ROCM_LIB_NAME}_VERSION_DEV "${${ROCM_LIB_NAME}_VERSION_DEV_MAJOR}.${${ROCM_LIB_NAME}_VERSION_DEV_MINOR}.${${ROCM_LIB_NAME}_VERSION_DEV_PATCH}")
endif()
if(UNIX)
message("\n***** ROCm version from ${ROCM_PATH}/.info/version-dev ****\n")
else() #Win32
message("\n***** HIP version from ${ROCM_PATH}/bin/hipconfig.bat --version ****\n")
endif()
message("${ROCM_LIB_NAME}_VERSION_DEV: ${${ROCM_LIB_NAME}_VERSION_DEV}")
message("${ROCM_LIB_NAME}_VERSION_DEV_MAJOR: ${${ROCM_LIB_NAME}_VERSION_DEV_MAJOR}")
message("${ROCM_LIB_NAME}_VERSION_DEV_MINOR: ${${ROCM_LIB_NAME}_VERSION_DEV_MINOR}")
message("${ROCM_LIB_NAME}_VERSION_DEV_PATCH: ${${ROCM_LIB_NAME}_VERSION_DEV_PATCH}")

if(UNIX)
message("\n***** Library versions from dpkg *****\n")
execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep hip_base COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
endif()
message("\n***** ROCm version from ${ROCM_PATH}/.info/version-dev ****\n")
message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}")
message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}")
message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}")
message("ROCM_VERSION_DEV_PATCH: ${ROCM_VERSION_DEV_PATCH}")

message("\n***** Library versions from dpkg *****\n")
execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep hip_base COMMAND awk "{print $2 \" VERSION: \" $3}")
execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}")

message("\n***** Library versions from cmake find_package *****\n")

Expand All @@ -176,7 +206,6 @@ if(HIP_FOUND)
### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.###

set(hip_DIR ${HIP_PATH}/lib/cmake/hip)
set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64)
set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs)
set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr)
set(rocrand_DIR ${ROCRAND_PATH}/lib/cmake/rocrand)
Expand All @@ -186,13 +215,11 @@ if(HIP_FOUND)
set(rocfft_DIR ${ROCFFT_PATH}/lib/cmake/rocfft)
set(hipfft_DIR ${HIPFFT_PATH}/lib/cmake/hipfft)
set(hipsparse_DIR ${HIPSPARSE_PATH}/lib/cmake/hipsparse)
set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl)
set(rocprim_DIR ${ROCPRIM_PATH}/lib/cmake/rocprim)
set(hipcub_DIR ${HIPCUB_PATH}/lib/cmake/hipcub)
set(rocthrust_DIR ${ROCTHRUST_PATH}/lib/cmake/rocthrust)

find_package_and_print_version(hip REQUIRED)
find_package_and_print_version(hsa-runtime64 REQUIRED)
find_package_and_print_version(amd_comgr REQUIRED)
find_package_and_print_version(rocrand REQUIRED)
find_package_and_print_version(hiprand REQUIRED)
Expand All @@ -203,7 +230,6 @@ if(HIP_FOUND)
find_package_and_print_version(hipfft REQUIRED)
endif()
find_package_and_print_version(hipsparse REQUIRED)
find_package_and_print_version(rccl)
find_package_and_print_version(rocprim REQUIRED)
find_package_and_print_version(hipcub REQUIRED)
find_package_and_print_version(rocthrust REQUIRED)
Expand All @@ -223,12 +249,22 @@ if(HIP_FOUND)
# TODO: miopen_LIBRARIES should return fullpath to the library file,
# however currently it's just the lib name
find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${MIOPEN_PATH}/lib)
# TODO: rccl_LIBRARIES should return fullpath to the library file,
# however currently it's just the lib name
find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib)
# hiprtc is part of HIP
find_library(ROCM_HIPRTC_LIB ${hip_library_name} HINTS ${HIP_PATH}/lib)
# roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib)
set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include)

if(UNIX)
set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64)
set(rccl_DIR ${RCCL_PATH}/lib/cmake/rccl)

find_package_and_print_version(hsa-runtime64 REQUIRED)
find_package_and_print_version(rccl)

# TODO: rccl_LIBRARIES should return fullpath to the library file,
# however currently it's just the lib name
find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${RCCL_PATH}/lib)
# roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCTRACER_PATH}/lib)
set(roctracer_INCLUDE_DIRS ${ROCTRACER_PATH}/include)
endif()
endif()

8 changes: 6 additions & 2 deletions tools/setup_helpers/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,13 @@ def build_extension(self, ext):
import sys

python_version = sys.version_info

cxx_compiler = os.environ.get('CXX', 'cl')
c_compiler = os.environ.get('CC', 'cl')

cmake_args += [
"-DCMAKE_C_COMPILER=cl",
"-DCMAKE_CXX_COMPILER=cl",
f"-DCMAKE_C_COMPILER={c_compiler}",
f"-DCMAKE_CXX_COMPILER={cxx_compiler}",
f"-DPYTHON_VERSION={python_version.major}.{python_version.minor}",
]

Expand Down