From 6b90b2fffb9035921fab6cd105469645c09a7b4d Mon Sep 17 00:00:00 2001 From: Jihoon Son Date: Mon, 25 Nov 2024 14:55:44 -0800 Subject: [PATCH] Add support for asynchronous writing for parquet (#11730) * Support async writing for query output Signed-off-by: Jihoon Son * doc change * use a long timeout * fix test failure due to a race * fix flaky test * address comments * fix the config name for hold gpu * Update sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/AsyncOutputStream.scala Simplify case arm Co-authored-by: Gera Shegalov * address comments * missing doc change * use trampoline --------- Signed-off-by: Jihoon Son Co-authored-by: Gera Shegalov --- .../spark/rapids/ColumnarOutputWriter.scala | 36 +++- .../spark/rapids/GpuParquetFileFormat.scala | 13 +- .../com/nvidia/spark/rapids/Plugin.scala | 3 + .../com/nvidia/spark/rapids/RapidsConf.scala | 35 ++++ .../rapids/io/async/AsyncOutputStream.scala | 186 ++++++++++++++++++ .../rapids/io/async/ThrottlingExecutor.scala | 43 ++++ .../rapids/io/async/TrafficController.scala | 142 +++++++++++++ .../io/async/AsyncOutputStreamSuite.scala | 162 +++++++++++++++ .../io/async/ThrottlingExecutorSuite.scala | 145 ++++++++++++++ .../io/async/TrafficControllerSuite.scala | 101 ++++++++++ 10 files changed, 855 insertions(+), 11 deletions(-) create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/AsyncOutputStream.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala create mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala create mode 100644 sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/AsyncOutputStreamSuite.scala create mode 100644 sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutorSuite.scala create mode 100644 sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/TrafficControllerSuite.scala diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala index 69157c046b6..df62683d346 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ColumnarOutputWriter.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,11 +25,13 @@ import com.nvidia.spark.Retryable import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitSpillableInHalfByRows, withRestoreOnRetry, withRetry, withRetryNoSplit} +import com.nvidia.spark.rapids.io.async.{AsyncOutputStream, TrafficController} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FSDataOutputStream, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.sql.rapids.{ColumnarWriteTaskStatsTracker, GpuWriteTaskStatsTracker} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch @@ -70,21 +72,31 @@ abstract class ColumnarOutputWriterFactory extends Serializable { abstract class ColumnarOutputWriter(context: TaskAttemptContext, dataSchema: StructType, rangeName: String, - includeRetry: Boolean) extends HostBufferConsumer { + includeRetry: Boolean, + holdGpuBetweenBatches: Boolean = false) extends HostBufferConsumer with Logging { protected val tableWriter: TableWriter protected val conf: Configuration = context.getConfiguration - // This is implemented as a method to make it easier to subclass - // ColumnarOutputWriter in the tests, and override this behavior. - protected def getOutputStream: FSDataOutputStream = { + private val trafficController: Option[TrafficController] = TrafficController.getInstance + + private def openOutputStream(): OutputStream = { val hadoopPath = new Path(path) val fs = hadoopPath.getFileSystem(conf) fs.create(hadoopPath, false) } - protected val outputStream: FSDataOutputStream = getOutputStream + // This is implemented as a method to make it easier to subclass + // ColumnarOutputWriter in the tests, and override this behavior. + protected def getOutputStream: OutputStream = { + trafficController.map(controller => { + logWarning("Async output write enabled") + new AsyncOutputStream(() => openOutputStream(), controller) + }).getOrElse(openOutputStream()) + } + + protected val outputStream: OutputStream = getOutputStream private[this] val tempBuffer = new Array[Byte](128 * 1024) private[this] var anythingWritten = false @@ -166,7 +178,11 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext, } // we successfully buffered to host memory, release the semaphore and write // the buffered data to the FS - GpuSemaphore.releaseIfNecessary(TaskContext.get) + if (!holdGpuBetweenBatches) { + logDebug("Releasing semaphore between batches") + GpuSemaphore.releaseIfNecessary(TaskContext.get) + } + writeBufferedData() updateStatistics(writeStartTime, gpuTime, statsTrackers) spillableBatch.numRows() @@ -202,6 +218,10 @@ abstract class ColumnarOutputWriter(context: TaskAttemptContext, // buffer an empty batch on close() to work around issues in cuDF // where corrupt files can be written if nothing is encoded via the writer. anythingWritten = true + + // tableWriter.write() serializes the table into the HostMemoryBuffer, and buffers it + // by calling handleBuffer() on the ColumnarOutputWriter. It may not write to the + // output stream just yet. tableWriter.write(table) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index 25105386b3d..2b5f246e56a 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -271,13 +271,19 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging { s"Set Parquet option ${ParquetOutputFormat.JOB_SUMMARY_LEVEL} to NONE.") } + val asyncOutputWriteEnabled = RapidsConf.ENABLE_ASYNC_OUTPUT_WRITE.get(sqlConf) + // holdGpuBetweenBatches is on by default if asyncOutputWriteEnabled is on + val holdGpuBetweenBatches = RapidsConf.ASYNC_QUERY_OUTPUT_WRITE_HOLD_GPU_IN_TASK.get(sqlConf) + .getOrElse(asyncOutputWriteEnabled) + new ColumnarOutputWriterFactory { override def newInstance( path: String, dataSchema: StructType, context: TaskAttemptContext): ColumnarOutputWriter = { new GpuParquetWriter(path, dataSchema, compressionType, outputTimestampType.toString, - dateTimeRebaseMode, timestampRebaseMode, context, parquetFieldIdWriteEnabled) + dateTimeRebaseMode, timestampRebaseMode, context, parquetFieldIdWriteEnabled, + holdGpuBetweenBatches) } override def getFileExtension(context: TaskAttemptContext): String = { @@ -299,8 +305,9 @@ class GpuParquetWriter( dateRebaseMode: DateTimeRebaseMode, timestampRebaseMode: DateTimeRebaseMode, context: TaskAttemptContext, - parquetFieldIdEnabled: Boolean) - extends ColumnarOutputWriter(context, dataSchema, "Parquet", true) { + parquetFieldIdEnabled: Boolean, + holdGpuBetweenBatches: Boolean) + extends ColumnarOutputWriter(context, dataSchema, "Parquet", true, holdGpuBetweenBatches) { override def throwIfRebaseNeededInExceptionMode(batch: ColumnarBatch): Unit = { val cols = GpuColumnVector.extractBases(batch) cols.foreach { col => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index e20b21da520..5127c7899a8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -31,6 +31,7 @@ import com.nvidia.spark.DFUDFPlugin import com.nvidia.spark.rapids.RapidsConf.AllowMultipleJars import com.nvidia.spark.rapids.RapidsPluginUtils.buildInfoEvent import com.nvidia.spark.rapids.filecache.{FileCache, FileCacheLocalityManager, FileCacheLocalityMsg} +import com.nvidia.spark.rapids.io.async.TrafficController import com.nvidia.spark.rapids.jni.GpuTimeZoneDB import com.nvidia.spark.rapids.python.PythonWorkerSemaphore import org.apache.commons.lang3.exception.ExceptionUtils @@ -554,6 +555,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { extraExecutorPlugins.foreach(_.init(pluginContext, extraConf)) GpuSemaphore.initialize() FileCache.init(pluginContext) + TrafficController.initialize(conf) } catch { // Exceptions in executor plugin can cause a single thread to die but the executor process // sticks around without any useful info until it hearbeat times out. Print what happened @@ -656,6 +658,7 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging { extraExecutorPlugins.foreach(_.shutdown()) FileCache.shutdown() GpuCoreDumpHandler.shutdown() + TrafficController.shutdown() } override def onTaskFailed(failureReason: TaskFailedReason): Unit = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index e22b8f53497..ab7a788d205 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -2406,6 +2406,36 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression. .booleanConf .createWithDefault(false) + val ENABLE_ASYNC_OUTPUT_WRITE = + conf("spark.rapids.sql.asyncWrite.queryOutput.enabled") + .doc("Option to turn on the async query output write. During the final output write, the " + + "task first copies the output to the host memory, and then writes it into the storage. " + + "When this option is enabled, the task will asynchronously write the output in the host " + + "memory to the storage. Only the Parquet format is supported currently.") + .internal() + .booleanConf + .createWithDefault(false) + + val ASYNC_QUERY_OUTPUT_WRITE_HOLD_GPU_IN_TASK = + conf("spark.rapids.sql.queryOutput.holdGpuInTask") + .doc("Option to hold GPU semaphore between batch processing during the final output write. " + + "This option could degrade query performance if it is enabled without the async query " + + "output write. It is recommended to consider enabling this option only when " + + s"${ENABLE_ASYNC_OUTPUT_WRITE.key} is set. This option is off by default when the async " + + "query output write is disabled; otherwise, it is on.") + .internal() + .booleanConf + .createOptional + + val ASYNC_WRITE_MAX_IN_FLIGHT_HOST_MEMORY_BYTES = + conf("spark.rapids.sql.asyncWrite.maxInFlightHostMemoryBytes") + .doc("Maximum number of host memory bytes per executor that can be in-flight for async " + + "query output write. Tasks may be blocked if the total host memory bytes in-flight " + + "exceeds this value.") + .internal() + .bytesConf(ByteUnit.BYTE) + .createWithDefault(2L * 1024 * 1024 * 1024) + private def printSectionHeader(category: String): Unit = println(s"\n### $category") @@ -2663,6 +2693,9 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val isFoldableNonLitAllowed: Boolean = get(FOLDABLE_NON_LIT_ALLOWED) + lazy val asyncWriteMaxInFlightHostMemoryBytes: Long = + get(ASYNC_WRITE_MAX_IN_FLIGHT_HOST_MEMORY_BYTES) + /** * Convert a string value to the injection configuration OomInjection. * @@ -3248,6 +3281,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging { lazy val caseWhenFuseEnabled: Boolean = get(CASE_WHEN_FUSE) + lazy val isAsyncOutputWriteEnabled: Boolean = get(ENABLE_ASYNC_OUTPUT_WRITE) + private val optimizerDefaults = Map( // this is not accurate because CPU projections do have a cost due to appending values // to each row that is produced, but this needs to be a really small number because diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/AsyncOutputStream.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/AsyncOutputStream.scala new file mode 100644 index 00000000000..40904a96dd2 --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/AsyncOutputStream.scala @@ -0,0 +1,186 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.io.async + +import java.io.{IOException, OutputStream} +import java.util.concurrent.{Callable, TimeUnit} +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} + +import com.nvidia.spark.rapids.RapidsPluginImplicits._ + +import org.apache.spark.sql.rapids.execution.TrampolineUtil + +/** + * OutputStream that performs writes asynchronously. Writes are scheduled on a background thread + * and executed in the order they were scheduled. This class is not thread-safe and should only be + * used by a single thread. + */ +class AsyncOutputStream(openFn: Callable[OutputStream], trafficController: TrafficController) + extends OutputStream { + + private var closed = false + + private val executor = new ThrottlingExecutor( + TrampolineUtil.newDaemonCachedThreadPool("AsyncOutputStream", 1, 1), + trafficController) + + // Open the underlying stream asynchronously as soon as the AsyncOutputStream is constructed, + // so that the open can be done in parallel with other operations. This could help with + // performance if the open is slow. + private val openFuture = executor.submit(openFn, 0) + // Let's give it enough time to open the stream. Something bad should have happened if it + // takes more than 5 minutes to open a stream. + private val openTimeoutMin = 5 + + private lazy val delegate: OutputStream = { + openFuture.get(openTimeoutMin, TimeUnit.MINUTES) + } + + class Metrics { + var numBytesScheduled: Long = 0 + // This is thread-safe as it is updated by the background thread and can be read by + // any threads. + val numBytesWritten: AtomicLong = new AtomicLong(0) + } + + val metrics = new Metrics + + /** + * The last error that occurred in the background thread, or None if no error occurred. + * Once it is set, all subsequent writes that are already scheduled will fail and no new + * writes will be accepted. + * + * This is thread-safe as it is set by the background thread and can be read by any threads. + */ + val lastError: AtomicReference[Option[Throwable]] = + new AtomicReference[Option[Throwable]](None) + + @throws[IOException] + private def throwIfError(): Unit = { + lastError.get() match { + case Some(t: IOException) => throw t + case Some(t) => throw new IOException(t) + case None => + } + } + + @throws[IOException] + private def ensureOpen(): Unit = { + if (closed) { + throw new IOException("Stream closed") + } + } + + private def scheduleWrite(fn: () => Unit, bytesToWrite: Int): Unit = { + throwIfError() + ensureOpen() + + metrics.numBytesScheduled += bytesToWrite + executor.submit(() => { + throwIfError() + ensureOpen() + + try { + fn() + metrics.numBytesWritten.addAndGet(bytesToWrite) + } catch { + case t: Throwable => + // Update the error state + lastError.set(Some(t)) + } + }, bytesToWrite) + } + + override def write(b: Int): Unit = { + scheduleWrite(() => delegate.write(b), 1) + } + + override def write(b: Array[Byte]): Unit = { + scheduleWrite(() => delegate.write(b), b.length) + } + + /** + * Schedules a write of the given bytes to the underlying stream. The write is executed + * asynchronously on a background thread. The method returns immediately, and the write may not + * have completed when the method returns. + * + * If an error has occurred in the background thread and [[lastError]] has been set, this function + * will throw an IOException immediately. + * + * If an error has occurred in the background thread while executing a previous write after the + * current write has been scheduled, the current write will fail with the same error. + */ + @throws[IOException] + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + scheduleWrite(() => delegate.write(b, off, len), len) + } + + /** + * Flushes all pending writes to the underlying stream. This method blocks until all pending + * writes have been completed. If an error has occurred in the background thread, this method + * will throw an IOException. + * + * If an error has occurred in the background thread and [[lastError]] has been set, this function + * will throw an IOException immediately. + * + * If an error has occurred in the background thread while executing a previous task after the + * current flush has been scheduled, the current flush will fail with the same error. + */ + @throws[IOException] + override def flush(): Unit = { + throwIfError() + ensureOpen() + + val f = executor.submit(() => { + throwIfError() + ensureOpen() + + delegate.flush() + }, 0) + + f.get() + } + + /** + * Closes the underlying stream and releases any resources associated with it. All pending writes + * are flushed before closing the stream. This method blocks until all pending writes have been + * completed. + * + * If an error has occurred while flushing, this function will throw an IOException. + * + * If an error has occurred while executing a previous task before this function is called, + * this function will throw the same error. All resources and the underlying stream are still + * guaranteed to be closed. + */ + @throws[IOException] + override def close(): Unit = { + if (!closed) { + Seq[AutoCloseable]( + () => { + // Wait for all pending writes to complete + // This will throw an exception if one of the writes fails + flush() + }, + () => { + // Give the executor a chance to shutdown gracefully. + executor.shutdownNow(10, TimeUnit.SECONDS) + }, + delegate, + () => closed = true).safeClose() + } + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala new file mode 100644 index 00000000000..45889bf89ac --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.io.async + +import java.util.concurrent.{Callable, ExecutorService, Future, TimeUnit} + +/** + * Thin wrapper around an ExecutorService that adds throttling. + */ +class ThrottlingExecutor( + val executor: ExecutorService, throttler: TrafficController) { + + def submit[T](callable: Callable[T], hostMemoryBytes: Long): Future[T] = { + val task = new Task[T](hostMemoryBytes, callable) + throttler.blockUntilRunnable(task) + executor.submit(() => { + try { + task.call() + } finally { + throttler.taskCompleted(task) + } + }) + } + + def shutdownNow(timeout: Long, timeUnit: TimeUnit): Unit = { + executor.shutdownNow() + executor.awaitTermination(timeout, timeUnit) + } +} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala new file mode 100644 index 00000000000..0110f2d89ca --- /dev/null +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.io.async + +import java.util.concurrent.Callable +import javax.annotation.concurrent.GuardedBy + +import com.nvidia.spark.rapids.RapidsConf + +/** + * Simple wrapper around a [[Callable]] that also keeps track of the host memory bytes used by + * the task. + * + * Note: we may want to add more metadata to the task in the future, such as the device memory, + * as we implement more throttling strategies. + */ +class Task[T](val hostMemoryBytes: Long, callable: Callable[T]) extends Callable[T] { + override def call(): T = callable.call() +} + +/** + * Throttle interface to be implemented by different throttling strategies. + * + * Currently, only HostMemoryThrottle is implemented, which limits the maximum in-flight host + * memory bytes. In the future, we can add more throttling strategies, such as limiting the + * device memory usage, the number of tasks, etc. + */ +trait Throttle { + + /** + * Returns true if the task can be accepted, false otherwise. + * TrafficController will block the task from being scheduled until this method returns true. + */ + def canAccept[T](task: Task[T]): Boolean + + /** + * Callback to be called when a task is scheduled. + */ + def taskScheduled[T](task: Task[T]): Unit + + /** + * Callback to be called when a task is completed, either successfully or with an exception. + */ + def taskCompleted[T](task: Task[T]): Unit +} + +/** + * Throttle implementation that limits the total host memory used by the in-flight tasks. + */ +class HostMemoryThrottle(val maxInFlightHostMemoryBytes: Long) extends Throttle { + private var totalHostMemoryBytes: Long = 0 + + override def canAccept[T](task: Task[T]): Boolean = { + totalHostMemoryBytes + task.hostMemoryBytes <= maxInFlightHostMemoryBytes + } + + override def taskScheduled[T](task: Task[T]): Unit = { + totalHostMemoryBytes += task.hostMemoryBytes + } + + override def taskCompleted[T](task: Task[T]): Unit = { + totalHostMemoryBytes -= task.hostMemoryBytes + } + + def getTotalHostMemoryBytes: Long = totalHostMemoryBytes +} + +/** + * TrafficController is responsible for blocking tasks from being scheduled when the throttle + * is exceeded. It also keeps track of the number of tasks that are currently scheduled. + * + * This class is thread-safe as it is used by multiple tasks. + */ +class TrafficController protected[rapids] (throttle: Throttle) { + + @GuardedBy("this") + private var numTasks: Int = 0 + + /** + * Blocks the task from being scheduled until the throttle allows it. If there is no task + * currently scheduled, the task is scheduled immediately even if the throttle is exceeded. + */ + def blockUntilRunnable[T](task: Task[T]): Unit = synchronized { + if (numTasks > 0) { + while (!throttle.canAccept(task)) { + wait(100) + } + } + numTasks += 1 + throttle.taskScheduled(task) + } + + def taskCompleted[T](task: Task[T]): Unit = synchronized { + numTasks -= 1 + throttle.taskCompleted(task) + notify() + } + + def numScheduledTasks: Int = synchronized { + numTasks + } +} + +object TrafficController { + + private var instance: TrafficController = _ + + /** + * Initializes the TrafficController singleton instance. + * This is called once per executor. + */ + def initialize(conf: RapidsConf): Unit = synchronized { + if (conf.isAsyncOutputWriteEnabled && instance == null) { + instance = new TrafficController( + new HostMemoryThrottle(conf.asyncWriteMaxInFlightHostMemoryBytes)) + } + } + + def getInstance: Option[TrafficController] = synchronized { + Option(instance) + } + + def shutdown(): Unit = synchronized { + if (instance != null) { + instance = null + } + } +} diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/AsyncOutputStreamSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/AsyncOutputStreamSuite.scala new file mode 100644 index 00000000000..a4fa35349ce --- /dev/null +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/AsyncOutputStreamSuite.scala @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.io.async + +import java.io.{BufferedOutputStream, File, FileOutputStream, IOException, OutputStream} +import java.util.concurrent.Callable + +import com.nvidia.spark.rapids.Arm.withResource +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +class AsyncOutputStreamSuite extends AnyFunSuite with BeforeAndAfterEach { + + private val bufLen = 128 * 1024 + private val buf: Array[Byte] = new Array[Byte](bufLen) + private val maxBufCount = 10 + private val trafficController = new TrafficController( + new HostMemoryThrottle(bufLen * maxBufCount)) + + def openStream(): AsyncOutputStream = { + new AsyncOutputStream(() => { + val file = File.createTempFile("async-write-test", "tmp") + new BufferedOutputStream(new FileOutputStream(file)) + }, trafficController) + } + + test("open, write, and close") { + val numBufs = 1000 + val stream = openStream() + withResource(stream) { os => + for (_ <- 0 until numBufs) { + os.write(buf) + } + } + assertResult(bufLen * numBufs)(stream.metrics.numBytesScheduled) + assertResult(bufLen * numBufs)(stream.metrics.numBytesWritten.get()) + } + + test("write after closed") { + val os = openStream() + os.close() + assertThrows[IOException] { + os.write(buf) + } + } + + test("flush after closed") { + val os = openStream() + os.close() + assertThrows[IOException] { + os.flush() + } + } + + class ThrowingOutputStream extends OutputStream { + + var failureCount = 0 + + override def write(i: Int): Unit = { + failureCount += 1 + throw new IOException(s"Failed ${failureCount} times") + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + failureCount += 1 + throw new IOException(s"Failed ${failureCount} times") + } + } + + def assertThrowsWithMsg[T](fn: Callable[T], clue: String, + expectedMsgPrefix: String): Unit = { + withClue(clue) { + try { + fn.call() + } catch { + case t: Throwable => + assertIOExceptionMsg(t, expectedMsgPrefix) + } + } + } + + def assertIOExceptionMsg(t: Throwable, expectedMsgPrefix: String): Unit = { + if (t.getClass.isAssignableFrom(classOf[IOException])) { + if (!t.getMessage.contains(expectedMsgPrefix)) { + fail(s"Unexpected exception message: ${t.getMessage}") + } + } else { + if (t.getCause != null) { + assertIOExceptionMsg(t.getCause, expectedMsgPrefix) + } else { + fail(s"Unexpected exception: $t") + } + } + } + + test("write after error") { + val os = new AsyncOutputStream(() => new ThrowingOutputStream, trafficController) + + // The first call to `write` should succeed + os.write(buf) + + // Wait for the first write to fail + while (os.lastError.get().isEmpty) { + Thread.sleep(100) + } + + // The second `write` call should fail with the exception thrown by the first write failure + assertThrowsWithMsg(() => os.write(buf), + "The second write should fail with the exception thrown by the first write failure", + "Failed 1 times") + + // `close` throws the same exception + assertThrowsWithMsg(() => os.close(), + "The second write should fail with the exception thrown by the first write failure", + "Failed 1 times") + + assertResult(bufLen)(os.metrics.numBytesScheduled) + assertResult(0)(os.metrics.numBytesWritten.get()) + assert(os.lastError.get().get.isInstanceOf[IOException]) + } + + test("flush after error") { + val os = new AsyncOutputStream(() => new ThrowingOutputStream, trafficController) + + // The first write should succeed + os.write(buf) + + // The flush should fail with the exception thrown by the write failure + assertThrowsWithMsg(() => os.flush(), + "The flush should fail with the exception thrown by the write failure", + "Failed 1 times") + + // `close` throws the same exception + assertThrowsWithMsg(() => os.close(), + "The flush should fail with the exception thrown by the write failure", + "Failed 1 times") + } + + test("close after error") { + val os = new AsyncOutputStream(() => new ThrowingOutputStream, trafficController) + + os.write(buf) + + assertThrowsWithMsg(() => os.close(), + "Close should fail with the exception thrown by the write failure", + "Failed 1 times") + } +} diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutorSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutorSuite.scala new file mode 100644 index 00000000000..a8acf240878 --- /dev/null +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutorSuite.scala @@ -0,0 +1,145 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.io.async + +import java.util.concurrent.{Callable, CountDownLatch, ExecutionException, Executors, Future, RejectedExecutionException, TimeUnit} + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach { + + // Some tests might take longer than usual in the limited CI environment. + // Use a long timeout to avoid flakiness. + val longTimeoutSec = 5 + + var throttle: HostMemoryThrottle = _ + var trafficController: TrafficController = _ + var executor: ThrottlingExecutor = _ + + class TestTask extends Callable[Unit] { + val latch = new CountDownLatch(1) + override def call(): Unit = { + latch.await() + } + } + + override def beforeEach(): Unit = { + throttle = new HostMemoryThrottle(100) + trafficController = new TrafficController(throttle) + executor = new ThrottlingExecutor( + Executors.newSingleThreadExecutor(), + trafficController + ) + } + + override def afterEach(): Unit = { + executor.shutdownNow(longTimeoutSec, TimeUnit.SECONDS) + } + + test("tasks submitted should update the state") { + val task1 = new TestTask + val future1 = executor.submit(task1, 10) + assertResult(1)(trafficController.numScheduledTasks) + assertResult(10)(throttle.getTotalHostMemoryBytes) + + val task2 = new TestTask + val future2 = executor.submit(task2, 20) + assertResult(2)(trafficController.numScheduledTasks) + assertResult(30)(throttle.getTotalHostMemoryBytes) + + task1.latch.countDown() + future1.get(longTimeoutSec, TimeUnit.SECONDS) + assertResult(1)(trafficController.numScheduledTasks) + assertResult(20)(throttle.getTotalHostMemoryBytes) + + task2.latch.countDown() + future2.get(longTimeoutSec, TimeUnit.SECONDS) + assertResult(0)(trafficController.numScheduledTasks) + assertResult(0)(throttle.getTotalHostMemoryBytes) + } + + test("tasks submission fails if total weight exceeds maxWeight") { + val task1 = new TestTask + val future1 = executor.submit(task1, 10) + assertResult(1)(trafficController.numScheduledTasks) + assertResult(10)(throttle.getTotalHostMemoryBytes) + + val task2 = new TestTask + val task2Weight = 100 + val exec = Executors.newSingleThreadExecutor() + val future2 = exec.submit(new Runnable { + override def run(): Unit = executor.submit(task2, task2Weight) + }) + Thread.sleep(100) + assert(!future2.isDone) + assertResult(1)(trafficController.numScheduledTasks) + assertResult(10)(throttle.getTotalHostMemoryBytes) + + task1.latch.countDown() + future1.get(longTimeoutSec, TimeUnit.SECONDS) + future2.get(longTimeoutSec, TimeUnit.SECONDS) + assertResult(1)(trafficController.numScheduledTasks) + assertResult(task2Weight)(throttle.getTotalHostMemoryBytes) + } + + test("submit one task heavier than maxWeight") { + val future = executor.submit(() => Thread.sleep(10), throttle.maxInFlightHostMemoryBytes + 1) + future.get(longTimeoutSec, TimeUnit.SECONDS) + assert(future.isDone) + assertResult(0)(trafficController.numScheduledTasks) + assertResult(0)(throttle.getTotalHostMemoryBytes) + } + + test("submit multiple tasks such that total weight does not exceed maxWeight") { + val numTasks = 10 + val taskRunTime = 10 + var future: Future[Unit] = null + for (_ <- 0 to numTasks) { + future = executor.submit(() => Thread.sleep(taskRunTime), 1) + } + // Give enough time for all tasks to complete + future.get(numTasks * taskRunTime * 5, TimeUnit.MILLISECONDS) + assertResult(0)(trafficController.numScheduledTasks) + assertResult(0)(throttle.getTotalHostMemoryBytes) + } + + test("shutdown while a task is blocked") { + val task1 = new TestTask + val future1 = executor.submit(task1, 10) + assertResult(1)(trafficController.numScheduledTasks) + assertResult(10)(throttle.getTotalHostMemoryBytes) + + val task2 = new TestTask + val task2Weight = 100 + val exec = Executors.newSingleThreadExecutor() + val future2 = exec.submit(new Runnable { + override def run(): Unit = executor.submit(task2, task2Weight) + }) + executor.shutdownNow(longTimeoutSec, TimeUnit.SECONDS) + + def assertCause(t: Throwable, cause: Class[_]): Unit = { + assert(t.getCause != null) + assert(cause.isInstance(t.getCause)) + } + + val e1 = intercept[ExecutionException](future1.get()) + assertCause(e1, classOf[InterruptedException]) + val e2 = intercept[ExecutionException](future2.get()) + assertCause(e2, classOf[RejectedExecutionException]) + } +} diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/TrafficControllerSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/TrafficControllerSuite.scala new file mode 100644 index 00000000000..32868ff6055 --- /dev/null +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/TrafficControllerSuite.scala @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.io.async + +import java.util.concurrent.{ExecutionException, Executors, ExecutorService, TimeUnit} + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach { + + private var throttle: HostMemoryThrottle = _ + private var controller: TrafficController = _ + private var executor: ExecutorService = _ + + override def beforeEach(): Unit = { + throttle = new HostMemoryThrottle(100) + controller = new TrafficController(throttle) + executor = Executors.newSingleThreadExecutor() + } + + override def afterEach(): Unit = { + executor.shutdownNow() + executor.awaitTermination(1, TimeUnit.SECONDS) + } + + class TestTask(taskMemoryBytes: Long) extends Task[Unit](taskMemoryBytes, () => {}) {} + + test("schedule tasks without blocking") { + val taskMemoryBytes = 50 + val t1 = new TestTask(taskMemoryBytes) + controller.blockUntilRunnable(t1) + assertResult(1)(controller.numScheduledTasks) + assertResult(taskMemoryBytes)(throttle.getTotalHostMemoryBytes) + + val t2 = new TestTask(50) + controller.blockUntilRunnable(t2) + assertResult(2)(controller.numScheduledTasks) + assertResult(2 * taskMemoryBytes)(throttle.getTotalHostMemoryBytes) + + controller.taskCompleted(t1) + assertResult(1)(controller.numScheduledTasks) + assertResult(taskMemoryBytes)(throttle.getTotalHostMemoryBytes) + } + + test("schedule task with blocking") { + val taskMemoryBytes = 50 + val t1 = new TestTask(taskMemoryBytes) + controller.blockUntilRunnable(t1) + + val t2 = new TestTask(taskMemoryBytes) + controller.blockUntilRunnable(t2) + + val t3 = new TestTask(taskMemoryBytes) + val f = executor.submit(new Runnable { + override def run(): Unit = controller.blockUntilRunnable(t3) + }) + Thread.sleep(100) + assert(!f.isDone) + + controller.taskCompleted(t1) + f.get(1, TimeUnit.SECONDS) + } + + test("shutdown while blocking") { + val t1 = new TestTask(10) + controller.blockUntilRunnable(t1) + + val t2 = new TestTask(110) + + val f = executor.submit(new Runnable { + override def run(): Unit = { + controller.blockUntilRunnable(t2) + } + }) + + executor.shutdownNow() + try { + f.get(1, TimeUnit.SECONDS) + fail("Should be interrupted") + } catch { + case ee: ExecutionException => + assert(ee.getCause.isInstanceOf[InterruptedException]) + case _: Throwable => fail("Should be interrupted") + } + } +}