Skip to content

Commit

Permalink
Fix kernel compilation on Windows (#543)
Browse files Browse the repository at this point in the history
* Use longlong on Windows

* Use correct LLD on Windows
  • Loading branch information
pxl-th authored Nov 21, 2023
1 parent fb7f556 commit 7a8bff1
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 48 deletions.
18 changes: 10 additions & 8 deletions src/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,16 @@ function create_executable(obj)
`$(AMDGPU.lld_path)`
end

path_exe = mktemp() do path_o, io_o
write(io_o, obj)
flush(io_o)
path_exe = path_o * ".exe"
run(`$lld -shared -o $path_exe $path_o`)
path_exe
end
return read(path_exe)
path_o = tempname(;cleanup=false) * ".obj"
path_exe = tempname(;cleanup=false) * ".exe"

write(path_o, obj)
run(`$lld -shared -o $path_exe $path_o`)
bin = read(path_exe)

rm(path_o)
rm(path_exe)
return bin
end

function hipcompile(@nospecialize(job::CompilerJob))
Expand Down
9 changes: 6 additions & 3 deletions src/device/gcn/wavefront.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
export wavefrontsize

_Clong = Sys.islinux() ? Clong : Clonglong
_Culong = Sys.islinux() ? Culong : Culonglong

for (name,op) in ((:add,typeof(+)), (:max,typeof(max)), (:min,typeof(min)))
wfred_name = Symbol("wfred_$name")
wfscan_name = Symbol("wfscan_$name")
for jltype in (Cint, Clong, Cuint, Culong, Float16, Float32, Float64)
for jltype in (Cint, _Clong, Cuint, _Culong, Float16, Float32, Float64)
type_suffix = fntypes[jltype]

@eval @device_function $(wfred_name)(x::$jltype) = ccall(
Expand All @@ -21,7 +24,7 @@ end
for (name,op) in ((:and,typeof(&)), (:or,typeof(|)), (:xor,typeof()))
wfred_name = Symbol("wfred_$name")
wfscan_name = Symbol("wfscan_$name")
for jltype in (Cint, Clong, Cuint, Culong)
for jltype in (Cint, _Clong, Cuint, _Culong)
type_suffix = fntypes[jltype]

@eval @device_function $(wfred_name)(x::$jltype) = ccall(
Expand All @@ -36,7 +39,7 @@ for (name,op) in ((:and,typeof(&)), (:or,typeof(|)), (:xor,typeof(⊻)))
@eval @inline wfscan(::$op, x, inclusive::Bool) = $(wfscan_name)(x, inclusive)
end

for jltype in (Cuint, Culong)
for jltype in (Cuint, _Culong)
type_suffix = fntypes[jltype]
@eval @device_function wfbcast(x::$jltype, i::Cuint) = ccall(
$("extern __ockl_wfbcast_$(type_suffix)"), llvmcall,
Expand Down
31 changes: 15 additions & 16 deletions src/discovery/discovery.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ function get_library(
end
end

function get_ld_lld(;
function get_ld_lld(rocm_paths::Vector{String};
from_artifact::Bool, artifact_library::Symbol, artifact_field::Symbol,
)
if from_artifact || Sys.iswindows() # TODO temporary fix for Windows
if from_artifact
get_artifact_library(artifact_library, artifact_field)
else
find_ld_lld()
find_ld_lld(rocm_paths)
end
end

Expand Down Expand Up @@ -111,18 +111,6 @@ function __init__()
rocm_paths = use_artifacts() ? String[] : find_roc_paths()

try
# Core.
lld_path = get_ld_lld(; from_artifact=false,
artifact_library=:LLD_jll, artifact_field=:lld_path)
lld_artifact = false
if isempty(lld_path)
lld_path = get_ld_lld(; from_artifact=true,
artifact_library=:LLD_jll, artifact_field=:lld_path)
lld_artifact = true
end
global lld_path = lld_path
global lld_artifact = lld_artifact

global libhsaruntime = if Sys.islinux()
get_library("libhsa-runtime64";
rocm_paths, artifact_library=:hsa_rocr_jll,
Expand All @@ -133,12 +121,23 @@ function __init__()

lib_prefix = Sys.islinux() ? "lib" : ""
# TODO if more than 1 path - force user to specify
@show rocm_paths
rocm_path = isempty(rocm_paths) ? "" : first(rocm_paths)
if Sys.iswindows() && !isempty(rocm_paths)
push!(rocm_paths, joinpath(first(rocm_paths), "bin"))
end

# Linker.
lld_path = get_ld_lld(rocm_paths; from_artifact=false,
artifact_library=:LLD_jll, artifact_field=:lld_path)
lld_artifact = false
if isempty(lld_path)
lld_path = get_ld_lld(rocm_paths; from_artifact=true,
artifact_library=:LLD_jll, artifact_field=:lld_path)
lld_artifact = true
end
global lld_path = lld_path
global lld_artifact = lld_artifact

# HIP.
global libhip = get_library(
Sys.islinux() ? "libamdhip64" : "amdhip64.dll";
Expand Down
29 changes: 8 additions & 21 deletions src/discovery/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,28 +93,16 @@ function find_roc_paths()
return filter(isdir, paths) # TODO require only 1 dir or specify explicitly?
end

function find_ld_lld()
# TODO this is incorrect for Windows
paths = split(get(ENV, "PATH", ""), ":")
paths = filter(path -> path != "", paths)
paths = map(Base.Filesystem.abspath, paths)

basedir = get(ENV, "ROCM_PATH", "/opt/rocm")
ispath(joinpath(basedir, "llvm/bin/ld.lld")) &&
push!(paths, joinpath(basedir, "llvm/bin/"))
ispath(joinpath(basedir, "hcc/bin/ld.lld")) &&
push!(paths, joinpath(basedir, "/hcc/bin/"))
ispath(joinpath(basedir, "opencl/bin/x86_64/ld.lld")) &&
push!(paths, joinpath(basedir, "opencl/bin/x86_64/"))

for path in paths
exp_ld_path = joinpath(path, "ld.lld")
function find_ld_lld(rocm_paths::Vector{String})
lld_name = "ld.lld" * (Sys.iswindows() ? ".exe" : "")
for path in rocm_paths
exp_ld_path = joinpath(path, lld_name)
if ispath(exp_ld_path)
try
tmpfile = mktemp()
run(pipeline(`$exp_ld_path -v`; stdout=tmpfile[1]))
vstr = read(tmpfile[1], String)
rm(tmpfile[1])
tmpfile = tempname(;cleanup=false)
run(pipeline(`$exp_ld_path -v`; stdout=tmpfile))
vstr = read(tmpfile, String)
rm(tmpfile)
vstr = replace(vstr, "AMD " => "")
vstr_splits = split(vstr, ' ')
if VersionNumber(vstr_splits[2]) >= v"6.0.0"
Expand All @@ -136,7 +124,6 @@ function find_device_libs(rocm_path::String)
devlibs_path !== "" && return devlibs_path

# Try the canonical location.
@show rocm_path
canonical_dir = joinpath(rocm_path, "amdgcn", "bitcode")
isdir(canonical_dir) && return canonical_dir

Expand Down

0 comments on commit 7a8bff1

Please sign in to comment.