From a24d9b8f00e03acc9b6be1e9d03dffac53e6e246 Mon Sep 17 00:00:00 2001
From: Chong Gao
Date: Wed, 1 Jan 2025 08:55:27 +0800
Subject: [PATCH] Fix NullType

---
 .../aggregate/GpuHyperLogLogPlusPlus.scala | 25 ++++++++++++++-----
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala
index f78ef47a5e4..0fc287e11bf 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/aggregate/GpuHyperLogLogPlusPlus.scala
@@ -21,7 +21,7 @@ import scala.collection.immutable.Seq
 import ai.rapids.cudf
 import ai.rapids.cudf.{DType, GroupByAggregation, ReductionAggregation}
 import com.nvidia.spark.rapids._
-import com.nvidia.spark.rapids.Arm.withResourceIfAllowed
+import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed}
 import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression
 import com.nvidia.spark.rapids.jni.HyperLogLogPlusPlusHostUDF
 import com.nvidia.spark.rapids.shims.ShimExpression
@@ -34,11 +34,24 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 case class CudfHLLPP(override val dataType: DataType,
                      precision: Int) extends CudfAggregate {
   override lazy val reductionAggregate: cudf.ColumnVector => cudf.Scalar =
-    (input: cudf.ColumnVector) => input.reduce(
-      ReductionAggregation.hostUDF(
-        HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF(
-          HyperLogLogPlusPlusHostUDF.AggregationType.Reduction, precision)),
-      DType.STRUCT)
+    (input: cudf.ColumnVector) => {
+      if (input.getNullCount == input.getRowCount) {
+        // For a NullType column, or when all values are null,
+        // return an all-zeros struct scalar: struct(0L, 0L, ..., 0L)
+        val numCols = (1 << precision) / 10 + 1
+        withResource(cudf.ColumnVector.fromLongs(0L)) { cv =>
+          // The views are deep-copied, so this single `cv` can be reused for every field.
+          val cvs: Array[cudf.ColumnView] = Array.fill(numCols)(cv)
+          cudf.Scalar.structFromColumnViews(cvs: _*)
+        }
+      } else {
+        input.reduce(
+          ReductionAggregation.hostUDF(
+            HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF(
+              HyperLogLogPlusPlusHostUDF.AggregationType.Reduction, precision)),
+          DType.STRUCT)
+      }
+    }
   override lazy val groupByAggregate: GroupByAggregation =
     GroupByAggregation.hostUDF(
       HyperLogLogPlusPlusHostUDF.createHLLPPHostUDF(
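
Note (outside the diff, not part of the commit): the fallback scalar above is the
all-zero HLL++ sketch, and numCols follows from packing the 2^precision six-bit
registers into 64-bit longs, 10 registers per long. That packing is an assumption
read off the formula in the patch, and the object and method names below are
illustrative only; this is a standalone sketch of the arithmetic, not plugin code.

    // Standalone Scala sketch: how many longs the empty sketch needs,
    // assuming 10 six-bit HLL++ registers are packed into each 64-bit long.
    object HllppEmptySketchSize {
      // 10 registers per long; integer division drops the remainder,
      // so add one more long for the leftover registers.
      def numLongs(precision: Int): Int = (1 << precision) / 10 + 1

      def main(args: Array[String]): Unit = {
        // e.g. precision 9  -> 512 registers   -> 52 longs
        //      precision 14 -> 16384 registers -> 1639 longs
        Seq(9, 14).foreach { p =>
          println(s"precision=$p registers=${1 << p} longs=${numLongs(p)}")
        }
      }
    }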