Skip to content

Commit

Permalink
Improve: use shared memory
Browse files Browse the repository at this point in the history
  • Loading branch information
Chong Gao committed Nov 1, 2024
1 parent 5d96885 commit 5af14f5
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
Original file line number Diff line number Diff line change
Expand Up @@ -231,36 +231,45 @@ CUDF_KERNEL void partial_group_sketches_from_hashs_kernel(
* Has num_groups rows.
*
*/
template <int block_size>
CUDF_KERNEL void merge_sketches_vertically(int64_t num_sketches,
                                           int64_t num_registers_per_sketch,
                                           int* const registers_output_cache,
                                           int const* const registers_thread_cache,
                                           size_type const* const group_lables_thread_cache)
{
  // Each thread owns one register index (`tid`) and walks every partial sketch in order,
  // keeping the running maximum of that register for the current group. When the group
  // label changes, the maximum is folded into `registers_output_cache` for the group that
  // just ended.
  //
  // `registers_thread_cache` is indexed as [sketch * num_registers_per_sketch + register];
  // `registers_output_cache` is indexed as [group * num_registers_per_sketch + register].
  // NOTE(review): the single pass assumes equal group labels are contiguous in
  // `group_lables_thread_cache` -- confirm against the kernel that produces them.
  //
  // The running maximum lives in dynamic shared memory, one int8_t slot per thread of the
  // block, so the launch must supply at least `block_size` bytes of dynamic shared memory.
  extern __shared__ int8_t shared_data[];
  auto const tid = cudf::detail::grid_1d::global_thread_id();

  // The grid is rounded up to whole blocks, so trailing threads may have no register to
  // process; with zero sketches there is no label to read either. No block-level barrier
  // is used in this kernel, so returning early is safe.
  if (tid >= num_registers_per_sketch || num_sketches <= 0) { return; }

  // Each thread only ever touches its own slot, hence no synchronization is required.
  int const shared_idx = static_cast<int>(tid % block_size);

  shared_data[shared_idx] = static_cast<int8_t>(0);
  int prev_group          = group_lables_thread_cache[0];
  for (int64_t i = 0; i < num_sketches; i++) {
    int const curr_group = group_lables_thread_cache[i];
    int8_t const curr_reg_v =
      static_cast<int8_t>(registers_thread_cache[tid + i * num_registers_per_sketch]);
    if (curr_group == prev_group) {
      if (curr_reg_v > shared_data[shared_idx]) { shared_data[shared_idx] = curr_reg_v; }
    } else {
      // meets a new group, store the result for previous group
      int64_t const out_idx = static_cast<int64_t>(prev_group) * num_registers_per_sketch + tid;
      // Deliberately NOT named `curr_reg_v`: shadowing the loop value here would re-seed the
      // accumulator below with the previous group's output instead of this sketch's register.
      int const out_reg_v = registers_output_cache[out_idx];
      if (shared_data[shared_idx] > out_reg_v) {
        registers_output_cache[out_idx] = shared_data[shared_idx];
      }

      // start the new group's running maximum from the current sketch's register
      shared_data[shared_idx] = curr_reg_v;
    }
    prev_group = curr_group;
  }

  // handles the last group seen by this thread
  int64_t const out_idx = static_cast<int64_t>(prev_group) * num_registers_per_sketch + tid;
  int const out_reg_v   = registers_output_cache[out_idx];
  if (shared_data[shared_idx] > out_reg_v) {
    registers_output_cache[out_idx] = shared_data[shared_idx];
  }
}

/**
Expand Down Expand Up @@ -363,12 +372,13 @@ std::unique_ptr<column> group_hllpp(column_view const& input,
// 3. merge the intermediate result
auto num_merge_threads = num_registers_per_sketch;
auto num_merge_blocks = cudf::util::div_rounding_up_safe(num_merge_threads, block_size);
merge_sketches_vertically<<<num_merge_blocks, block_size, 0, stream.value()>>>(
total_threads_partial_group, // num_sketches
num_registers_per_sketch,
sketches_output.begin(),
registers_thread_cache.begin(),
group_lables_thread_cache.begin());
merge_sketches_vertically<block_size>
<<<num_merge_blocks, block_size, block_size, stream.value()>>>(
total_threads_partial_group, // num_sketches
num_registers_per_sketch,
sketches_output.begin(),
registers_thread_cache.begin(),
group_lables_thread_cache.begin());

// 4. create output columns
auto num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1;
Expand Down Expand Up @@ -583,12 +593,13 @@ std::unique_ptr<column> merge_hyper_log_log(
auto const num_phase2_threads = num_registers_per_sketch;
auto const num_phase2_blocks = cudf::util::div_rounding_up_safe(num_phase2_threads, block_size);
// 2nd kernel: vertical merge
merge_sketches_vertically<<<num_phase2_blocks, block_size, 0, stream.value()>>>(
num_threads_per_col_phase1, // num_sketches
num_registers_per_sketch,
registers_output_cache.begin(),
registers_thread_cache.begin(),
group_lables_thread_cache.begin());
merge_sketches_vertically<block_size>
<<<num_phase2_blocks, block_size, block_size, stream.value()>>>(
num_threads_per_col_phase1, // num_sketches
num_registers_per_sketch,
registers_output_cache.begin(),
registers_thread_cache.begin(),
group_lables_thread_cache.begin());
}

// create output columns
Expand Down

0 comments on commit 5af14f5

Please sign in to comment.