From 963bdd077bedcfffda040fcfbb88220e63ceb7db Mon Sep 17 00:00:00 2001
From: Erik Schnetter
Date: Mon, 2 Sep 2024 14:01:55 -0400
Subject: [PATCH] Bump minor version

---
 Project.toml       |  2 +-
 src/IndexSpaces.jl | 65 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 66 insertions(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 0b32573..f6abf3c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "IndexSpaces"
 uuid = "3ceaaf38-e7b4-4193-a23a-303dd231a83a"
 authors = ["Erik Schnetter "]
-version = "1.5.1"
+version = "1.6.0"
 
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/src/IndexSpaces.jl b/src/IndexSpaces.jl
index c18d747..555c25b 100644
--- a/src/IndexSpaces.jl
+++ b/src/IndexSpaces.jl
@@ -935,6 +935,46 @@ end
 
 # Memory access
 
+# Load two consecutive `Int32` values (8 bytes) as one `[2 x i32]` LLVM load.
+# Address space 1 is global, 3 is shared
+# The IR declares `align 8`, the size of the loaded value; `load!` emits these calls
+# on its `align == 8` path, so a stricter alignment must not be promised here.
+function unsafe_load2(ptr::Core.LLVMPtr{Int32,1})
+    return Base.llvmcall(
+        """
+           %ptr = bitcast i8 addrspace(1)* %0 to [2 x i32] addrspace(1)*
+           %val = load [2 x i32], [2 x i32] addrspace(1)* %ptr, align 8
+           ret [2 x i32] %val
+        """,
+        NTuple{2,Int32},
+        Tuple{Core.LLVMPtr{Int32,1}},
+        ptr,
+    )
+end
+function unsafe_load2(ptr::Core.LLVMPtr{Int32,3})
+    return Base.llvmcall(
+        """
+           %ptr = bitcast i8 addrspace(3)* %0 to [2 x i32] addrspace(3)*
+           %val = load [2 x i32], [2 x i32] addrspace(3)* %ptr, align 8
+           ret [2 x i32] %val
+        """,
+        NTuple{2,Int32},
+        Tuple{Core.LLVMPtr{Int32,3}},
+        ptr,
+    )
+end
+# Array front ends: load the two elements starting at index `idx`.
+unsafe_load2(arr::CuDeviceArray{Int32}, idx::Integer) = unsafe_load2(pointer(arr, idx))
+# Other 4-byte element types are loaded as `Int32` and rebuilt with `T(::UInt32)`.
+function unsafe_load2(arr::CuDeviceArray{T}, idx::Integer) where {T}
+    @assert sizeof(T) == sizeof(Int32)
+    res = unsafe_load2(reinterpret(Int32, arr), idx)::NTuple{2,Int32}
+    # NOTE(review): `T(x)` converts rather than reinterprets for types like `Float32`;
+    # presumably every `T` used here wraps a `UInt32` bit pattern -- confirm.
+    # return ntuple(n -> reinterpret(T, res[n]), 2)::NTuple{2,T}
+    return ntuple(n -> T(res[n] % UInt32), 2)::NTuple{2,T}
+end
+
 # Address space 1 is global, 3 is shared
 function unsafe_load4(ptr::Core.LLVMPtr{Int32,1})
     return Base.llvmcall(
@@ -989,6 +1029,31 @@ function load!(
             addr = memory_index(reg_layout, mem_layout, vals)
             push!(emitter.statements, :($reg_name = $mem_var[$(postprocess(addr)) + 0x1]))
         end
+    elseif align == 8
+        # Find registers with stride 1
+        # TODO: Better error message if not accessing global memory
+        phys0 = inv(mem_layout)[Memory(:memory, 1, 2)]
+        register0 = reg_layout[phys0]::Register
+        tmp_layout = copy(reg_layout)
+        delete!(tmp_layout, phys0)
+        loop_over_registers(emitter, tmp_layout) do state
+            state0 = copy(state)
+            state1 = copy(state)
+            state0.dict[register0.name] = get(state0.dict, register0.name, Int32(0)) + Int32(0 * register0.offset)
+            state1.dict[register0.name] = get(state1.dict, register0.name, Int32(0)) + Int32(1 * register0.offset)
+            reg0_name = register_name(reg_var, state0)
+            reg1_name = register_name(reg_var, state1)
+            vals = physics_values(state0, reg_layout)
+            addr = memory_index(reg_layout, mem_layout, vals)
+            push!(
+                emitter.statements,
+                :(
+                    ($reg0_name, $reg1_name) = IndexSpaces.unsafe_load2(
+                        $mem_var, $(postprocess(addr)) + 0x1
+                    )
+                ),
+            )
+        end
     elseif align == 16
         # Find registers with strides 1 and 2
         # TODO: Better error message if not accessing global memory