diff --git a/cpp/src/io/orc/stripe_data.cu b/cpp/src/io/orc/stripe_data.cu index 675169ca782..3fe3ae75599 100644 --- a/cpp/src/io/orc/stripe_data.cu +++ b/cpp/src/io/orc/stripe_data.cu @@ -139,8 +139,8 @@ struct orcdec_state_s { * adjacent row groups and its length is greater than the maximum length allowed to be consumed. * This limit is imposed by the decoder when processing the SECONDARY stream. This class shall be * instantiated in the shared memory, and be used to cache the DATA stream with a decoded data type - * of `int64_t`. As an optimization, the actual cache is a local variable and does not reside in the - * shared memory. + * of `int64_t`. As an optimization, the actual cache is implemented in the cache_helper class as a + * local variable and does not reside in the shared memory. */ class run_cache_manager { private: @@ -204,9 +204,9 @@ class run_cache_manager { * @brief Copy the excess data from the intermediate buffer for the DATA stream to the cache. * * @param[in] src Intermediate buffer for the DATA stream. - * @param[out] cache Local variable serving as the cache for the DATA stream. + * @param[out] cache_storage Local variable serving as the cache for the DATA stream. */ - __device__ void write_to_cache(int64_t* src, int64_t& cache) + __device__ void write_to_cache(int64_t* src, int64_t& cache_storage) { if (_status != status::CAN_WRITE_TO_CACHE) { return; } @@ -220,7 +220,7 @@ class run_cache_manager { auto const length_to_skip = _run_length - _reusable_length; if (tid < _reusable_length) { auto const src_idx = tid + length_to_skip; - cache = src[src_idx]; + cache_storage = src[src_idx]; } if (tid == 0) { _status = status::CAN_READ_FROM_CACHE; } } else { @@ -235,9 +235,9 @@ class run_cache_manager { * * @param[in,out] dst Intermediate buffer for the DATA stream. * @param[in,out] rle Run length decoder state object. - * @param[in] cache Local variable serving as the cache for the DATA stream. 
+ * @param[in] cache_storage Local variable serving as the cache for the DATA stream. */ - __device__ void read_from_cache(int64_t* dst, orc_rlev2_state_s* rle, int64_t cache) + __device__ void read_from_cache(int64_t* dst, orc_rlev2_state_s* rle, int64_t cache_storage) { if (_status != status::CAN_READ_FROM_CACHE) { return; } @@ -252,7 +252,7 @@ class run_cache_manager { __syncthreads(); // Second, insert the cached data - if (tid < _reusable_length) { dst[tid] = cache; } + if (tid < _reusable_length) { dst[tid] = cache_storage; } __syncthreads(); if (tid == 0) { @@ -274,6 +274,64 @@ class run_cache_manager { uint32_t _run_length; ///< The length of the run, 512 in the above example. }; +/** + * @brief Helper class to help run_cache_manager cache the first run of TIMESTAMP's DATA stream for + * a row group. + * + * The run_cache_manager is intended to be stored in the shared memory, whereas the actual cache is + * kept in local storage as an optimization. If a function is to use run_cache_manager, both the manager + * and the cache objects need to be passed. This class is introduced to simplify the function call, + * so that only a single cache_helper object needs to be passed. + */ +class cache_helper { + public: + /** + * @brief Constructor. + * + * @param[in] run_cache_manager_inst The run_cache_manager object that this helper wraps and forwards its calls to. + */ + __device__ explicit cache_helper(run_cache_manager& run_cache_manager_inst) + : _run_cache_manager_inst(run_cache_manager_inst) + { + } + + /** + * @brief Wrapper of run_cache_manager's namesake function. + */ + __device__ void set_reusable_length(uint32_t run_length, uint32_t max_length) + { + _run_cache_manager_inst.set_reusable_length(run_length, max_length); + } + + /** + * @brief Wrapper of run_cache_manager's namesake function. + */ + __device__ void write_to_cache(int64_t* src) + { + _run_cache_manager_inst.write_to_cache(src, _storage); + } + + /** + * @brief Wrapper of run_cache_manager's namesake function. 
+ */ + __device__ void read_from_cache(int64_t* dst, orc_rlev2_state_s* rle) + { + _run_cache_manager_inst.read_from_cache(dst, rle, _storage); + } + + /** + * @brief Wrapper of run_cache_manager's namesake function. + */ + [[nodiscard]] __device__ uint32_t adjust_max_length(uint32_t max_length) + { + return _run_cache_manager_inst.adjust_max_length(max_length); + } + + private: + run_cache_manager& _run_cache_manager_inst; + int64_t _storage; +}; + /** * @brief Initializes byte stream, modifying length and start position to keep the read pointer * 8-byte aligned. @@ -785,14 +843,11 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs, T* vals, uint32_t maxvals, int t, - bool has_buffered_values = false, - run_cache_manager* run_cache_manager_inst = nullptr, - int64_t* cache = nullptr) + bool has_buffered_values = false, + cache_helper* cache_helper_inst = nullptr) { if (t == 0) { - if (run_cache_manager_inst != nullptr) { - maxvals = run_cache_manager_inst->adjust_max_length(maxvals); - } + if (cache_helper_inst != nullptr) { maxvals = cache_helper_inst->adjust_max_length(maxvals); } uint32_t maxpos = min(bs->len, bs->pos + (bytestream_buffer_size - 8u)); uint32_t lastpos = bs->pos; auto numvals = 0; @@ -836,9 +891,7 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs, } } - if (run_cache_manager_inst != nullptr) { - run_cache_manager_inst->set_reusable_length(n, maxvals); - } + if (cache_helper_inst != nullptr) { cache_helper_inst->set_reusable_length(n, maxvals); } if ((numvals != 0) and (numvals + n > maxvals)) break; // case where there are buffered values and can't consume a whole chunk @@ -1024,12 +1077,12 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs, // Currently run_cache_manager is only designed to fix the TIMESTAMP's DATA stream bug where the // data type is int64_t. 
if constexpr (cuda::std::is_same_v<T, int64_t>) { - if (run_cache_manager_inst != nullptr) { + if (cache_helper_inst != nullptr) { // Run cache is read from during the 2nd iteration of the top-level while loop in // gpuDecodeOrcColumnData(). - run_cache_manager_inst->read_from_cache(vals, rle, *cache); + cache_helper_inst->read_from_cache(vals, rle); // Run cache is written to during the 1st iteration of the loop. - run_cache_manager_inst->write_to_cache(vals, *cache); + cache_helper_inst->write_to_cache(vals); } } return rle->num_vals; @@ -1568,7 +1621,7 @@ CUDF_KERNEL void __launch_bounds__(block_size) bool const is_valid = s->chunk.type_kind != STRUCT; size_t const max_num_rows = s->chunk.column_num_rows; __shared__ run_cache_manager run_cache_manager_inst; - int64_t cache{}; + cache_helper cache_helper_inst(run_cache_manager_inst); if (t == 0 and is_valid) { // If we have an index, seek to the initial run and update row positions if (num_rowgroups > 0) { @@ -1772,8 +1825,13 @@ CUDF_KERNEL void __launch_bounds__(block_size) if (is_rlev1(s->chunk.encoding_kind)) { numvals = Integer_RLEv1(bs, &s->u.rlev1, s->vals.i64, numvals, t); } else { - numvals = Integer_RLEv2( - bs, &s->u.rlev2, s->vals.i64, numvals, t, false, &run_cache_manager_inst, &cache); + numvals = Integer_RLEv2(bs, + &s->u.rlev2, + s->vals.i64, + numvals, + t, + false /**has_buffered_values */, + &cache_helper_inst); } if (s->chunk.type_kind == DECIMAL) { // If we're using an index, we may have to drop values from the initial run