Skip to content

Commit

Permalink
Improve threading in _set_interpolated_values_device
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jan 24, 2025
1 parent 69eff97 commit e612a32
Showing 1 changed file with 29 additions and 39 deletions.
68 changes: 29 additions & 39 deletions ext/cuda/remapping_distributed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,28 @@ function _set_interpolated_values_device!(
# FIXME: Avoid allocation of tuple
field_values = tuple(map(f -> Fields.field_values(f), fields)...)

purely_vertical_space = isnothing(interpolation_matrix)
num_horizontal_points =
purely_vertical_space ? 1 : prod(size(local_horiz_indices))
num_points = num_horizontal_points * length(vert_interpolation_weights)
max_threads = 256
nthreads = min(num_points, max_threads)
nblocks = cld(num_points, nthreads)
num_horiz = length(local_horiz_indices)
num_vert = length(vert_bounding_indices)
num_fields = length(field_values)
nitems = num_horiz * num_vert * num_fields

_, Nq = size(interpolation_matrix[1])
args = (
out,
interpolation_matrix,
local_horiz_indices,
vert_interpolation_weights,
vert_bounding_indices,
field_values,
Val(Nq),
)
threads = threads_via_occupancy(set_interpolated_values_kernel!, args)
p = linear_partition(nitems, threads)
auto_launch!(
set_interpolated_values_kernel!,
args;
threads_s = (nthreads),
blocks_s = (nblocks),
threads_s = (p.threads),
blocks_s = (p.blocks),
)
call_post_op_callback() && post_op_callback(
out,
Expand All @@ -60,42 +62,30 @@ function set_interpolated_values_kernel!(
vert_interpolation_weights,
vert_bounding_indices,
field_values,
)
# TODO: Check the memory access pattern. This was not optimized and likely inefficient!
::Val{Nq},
) where {Nq}
num_horiz = length(local_horiz_indices)
num_vert = length(vert_bounding_indices)
num_fields = length(field_values)

hindex = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
vindex = (blockIdx().y - Int32(1)) * blockDim().y + threadIdx().y
findex = (blockIdx().z - Int32(1)) * blockDim().z + threadIdx().z

totalThreadsX = gridDim().x * blockDim().x
totalThreadsY = gridDim().y * blockDim().y
totalThreadsZ = gridDim().z * blockDim().z
I = (CUDA.blockIdx().x - Int32(1)) * CUDA.blockDim().x + CUDA.threadIdx().x
inds = (num_horiz, num_vert, num_fields)
(i, j, k) = CartesianIndices(map(x -> Base.OneTo(x), inds))[I].I

_, Nq = size(I1)
1 ≤ I ≤ prod(inds) || return nothing
CI = CartesianIndex
for i in hindex:totalThreadsX:num_horiz
h = local_horiz_indices[i]
for j in vindex:totalThreadsY:num_vert
v_lo, v_hi = vert_bounding_indices[j]
A, B = vert_interpolation_weights[j]
for k in findex:totalThreadsZ:num_fields
if i ≤ num_horiz && j ≤ num_vert && k ≤ num_fields
out[i, j, k] = 0
for t in 1:Nq, s in 1:Nq
out[i, j, k] +=
I1[i, t] *
I2[i, s] *
(
A * field_values[k][CI(t, s, 1, v_lo, h)] +
B * field_values[k][CI(t, s, 1, v_hi, h)]
)
end
end
end
end
h = local_horiz_indices[i]
v_lo, v_hi = vert_bounding_indices[j]
A, B = vert_interpolation_weights[j]
out[i, j, k] = 0
for t in 1:Nq, s in 1:Nq
out[i, j, k] +=
I1[i, t] *
I2[i, s] *
(
A * field_values[k][CI(t, s, 1, v_lo, h)] +
B * field_values[k][CI(t, s, 1, v_hi, h)]
)
end
return nothing
end
Expand Down

0 comments on commit e612a32

Please sign in to comment.