diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala index 4fbc612591b..4d691703b74 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPartitioning.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2020-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,15 +29,12 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.GpuShuffleEnv import org.apache.spark.sql.vectorized.ColumnarBatch -object GpuPartitioning { - // The maximum size of an Array minus a bit for overhead for metadata - val MaxCpuBatchSize = 2147483639L - 2048L -} - trait GpuPartitioning extends Partitioning { - private[this] val (maxCompressionBatchSize, _useGPUShuffle, _useMultiThreadedShuffle) = { + private[this] val ( + maxCpuBatchSize, maxCompressionBatchSize, _useGPUShuffle, _useMultiThreadedShuffle) = { val rapidsConf = new RapidsConf(SQLConf.get) - (rapidsConf.shuffleCompressionMaxBatchMemory, + (rapidsConf.shuffleParitioningMaxCpuBatchSize, + rapidsConf.shuffleCompressionMaxBatchMemory, GpuShuffleEnv.useGPUShuffle(rapidsConf), GpuShuffleEnv.useMultiThreadedShuffle(rapidsConf)) } @@ -124,7 +121,7 @@ trait GpuPartitioning extends Partitioning { // This should be a temp work around. partitionColumns.foreach(_.getBase.getNullCount) val totalInputSize = GpuColumnVector.getTotalDeviceMemoryUsed(partitionColumns) - val mightNeedToSplit = totalInputSize > GpuPartitioning.MaxCpuBatchSize + val mightNeedToSplit = totalInputSize > maxCpuBatchSize // We have to wrap the NvtxWithMetrics over both copyToHostAsync and corresponding CudaSync, // because the copyToHostAsync calls above are not guaranteed to be asynchronous (e.g.: when @@ -164,7 +161,7 @@ trait GpuPartitioning extends Partitioning { case (batch, part) => val totalSize = SlicedGpuColumnVector.getTotalHostMemoryUsed(batch) val numOutputBatches = - math.ceil(totalSize.toDouble / GpuPartitioning.MaxCpuBatchSize).toInt + math.ceil(totalSize.toDouble / maxCpuBatchSize).toInt if (numOutputBatches > 1) { // For now we are going to slice it on number of rows instead of looking // at each row to try and decide. If we get in trouble we can probably diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 609155e4b86..b0cd3014b91 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2025, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -1976,6 +1976,20 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .integerConf .createWithDefault(20) + val SHUFFLE_PARTITIONING_MAX_CPU_BATCH_SIZE = + conf("spark.rapids.shuffle.partitioning.maxCpuBatchSize") + .doc("The maximum size of a sliced batch output to the CPU side " + + "when GPU partitioning shuffle data. This can be used to limit the peak on-heap memory " + + "used by CPU to serialize the shuffle data, especially for skew data cases. " + + "The default value is maximum size of an Array minus 2k overhead (2147483639L - 2048L), " + + "user should only set a smaller value than default value to avoid subsequent failures.") + .internal() + .bytesConf(ByteUnit.BYTE) + .checkValue(v => v > 0 && v <= 2147483639L - 2048L, + s"maxCpuBatchSize must be positive and not exceed ${2147483639L - 2048L} bytes.") + // The maximum size of an Array minus a bit for overhead for metadata + .createWithDefault(2147483639L - 2048L) + val SHUFFLE_MULTITHREADED_READER_THREADS = conf("spark.rapids.shuffle.multiThreaded.reader.threads") .doc("The number of threads to use for reading shuffle blocks per executor in the " + @@ -3176,6 +3190,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val shuffleMultiThreadedReaderThreads: Int = get(SHUFFLE_MULTITHREADED_READER_THREADS) + lazy val shuffleParitioningMaxCpuBatchSize: Long = get(SHUFFLE_PARTITIONING_MAX_CPU_BATCH_SIZE) + lazy val shuffleKudoSerializerEnabled: Boolean = get(SHUFFLE_KUDO_SERIALIZER_ENABLED) def isUCXShuffleManagerMode: Boolean =