Skip to content

Commit

Permalink
Merge pull request #5 from eschnett/eschnett/shared-layouts-3
Browse files Browse the repository at this point in the history
Improve shared memory layout handling
  • Loading branch information
eschnett authored Jul 18, 2024
2 parents 92d2d26 + dd62af7 commit a3a411e
Showing 1 changed file with 50 additions and 37 deletions.
87 changes: 50 additions & 37 deletions src/IndexSpaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,35 @@ indexvalue(::State, unrolled_loop::UnrolledLoop) = :($(unrolled_loop.name)::Int3
indexvalue(::State, ::Shared) = @assert false
indexvalue(::State, ::Memory) = @assert false

"""
    typerange(T)

Return the range `typemin(T):typemax(T)`, i.e. every value between the smallest and the
largest value that `typemin` and `typemax` report for the type `T`.
"""
function typerange(::Type{T}) where {T}
    lo = typemin(T)
    hi = typemax(T)
    return lo:hi
end

# Combine two lists of code terms into a single sum expression.
#
# `vals32` holds the terms that fit 32-bit arithmetic and `vals64` the terms that need
# 64-bit arithmetic. Each list is summed on its own and passed through
# `evaluate_partially` (defined elsewhere in this file; presumably it folds constant
# subexpressions -- TODO confirm). The two partial sums are then added, so that the
# 32-bit terms are combined with each other before meeting any 64-bit term.
#
# Returns the 32-bit sum alone when `vals64` is empty, otherwise the expression
# `sum32 + sum64`.
function combine_32_64(vals32::AbstractVector{<:Code}, vals64::AbstractVector{<:Code})
    # Literal numbers must sit in the list that matches their magnitude.
    @assert all(val isa Number ? val in typerange(Int32) : true for val in vals32)
    @assert all(val isa Number ? val in typerange(Int64) && !(val in typerange(Int32)) : true for val in vals64)

    # 32-bit part: no terms gives the `0i32` literal, one term is used as is, several
    # terms become a single n-ary `+` call.
    sum32 = if isempty(vals32)
        0i32
    elseif length(vals32) == 1
        vals32[1]
    else
        Expr(:call, :+, vals32...)
    end
    sum32 = evaluate_partially(sum32)

    # Without 64-bit terms the 32-bit sum is already the full result.
    isempty(vals64) && return sum32::Code

    # 64-bit part: built the same way; the list is known to be non-empty here.
    sum64 = length(vals64) == 1 ? vals64[1] : Expr(:call, :+, vals64...)
    sum64 = evaluate_partially(sum64)

    total = :($sum32 + $sum64)
    return total::Code
end

function physics_values(state::State, reg_layout::Layout{Physics,Machine})
vals = Dict{IndexTag,Code}()
vals = Dict{IndexTag,NTuple{2,Vector{Code}}}()
for (phys, mach) in reg_layout.dict
mach isa SIMD && continue
machtag = indextag(mach)
Expand All @@ -619,16 +646,19 @@ function physics_values(state::State, reg_layout::Layout{Physics,Machine})
@assert !(phys isa SIMD)
phystag = indextag(phys)
physoff = Int32(phys.offset)
# physlen = Int32(phys.length)
physval = :($val * $physoff)
oldphysval = get(vals, phystag, 0i32)
if oldphysval isa Number
oldphysval::Int32
physlen = Int32(phys.length)
physvals32, physvals64 = get!(vals, phystag, (Code[], Code[]))
if (physlen - 1) * Int64(physoff) in typerange(Int32)
# We can use 32-bit indexing
physval = :($val * $physoff)
push!(physvals32, physval)
else
# We need 64-bit indexing
physval = :($val * $(Int64(physoff)))
push!(physvals64, physval)
end
physval = :($oldphysval + $physval)
vals[phystag] = physval
end
return Dict{IndexTag,Code}(k => evaluate_partially(v) for (k, v) in vals)
return Dict{IndexTag,Code}(k => combine_32_64(vals32, vals64) for (k, (vals32, vals64)) in vals)
end

function memory_index(reg_layout::Layout{Physics,Machine}, mem_layout::Layout{Physics,Machine}, vals::Dict{IndexTag,Code})
Expand All @@ -648,16 +678,17 @@ function memory_index(reg_layout::Layout{Physics,Machine}, mem_layout::Layout{Ph

is_shared = any(mach isa Shared for (phys, mach) in mem_layout.dict)
is_memory = any(mach isa Memory for (phys, mach) in mem_layout.dict)
@assert is_shared + is_memory == 1
# If the memory region is very small (4 bytes long) then there are
# neither Shared nor Memory references
@assert is_shared + is_memory <= 1

# addr = 0i32
addrs32 = Code[]
addrs64 = Code[]
for (phys, mach) in mem_layout.dict
mach isa SIMD && continue
# Ensure that we are mapping to memory
is_shared && mach::Union{Block,Shared,Loop,UnrolledLoop}
is_memory && mach::Memory
# Ensure that we are mapping to the right kind of memory
mach isa Union{Block,Shared,Loop,UnrolledLoop} && @assert is_shared
mach isa Memory && @assert is_memory
# Only memory addresses contribute
!(mach isa Union{Shared,Memory}) && continue

Expand All @@ -669,7 +700,7 @@ function memory_index(reg_layout::Layout{Physics,Machine}, mem_layout::Layout{Ph

machoff = Int32(mach.offset)
machlen = Int32(mach.length)
if (machlen - 1) * Int64(machoff) <= typemax(Int32)
if (machlen - 1) * Int64(machoff) in typerange(Int32)
# We can use 32-bit indexing
machval = :($val * $machoff)
push!(addrs32, machval)
Expand All @@ -680,25 +711,7 @@ function memory_index(reg_layout::Layout{Physics,Machine}, mem_layout::Layout{Ph
push!(addrs64, machval)
end
end
if isempty(addrs32)
addr32 = 0i32
elseif length(addrs32) == 1
addr32 = addrs32[1]
else
addr32 = :(+($(addrs32...)))
end
addr32 = evaluate_partially(addr32)
if isempty(addrs64)
addr = addr32
else
if length(addrs64) == 1
addr64 = addrs64[1]
else
addr64 = :(+($(addrs64...)))
end
addr64 = evaluate_partially(addr64)
addr = :($addr + $addr64)
end
addr = combine_32_64(addrs32, addrs64)
return addr::Code
end

Expand Down Expand Up @@ -1064,7 +1077,7 @@ function store!(
mem::Pair{Symbol,Layout{Physics,Machine}},
reg_var::Symbol;
align::Int=4,
condition=nothing,
condition=Returns(true),
offset::Code=0,
postprocess=identity,
)
Expand All @@ -1080,7 +1093,7 @@ function store!(
stmt = quote
$mem_var[$(postprocess(:($addr + $offset))) + 0x1] = $reg_name
end
if cond !== nothing
if cond !== true
stmt = quote
if $cond
$stmt
Expand Down Expand Up @@ -1124,7 +1137,7 @@ function store!(
$mem_var, $(postprocess(:($addr + $offset))) + 0x1, ($reg0_name, $reg1_name, $reg2_name, $reg3_name)
)
end
if cond !== nothing
if cond !== true
stmt = quote
if $cond
$stmt
Expand Down

0 comments on commit a3a411e

Please sign in to comment.