Skip to content

Commit

Permalink
Reduction for hllpp
Browse files Browse the repository at this point in the history
  • Loading branch information
Chong Gao committed Nov 21, 2024
1 parent abb4cad commit 77ea21c
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/types.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <rmm/cuda_stream_view.hpp>

namespace cudf {
namespace groupby::detail {

/**
* Compute the hashs of the input column, then generate a scalar that is a sketch in long array
* format
*/
std::unique_ptr<scalar> reduce_hyper_log_log_plus_plus(column_view const& input,
int64_t const precision,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);

/**
* Merge sketches in long array format, and compute the estimated distinct value(long)
* Input is a struct column with multiple long columns which is consistent with Spark.
*/
std::unique_ptr<scalar> reduce_merge_hyper_log_log_plus_plus(column_view const& input,
int64_t const precision,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr);

} // namespace groupby::detail
} // namespace cudf
231 changes: 219 additions & 12 deletions cpp/src/groupby/sort/group_hyper_log_log_plus_plus.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ constexpr int REGISTERS_PER_LONG = 64 / REGISTER_VALUE_BITS;
// XXHash seed
constexpr int64_t SEED = 42L;

// max precision, if require a precision bigger than 18, then use 18.
constexpr int MAX_PRECISION = 18;

/**
*
* Computes register values from hash values and partially groups from register values.
Expand Down Expand Up @@ -238,7 +241,7 @@ CUDF_KERNEL void merge_sketches_vertically(int64_t num_sketches,
int const* const registers_thread_cache,
size_type const* const group_lables_thread_cache)
{
extern __shared__ int8_t shared_data[];
__shared__ int8_t shared_data[block_size];
auto const tid = cudf::detail::grid_1d::global_thread_id();
int shared_idx = tid % block_size;

Expand Down Expand Up @@ -626,6 +629,172 @@ std::unique_ptr<column> merge_hyper_log_log(
return make_structs_column(num_groups, std::move(results), 0, rmm::device_buffer{});
}

/**
* launch only 1 block
*/
template <int block_size>
CUDF_KERNEL void reduce_hllpp_kernel(column_device_view hashs, int32_t* const output, int precision)
{
__shared__ int32_t shared_data[block_size];

auto const tid = cudf::detail::grid_1d::global_thread_id();
auto const num_hashs = hashs.size();
uint64_t const num_registers_per_sketch = 1L << precision;
int const idx_shift = 64 - precision;
uint64_t const w_padding = 1ULL << (precision - 1);

// init tmp data
for (int i = tid; i < num_registers_per_sketch; i += block_size) {
shared_data[i] = 0;
}
__syncthreads();

// update max reg value
for (int i = tid; i < num_hashs; i += block_size) {
uint64_t const hash = static_cast<uint64_t>(hashs.element<int64_t>(i));
uint64_t const reg_idx = hash >> idx_shift;
int const reg_v =
static_cast<int>(cuda::std::countl_zero((hash << precision) | w_padding) + 1ULL);
cuda::atomic_ref<int32_t, cuda::thread_scope_block> register_ref(shared_data[reg_idx]);
register_ref.fetch_max(reg_v, cuda::memory_order_relaxed);
}
__syncthreads();

// copy to output
for (int i = tid; i < num_registers_per_sketch; i += block_size) {
output[i] = shared_data[i];
}
}

std::unique_ptr<scalar> reduce_hllpp(column_view const& input,
int64_t const precision,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
int64_t num_registers_per_sketch = 1L << precision;
// 1. compute all the hashs
auto hash_col =
make_numeric_column(data_type{type_id::INT64}, input.size(), mask_state::ALL_VALID, stream, mr);
auto input_table = cudf::table_view{{input}};
auto d_input_table = cudf::table_device_view::create(input_table, stream);
bool const nullable = input.has_nulls();
thrust::tabulate(
rmm::exec_policy(stream),
hash_col->mutable_view().begin<int64_t>(),
hash_col->mutable_view().end<int64_t>(),
cudf::hashing::detail::xxhash_64_device_row_hasher(nullable, *d_input_table, SEED));
auto d_hashs = cudf::column_device_view::create(hash_col->view(), stream);

// 2. reduce
rmm::device_uvector<int32_t> output_tmp(num_registers_per_sketch, stream, mr);
constexpr int64_t block_size = 256;
// max shared memory is 2^18 * 4 = 1M
auto const shared_mem_size = num_registers_per_sketch * sizeof(int32_t);
reduce_hllpp_kernel<block_size>
<<<1, block_size, shared_mem_size, stream.value()>>>(*d_hashs, output_tmp.begin(), precision);

// 3. compact to longs
auto num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1;
auto const results_iter = cudf::detail::make_counting_transform_iterator(0, [&](int i) {
return make_numeric_column(
data_type{type_id::INT64}, 1 /**num_groups*/, mask_state::ALL_VALID, stream, mr);
});
auto children = std::vector<std::unique_ptr<column>>(results_iter, results_iter + num_long_cols);
auto d_results = [&] {
auto host_results_pointer_iter =
thrust::make_transform_iterator(children.begin(), [](auto const& results_column) {
return results_column->mutable_view().template data<int64_t>();
});
auto host_results_pointers =
std::vector<int64_t*>(host_results_pointer_iter, host_results_pointer_iter + children.size());
return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr);
}();
auto const num_compact_threads = num_long_cols;
auto const num_compact_blocks = cudf::util::div_rounding_up_safe(num_compact_threads, block_size);
compact_kernel<<<num_compact_blocks, block_size, 0, stream.value()>>>(
1 /**num_groups*/, num_registers_per_sketch, d_results, output_tmp);

// 4. create scalar
auto host_results_view_iter = thrust::make_transform_iterator(
children.begin(), [](auto const& results_column) { return results_column->view(); });
auto views =
std::vector<column_view>(host_results_view_iter, host_results_view_iter + num_long_cols);
auto table_view = cudf::table_view{views};
auto table = cudf::table(table_view);
return std::make_unique<cudf::struct_scalar>(std::move(table), true, stream, mr);
}

CUDF_KERNEL void reduce_merge_hll_kernel_vertically(cudf::device_span<int64_t const*> sketch_longs,
size_type num_sketches,
int num_registers_per_sketch,
int* const output)
{
auto const tid = cudf::detail::grid_1d::global_thread_id();
if (tid >= num_registers_per_sketch) { return; }
auto long_idx = tid / REGISTERS_PER_LONG;
auto reg_idx_in_long = tid % REGISTERS_PER_LONG;
int max = 0;
for (auto row_idx = 0; row_idx < num_sketches; row_idx++) {
int reg_v = get_register_value(sketch_longs[long_idx][row_idx], reg_idx_in_long);
if (reg_v > max) { max = reg_v; }
}
output[tid] = max;
}

std::unique_ptr<scalar> reduce_merge_hllpp(column_view const& input,
int64_t const precision,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
// create device input
int64_t num_registers_per_sketch = 1 << precision;
auto num_long_cols = num_registers_per_sketch / REGISTERS_PER_LONG + 1;
cudf::structs_column_view scv(input);
auto const input_iter = cudf::detail::make_counting_transform_iterator(
0, [&](int i) { return scv.get_sliced_child(i, stream).begin<int64_t>(); });
auto input_cols = std::vector<int64_t const*>(input_iter, input_iter + num_long_cols);
auto d_inputs = cudf::detail::make_device_uvector_async(input_cols, stream, mr);

// create one row output
auto const results_iter = cudf::detail::make_counting_transform_iterator(0, [&](int i) {
return make_numeric_column(
data_type{type_id::INT64}, 1 /** num_rows */, mask_state::ALL_VALID, stream, mr);
});
auto children = std::vector<std::unique_ptr<column>>(results_iter, results_iter + num_long_cols);
auto d_results = [&] {
auto host_results_pointer_iter =
thrust::make_transform_iterator(children.begin(), [](auto const& results_column) {
return results_column->mutable_view().template data<int64_t>();
});
auto host_results_pointers =
std::vector<int64_t*>(host_results_pointer_iter, host_results_pointer_iter + children.size());
return cudf::detail::make_device_uvector_async(host_results_pointers, stream, mr);
}();

// execute merge kernel
auto num_threads = num_registers_per_sketch;
constexpr int64_t block_size = 256;
auto num_blocks = cudf::util::div_rounding_up_safe(num_threads, block_size);
auto output_cache = rmm::device_uvector<int32_t>(num_registers_per_sketch, stream, mr);
reduce_merge_hll_kernel_vertically<<<num_blocks, block_size, 0, stream.value()>>>(
d_inputs, input.size(), num_registers_per_sketch, output_cache.begin());

// compact to longs
auto const num_compact_threads = num_long_cols;
auto const num_compact_blocks = cudf::util::div_rounding_up_safe(num_compact_threads, block_size);
compact_kernel<<<num_compact_blocks, block_size, 0, stream.value()>>>(
1 /** num_groups **/, num_registers_per_sketch, d_results, output_cache);

// create scalar
auto host_results_view_iter = thrust::make_transform_iterator(
children.begin(), [](auto const& results_column) { return results_column->view(); });
auto views =
std::vector<column_view>(host_results_view_iter, host_results_view_iter + num_long_cols);
auto table_view = cudf::table_view{views};
auto table = cudf::table(table_view);
return std::make_unique<cudf::struct_scalar>(std::move(table), true, stream, mr);
}

} // namespace

/**
Expand All @@ -640,33 +809,71 @@ std::unique_ptr<column> group_hyper_log_log_plus_plus(
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_EXPECTS(precision >= 4 && precision <= 18, "HLL++ requires precision in range: [4, 18]");
auto input_type =
cudf::is_dictionary(input.type()) ? dictionary_column_view(input).keys().type() : input.type();

return group_hllpp(input, num_groups, group_lables, precision, stream, mr);
CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4.");
auto adjust_precision = precision > MAX_PRECISION ? MAX_PRECISION : precision;
return group_hllpp(input, num_groups, group_lables, adjust_precision, stream, mr);
}

/**
* Merge sketches in the same group.
* Input is a struct column with multiple long columns which is consistent with Spark.
*/
std::unique_ptr<column> group_merge_hyper_log_log_plus_plus(
column_view const& values,
column_view const& input,
int64_t const num_groups,
cudf::device_span<size_type const> group_lables,
int64_t const precision,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_EXPECTS(precision >= 4 && precision <= 18, "HLL++ requires precision in range: [4, 18]");
CUDF_EXPECTS(values.type().id() == type_id::STRUCT,
CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4.");
CUDF_EXPECTS(input.type().id() == type_id::STRUCT,
"HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
for (auto i = 0; i < input.num_children(); i++) {
CUDF_EXPECTS(input.child(i).type().id() == type_id::INT64,
"HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
}
auto adjust_precision = precision > MAX_PRECISION ? MAX_PRECISION : precision;
auto expected_num_longs = (1 << adjust_precision) / REGISTERS_PER_LONG + 1;
CUDF_EXPECTS(input.num_children() == expected_num_longs,
"The num of long columns in input is incorrect.");
return merge_hyper_log_log(input, num_groups, group_lables, adjust_precision, stream, mr);
}

/**
* Compute the hashs of the input column, then generate a sketch stored in a struct of long scalar.
*/
std::unique_ptr<scalar> reduce_hyper_log_log_plus_plus(column_view const& input,
int64_t const precision,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4.");
auto adjust_precision = precision > MAX_PRECISION ? MAX_PRECISION : precision;
return reduce_hllpp(input, adjust_precision, stream, mr);
}

/**
* Merge all sketches in the input column into one sketch.
* Input is a struct column with multiple long columns which is consistent with Spark.
*/
std::unique_ptr<scalar> reduce_merge_hyper_log_log_plus_plus(column_view const& input,
int64_t const precision,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision is bigger than 4.");
CUDF_EXPECTS(input.type().id() == type_id::STRUCT,
"HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
for (auto i = 0; i < values.num_children(); i++) {
CUDF_EXPECTS(values.child(i).type().id() == type_id::INT64,
for (auto i = 0; i < input.num_children(); i++) {
CUDF_EXPECTS(input.child(i).type().id() == type_id::INT64,
"HyperLogLogPlusPlus buffer type must be a STRUCT of long columns.");
}
return merge_hyper_log_log(values, num_groups, group_lables, precision, stream, mr);
auto adjust_precision = precision > MAX_PRECISION ? MAX_PRECISION : precision;
auto expected_num_longs = (1 << adjust_precision) / REGISTERS_PER_LONG + 1;
CUDF_EXPECTS(input.num_children() == expected_num_longs,
"The num of long columns in input is incorrect.");
return reduce_merge_hllpp(input, adjust_precision, stream, mr);
}

} // namespace detail
Expand Down
10 changes: 9 additions & 1 deletion cpp/src/reductions/reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/memory_resource.hpp>
#include <cudf/utilities/type_checks.hpp>

#include <cudf/detail/hyper_log_log_plus_plus/hyper_log_log_plus_plus.hpp>
#include <rmm/cuda_stream_view.hpp>

#include <utility>
Expand Down Expand Up @@ -144,6 +144,14 @@ struct reduce_dispatch_functor {
auto td_agg = static_cast<cudf::detail::merge_tdigest_aggregation const&>(agg);
return tdigest::detail::reduce_merge_tdigest(col, td_agg.max_centroids, stream, mr);
}
case aggregation::HLLPP: {
auto hllpp_agg = static_cast<cudf::detail::hyper_log_log_aggregation const&>(agg);
return cudf::groupby::detail::reduce_hyper_log_log_plus_plus(col, hllpp_agg.precision, stream, mr);
}
case aggregation::MERGE_HLLPP: {
auto hllpp_agg = static_cast<cudf::detail::merge_hyper_log_log_aggregation const&>(agg);
return cudf::groupby::detail::reduce_merge_hyper_log_log_plus_plus(col, hllpp_agg.precision, stream, mr);
}
default: CUDF_FAIL("Unsupported reduction operator");
}
}
Expand Down

0 comments on commit 77ea21c

Please sign in to comment.