From f3ac8bede0afdd877c698fb1e54edd5b7f1366c8 Mon Sep 17 00:00:00 2001 From: Jihoon Son Date: Thu, 5 Dec 2024 12:20:49 -0800 Subject: [PATCH] Fix the task count check in TrafficController (#11783) * Fix TrafficController numTasks check Signed-off-by: Jihoon Son * rename weights properly * simplify the loop condition * Rename the condition variable for readability Co-authored-by: Gera Shegalov * missing renames * add test for when all tasks are big --------- Signed-off-by: Jihoon Son Co-authored-by: Gera Shegalov --- .../rapids/io/async/ThrottlingExecutor.scala | 2 + .../rapids/io/async/TrafficController.scala | 46 +++++++---- .../io/async/ThrottlingExecutorSuite.scala | 16 ++-- .../io/async/TrafficControllerSuite.scala | 80 ++++++++++++++++++- 4 files changed, 118 insertions(+), 26 deletions(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala index 45889bf89ac..99c3cc9ea5e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutor.scala @@ -20,6 +20,8 @@ import java.util.concurrent.{Callable, ExecutorService, Future, TimeUnit} /** * Thin wrapper around an ExecutorService that adds throttling. + * + * The given executor is owned by this class and will be shut down when this class is shut down. 
*/ class ThrottlingExecutor( val executor: ExecutorService, throttler: TrafficController) { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala index 0110f2d89ca..e69af5bf258 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/io/async/TrafficController.scala @@ -17,6 +17,7 @@ package com.nvidia.spark.rapids.io.async import java.util.concurrent.Callable +import java.util.concurrent.locks.ReentrantLock import javax.annotation.concurrent.GuardedBy import com.nvidia.spark.rapids.RapidsConf @@ -85,38 +86,55 @@ class HostMemoryThrottle(val maxInFlightHostMemoryBytes: Long) extends Throttle * * This class is thread-safe as it is used by multiple tasks. */ -class TrafficController protected[rapids] (throttle: Throttle) { +class TrafficController protected[rapids] (@GuardedBy("lock") throttle: Throttle) { - @GuardedBy("this") + @GuardedBy("lock") private var numTasks: Int = 0 + private val lock = new ReentrantLock() + private val canBeScheduled = lock.newCondition() + /** * Blocks the task from being scheduled until the throttle allows it. If there is no task * currently scheduled, the task is scheduled immediately even if the throttle is exceeded. 
*/ - def blockUntilRunnable[T](task: Task[T]): Unit = synchronized { - if (numTasks > 0) { - while (!throttle.canAccept(task)) { - wait(100) + def blockUntilRunnable[T](task: Task[T]): Unit = { + lock.lockInterruptibly() + try { + while (numTasks > 0 && !throttle.canAccept(task)) { + canBeScheduled.await() } + numTasks += 1 + throttle.taskScheduled(task) + } finally { + lock.unlock() } - numTasks += 1 - throttle.taskScheduled(task) } - def taskCompleted[T](task: Task[T]): Unit = synchronized { - numTasks -= 1 - throttle.taskCompleted(task) - notify() + def taskCompleted[T](task: Task[T]): Unit = { + lock.lockInterruptibly() + try { + numTasks -= 1 + throttle.taskCompleted(task) + canBeScheduled.signal() + } finally { + lock.unlock() + } } - def numScheduledTasks: Int = synchronized { - numTasks + def numScheduledTasks: Int = { + lock.lockInterruptibly() + try { + numTasks + } finally { + lock.unlock() + } } } object TrafficController { + @GuardedBy("this") private var instance: TrafficController = _ /** diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutorSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutorSuite.scala index a8acf240878..86fb692cd64 100644 --- a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutorSuite.scala +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/ThrottlingExecutorSuite.scala @@ -73,17 +73,17 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach { assertResult(0)(throttle.getTotalHostMemoryBytes) } - test("tasks submission fails if total weight exceeds maxWeight") { + test("tasks submission fails if totalHostMemoryBytes exceeds maxHostMemoryBytes") { val task1 = new TestTask val future1 = executor.submit(task1, 10) assertResult(1)(trafficController.numScheduledTasks) assertResult(10)(throttle.getTotalHostMemoryBytes) val task2 = new TestTask - val task2Weight = 100 + val task2HostMemory = 100 val exec 
= Executors.newSingleThreadExecutor() val future2 = exec.submit(new Runnable { - override def run(): Unit = executor.submit(task2, task2Weight) + override def run(): Unit = executor.submit(task2, task2HostMemory) }) Thread.sleep(100) assert(!future2.isDone) @@ -94,10 +94,10 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach { future1.get(longTimeoutSec, TimeUnit.SECONDS) future2.get(longTimeoutSec, TimeUnit.SECONDS) assertResult(1)(trafficController.numScheduledTasks) - assertResult(task2Weight)(throttle.getTotalHostMemoryBytes) + assertResult(task2HostMemory)(throttle.getTotalHostMemoryBytes) } - test("submit one task heavier than maxWeight") { + test("submit one task heavier than maxHostMemoryBytes") { val future = executor.submit(() => Thread.sleep(10), throttle.maxInFlightHostMemoryBytes + 1) future.get(longTimeoutSec, TimeUnit.SECONDS) assert(future.isDone) @@ -105,7 +105,7 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach { assertResult(0)(throttle.getTotalHostMemoryBytes) } - test("submit multiple tasks such that total weight does not exceed maxWeight") { + test("submit multiple tasks such that totalHostMemoryBytes does not exceed maxHostMemoryBytes") { val numTasks = 10 val taskRunTime = 10 var future: Future[Unit] = null @@ -125,10 +125,10 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach { assertResult(10)(throttle.getTotalHostMemoryBytes) val task2 = new TestTask - val task2Weight = 100 + val task2HostMemory = 100 val exec = Executors.newSingleThreadExecutor() val future2 = exec.submit(new Runnable { - override def run(): Unit = executor.submit(task2, task2Weight) + override def run(): Unit = executor.submit(task2, task2HostMemory) }) executor.shutdownNow(longTimeoutSec, TimeUnit.SECONDS) diff --git a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/TrafficControllerSuite.scala b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/TrafficControllerSuite.scala 
index 32868ff6055..1c06755a8af 100644 --- a/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/TrafficControllerSuite.scala +++ b/sql-plugin/src/test/scala/com/nvidia/spark/rapids/io/async/TrafficControllerSuite.scala @@ -16,19 +16,34 @@ package com.nvidia.spark.rapids.io.async -import java.util.concurrent.{ExecutionException, Executors, ExecutorService, TimeUnit} +import java.util.concurrent.{ExecutionException, Executors, ExecutorService, Future, TimeUnit} import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.TimeLimitedTests import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.time.Span +import org.scalatest.time.SpanSugar._ -class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach { +class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach with TimeLimitedTests { - private var throttle: HostMemoryThrottle = _ + class RecordingExecOrderHostMemoryThrottle(maxInFlightHostMemoryBytes: Long) + extends HostMemoryThrottle(maxInFlightHostMemoryBytes) { + var tasksScheduled = Seq.empty[TestTask] + + override def taskScheduled[T](task: Task[T]): Unit = { + tasksScheduled = tasksScheduled :+ task.asInstanceOf[TestTask] + super.taskScheduled(task) + } + } + + val timeLimit: Span = 10.seconds + + private var throttle: RecordingExecOrderHostMemoryThrottle = _ private var controller: TrafficController = _ private var executor: ExecutorService = _ override def beforeEach(): Unit = { - throttle = new HostMemoryThrottle(100) + throttle = new RecordingExecOrderHostMemoryThrottle(100) controller = new TrafficController(throttle) executor = Executors.newSingleThreadExecutor() } @@ -76,6 +91,63 @@ class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach { f.get(1, TimeUnit.SECONDS) } + test("big task should be scheduled after all running tasks are completed") { + val taskMemoryBytes = 50 + val t1 = new TestTask(taskMemoryBytes) + controller.blockUntilRunnable(t1) + + val t2 = new TestTask(150) + 
val f = executor.submit(new Runnable { + override def run(): Unit = controller.blockUntilRunnable(t2) + }) + Thread.sleep(100) + assert(!f.isDone) + + controller.taskCompleted(t1) + f.get(1, TimeUnit.SECONDS) + } + + test("all tasks are bigger than the total memory limit") { + val bigTaskMemoryBytes = 130 + val (tasks, futures) = (0 to 2).map { _ => + val t = new TestTask(bigTaskMemoryBytes) + val f: Future[_] = executor.submit(new Runnable { + override def run(): Unit = controller.blockUntilRunnable(t) + }) + (t, f.asInstanceOf[Future[Unit]]) + }.unzip + while (controller.numScheduledTasks == 0) { + Thread.sleep(100) + } + assert(futures(0).isDone) + assertResult(1)(controller.numScheduledTasks) + assertResult(throttle.tasksScheduled.head)(tasks(0)) + + // The first task has been completed + controller.taskCompleted(tasks(0)) + // Wait for the second task to be scheduled + while (controller.numScheduledTasks == 0) { + Thread.sleep(100) + } + assert(futures(1).isDone) + assertResult(1)(controller.numScheduledTasks) + assertResult(throttle.tasksScheduled(1))(tasks(1)) + + // The second task has been completed + controller.taskCompleted(tasks(1)) + // Wait for the third task to be scheduled + while (controller.numScheduledTasks == 0) { + Thread.sleep(100) + } + assert(futures(2).isDone) + assertResult(1)(controller.numScheduledTasks) + assertResult(throttle.tasksScheduled(2))(tasks(2)) + + // The third task has been completed + controller.taskCompleted(tasks(2)) + assertResult(0)(controller.numScheduledTasks) + } + test("shutdown while blocking") { val t1 = new TestTask(10) controller.blockUntilRunnable(t1)