Skip to content

Commit

Permalink
Add cache_helper class to simplify function call
Browse files Browse the repository at this point in the history
  • Loading branch information
kingcrimsontianyu committed Dec 21, 2024
1 parent 5c47aaa commit 9d2cac0
Showing 1 changed file with 81 additions and 23 deletions.
104 changes: 81 additions & 23 deletions cpp/src/io/orc/stripe_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ struct orcdec_state_s {
* adjacent row groups and its length is greater than the maximum length allowed to be consumed.
* This limit is imposed by the decoder when processing the SECONDARY stream. This class shall be
* instantiated in the shared memory, and be used to cache the DATA stream with a decoded data type
* of `int64_t`. As an optimization, the actual cache is a local variable and does not reside in the
* shared memory.
* of `int64_t`. As an optimization, the actual cache is implemented in the cache_helper class as a
* local variable and does not reside in the shared memory.
*/
class run_cache_manager {
private:
Expand Down Expand Up @@ -204,9 +204,9 @@ class run_cache_manager {
* @brief Copy the excess data from the intermediate buffer for the DATA stream to the cache.
*
* @param[in] src Intermediate buffer for the DATA stream.
* @param[out] cache Local variable serving as the cache for the DATA stream.
* @param[out] cache_storage Local variable serving as the cache for the DATA stream.
*/
__device__ void write_to_cache(int64_t* src, int64_t& cache)
__device__ void write_to_cache(int64_t* src, int64_t& cache_storage)
{
if (_status != status::CAN_WRITE_TO_CACHE) { return; }

Expand All @@ -220,7 +220,7 @@ class run_cache_manager {
auto const length_to_skip = _run_length - _reusable_length;
if (tid < _reusable_length) {
auto const src_idx = tid + length_to_skip;
cache = src[src_idx];
cache_storage = src[src_idx];
}
if (tid == 0) { _status = status::CAN_READ_FROM_CACHE; }
} else {
Expand All @@ -235,9 +235,9 @@ class run_cache_manager {
*
* @param[in,out] dst Intermediate buffer for the DATA stream.
* @param[in,out] rle Run length decoder state object.
* @param[in] cache Local variable serving as the cache for the DATA stream.
* @param[in] cache_storage Local variable serving as the cache for the DATA stream.
*/
__device__ void read_from_cache(int64_t* dst, orc_rlev2_state_s* rle, int64_t cache)
__device__ void read_from_cache(int64_t* dst, orc_rlev2_state_s* rle, int64_t cache_storage)
{
if (_status != status::CAN_READ_FROM_CACHE) { return; }

Expand All @@ -252,7 +252,7 @@ class run_cache_manager {
__syncthreads();

// Second, insert the cached data
if (tid < _reusable_length) { dst[tid] = cache; }
if (tid < _reusable_length) { dst[tid] = cache_storage; }
__syncthreads();

if (tid == 0) {
Expand All @@ -274,6 +274,64 @@ class run_cache_manager {
uint32_t _run_length; ///< The length of the run, 512 in the above example.
};

/**
* @brief Helper class to help run_cache_manager cache the first run of TIMESTAMP's DATA stream for
* a row group.
*
* The run_cache_manager is intended to be stored in the shared memory, whereas the actual cache in
* the local storage as an optimization. If a function is to use run_cache_manager, both the manager
* and the cache objects need to be passed. This class is introduced to simplify the function call,
* so that only a single cache_helper object needs to be passed.
*/
class cache_helper {
public:
/**
* @brief Constructor.
*
* @param[in] run_cache_manager_inst
*/
__device__ explicit cache_helper(run_cache_manager& run_cache_manager_inst)
: _run_cache_manager_inst(run_cache_manager_inst)
{
}

/**
* @brief Wrapper of run_cache_manager's namesake function.
*/
__device__ void set_reusable_length(uint32_t run_length, uint32_t max_length)
{
_run_cache_manager_inst.set_reusable_length(run_length, max_length);
}

/**
* @brief Wrapper of run_cache_manager's namesake function.
*/
__device__ void write_to_cache(int64_t* src)
{
_run_cache_manager_inst.write_to_cache(src, _storage);
}

/**
* @brief Wrapper of run_cache_manager's namesake function.
*/
__device__ void read_from_cache(int64_t* dst, orc_rlev2_state_s* rle)
{
_run_cache_manager_inst.read_from_cache(dst, rle, _storage);
}

/**
* @brief Wrapper of run_cache_manager's namesake function.
*/
[[nodiscard]] __device__ uint32_t adjust_max_length(uint32_t max_length)
{
return _run_cache_manager_inst.adjust_max_length(max_length);
}

private:
run_cache_manager& _run_cache_manager_inst;
int64_t _storage;
};

/**
* @brief Initializes byte stream, modifying length and start position to keep the read pointer
* 8-byte aligned.
Expand Down Expand Up @@ -785,14 +843,11 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
T* vals,
uint32_t maxvals,
int t,
bool has_buffered_values = false,
run_cache_manager* run_cache_manager_inst = nullptr,
int64_t* cache = nullptr)
bool has_buffered_values = false,
cache_helper* cache_helper_inst = nullptr)
{
if (t == 0) {
if (run_cache_manager_inst != nullptr) {
maxvals = run_cache_manager_inst->adjust_max_length(maxvals);
}
if (cache_helper_inst != nullptr) { maxvals = cache_helper_inst->adjust_max_length(maxvals); }
uint32_t maxpos = min(bs->len, bs->pos + (bytestream_buffer_size - 8u));
uint32_t lastpos = bs->pos;
auto numvals = 0;
Expand Down Expand Up @@ -836,9 +891,7 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
}
}

if (run_cache_manager_inst != nullptr) {
run_cache_manager_inst->set_reusable_length(n, maxvals);
}
if (cache_helper_inst != nullptr) { cache_helper_inst->set_reusable_length(n, maxvals); }

if ((numvals != 0) and (numvals + n > maxvals)) break;
// case where there are buffered values and can't consume a whole chunk
Expand Down Expand Up @@ -1024,12 +1077,12 @@ static __device__ uint32_t Integer_RLEv2(orc_bytestream_s* bs,
// Currently run_cache_manager is only designed to fix the TIMESTAMP's DATA stream bug where the
// data type is int64_t.
if constexpr (cuda::std::is_same_v<T, int64_t>) {
if (run_cache_manager_inst != nullptr) {
if (cache_helper_inst != nullptr) {
// Run cache is read from during the 2nd iteration of the top-level while loop in
// gpuDecodeOrcColumnData().
run_cache_manager_inst->read_from_cache(vals, rle, *cache);
cache_helper_inst->read_from_cache(vals, rle);
// Run cache is written to during the 1st iteration of the loop.
run_cache_manager_inst->write_to_cache(vals, *cache);
cache_helper_inst->write_to_cache(vals);
}
}
return rle->num_vals;
Expand Down Expand Up @@ -1568,7 +1621,7 @@ CUDF_KERNEL void __launch_bounds__(block_size)
bool const is_valid = s->chunk.type_kind != STRUCT;
size_t const max_num_rows = s->chunk.column_num_rows;
__shared__ run_cache_manager run_cache_manager_inst;
int64_t cache{};
cache_helper cache_helper_inst(run_cache_manager_inst);
if (t == 0 and is_valid) {
// If we have an index, seek to the initial run and update row positions
if (num_rowgroups > 0) {
Expand Down Expand Up @@ -1772,8 +1825,13 @@ CUDF_KERNEL void __launch_bounds__(block_size)
if (is_rlev1(s->chunk.encoding_kind)) {
numvals = Integer_RLEv1<int64_t>(bs, &s->u.rlev1, s->vals.i64, numvals, t);
} else {
numvals = Integer_RLEv2<int64_t>(
bs, &s->u.rlev2, s->vals.i64, numvals, t, false, &run_cache_manager_inst, &cache);
numvals = Integer_RLEv2<int64_t>(bs,
&s->u.rlev2,
s->vals.i64,
numvals,
t,
false /**has_buffered_values */,
&cache_helper_inst);
}
if (s->chunk.type_kind == DECIMAL) {
// If we're using an index, we may have to drop values from the initial run
Expand Down

0 comments on commit 9d2cac0

Please sign in to comment.