-
Notifications
You must be signed in to change notification settings - Fork 66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Hyper log log plus plus(HLL++) #2522
base: branch-25.02
Are you sure you want to change the base?
Changes from 6 commits
03c0f5a
df8b223
2daca3f
3afdfde
956af39
8aaf0f6
5bfb544
f8c6a02
208d67e
e29d5a1
9f7ec44
3c70a30
3e22512
aa7ca68
f0970c0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -196,6 +196,7 @@ add_library( | |||||
src/HashJni.cpp | ||||||
src/HistogramJni.cpp | ||||||
src/HostTableJni.cpp | ||||||
src/HLLPPJni.cpp | ||||||
src/JSONUtilsJni.cpp | ||||||
src/NativeParquetJni.cpp | ||||||
src/ParseURIJni.cpp | ||||||
|
@@ -204,6 +205,7 @@ add_library( | |||||
src/SparkResourceAdaptorJni.cpp | ||||||
src/SubStringIndexJni.cpp | ||||||
src/ZOrderJni.cpp | ||||||
src/HLLPP.cu | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about
Suggested change
This name is also applied for the |
||||||
src/bloom_filter.cu | ||||||
src/case_when.cu | ||||||
src/cast_decimal_to_string.cu | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,102 @@ | ||||||||||||
/* | ||||||||||||
* Copyright (c) 2023-2024, NVIDIA CORPORATION. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
* | ||||||||||||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||
* you may not use this file except in compliance with the License. | ||||||||||||
* You may obtain a copy of the License at | ||||||||||||
* | ||||||||||||
* http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||
* | ||||||||||||
* Unless required by applicable law or agreed to in writing, software | ||||||||||||
* distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||
* See the License for the specific language governing permissions and | ||||||||||||
* limitations under the License. | ||||||||||||
*/ | ||||||||||||
#include "HLLPP.hpp" | ||||||||||||
|
||||||||||||
#include <cudf/column/column.hpp> | ||||||||||||
#include <cudf/column/column_device_view.cuh> | ||||||||||||
#include <cudf/column/column_factories.hpp> | ||||||||||||
#include <cudf/column/column_view.hpp> | ||||||||||||
#include <cudf/detail/iterator.cuh> | ||||||||||||
#include <cudf/lists/lists_column_view.hpp> | ||||||||||||
#include <cudf/types.hpp> | ||||||||||||
|
||||||||||||
#include <rmm/cuda_stream_view.hpp> | ||||||||||||
#include <rmm/device_uvector.hpp> | ||||||||||||
#include <rmm/exec_policy.hpp> | ||||||||||||
|
||||||||||||
#include <cuco/detail/hyperloglog/finalizer.cuh> | ||||||||||||
#include <thrust/for_each.h> | ||||||||||||
#include <thrust/iterator/counting_iterator.h> | ||||||||||||
|
||||||||||||
namespace spark_rapids_jni { | ||||||||||||
|
||||||||||||
namespace { | ||||||||||||
|
||||||||||||
// Width in bits of a single packed register. A register stores a number of zeros taken from
// a 64-bit XXHash64 value, so 6 bits are enough to hold it.
constexpr int REGISTER_VALUE_BITS = 6;

// Mask selecting one register: binary 111111.
constexpr uint64_t MASK = (uint64_t{1} << REGISTER_VALUE_BITS) - 1;

// How many whole registers fit into one 64-bit long: 10.
constexpr int REGISTERS_PER_LONG = 64 / REGISTER_VALUE_BITS;
|
||||||||||||
/**
 * Extracts one packed register from a 64-bit word.
 *
 * @param long_10_registers Word holding the packed registers, REGISTER_VALUE_BITS bits each,
 *                          with register 0 in the lowest bits.
 * @param reg_idx           Position of the wanted register inside the word.
 * @return The value of that register.
 */
__device__ inline int get_register_value(int64_t const long_10_registers, int reg_idx)
{
  auto const bit_offset = REGISTER_VALUE_BITS * reg_idx;
  // Move the mask up to the register's slot, isolate those bits, then bring them back down.
  int64_t const slot_mask = MASK << bit_offset;
  int64_t const isolated  = long_10_registers & slot_mask;
  return static_cast<int>(isolated >> bit_offset);
}
|
||||||||||||
struct estimate_fn { | ||||||||||||
cudf::device_span<int64_t const*> sketch_longs; | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
int const precision; | ||||||||||||
int64_t* const out; | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We now favor non-const members so the functor can be moved by the compiler if needed.
Suggested change
|
||||||||||||
|
||||||||||||
__device__ void operator()(cudf::size_type const idx) const | ||||||||||||
{ | ||||||||||||
auto const num_regs = 1ull << precision; | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems to be used to compare with signed int later, thus it should not be unsigned here.
Suggested change
|
||||||||||||
double sum = 0; | ||||||||||||
int zeroes = 0; | ||||||||||||
|
||||||||||||
for (auto reg_idx = 0; reg_idx < num_regs; ++reg_idx) { | ||||||||||||
// each long contains 10 register values | ||||||||||||
int long_col_idx = reg_idx / REGISTERS_PER_LONG; | ||||||||||||
int reg_idx_in_long = reg_idx % REGISTERS_PER_LONG; | ||||||||||||
int reg = get_register_value(sketch_longs[long_col_idx][idx], reg_idx_in_long); | ||||||||||||
sum += double{1} / static_cast<double>(1ull << reg); | ||||||||||||
zeroes += reg == 0; | ||||||||||||
} | ||||||||||||
|
||||||||||||
auto const finalize = cuco::hyperloglog_ns::detail::finalizer(precision); | ||||||||||||
out[idx] = finalize(sum, zeroes); | ||||||||||||
} | ||||||||||||
}; | ||||||||||||
|
||||||||||||
} // end anonymous namespace | ||||||||||||
|
||||||||||||
/**
 * Computes an approximate distinct count for every sketch in `input`.
 *
 * @param input     Column whose children are INT64 columns holding the packed registers;
 *                  row `i` of every child together forms sketch `i`.
 * @param precision Number of addressing bits; the sketch has 2^precision registers.
 *                  Must be at least 4.
 * @param stream    CUDA stream used for the copy, the allocation and the kernel.
 * @param mr        Memory resource used to allocate the returned column.
 * @return An INT64 column with one estimate per input row and no null mask.
 * @throws cudf::logic_error if `precision` is below 4 or `input` has too few children.
 */
std::unique_ptr<cudf::column> estimate_from_hll_sketches(cudf::column_view const& input,
                                                         int precision,
                                                         rmm::cuda_stream_view stream,
                                                         rmm::device_async_resource_ref mr)
{
  CUDF_EXPECTS(precision >= 4, "HyperLogLogPlusPlus requires precision to be at least 4.");

  // The kernel reads child (num_regs - 1) / REGISTERS_PER_LONG at most; reject inputs that
  // would make it index past the end of the child pointer array.
  auto const num_regs       = 1 << precision;
  auto const required_longs = (num_regs - 1) / REGISTERS_PER_LONG + 1;
  CUDF_EXPECTS(input.num_children() >= required_longs,
               "The sketch column does not have enough children for the given precision.");

  // Gather the data pointer of every child on the host, then copy the array to the device.
  auto const input_iter = cudf::detail::make_counting_transform_iterator(
    0, [&input](int i) { return input.child(i).begin<int64_t>(); });
  auto input_cols = std::vector<int64_t const*>(input_iter, input_iter + input.num_children());
  // NOTE(review): this is a temporary buffer; reviewers asked for it to be allocated from
  // cudf::get_current_device_resource_ref() rather than the output resource `mr`.
  auto d_inputs = cudf::detail::make_device_uvector_async(input_cols, stream, mr);

  // Every output row is valid, so no null mask is allocated. The output uses the caller's `mr`.
  auto result = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT64},
                                          input.size(),
                                          cudf::mask_state::UNALLOCATED,
                                          stream,
                                          mr);

  // Evaluate one sketch per row from struct<long, ..., long>.
  thrust::for_each_n(rmm::exec_policy_nosync(stream),
                     thrust::make_counting_iterator(0),
                     input.size(),
                     estimate_fn{d_inputs, precision, result->mutable_view().data<int64_t>()});
  return result;
}
|
||||||||||||
} // namespace spark_rapids_jni |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,32 @@ | ||||||||
/* | ||||||||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||||||||
* | ||||||||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||
* you may not use this file except in compliance with the License. | ||||||||
* You may obtain a copy of the License at | ||||||||
* | ||||||||
* http://www.apache.org/licenses/LICENSE-2.0 | ||||||||
* | ||||||||
* Unless required by applicable law or agreed to in writing, software | ||||||||
* distributed under the License is distributed on an "AS IS" BASIS, | ||||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||
* See the License for the specific language governing permissions and | ||||||||
* limitations under the License. | ||||||||
*/ | ||||||||
|
||||||||
#include <cudf/column/column.hpp> | ||||||||
#include <cudf/column/column_view.hpp> | ||||||||
#include <cudf/utilities/default_stream.hpp> | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
#include <rmm/cuda_stream_view.hpp> | ||||||||
#include <rmm/resource_ref.hpp> | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
namespace spark_rapids_jni { | ||||||||
|
||||||||
std::unique_ptr<cudf::column> estimate_from_hll_sketches( | ||||||||
cudf::column_view const& input, | ||||||||
int precision, | ||||||||
rmm::cuda_stream_view stream = cudf::get_default_stream(), | ||||||||
rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()); | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
} // namespace spark_rapids_jni |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "HLLPP.hpp" | ||
#include "cudf_jni_apis.hpp" | ||
|
||
extern "C" { | ||
|
||
/**
 * JNI entry point for HLLPP.estimateDistinctValueFromSketches.
 *
 * @param sketches  Native handle that is reinterpreted as a `cudf::column_view const*`.
 * @param precision Precision forwarded to `estimate_from_hll_sketches`.
 * @return Handle to the released result column, or 0 when the null check or the exception
 *         handler takes over.
 */
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_HLLPP_estimateDistinctValueFromSketches(
  JNIEnv* env, jclass, jlong sketches, jint precision)
{
  JNI_NULL_CHECK(env, sketches, "Sketch column is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    auto const& sketch_col = *reinterpret_cast<cudf::column_view const*>(sketches);
    auto estimates         = spark_rapids_jni::estimate_from_hll_sketches(sketch_col, precision);
    // Ownership of the column is handed over to the caller through the raw handle.
    return cudf::jni::ptr_as_jlong(estimates.release());
  }
  CATCH_STD(env, 0);
}
} |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,48 @@ | ||||||
/* | ||||||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||||||
* | ||||||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
* you may not use this file except in compliance with the License. | ||||||
* You may obtain a copy of the License at | ||||||
* | ||||||
* http://www.apache.org/licenses/LICENSE-2.0 | ||||||
* | ||||||
* Unless required by applicable law or agreed to in writing, software | ||||||
* distributed under the License is distributed on an "AS IS" BASIS, | ||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
* See the License for the specific language governing permissions and | ||||||
* limitations under the License. | ||||||
*/ | ||||||
|
||||||
package com.nvidia.spark.rapids.jni; | ||||||
|
||||||
import ai.rapids.cudf.ColumnVector; | ||||||
import ai.rapids.cudf.ColumnView; | ||||||
import ai.rapids.cudf.NativeDepsLoader; | ||||||
|
||||||
/** | ||||||
* HyperLogLogPlusPlus | ||||||
*/ | ||||||
public class HLLPP { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
static { | ||||||
NativeDepsLoader.loadNativeDeps(); | ||||||
} | ||||||
|
||||||
/** | ||||||
* Compute the approximate count distinct value from sketch values. | ||||||
* <p> | ||||||
* The input sketch values must be given in the format `Struct<INT64, INT64, ...>`, | ||||||
* The num of children is: num_registers_per_sketch / 10 + 1, here 10 means a INT64 contains | ||||||
* max 10 registers. Register value is 6 bits. The input is columnar data, e.g.: sketch 0 | ||||||
* is composed of by all the data of the children at index 0. | ||||||
* | ||||||
* @param input The sketch column which constains Struct<INT64, INT64, ...> values. | ||||||
* @param precision The num of bits for addressing. | ||||||
* @return A INT64 column with each value indicates the approximate count distinct value. | ||||||
*/ | ||||||
public static ColumnVector estimateDistinctValueFromSketches(ColumnView input, int precision) { | ||||||
return new ColumnVector(estimateDistinctValueFromSketches(input.getNativeView(), precision)); | ||||||
} | ||||||
|
||||||
private static native long estimateDistinctValueFromSketches(long inputHandle, int precision); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think if this Java interface will no longer be needed after converting the code to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rename to: |
||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Copyright (c) 2024, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.nvidia.spark.rapids.jni; | ||
|
||
import ai.rapids.cudf.GroupByAggregation; | ||
import ai.rapids.cudf.Table; | ||
|
||
import org.junit.jupiter.api.Test; | ||
|
||
|
||
public class HLLPPTest { | ||
|
||
@Test | ||
void testGroupByHLL() { | ||
// A trivial test: | ||
try (Table input = new Table.TestBuilder().column(1, 2, 3, 1, 2, 2, 1, 3, 3, 2) | ||
.column(0, 1, -2, 3, -4, -5, -6, 7, -8, 9) | ||
.build()){ | ||
input.groupBy(0).aggregate(GroupByAggregation.HLLPP(0) | ||
.onColumn(1)); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's try to be generic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to: HLLPPHostUDFJni
AggregationJni is too generic