requests;
+ requests.emplace_back();
+ requests[0].values = vals;
+ requests[0].aggregations.push_back(std::move(agg));
+ cudf::groupby::groupby gb_obj(
+ cudf::table_view({keys}), cudf::null_policy::INCLUDE, cudf::sorted::NO, {}, {});
+ [[maybe_unused]] auto const grp_result =
+ gb_obj.aggregate(requests, cudf::test::get_default_stream());
+ EXPECT_TRUE(test_run);
+ }
+}
diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java
index 379750bb0b7..2276b223740 100644
--- a/java/src/main/java/ai/rapids/cudf/Aggregation.java
+++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java
@@ -62,15 +62,16 @@ enum Kind {
LAG(23),
PTX(24),
CUDA(25),
- M2(26),
- MERGE_M2(27),
- RANK(28),
- DENSE_RANK(29),
- PERCENT_RANK(30),
- TDIGEST(31), // This can take a delta argument for accuracy level
- MERGE_TDIGEST(32), // This can take a delta argument for accuracy level
- HISTOGRAM(33),
- MERGE_HISTOGRAM(34);
+ HOST_UDF(26),
+ M2(27),
+ MERGE_M2(28),
+ RANK(29),
+ DENSE_RANK(30),
+ PERCENT_RANK(31),
+ TDIGEST(32), // This can take a delta argument for accuracy level
+ MERGE_TDIGEST(33), // This can take a delta argument for accuracy level
+ HISTOGRAM(34),
+ MERGE_HISTOGRAM(35);
final int nativeId;
@@ -385,6 +386,35 @@ public boolean equals(Object other) {
}
}
+ static final class HostUDFAggregation extends Aggregation { // Aggregation of Kind.HOST_UDF backed by a HostUDFWrapper.
+ private final HostUDFWrapper wrapper; // Supplies the native UDF handle and this aggregation's hashCode/equals.
+
+ private HostUDFAggregation(HostUDFWrapper wrapper) {
+ super(Kind.HOST_UDF);
+ this.wrapper = wrapper;
+ }
+
+ @Override
+ long createNativeInstance() {
+ return Aggregation.createHostUDFAgg(wrapper.udfNativeHandle); // The wrapper's handle is forwarded to the native factory.
+ }
+
+ @Override
+ public int hashCode() {
+ return 31 * kind.hashCode() + wrapper.hashCode(); // Mix the kind with the wrapper-defined hash.
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ } else if (other instanceof HostUDFAggregation) {
+ return wrapper.equals(((HostUDFAggregation) other).wrapper); // Equality is decided solely by the wrappers.
+ }
+ return false;
+ }
+ }
+
protected final Kind kind;
protected Aggregation(Kind kind) {
@@ -837,6 +867,15 @@ static MergeSetsAggregation mergeSets(NullEquality nullEquality, NaNEquality nan
return new MergeSetsAggregation(nullEquality, nanEquality);
}
+ /**
+ * Create a host UDF aggregation, which executes a host-side user-defined function (UDF).
+ * @param wrapper The wrapper holding the native handle of the host UDF instance.
+ * @return A new HostUDFAggregation instance backed by the given wrapper.
+ */
+ static HostUDFAggregation hostUDF(HostUDFWrapper wrapper) {
+ return new HostUDFAggregation(wrapper);
+ }
+
static final class LeadAggregation extends LeadLagAggregation {
private LeadAggregation(int offset, ColumnVector defaultOutput) {
super(Kind.LEAD, offset, defaultOutput);
@@ -990,4 +1029,9 @@ static MergeHistogramAggregation mergeHistogram() {
* Create a TDigest aggregation.
*/
private static native long createTDigestAgg(int kind, int delta);
+
+ /**
+ * Create a native HOST_UDF aggregation from the native handle of a host UDF instance.
+ */
+ private static native long createHostUDFAgg(long udfNativeHandle);
}
diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
index 0fae33927b6..27966ddfdd4 100644
--- a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
+++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
@@ -277,6 +277,15 @@ public static GroupByAggregation mergeSets() {
return new GroupByAggregation(Aggregation.mergeSets());
}
+ /**
+ * Create a group-by aggregation that executes a host-side user-defined function (UDF).
+ * @param wrapper The wrapper holding the native handle of the host UDF instance.
+ * @return A new GroupByAggregation instance wrapping the host UDF aggregation.
+ */
+ public static GroupByAggregation hostUDF(HostUDFWrapper wrapper) {
+ return new GroupByAggregation(Aggregation.hostUDF(wrapper));
+ }
+
/**
* Merge the partial sets produced by multiple CollectSetAggregations.
*
diff --git a/java/src/main/java/ai/rapids/cudf/HostUDFWrapper.java b/java/src/main/java/ai/rapids/cudf/HostUDFWrapper.java
new file mode 100644
index 00000000000..0b6ecf2e140
--- /dev/null
+++ b/java/src/main/java/ai/rapids/cudf/HostUDFWrapper.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ai.rapids.cudf;
+
+/**
+ * A wrapper around native host UDF aggregations.
+ *
+ * This class stores the native handle of a host UDF aggregation and serves as a proxy
+ * object for computing the hash code of, and comparing for equality, host UDF aggregations.
+ *
+ * A new host UDF aggregation implementation must extend this class and override {@code hashCode}
+ * and {@code equals}; this class does not enforce it, so omitting them leaves Object's defaults.
+ */
+public abstract class HostUDFWrapper {
+ public final long udfNativeHandle; // Native handle of the host UDF; stored exactly as given.
+
+ public HostUDFWrapper(long udfNativeHandle) {
+ this.udfNativeHandle = udfNativeHandle;
+ }
+}
diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp
index c40f1c55500..dd41c677761 100644
--- a/java/src/main/native/src/AggregationJni.cpp
+++ b/java/src/main/native/src/AggregationJni.cpp
@@ -17,6 +17,7 @@
#include "cudf_jni_apis.hpp"
#include
+#include
extern "C" {
@@ -80,25 +81,28 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createNoParamAgg(JNIEnv*
// case 23: LAG
// case 24: PTX
// case 25: CUDA
- case 26: // M2
+ // case 26: HOST_UDF
+ case 27: // M2
return cudf::make_m2_aggregation();
- case 27: // MERGE_M2
+ case 28: // MERGE_M2
return cudf::make_merge_m2_aggregation();
- case 28: // RANK
+ case 29: // RANK
return cudf::make_rank_aggregation(
cudf::rank_method::MIN, {}, cudf::null_policy::INCLUDE);
- case 29: // DENSE_RANK
+ case 30: // DENSE_RANK
return cudf::make_rank_aggregation(
cudf::rank_method::DENSE, {}, cudf::null_policy::INCLUDE);
- case 30: // ANSI SQL PERCENT_RANK
+ case 31: // ANSI SQL PERCENT_RANK
return cudf::make_rank_aggregation(cudf::rank_method::MIN,
{},
cudf::null_policy::INCLUDE,
{},
cudf::rank_percentage::ONE_NORMALIZED);
- case 33: // HISTOGRAM
+ // case 32: TDIGEST
+ // case 33: MERGE_TDIGEST
+ case 34: // HISTOGRAM
return cudf::make_histogram_aggregation();
- case 34: // MERGE_HISTOGRAM
+ case 35: // MERGE_HISTOGRAM
return cudf::make_merge_histogram_aggregation();
default: throw std::logic_error("Unsupported No Parameter Aggregation Operation");
@@ -160,10 +164,10 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createTDigestAgg(JNIEnv*
std::unique_ptr ret;
// These numbers come from Aggregation.java and must stay in sync
switch (kind) {
- case 31: // TDIGEST
+ case 32: // TDIGEST
ret = cudf::make_tdigest_aggregation(delta);
break;
- case 32: // MERGE_TDIGEST
+ case 33: // MERGE_TDIGEST
ret = cudf::make_merge_tdigest_aggregation(delta);
break;
default: throw std::logic_error("Unsupported TDigest Aggregation Operation");
@@ -296,4 +300,18 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createMergeSetsAgg(JNIEn
CATCH_STD(env, 0);
}
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createHostUDFAgg(JNIEnv* env,
+ jclass class_object,
+ jlong udf_native_handle)
+{ // JNI entry point for Aggregation.createHostUDFAgg: builds a host UDF aggregation and returns it as a jlong.
+ JNI_NULL_CHECK(env, udf_native_handle, "udf_native_handle is null", 0); // macro defined elsewhere; presumably bails out with 0 on a null handle
+ try {
+ cudf::jni::auto_set_device(env);
+ auto const udf_ptr = reinterpret_cast(udf_native_handle); // NOTE(review): cast target type is missing on this line — confirm against the original patch
+ auto output = cudf::make_host_udf_aggregation(udf_ptr->clone()); // the aggregation is built from a clone of the UDF, not the UDF itself
+ return reinterpret_cast(output.release()); // NOTE(review): cast target type missing (the function returns jlong) — confirm
+ }
+ CATCH_STD(env, 0);
+}
+
} // extern "C"