Skip to content

Commit

Permalink
consolidate bounce buffer pool definitions
Browse files Browse the repository at this point in the history
Signed-off-by: Zach Puller <[email protected]>
  • Loading branch information
zpuller committed Jan 17, 2025
1 parent 0edcb08 commit 872dc5e
Showing 1 changed file with 65 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableSeq
import com.nvidia.spark.rapids.format.TableMeta
import com.nvidia.spark.rapids.internal.HostByteBufferIterator
import org.apache.commons.io.IOUtils
import scala.collection.JavaConverters.collectionAsScalaIterableConverter

import org.apache.spark.{SparkConf, SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -258,7 +257,7 @@ object SpillableHostBufferHandle extends Logging {
val (bb, len) = chunkedPacker.next()
withResource(bb) { _ =>
withResource(new NvtxRange("chunked packer bounce buffer", NvtxColor.RED)) { _ =>
builder.copyNext(bb.dmb, len, Cuda.DEFAULT_STREAM)
builder.copyNext(bb.buf.get, len, Cuda.DEFAULT_STREAM)
// copyNext is synchronous w.r.t. the cuda stream passed,
// no need to synchronize here.
}
Expand Down Expand Up @@ -1493,16 +1492,6 @@ object SpillableColumnarBatchHandle {
}
}


class PooledDeviceBounceBuffer(var buf: DeviceMemoryBuffer,
deviceBounceBufferPool: DeviceBounceBufferPool)
extends DeviceBounceBuffer(buf) {
override def close(): Unit = {
super.close()
deviceBounceBufferPool.returnBuffer(this)
}
}

object SpillFramework extends Logging {
// public for tests. Some tests not in the `spill` package require setting this
// because they need fine control over allocations.
Expand All @@ -1516,39 +1505,7 @@ object SpillFramework extends Logging {
storesInternal
}

class HostBounceBufferPool(rapidsConf: RapidsConf) extends AutoCloseable {
private val pool = new ConcurrentLinkedQueue[HostMemoryBuffer]()

{
val bbCount = rapidsConf.spillToDiskBounceBufferCount
for (_ <- 1L to bbCount) {
pool.offer(HostMemoryBuffer.allocate(rapidsConf.spillToDiskBounceBufferSize))
}
}

def nextBuffer(): HostMemoryBuffer = synchronized {
var buf = pool.poll()
while (buf == null) {
wait()
buf = pool.poll()
}
buf
}

def returnToPool(hmb: HostMemoryBuffer): Unit = synchronized {
pool.offer(hmb)
notifyAll()
}

override def close(): Unit = {
if (pool.size() != rapidsConf.spillToDiskBounceBufferCount) {
throw new IllegalStateException("HostBounceBufferPool closed while buffers are " +
"still being used");
}
pool.forEach(_.close())
}
}
private var hostSpillBounceBufferPool: HostBounceBufferPool = _
private var hostSpillBounceBufferPool: BounceBufferPool[HostMemoryBuffer] = _

private lazy val conf: SparkConf = {
val env = SparkEnv.get
Expand All @@ -1573,41 +1530,15 @@ object SpillFramework extends Logging {
} else {
Some(rapidsConf.hostSpillStorageSize)
}
// this should hopefully be pinned, but it would work without
hostSpillBounceBufferPool = new HostBounceBufferPool(rapidsConf)

chunkedPackBounceBufferPool = new DeviceBounceBufferPool {
private val pool = new ConcurrentLinkedQueue[DeviceBounceBuffer]()

{
val bbCount = rapidsConf.chunkedPackBounceBufferCount
for (_ <- 1L to bbCount) {
pool.offer(new PooledDeviceBounceBuffer(
DeviceMemoryBuffer.allocate(rapidsConf.chunkedPackBounceBufferSize), this))
}
}

override def bufferSize: Long = rapidsConf.chunkedPackBounceBufferSize
override def nextBuffer(): DeviceBounceBuffer = synchronized {
// can block waiting for bounceBuffer to be released
var buffer = pool.poll()
while (buffer == null) {
wait()
buffer = pool.poll()
}

buffer.acquire()
}
override def close(): Unit = {
// this closes the DeviceMemoryBuffer wrapped by the bounce buffer class
pool.asScala.foreach(_.release())
}
hostSpillBounceBufferPool = new BounceBufferPool[HostMemoryBuffer](
rapidsConf.spillToDiskBounceBufferSize,
rapidsConf.spillToDiskBounceBufferCount,
HostMemoryBuffer.allocate)

override def returnBuffer(buffer: DeviceBounceBuffer): Unit = synchronized {
pool.offer(buffer)
notifyAll()
}
}
chunkedPackBounceBufferPool = new BounceBufferPool[DeviceMemoryBuffer](
rapidsConf.chunkedPackBounceBufferSize,
rapidsConf.chunkedPackBounceBufferCount,
DeviceMemoryBuffer.allocate)
storesInternal = new SpillableStores {
override var deviceStore: SpillableDeviceStore = new SpillableDeviceStore
override var hostStore: SpillableHostStore = new SpillableHostStore(hostSpillStorageSize)
Expand All @@ -1633,15 +1564,12 @@ object SpillFramework extends Logging {
}

def withHostSpillBounceBuffer[T](body: HostMemoryBuffer => T): T = {
val hmb = hostSpillBounceBufferPool.nextBuffer()
withResource(new NvtxRange("host spill bounce buffer", NvtxColor.RED)) { _ =>
val res = body(hmb)
hostSpillBounceBufferPool.returnToPool(hmb)
res
withResource(hostSpillBounceBufferPool.nextBuffer()) { hmb =>
body(hmb.buf.get)
}
}

var chunkedPackBounceBufferPool: DeviceBounceBufferPool = _
var chunkedPackBounceBufferPool: BounceBufferPool[DeviceMemoryBuffer] = _

// if the stores have already shut down, we don't want to create them here
// so we use `storesInternal` directly in these remove functions.
Expand All @@ -1668,60 +1596,71 @@ object SpillFramework extends Logging {
/**
* A bounce buffer wrapper class that supports the concept of acquisition.
*
* The bounce buffer is acquired exclusively. So any calls to acquire while the
* buffer is in use will block at `acquire`. Calls to `release` notify the blocked
* threads, and they will check to see if they can acquire.
* The bounce buffer is acquired from a BounceBufferPool, so any calls to
* BounceBufferPool.nextBuffer will block if the pool has no available buffers.
*
* `close` is the interface to unacquire the bounce buffer.
* 'close' restores a bounce buffer to the pool for other callers to use.
*
* `release` actually closes the underlying DeviceMemoryBuffer, and should be called
* `release` actually closes the underlying buffer, and should be called
* once at the end of the lifetime of the executor.
*
* @param dmb - actual cudf DeviceMemoryBuffer that this class is protecting.
* @param buf - actual cudf DeviceMemoryBuffer that this class is protecting.
* @param pool - the pool to which this buffer belongs
*/
private[spill] case class DeviceBounceBuffer(var dmb: DeviceMemoryBuffer) extends AutoCloseable {
private var acquired: Boolean = false
def acquire(): DeviceBounceBuffer = synchronized {
while (acquired) {
wait()
}
acquired = true
this
}

private def unaquire(): Unit = synchronized {
acquired = false
notifyAll()
}

private[spill] class BounceBuffer[T <: AutoCloseable](
var buf: Option[T],
private val pool: BounceBufferPool[T])
extends AutoCloseable {
override def close(): Unit = {
unaquire()
pool.returnBuffer(this)
}

def release(): Unit = synchronized {
if (acquired) {
throw new IllegalStateException(
"closing device buffer pool, but some bounce buffers are in use.")
}
if (dmb != null) {
dmb.close()
dmb = null
}
def release(): Unit = {
buf.foreach(_.close())
buf = None
}
}


/**
* A bounce buffer pool with buffers of size `bufferSize`
* A bounce buffer pool with buffers of size `bufSize`
*
* This pool returns instances of `DeviceBounceBuffer`, that should
* This pool returns instances of `BounceBuffer[T]`, that should
* be closed in order to be reused.
*
* Callers should synchronize before calling close on their `DeviceMemoryBuffer`s.
*/
trait DeviceBounceBufferPool extends AutoCloseable {
def bufferSize: Long
def nextBuffer(): DeviceBounceBuffer
def returnBuffer(buffer: DeviceBounceBuffer): Unit
class BounceBufferPool[T <: AutoCloseable](private val bufSize: Long,
private val bbCount: Long,
private val allocator: Long => T) extends AutoCloseable {
private val pool = new ConcurrentLinkedQueue[BounceBuffer[T]]
for (_ <- 1L to bbCount) {
pool.offer(new BounceBuffer[T](Some(allocator(bufSize)), this))
}

def bufferSize: Long = bufSize
def nextBuffer(): BounceBuffer[T] = synchronized {
var res = pool.poll()
while (res == null) {
wait()
res = pool.poll()
}

res
}
def returnBuffer(buffer: BounceBuffer[T]): Unit = synchronized {
pool.offer(buffer)
notifyAll()
}

override def close(): Unit = {
if (pool.size() < bbCount) {
throw new IllegalStateException("tried to close BounceBufferPool when buffers are still " +
"being used")
}

pool.forEach(_.release())
}
}

/**
Expand All @@ -1739,8 +1678,8 @@ trait DeviceBounceBufferPool extends AutoCloseable {
* @param bounceBufferPool bounce buffer pool to use during the lifetime of this packer.
*/
class ChunkedPacker(table: Table,
bounceBufferPool: DeviceBounceBufferPool)
extends Iterator[(DeviceBounceBuffer, Long)] with Logging with AutoCloseable {
bounceBufferPool: BounceBufferPool[DeviceMemoryBuffer])
extends Iterator[(BounceBuffer[DeviceMemoryBuffer], Long)] with Logging with AutoCloseable {

private var closed: Boolean = false

Expand Down Expand Up @@ -1790,13 +1729,12 @@ class ChunkedPacker(table: Table,
chunkedPack.hasNext
}

override def next(): (DeviceBounceBuffer, Long) = {
// this ONLY CLOSES if EXCEPT
override def next(): (BounceBuffer[DeviceMemoryBuffer], Long) = {
closeOnExcept(bounceBufferPool.nextBuffer()) { bounceBuffer =>
if (closed) {
throw new IllegalStateException(s"ChunkedPacker is closed")
}
val bytesWritten = chunkedPack.next(bounceBuffer.dmb)
val bytesWritten = chunkedPack.next(bounceBuffer.buf.get)
// we increment the refcount because the caller has no idea where
// this memory came from, so it should close it.
(bounceBuffer, bytesWritten)
Expand Down

0 comments on commit 872dc5e

Please sign in to comment.