Skip to content

Commit

Permalink
Bump minor version
Browse files Browse the repository at this point in the history
  • Loading branch information
eschnett committed Sep 2, 2024
1 parent a3a411e commit 963bdd0
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "IndexSpaces"
uuid = "3ceaaf38-e7b4-4193-a23a-303dd231a83a"
authors = ["Erik Schnetter <[email protected]>"]
version = "1.5.1"
version = "1.6.0"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
58 changes: 58 additions & 0 deletions src/IndexSpaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,39 @@ end

# Memory access

# Address space 1 is global, 3 is shared
"""
    unsafe_load2(ptr::Core.LLVMPtr{Int32,1})

Load two consecutive `Int32` values from address space 1 (global memory)
with a single `[2 x i32]` LLVM load and return them as an `NTuple{2,Int32}`.

`ptr` must be aligned to 8 bytes (the size of the loaded pair). No bounds or
alignment checks are performed.
"""
function unsafe_load2(ptr::Core.LLVMPtr{Int32,1})
    # The access is 8 bytes wide, so 8 is the strongest alignment that may be
    # promised to LLVM. Promising more (e.g. 16) is undefined behaviour for
    # addresses that are only 8-byte aligned.
    return Base.llvmcall(
        """
        %ptr = bitcast i8 addrspace(1)* %0 to [2 x i32] addrspace(1)*
        %val = load [2 x i32], [2 x i32] addrspace(1)* %ptr, align 8
        ret [2 x i32] %val
        """,
        NTuple{2,Int32},
        Tuple{Core.LLVMPtr{Int32,1}},
        ptr,
    )
end
"""
    unsafe_load2(ptr::Core.LLVMPtr{Int32,3})

Load two consecutive `Int32` values from address space 3 (shared memory)
with a single `[2 x i32]` LLVM load and return them as an `NTuple{2,Int32}`.

`ptr` must be aligned to 8 bytes (the size of the loaded pair). No bounds or
alignment checks are performed.
"""
function unsafe_load2(ptr::Core.LLVMPtr{Int32,3})
    # The access is 8 bytes wide, so 8 is the strongest alignment that may be
    # promised to LLVM. Promising more (e.g. 16) is undefined behaviour for
    # addresses that are only 8-byte aligned.
    return Base.llvmcall(
        """
        %ptr = bitcast i8 addrspace(3)* %0 to [2 x i32] addrspace(3)*
        %val = load [2 x i32], [2 x i32] addrspace(3)* %ptr, align 8
        ret [2 x i32] %val
        """,
        NTuple{2,Int32},
        Tuple{Core.LLVMPtr{Int32,3}},
        ptr,
    )
end
"""
    unsafe_load2(arr::CuDeviceArray{Int32}, idx::Integer)

Load the pair of `Int32` values starting at element `idx` of `arr` by
forwarding `pointer(arr, idx)` to the pointer-based `unsafe_load2` method.
"""
function unsafe_load2(arr::CuDeviceArray{Int32}, idx::Integer)
    elem_ptr = pointer(arr, idx)
    return unsafe_load2(elem_ptr)
end
"""
    unsafe_load2(arr::CuDeviceArray{T}, idx::Integer) where {T}

Load two consecutive elements of `arr`, starting at element `idx`, for an
element type `T` that has the same size as `Int32`.

The array is reinterpreted as `Int32`, the pair is fetched through the
`Int32` method, and each word is converted to `T` by calling `T` on the
word's bit pattern viewed as a `UInt32`. Returns an `NTuple{2,T}`.
"""
function unsafe_load2(arr::CuDeviceArray{T}, idx::Integer) where {T}
    @assert sizeof(T) == sizeof(Int32)
    words = unsafe_load2(reinterpret(Int32, arr), idx)::NTuple{2,Int32}
    first_word, second_word = words
    # Each element is built with the `T(::UInt32)` constructor rather than
    # with `reinterpret(T, word)`; `% UInt32` keeps the bits and only drops
    # the sign interpretation.
    return (T(first_word % UInt32), T(second_word % UInt32))::NTuple{2,T}
end

# Address space 1 is global, 3 is shared
function unsafe_load4(ptr::Core.LLVMPtr{Int32,1})
return Base.llvmcall(
Expand Down Expand Up @@ -989,6 +1022,31 @@ function load!(
addr = memory_index(reg_layout, mem_layout, vals)
push!(emitter.statements, :($reg_name = $mem_var[$(postprocess(addr)) + 0x1]))
end
elseif align == 8
# Find registers with stride 1
# TODO: Better error message if not accessing global memory
phys0 = inv(mem_layout)[Memory(:memory, 1, 2)]
register0 = reg_layout[phys0]::Register
tmp_layout = copy(reg_layout)
delete!(tmp_layout, phys0)
loop_over_registers(emitter, tmp_layout) do state
state0 = copy(state)
state1 = copy(state)
state0.dict[register0.name] = get(state0.dict, register0.name, Int32(0)) + Int32(0 * register0.offset)
state1.dict[register0.name] = get(state1.dict, register0.name, Int32(0)) + Int32(1 * register0.offset)
reg0_name = register_name(reg_var, state0)
reg1_name = register_name(reg_var, state1)
vals = physics_values(state0, reg_layout)
addr = memory_index(reg_layout, mem_layout, vals)
push!(
emitter.statements,
:(
($reg0_name, $reg1_name) = IndexSpaces.unsafe_load2(
$mem_var, $(postprocess(addr)) + 0x1
)
),
)
end
elseif align == 16
# Find registers with strides 1 and 2
# TODO: Better error message if not accessing global memory
Expand Down

0 comments on commit 963bdd0

Please sign in to comment.