Skip to content

Commit

Permalink
Fix the task count check in TrafficController (#11783)
Browse files Browse the repository at this point in the history
* Fix TrafficController numTasks check

Signed-off-by: Jihoon Son <[email protected]>

* rename weights properly

* simplify the loop condition

* Rename the condition variable for readability

Co-authored-by: Gera Shegalov <[email protected]>

* missing renames

* add test for when all tasks are big

---------

Signed-off-by: Jihoon Son <[email protected]>
Co-authored-by: Gera Shegalov <[email protected]>
  • Loading branch information
jihoonson and gerashegalov authored Dec 5, 2024
1 parent 234f4db commit f3ac8be
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import java.util.concurrent.{Callable, ExecutorService, Future, TimeUnit}

/**
* Thin wrapper around an ExecutorService that adds throttling.
*
* The given executor is owned by this class and will be shut down when this class is shut down.
*/
class ThrottlingExecutor(
val executor: ExecutorService, throttler: TrafficController) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.nvidia.spark.rapids.io.async

import java.util.concurrent.Callable
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.GuardedBy

import com.nvidia.spark.rapids.RapidsConf
Expand Down Expand Up @@ -85,38 +86,55 @@ class HostMemoryThrottle(val maxInFlightHostMemoryBytes: Long) extends Throttle
*
* This class is thread-safe as it is used by multiple tasks.
*/
class TrafficController protected[rapids] (throttle: Throttle) {
class TrafficController protected[rapids] (@GuardedBy("lock") throttle: Throttle) {

@GuardedBy("this")
@GuardedBy("lock")
private var numTasks: Int = 0

private val lock = new ReentrantLock()
private val canBeScheduled = lock.newCondition()

/**
* Blocks the task from being scheduled until the throttle allows it. If there is no task
* currently scheduled, the task is scheduled immediately even if the throttle is exceeded.
*/
def blockUntilRunnable[T](task: Task[T]): Unit = synchronized {
if (numTasks > 0) {
while (!throttle.canAccept(task)) {
wait(100)
/**
 * Blocks the calling thread until `task` may be scheduled, then records it as scheduled.
 *
 * The task is admitted as soon as either no task is currently scheduled (even if the throttle
 * would reject it) or the throttle reports that it can accept the task.
 *
 * @param task the task asking to be scheduled
 * @throws InterruptedException if the calling thread is interrupted while acquiring the lock
 *                              or while waiting on the condition
 */
def blockUntilRunnable[T](task: Task[T]): Unit = {
  lock.lockInterruptibly()
  try {
    // `numTasks > 0` is checked on every wake-up, so a task the throttle can never accept
    // is still admitted once nothing else is scheduled.
    while (numTasks > 0 && !throttle.canAccept(task)) {
      canBeScheduled.await()
    }
    // Bookkeeping happens exactly once and only while the lock is held, so the counter and the
    // throttle state stay consistent with each other.
    numTasks += 1
    throttle.taskScheduled(task)
  } finally {
    lock.unlock()
  }
}

def taskCompleted[T](task: Task[T]): Unit = synchronized {
numTasks -= 1
throttle.taskCompleted(task)
notify()
/**
 * Records that `task` has finished and wakes up the threads blocked in `blockUntilRunnable`
 * so that they can re-check whether their task may now be scheduled.
 *
 * @param task the task that has completed
 */
def taskCompleted[T](task: Task[T]): Unit = {
  // Completion bookkeeping must always run: acquiring the lock interruptibly would throw on an
  // already-interrupted thread before the counter and the throttle are updated, leaving the
  // task counted as scheduled forever. `lock()` keeps the thread's interrupt status intact.
  lock.lock()
  try {
    numTasks -= 1
    throttle.taskCompleted(task)
    // Wake every waiter rather than one: each waiter re-evaluates its own admission condition,
    // so a single wake-up could land on a task that still does not fit while another waiter
    // that does fit (or several that fit together) would stay parked until a later completion.
    canBeScheduled.signalAll()
  } finally {
    lock.unlock()
  }
}

def numScheduledTasks: Int = synchronized {
numTasks
/**
 * Returns the current value of the scheduled-task counter, read while holding the lock.
 *
 * @throws InterruptedException if the calling thread is interrupted while acquiring the lock
 */
def numScheduledTasks: Int = {
  lock.lockInterruptibly()
  try numTasks finally lock.unlock()
}
}

object TrafficController {

@GuardedBy("this")
private var instance: TrafficController = _

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach {
assertResult(0)(throttle.getTotalHostMemoryBytes)
}

test("tasks submission fails if total weight exceeds maxWeight") {
test("tasks submission fails if totalHostMemoryBytes exceeds maxHostMemoryBytes") {
val task1 = new TestTask
val future1 = executor.submit(task1, 10)
assertResult(1)(trafficController.numScheduledTasks)
assertResult(10)(throttle.getTotalHostMemoryBytes)

val task2 = new TestTask
val task2Weight = 100
val task2HostMemory = 100
val exec = Executors.newSingleThreadExecutor()
val future2 = exec.submit(new Runnable {
override def run(): Unit = executor.submit(task2, task2Weight)
override def run(): Unit = executor.submit(task2, task2HostMemory)
})
Thread.sleep(100)
assert(!future2.isDone)
Expand All @@ -94,18 +94,18 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach {
future1.get(longTimeoutSec, TimeUnit.SECONDS)
future2.get(longTimeoutSec, TimeUnit.SECONDS)
assertResult(1)(trafficController.numScheduledTasks)
assertResult(task2Weight)(throttle.getTotalHostMemoryBytes)
assertResult(task2HostMemory)(throttle.getTotalHostMemoryBytes)
}

test("submit one task heavier than maxWeight") {
test("submit one task heavier than maxHostMemoryBytes") {
  // One byte more than the throttle's configured in-flight limit.
  val oversizedBytes = throttle.maxInFlightHostMemoryBytes + 1
  val future = executor.submit(() => Thread.sleep(10), oversizedBytes)

  // The oversized task must still run to completion within the timeout ...
  future.get(longTimeoutSec, TimeUnit.SECONDS)
  assert(future.isDone)

  // ... and leave no residual accounting behind once it has finished.
  assertResult(0)(trafficController.numScheduledTasks)
  assertResult(0)(throttle.getTotalHostMemoryBytes)
}

test("submit multiple tasks such that total weight does not exceed maxWeight") {
test("submit multiple tasks such that totalHostMemoryBytes does not exceed maxHostMemoryBytes") {
val numTasks = 10
val taskRunTime = 10
var future: Future[Unit] = null
Expand All @@ -125,10 +125,10 @@ class ThrottlingExecutorSuite extends AnyFunSuite with BeforeAndAfterEach {
assertResult(10)(throttle.getTotalHostMemoryBytes)

val task2 = new TestTask
val task2Weight = 100
val task2HostMemory = 100
val exec = Executors.newSingleThreadExecutor()
val future2 = exec.submit(new Runnable {
override def run(): Unit = executor.submit(task2, task2Weight)
override def run(): Unit = executor.submit(task2, task2HostMemory)
})
executor.shutdownNow(longTimeoutSec, TimeUnit.SECONDS)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,34 @@

package com.nvidia.spark.rapids.io.async

import java.util.concurrent.{ExecutionException, Executors, ExecutorService, TimeUnit}
import java.util.concurrent.{ExecutionException, Executors, ExecutorService, Future, TimeUnit}

import org.scalatest.BeforeAndAfterEach
import org.scalatest.concurrent.TimeLimitedTests
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._

class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach {
class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach with TimeLimitedTests {

private var throttle: HostMemoryThrottle = _
/**
 * A [[HostMemoryThrottle]] that additionally remembers, in call order, every task handed to
 * `taskScheduled`, so that tests can assert on the order in which tasks were scheduled.
 */
class RecordingExecOrderHostMemoryThrottle(maxInFlightHostMemoryBytes: Long)
  extends HostMemoryThrottle(maxInFlightHostMemoryBytes) {

  /** Tasks seen by `taskScheduled`, oldest first. */
  var tasksScheduled = Seq.empty[TestTask]

  override def taskScheduled[T](task: Task[T]): Unit = {
    // Record the task first, then let the parent throttle do its own accounting.
    val scheduled = task.asInstanceOf[TestTask]
    tasksScheduled :+= scheduled
    super.taskScheduled(task)
  }
}

val timeLimit: Span = 10.seconds

private var throttle: RecordingExecOrderHostMemoryThrottle = _
private var controller: TrafficController = _
private var executor: ExecutorService = _

/** Gives every test a fresh throttle, controller and single-threaded executor. */
override def beforeEach(): Unit = {
  // Only the recording throttle is created: the field is declared with the recording type, and
  // the tests rely on it to inspect the scheduling order.
  throttle = new RecordingExecOrderHostMemoryThrottle(100)
  controller = new TrafficController(throttle)
  executor = Executors.newSingleThreadExecutor()
}
Expand Down Expand Up @@ -76,6 +91,63 @@ class TrafficControllerSuite extends AnyFunSuite with BeforeAndAfterEach {
f.get(1, TimeUnit.SECONDS)
}

test("big task should be scheduled after all running tasks are completed") {
  // A small task goes in first so the controller is not idle.
  val runningTask = new TestTask(50)
  controller.blockUntilRunnable(runningTask)

  // 150 bytes is more than the 100-byte limit set up in beforeEach, so this task has to wait
  // for as long as another task is scheduled.
  val oversizedTask = new TestTask(150)
  val pending = executor.submit(new Runnable {
    override def run(): Unit = controller.blockUntilRunnable(oversizedTask)
  })
  Thread.sleep(100)
  assert(!pending.isDone)

  // Once the running task completes, the oversized task must be let through promptly.
  controller.taskCompleted(runningTask)
  pending.get(1, TimeUnit.SECONDS)
}

test("all tasks are bigger than the total memory limit") {
  // Each task alone exceeds the 100-byte limit set up in beforeEach, so a task can only be
  // admitted while nothing else is scheduled, i.e. strictly one at a time.
  val bigTaskMemoryBytes = 130
  val tasks = Seq.fill(3)(new TestTask(bigTaskMemoryBytes))
  // The single-threaded executor issues the blockUntilRunnable calls in submission order.
  val futures: Seq[Future[_]] = tasks.map { t =>
    executor.submit(new Runnable {
      override def run(): Unit = controller.blockUntilRunnable(t)
    })
  }

  tasks.indices.foreach { i =>
    // Wait on the future itself instead of polling numScheduledTasks and then asserting isDone:
    // the counter is incremented inside blockUntilRunnable, before the future is marked as
    // completed, so the polling variant can observe the counter first and fail spuriously.
    futures(i).get(1, TimeUnit.SECONDS)
    assertResult(1)(controller.numScheduledTasks)
    assertResult(tasks(i))(throttle.tasksScheduled(i))

    // Completing the current task is what allows the next one (if any) to be scheduled.
    controller.taskCompleted(tasks(i))
  }
  assertResult(0)(controller.numScheduledTasks)
}

test("shutdown while blocking") {
val t1 = new TestTask(10)
controller.blockUntilRunnable(t1)
Expand Down

0 comments on commit f3ac8be

Please sign in to comment.