Skip to content

Commit

Permalink
Support CUDA on Julia 1.9+ via a package extension.
Browse files Browse the repository at this point in the history
  • Loading branch information
GunnarFarneback committed Oct 11, 2023
1 parent dfb67d6 commit 5f07cf3
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 8 deletions.
11 changes: 11 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,24 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
[compat]
ArgCheck = "2"
CEnum = "0.4"
CUDA = "4, 5"
DataStructures = "0.18"
DocStringExtensions = "0.8, 0.9"
Requires = "1"
cuDNN = "1.1"
julia = "1.6"

[extensions]
CUDAExt = ["CUDA", "cuDNN"]

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test"]

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
11 changes: 11 additions & 0 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module CUDAExt

# These functions are only defined for diagnostic purposes. Otherwise
# the CUDA extension only relies on the CUDA and cuDNN dependencies to
# have loaded the libraries needed by ONNXRunTime's CUDA execution
# provider.
import CUDA
cuda_functional() = CUDA.functional()
cuda_runtime_version() = CUDA.runtime_version()

end
12 changes: 8 additions & 4 deletions src/ONNXRunTime.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module ONNXRunTime
using Requires:@require
if !isdefined(Base, :get_extension)
using Requires: @require
end

function _perm(arr::AbstractArray{T,N}) where {T,N}
ntuple(i->N+1-i, N)
Expand All @@ -14,9 +16,11 @@ end
include("capi.jl")
include("highlevel.jl")

function __init__()
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
CUDA.functional() && include("cuda.jl")
@static if !isdefined(Base, :get_extension)
function __init__()
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
CUDA.functional() && include("cuda.jl")
end
end
end

Expand Down
29 changes: 25 additions & 4 deletions src/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,31 @@ function load_inference(path::AbstractString; execution_provider::Symbol=:cpu,
if execution_provider === :cpu
session_options = CreateSessionOptions(api)
elseif execution_provider === :cuda
if !(isdefined(@__MODULE__, :CUDA))
@warn """
The $(repr(execution_provider)) requires the CUDA.jl package to be available. Try adding `import CUDA` to your code.
"""
if isdefined(Base, :get_extension)
CUDAExt = Base.get_extension(@__MODULE__, :CUDAExt)
if isnothing(CUDAExt)
@warn """
The $(repr(execution_provider)) execution provider requires the CUDA.jl and cuDNN.jl packages to be available. Try adding `import CUDA, cuDNN` to your code.
"""
elseif !getfield(CUDAExt, :cuda_functional)()
@warn """
The $(repr(execution_provider)) execution provider requires CUDA to be functional. See `CUDA.functional`.
"""
elseif !(v"11.8" <= getfield(CUDAExt, :cuda_runtime_version)() < v"12")
# Note: The supported version range is a property
# inherited from the CUDA runtime library and needs to
# be updated when the library is updated. It may be a
# good idea to centralize this information somewhere.
@warn """
The $(repr(execution_provider)) execution provider requires a CUDA runtime version of at least 11.8 but less than 12. See `CUDA.set_runtime_version!`.
"""
end
else
if !isdefined(@__MODULE__, :CUDA)
@warn """
The $(repr(execution_provider)) execution provider requires the CUDA.jl package to be available. Try adding `import CUDA` to your code.
"""
end
end
session_options = CreateSessionOptions(api)
cuda_options = OrtCUDAProviderOptions()
Expand Down
178 changes: 178 additions & 0 deletions test/test_cuda_extension.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# This file is not included from `runtests.jl` nor run in CI.
#
# Run it with `julia tests/test_cuda_extension.jl`. This requires that
# Julia is installed with juliaup and will involve downloading of a
# lot of big artifacts. The output will contain lots of error messages
# from caught errors; what matters is that all testsets pass.

using Test

juliaup_found = false
try run(pipeline(`juliaup --version`, stdout = devnull, stderr = devnull))
global juliaup_found = true
catch e
end

if !juliaup_found
error("`juliaup` needs to be installed for the CUDA extension tests")
end

wait(run(`juliaup add 1.6`, wait = false))
wait(run(`juliaup add 1.9`, wait = false))

package_path = dirname(@__DIR__)
onnx_path = joinpath(@__DIR__, "data", "copy2d.onnx")

function with_environment(f::Function; cuda_runtime_version)
mktempdir() do env
write(joinpath(env, "LocalPreferences.toml"),
"""
[CUDA_Runtime_jll]
version = "$(cuda_runtime_version)"
""")
write(joinpath(env, "Project.toml"),
"""
[extras]
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
""")
f(env)
end
end

@testset "Julia 1.6 CUDA 3" begin
with_environment(cuda_runtime_version = "11.8") do env
install_script = """
using Pkg
Pkg.develop(path = "$(package_path)")
Pkg.add(name = "CUDA", version = "3")
"""
@test success(run(`julia +1.6 --project=$(env) -e "$(install_script)"`))
# Correct dependency for :cuda.
test_script = """
using ONNXRunTime, CUDA
load_inference("$(onnx_path)", execution_provider = :cuda)
"""
@test success(run(`julia +1.6 --project=$(env) -e "$(test_script)"`))
# CUDA not loaded.
test_script = """
using ONNXRunTime
load_inference("$(onnx_path)", execution_provider = :cuda)
"""
@test_throws ProcessFailedException run(`julia +1.6 --project=$(env) -e "$(test_script)"`)
# CUDA not loaded but running on CPU, so it's fine.
test_script = """
using ONNXRunTime
load_inference("$(onnx_path)", execution_provider = :cpu)
"""
@test success(run(`julia +1.6 --project=$(env) -e "$(test_script)"`))
end
end

@testset "Julia 1.9 CUDA 3" begin
with_environment(cuda_runtime_version = "11.8") do env
install_script = """
using Pkg
Pkg.develop(path = "$(package_path)")
Pkg.add(name = "CUDA", version = "3")
"""
# CUDA 3 is not possible to install together with ONNXRunTime
# on Julia 1.9 due to Compat requirements.
@test_throws ProcessFailedException run(`julia +1.9 --project=$(env) -e "$(install_script)"`)
end
end

@testset "Julia 1.9 CUDA.jl $(cuda_version) CUDA runtime 11.8" for cuda_version in (4, 5)
with_environment(cuda_runtime_version = "11.8") do env
install_script = """
using Pkg
Pkg.develop(path = "$(package_path)")
Pkg.add(name = "CUDA", version = "$(cuda_version)")
Pkg.add(name = "cuDNN")
"""
@test success(run(`julia +1.9 --project=$(env) -e "$(install_script)"`))
# Correct dependencies for :cuda.
test_script = """
using ONNXRunTime, CUDA, cuDNN
load_inference("$(onnx_path)", execution_provider = :cuda)
"""
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
# Neither CUDA nor cuDNN loaded.
test_script = """
using ONNXRunTime
load_inference("$(onnx_path)", execution_provider = :cuda)
"""
@test_throws ProcessFailedException run(`julia +1.9 --project=$(env) -e "$(test_script)"`)
# Neither CUDA nor cuDNN loaded but running on CPU, so it's fine.
test_script = """
using ONNXRunTime
load_inference("$(onnx_path)", execution_provider = :cpu)
"""
# CUDA not loaded. Well, cuDNN pulls in CUDA so this passes anyway.
test_script = """
using ONNXRunTime
using cuDNN
load_inference("$(onnx_path)", execution_provider = :cuda)
"""
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
# CUDA not loaded but running on CPU, so it's fine.
test_script = """
using ONNXRunTime
using cuDNN
load_inference("$(onnx_path)", execution_provider = :cpu)
"""
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
# cuDNN not loaded.
test_script = """
using ONNXRunTime
using CUDA
load_inference("$(onnx_path)", execution_provider = :cuda)
"""
@test_throws ProcessFailedException run(`julia +1.9 --project=$(env) -e "$(test_script)"`)
# cuDNN not loaded but running on CPU, so it's fine.
test_script = """
using ONNXRunTime
using CUDA
load_inference("$(onnx_path)", execution_provider = :cpu)
"""
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
end
end

@testset "Julia 1.9 CUDA.jl $(cuda_version) CUDA runtime 11.6" for cuda_version in (4, 5)
with_environment(cuda_runtime_version = "11.6") do env
install_script = """
using Pkg
Pkg.develop(path = "$(package_path)")
Pkg.add(name = "CUDA", version = "$(cuda_version)")
Pkg.add(name = "cuDNN")
"""
@test success(run(`julia +1.9 --project=$(env) -e "$(install_script)"`))
# Correct dependencies for :cuda. CUDA runtime version is
# lower than officially supported but close enough to at least
# load so there will be a warning but no error.
test_script = """
using ONNXRunTime, CUDA, cuDNN
load_inference("$(onnx_path)", execution_provider = :cuda)
"""
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
end
end

@testset "Julia 1.9 CUDA.jl $(cuda_version) CUDA runtime 12.1" for cuda_version in (4, 5)
with_environment(cuda_runtime_version = "12.1") do env
install_script = """
using Pkg
Pkg.develop(path = "$(package_path)")
Pkg.add(name = "CUDA", version = "$(cuda_version)")
Pkg.add(name = "cuDNN")
"""
@test success(run(`julia +1.9 --project=$(env) -e "$(install_script)"`))
# Correct dependencies for :cuda but fails due to bad version
# of CUDA runtime.
test_script = """
using ONNXRunTime, CUDA, cuDNN
load_inference("$(onnx_path)", execution_provider = :cuda)
"""
@test_throws ProcessFailedException run(`julia +1.9 --project=$(env) -e "$(test_script)"`)
end
end

0 comments on commit 5f07cf3

Please sign in to comment.