Skip to content

Commit

Permalink
Add skipping to rle_stream, use for lists (chunked reads)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmattione-nvidia committed Sep 23, 2024
1 parent 8852839 commit e285fbf
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 47 deletions.
112 changes: 65 additions & 47 deletions cpp/src/io/parquet/decode_fixed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ __device__ inline void gpuDecodeFixedWidthValues(

static constexpr bool enable_print = false;
static constexpr bool enable_print_range_error = false;
static constexpr bool enable_print_large_list = false;
// static constexpr bool enable_print_large_list = true;

if constexpr (enable_print) {
if(t == 0) { printf("DECODE VALUES: start %d, end %d, first_row %d, leaf_level_index %d, dtype_len %u, "
Expand Down Expand Up @@ -133,7 +133,7 @@ __device__ inline void gpuDecodeFixedWidthValues(
} else {
gpuOutputGeneric(s, sb, src_pos, static_cast<uint8_t*>(dst), dtype_len);
}

/*
if constexpr (enable_print_large_list) {
if (dtype == INT32) {
int value_stored = *static_cast<uint32_t*>(dst);
Expand All @@ -143,6 +143,7 @@ __device__ inline void gpuDecodeFixedWidthValues(
}
}
}
*/
}

pos += batch_size;
Expand Down Expand Up @@ -628,15 +629,7 @@ static __device__ int gpuUpdateValidityAndRowIndicesLists(
int max_depth_valid_count = s->nesting_info[max_depth].valid_count;

__syncthreads();

if constexpr (enable_print_large_list) {
auto first_ni_value_count = s->nesting_info[0].value_count;
if((value_count != (4*input_row_count)) || (input_row_count != first_ni_value_count)){
printf("ALGO GARBAGE GET: blockIdx.x %d, value_count %d, target_value_count %d, t %d, value_count %d, input_row_count %d, first_ni_value_count %d\n",
blockIdx.x, value_count, target_value_count, t, value_count, input_row_count, first_ni_value_count);
}
}


using block_scan = cub::BlockScan<int, decode_block_size>;
__shared__ typename block_scan::TempStorage scan_storage;

Expand Down Expand Up @@ -700,15 +693,15 @@ if constexpr (enable_print_large_list) {
__syncthreads();

if constexpr (enable_print_large_list) {
if(bool(is_new_row) != (t % 4 == 0)) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, is_new_row %d\n",
blockIdx.x, value_count, target_value_count, t, is_new_row);
if(within_batch && (bool(is_new_row) != (t % 4 == 0))) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, is_new_row %d, start_depth %d, rep_level %d\n",
blockIdx.x, value_count, target_value_count, t, is_new_row, start_depth, rep_level);
}
if(num_prior_new_rows != ((t + 3) / 4)) {
if(within_batch && (num_prior_new_rows != ((t + 3) / 4))) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, num_prior_new_rows %d\n",
blockIdx.x, value_count, target_value_count, t, num_prior_new_rows);
}
if(total_num_new_rows != 32) {
if((value_count + 128 <= target_value_count) && (total_num_new_rows != 32)) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, total_num_new_rows %d\n",
blockIdx.x, value_count, target_value_count, t, total_num_new_rows);
}
Expand Down Expand Up @@ -747,15 +740,17 @@ if constexpr (enable_print_large_list) {
int block_value_count = value_count_scan_results.block_count;

if constexpr (enable_print_large_list) {
if(in_nesting_bounds != (t % 4 == 0)) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, in_nesting_bounds %d, start_depth %d, end_depth %d, in_row_bounds %d, row_index %d, input_row_count %d\n",
blockIdx.x, value_count, target_value_count, t, in_nesting_bounds, start_depth, end_depth, in_row_bounds, row_index, input_row_count);
if(within_batch && in_row_bounds && (in_nesting_bounds != (t % 4 == 0))) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, in_nesting_bounds %d, start_depth %d, end_depth %d, "
"in_row_bounds %d, row_index %d, input_row_count %d, row_index_lower_bound %d, last_row %d, first_row %d, s->num_rows %d\n",
blockIdx.x, value_count, target_value_count, t, in_nesting_bounds, start_depth, end_depth, in_row_bounds, row_index, input_row_count,
row_index_lower_bound, last_row, first_row, s->num_rows);
}
if(thread_value_count != ((t + 3) / 4)) {
if(within_batch && in_row_bounds && (thread_value_count != ((t + 3) / 4))) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, thread_value_count %d\n",
blockIdx.x, value_count, target_value_count, t, thread_value_count);
}
if(block_value_count != 32) {
if((value_count + 128 <= target_value_count) && (input_row_count + total_num_new_rows <= last_row) && (block_value_count != 32)) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, block_value_count %d\n",
blockIdx.x, value_count, target_value_count, t, block_value_count);
}
Expand Down Expand Up @@ -813,15 +808,15 @@ if constexpr (enable_print_large_list) {
int block_valid_count = valid_count_scan_results.block_count;

if constexpr (enable_print_large_list) {
if(((d_idx == 0) && (is_valid != (t % 4 == 0))) || ((d_idx == 1) && !is_valid)) {
if(within_batch && in_row_bounds && (((d_idx == 0) && (is_valid != (t % 4 == 0))) || ((d_idx == 1) && !is_valid))) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, d_idx %d, is_valid %d, in_nesting_bounds %d\n",
blockIdx.x, value_count, target_value_count, t, d_idx, is_valid, in_nesting_bounds);
}
if (((d_idx == 0) && (thread_valid_count != ((t + 3)/ 4))) || ((d_idx == 1) && (thread_valid_count != t))) {
if (within_batch && in_row_bounds && (((d_idx == 0) && (thread_valid_count != ((t + 3)/ 4))) || ((d_idx == 1) && (thread_valid_count != t)))) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, d_idx %d, thread_valid_count %d\n",
blockIdx.x, value_count, target_value_count, t, d_idx, thread_valid_count);
}
if(((d_idx == 0) && (block_valid_count != 32)) || ((d_idx == 1) && (block_valid_count != 128))) {
if((value_count + 128 <= target_value_count) && (input_row_count + total_num_new_rows <= last_row) && (((d_idx == 0) && (block_valid_count != 32)) || ((d_idx == 1) && (block_valid_count != 128)))) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, d_idx %d, block_valid_count %d\n",
blockIdx.x, value_count, target_value_count, t, d_idx, block_valid_count);
}
Expand Down Expand Up @@ -859,17 +854,16 @@ if constexpr (enable_print_large_list) {
next_thread_value_count = next_value_count_scan_results.thread_count_within_block;
next_block_value_count = next_value_count_scan_results.block_count;


if constexpr (enable_print_large_list) {
if(next_in_nesting_bounds != 1) {
if(within_batch && in_row_bounds && (next_in_nesting_bounds != 1)) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, next_in_nesting_bounds %d, start_depth %d, end_depth %d, in_row_bounds %d, row_index %d, input_row_count %d\n",
blockIdx.x, value_count, target_value_count, t, next_in_nesting_bounds, start_depth, end_depth, in_row_bounds, row_index, input_row_count);
}
if(next_thread_value_count != t) {
if(within_batch && in_row_bounds && (next_thread_value_count != t)) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, next_thread_value_count %d\n",
blockIdx.x, value_count, target_value_count, t, next_thread_value_count);
}
if(next_block_value_count != 128) {
if((value_count + 128 <= target_value_count) && (input_row_count + total_num_new_rows <= last_row) && (next_block_value_count != 128)) {
printf("CUB GARBAGE: blockIdx.x %d, value_count %d, target_value_count %d, t %d, next_block_value_count %d\n",
blockIdx.x, value_count, target_value_count, t, next_block_value_count);
}
Expand All @@ -893,9 +887,11 @@ if constexpr (enable_print_large_list) {
//STORE THE OFFSET FOR THE NEW LIST LOCATION
(reinterpret_cast<cudf::size_type*>(ni.data_out))[idx] = ofs;

/*
if constexpr (enable_print_large_list) {
int overall_index = 4*(blockIdx.x * 20000 + idx);
if(overall_index != ofs) {
printf("WHOA BAD OFFSET\n");
printf("WHOA BAD OFFSET: WROTE %d to %d! t %d, blockIdx.x %d, idx %d, d_idx %d, start_depth %d, end_depth %d, max_depth %d, "
"in_row_bounds %d, in_nesting_bounds %d, next_in_nesting_bounds %d, row_index %d, row_index_lower_bound %d, last_row %d, "
"input_row_count %d, num_prior_new_rows %d, is_new_row %d, total_num_new_rows %d, rep_level %d, def_level %d, ni.value_count %d, "
Expand All @@ -907,7 +903,7 @@ if constexpr (enable_print_large_list) {
next_thread_value_count, next_ni.page_start_value, value_count, target_value_count, block_value_count, next_block_value_count);
}
}

*/
if constexpr (enable_print || enable_print_range_error) {
if((idx < 0) || (idx > 50000)){ printf("WHOA: offset index %d out of bounds!\n", idx); }
if(ofs < 0){ printf("WHOA: offset value %d out of bounds!\n", ofs); }
Expand Down Expand Up @@ -1030,13 +1026,6 @@ if constexpr (enable_print_large_list) {

// If we have lists # rows != # values
s->input_row_count = input_row_count;
if constexpr (enable_print_large_list) {
auto first_ni_value_count = s->nesting_info[0].value_count;
if((value_count != (4*input_row_count)) || (input_row_count != first_ni_value_count)){
printf("ALGO GARBAGE SET: blockIdx.x %d, value_count %d, target_value_count %d, t %d, value_count %d, input_row_count %d, first_ni_value_count %d\n",
blockIdx.x, value_count, target_value_count, t, value_count, input_row_count, first_ni_value_count);
}
}
}

return max_depth_valid_count;
Expand Down Expand Up @@ -1069,6 +1058,32 @@ __device__ inline bool maybe_has_nulls(page_state_s* s)
return run_val != s->col.max_level[lvl];
}

// Advance an rle_stream past `num_to_skip` already-consumed leaf values so that
// chunked reads do not redecode work from a previous pass.
//
// First the stream itself jumps over whole runs; if a run straddles the skip
// target, the stream stops at that run's start, and the leftover values are
// decoded-and-discarded here in batches of at most 2 * decode_block_size_t
// (the stream's maximum per-call output).
//
// All threads of the block must call this (contains __syncthreads()).
// Returns the total number of values skipped (== num_to_skip on success).
template <int decode_block_size_t, typename stream_type>
__device__ int skip_decode(stream_type& parquet_stream, int num_to_skip, int t)
{
  static constexpr bool enable_print = false;

  // Fast path: skip whole runs without decoding them.
  int total_skipped = parquet_stream.skip_decode(t, num_to_skip);
  if constexpr (enable_print) {
    if (t == 0) { printf("SKIPPED: num_skipped %d, for %d\n", total_skipped, num_to_skip); }
  }

  // It could be that (e.g.) we want to skip 5000 values but a run of length
  // 2000 starts at value 4000: skip_decode() above only skips 4000, and the
  // remaining 1000 must be decoded and thrown away here — at most
  // 2 * decode_block_size_t per iteration, since that is the per-call limit.
  while (total_skipped < num_to_skip) {
    int const batch = min(2 * decode_block_size_t, num_to_skip - total_skipped);
    parquet_stream.decode_next(t, batch);
    total_skipped += batch;
    if constexpr (enable_print) {
      if (t == 0) { printf("EXTRA SKIPPED: to_skip %d, at %d, for %d\n", batch, total_skipped, num_to_skip); }
    }
    __syncthreads();
  }

  return total_skipped;
}

/**
* @brief Kernel for computing fixed width non dictionary column data stored in the pages
*
Expand Down Expand Up @@ -1190,18 +1205,8 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size_t)
if(t == 0) { printf("INIT DICT: dict_bits %d, data_start %p, data_end %p, dict_idx %p, page.num_input_values %d, s->dict_pos %d \n",
s->dict_bits, s->data_start, s->data_end, sb->dict_idx, s->page.num_input_values, s->dict_pos); }
}
if constexpr (has_lists_t){
int init_decode = 0;
while (init_decode < s->page.skipped_leaf_values) {
auto const to_skip = min(decode_block_size_t, s->page.skipped_leaf_values - init_decode);
dict_stream.decode_next(t, to_skip);
init_decode += to_skip;
__syncthreads();
}
}
}
__syncthreads();


if constexpr (enable_print) {
if((t == 0) && (page_idx == 0)){
printf("SIZES: shared_rep_size %d, shared_dict_size %d, shared_def_size %d\n", shared_rep_size, shared_dict_size, shared_def_size);
Expand All @@ -1225,6 +1230,19 @@ CUDF_KERNEL void __launch_bounds__(decode_block_size_t)
// the core loop. decode batches of level stream data using rle_stream objects
// and pass the results to gpuDecodeValues

// For lists (which can have skipped values), skip ahead in the decoding so that we don't repeat work
if constexpr (has_lists_t){
if(s->page.skipped_leaf_values > 0) {
if (should_process_nulls) {
skip_decode<decode_block_size_t>(def_decoder, s->page.skipped_leaf_values, t);
}
processed_count = skip_decode<decode_block_size_t>(rep_decoder, s->page.skipped_leaf_values, t);
if constexpr (has_dict_t) {
skip_decode<decode_block_size_t>(dict_stream, s->page.skipped_leaf_values, t);
}
}
}

if constexpr (enable_print) {
if(t == 0) { printf("page_idx %d, nullable %d, should_process_nulls %d, has_lists_t %d, has_dict_t %d, num_rows %lu, page.num_input_values %d\n",
page_idx, int(nullable), int(should_process_nulls), int(has_lists_t), int(has_dict_t), num_rows, s->page.num_input_values); }
Expand Down
69 changes: 69 additions & 0 deletions cpp/src/io/parquet/rle_stream.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ struct rle_stream {
run.level_run = level_run;
run.remaining = run.size;
cur += run_bytes;
//printf("STORE RUN: decode_index %d, fill_index %d, output_pos %d, run.size %d\n",
//decode_index, fill_index, output_pos, run.size);
output_pos += run.size;
fill_index++;
}
Expand Down Expand Up @@ -353,6 +355,8 @@ struct rle_stream {
// this is the last batch we will process this iteration if:
// - either this run still has remaining values
// - or it is consumed fully and its last index corresponds to output_count
//printf("STATUS: run_index %d, batch_len %d, remaining %d, at_end %d, last_run_pos %d, cur_values %d\n",
//run_index, batch_len, remaining, at_end, last_run_pos, cur_values);
if (remaining > 0 || at_end) { values_processed_shared = output_count; }
if (remaining == 0 && (at_end || is_last_decode_warp(warp_id))) {
decode_index_shared = run_index + 1;
Expand All @@ -372,6 +376,71 @@ struct rle_stream {
return values_processed_shared;
}

// Walk the raw RLE byte stream past every run that ends at or before
// `target_count` values, without storing any run info. The run that overlaps
// `target_count` is deliberately NOT consumed: `cur` and `output_pos` are left
// pointing at it, which is exactly the state fill_run_batch() expects on its
// first invocation — as if decoding had never started.
//
// Single-threaded: intended to be called by thread 0 only (see skip_decode).
// Returns the number of values skipped (<= target_count).
__device__ inline int skip_runs(int target_count)
{
  while (cur < end) {
    // Decode the varint run header to get this run's type and length.
    uint8_t const* header_end = cur;
    int const level_run      = get_vlq32(header_end, end);

    // Bytes consumed by this run: header plus payload.
    int consumed = header_end - cur;
    int run_size;
    if (is_literal_run(level_run)) {
      // From the parquet spec: literal runs always come in multiples of 8 values.
      run_size = (level_run >> 1) * 8;
      consumed += ((run_size * level_bits) + 7) >> 3;
    } else {
      // Repeated-value run: a single bit-packed value follows the header.
      run_size = level_run >> 1;
      consumed += (level_bits + 7) >> 3;
    }

    // Bail before consuming the run that would carry us past the target:
    // that run must be decoded normally later.
    if (output_pos + run_size > target_count) { return output_pos; }

    output_pos += run_size;
    cur += consumed;
  }

  // Reached the end of the level data: everything available was skipped.
  return output_pos;
}


// Block-cooperative skip: advance the stream past up to `count` values without
// producing output. Thread 0 does the actual byte-stream walk (skip_runs);
// the result is broadcast to the whole block via shared memory.
//
// All threads of the block must call this (contains __syncthreads()).
// Returns the number of values skipped, valid in every thread.
__device__ inline int skip_decode(int t, int count)
{
  // Clamp to what the stream actually has left.
  int const clamped_count = min(count, total_values - cur_values);

  // Special case: level_bits == 0 means every level is implicitly zero and no
  // run data exists to walk — common for columns with no nulls.
  // NOTE(review): plain assignment (not +=) assumes skipping happens before any
  // values have been decoded (cur_values still 0) — confirm if ever reused mid-stream.
  if (level_bits == 0) {
    cur_values = clamped_count;
    return clamped_count;
  }

  __shared__ int num_skipped_shared;

  // Ensure no thread is still reading the shared slot from a prior use.
  __syncthreads();

  // Only thread 0 walks the byte stream; everyone else waits for the result.
  if (t == 0) { num_skipped_shared = skip_runs(clamped_count); }
  __syncthreads();

  cur_values = num_skipped_shared;

  return num_skipped_shared;
}

// Convenience overload: decode a full batch of up to max_output_values values.
__device__ inline int decode_next(int t)
{
  return decode_next(t, max_output_values);
}
};

Expand Down

0 comments on commit e285fbf

Please sign in to comment.