Support CUDA arrays #36
Quick test on caviness at UD (interactive session with 16G of memory and a GPU) of the Khatri-Rao product implementation from #34 (comment):

```
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.9.4 (2023-11-14)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |
```
```julia
(gpu-cuda) pkg> st
Status `~/gpu-cuda/Project.toml`
⌃ [052768ef] CUDA v5.0.0
Info Packages marked with ⌃ have new versions available and may be upgradable.

julia> using CUDA

julia> CUDA.versioninfo()
CUDA runtime 12.2, artifact installation
CUDA driver 12.1
NVIDIA driver 530.30.2

CUDA libraries:
- CUBLAS: 12.2.5
- CURAND: 10.3.3
- CUFFT: 11.0.8
- CUSOLVER: 11.5.2
- CUSPARSE: 12.1.2
- CUPTI: 20.0.0
- NVML: 12.0.0+530.30.2

Julia packages:
- CUDA: 5.0.0
- CUDA_Driver_jll: 0.6.0+4
- CUDA_Runtime_jll: 0.9.2+4

Toolchain:
- Julia: 1.9.4
- LLVM: 14.0.6
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0, 7.1, 7.2, 7.3, 7.4, 7.5
- Device capability support: sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80, sm_86

1 device:
  0: Tesla P100-PCIE-12GB (sm_60, 11.906 GiB / 12.000 GiB available)
```
```julia
julia> function khatrirao(A::Vararg{T,N}) where {T<:AbstractMatrix,N}
           r = size(A[1], 2)
           # @boundscheck all(==(r), size.(A, 2)) || throw(DimensionMismatch())
           R = ntuple(Val(N)) do k
               dims = (ntuple(i -> 1, Val(N - k))..., :, ntuple(i -> 1, Val(k - 1))..., r)
               return reshape(A[k], dims)
           end
           return reshape(broadcast(*, R...), :, r)
       end
khatrirao (generic function with 1 method)
```
```julia
julia> A = CUDA.randn.((100, 100, 100), 500);

julia> typeof(A)
Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}

julia> A_cpu = Array.(A);

julia> typeof(A_cpu)
Tuple{Matrix{Float32}, Matrix{Float32}, Matrix{Float32}}

julia> A_cpu_64 = convert.(Matrix{Float64}, A_cpu);

julia> typeof(A_cpu_64)
Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}

julia> typeof(khatrirao(A...))
CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}

julia> typeof(khatrirao(A_cpu...))
Matrix{Float32} (alias for Array{Float32, 2})

julia> typeof(khatrirao(A_cpu_64...))
Matrix{Float64} (alias for Array{Float64, 2})

julia> Array(khatrirao(A...)) == khatrirao(A_cpu...) ≈ khatrirao(A_cpu_64...)
true

julia> CUDA.@time khatrirao(A...);
  0.065930 seconds (111 CPU allocations: 6.359 KiB) (1 GPU allocation: 1.863 GiB, 6.48% memmgmt time)

julia> @time khatrirao(A_cpu...);
  1.039305 seconds (18 allocations: 1.863 GiB, 4.55% gc time)

julia> @time khatrirao(A_cpu_64...);
  1.450372 seconds (21 allocations: 3.725 GiB, 0.75% gc time)
```

Looks promising!
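The reshape-and-broadcast implementation above computes the column-wise Kronecker product of its arguments; because Julia arrays are column-major, the first argument varies slowest within each column of the result. As a sanity-check of that semantics, here is a minimal brute-force sketch in pure Python (list-of-rows matrices; all names here are hypothetical, not part of the package):

```python
def khatrirao(*mats):
    """Khatri-Rao product: column-wise Kronecker product of matrices
    that all share the same number of columns."""
    r = len(mats[0][0])  # shared column count
    assert all(len(row) == r for m in mats for row in m)
    cols = []
    for j in range(r):
        # Build kron(mats[0][:, j], ..., mats[-1][:, j]); the first
        # factor varies slowest, matching the Julia code above.
        col = [1.0]
        for m in mats:
            col = [c * row[j] for c in col for row in m]
        cols.append(col)
    # Transpose the list of columns back into a list of rows.
    return [[cols[j][i] for j in range(r)] for i in range(len(cols[0]))]

# Example: two 2x2 matrices give a 4x2 result.
A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
print(khatrirao(A, B))  # [[5.0, 12.0], [7.0, 16.0], [15.0, 24.0], [21.0, 32.0]]
```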
Using the simple MTTKRP implementation in master with the Khatri-Rao function above, we can get a significant speedup on the GPU for a tensor of size (100, 200, 300), for all three modes of MTTKRP.
The new MTTKRP function from #35 works on the GPU with one modification: making a copy of the result of `selectdim` in the loop in the else case. This is necessary because `selectdim` returns a view, which is not GPU-compatible.
Rewriting the loop in the else case as a single multiplication would make the code fully compatible with `CuArray`s.
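For reference, the MTTKRP (matricized tensor times Khatri-Rao product) these experiments exercise can be written element-wise: for mode 1 of a 3-way tensor X with factor matrices A2 and A3, M[i1, j] = Σ over i2, i3 of X[i1, i2, i3] · A2[i2, j] · A3[i3, j]. A brute-force pure-Python sketch of just this definition (hypothetical names, nested lists standing in for arrays; not the package's implementation):

```python
def mttkrp_mode1(X, A2, A3):
    """Mode-1 MTTKRP of a 3-way tensor X (shape I1 x I2 x I3) with factor
    matrices A2 (I2 x r) and A3 (I3 x r):
    M[i1][j] = sum over i2, i3 of X[i1][i2][i3] * A2[i2][j] * A3[i3][j]."""
    I1, I2, I3 = len(X), len(X[0]), len(X[0][0])
    r = len(A2[0])
    M = [[0.0] * r for _ in range(I1)]
    for i1 in range(I1):
        for j in range(r):
            M[i1][j] = sum(X[i1][i2][i3] * A2[i2][j] * A3[i3][j]
                           for i2 in range(I2) for i3 in range(I3))
    return M

# Rank-1 sanity check: if X[i1][i2][i3] = a[i1] * b[i2] * c[i3], then
# M[i1][0] = a[i1] * (b . b) * (c . c).
a, b, c = [1.0, 2.0], [1.0, 1.0], [1.0, 3.0]
X = [[[ai * bi * ci for ci in c] for bi in b] for ai in a]
A2 = [[bi] for bi in b]   # r = 1 factor matrices
A3 = [[ci] for ci in c]
print(mttkrp_mode1(X, A2, A3))  # [[20.0], [40.0]]
```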
Idea is to add GPU support by making the `gcp` implementation sufficiently generic so that it can work with `CuArray`s. Can also add `CuArray`-specific methods via package extensions. Should probably wait until current work on #35 and perhaps #34 are done.
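As a sketch of the extension route (all names here are illustrative, and only the abbreviated CUDA UUID appears in the package status above), the package's Project.toml could declare CUDA.jl as a weak dependency and attach a CUDA-specific extension module that loads only when CUDA.jl is present:

```toml
[weakdeps]
# Use the full UUID from CUDA.jl's own Project.toml; only the
# abbreviated form [052768ef] appears in the session above.
CUDA = "052768ef-..."

[extensions]
# Hypothetical extension module that would hold CuArray-specific methods.
GCPCUDAExt = "CUDA"
```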
Tagging @alexmul1114 who will be working on this as part of his winter fellowship.