Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
Signed-off-by: Zach Puller <[email protected]>
  • Loading branch information
zpuller committed Jan 17, 2025
1 parent fbdc9fb commit f404831
Showing 1 changed file with 28 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import java.nio.channels.{Channels, FileChannel, WritableByteChannel}
import java.nio.file.StandardOpenOption
import java.util
import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

import scala.collection.mutable

Expand All @@ -43,6 +43,7 @@ import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.BlockId


/**
* Spark-RAPIDS Spill Framework
*
Expand Down Expand Up @@ -256,7 +257,7 @@ object SpillableHostBufferHandle extends Logging {
while (chunkedPacker.hasNext) {
val (bb, len) = chunkedPacker.next()
withResource(bb) { _ =>
builder.copyNext(bb.buf.get, len, Cuda.DEFAULT_STREAM)
builder.copyNext(bb.buf, len, Cuda.DEFAULT_STREAM)
// copyNext is synchronous w.r.t. the cuda stream passed,
// no need to synchronize here.
}
Expand Down Expand Up @@ -1559,7 +1560,7 @@ object SpillFramework extends Logging {

def withHostSpillBounceBuffer[T](body: HostMemoryBuffer => T): T = {
withResource(hostSpillBounceBufferPool.nextBuffer()) { hmb =>
body(hmb.buf.get)
body(hmb.buf)
}
}

Expand Down Expand Up @@ -1602,16 +1603,16 @@ object SpillFramework extends Logging {
* @param pool - the pool to which this buffer belongs
*/
private[spill] class BounceBuffer[T <: AutoCloseable](
var buf: Option[T],
var buf: T,
private val pool: BounceBufferPool[T])
extends AutoCloseable {
override def close(): Unit = {
pool.returnBuffer(this)
}

def release(): Unit = {
buf.foreach(_.close())
buf = None
buf.close()
buf = null.asInstanceOf[T]
}
}

Expand All @@ -1626,34 +1627,35 @@ private[spill] class BounceBuffer[T <: AutoCloseable](
*/
class BounceBufferPool[T <: AutoCloseable](private val bufSize: Long,
private val bbCount: Long,
private val allocator: Long => T) extends AutoCloseable {
private val pool = new ConcurrentLinkedQueue[BounceBuffer[T]]
private val allocator: Long => T)
extends AutoCloseable with Logging {

private val pool = new LinkedBlockingQueue[BounceBuffer[T]]
for (_ <- 1L to bbCount) {
pool.offer(new BounceBuffer[T](Some(allocator(bufSize)), this))
pool.offer(new BounceBuffer[T](allocator(bufSize), this))
}

def bufferSize: Long = bufSize
def nextBuffer(): BounceBuffer[T] = synchronized {
var res = pool.poll()
while (res == null) {
wait()
res = pool.poll()
def nextBuffer(): BounceBuffer[T] = pool.take()
def returnBuffer(buffer: BounceBuffer[T]): Unit = {
if (closed) {
buffer.release()
} else {
pool.offer(buffer)
}

res
}
def returnBuffer(buffer: BounceBuffer[T]): Unit = synchronized {
pool.offer(buffer)
notifyAll()
}

private var closed = false
override def close(): Unit = {
if (pool.size() < bbCount) {
throw new IllegalStateException("tried to close BounceBufferPool when buffers are still " +
"being used")
}
if (!closed) {
closed = true
if (pool.size() < bbCount) {
logError("tried to close BounceBufferPool when buffers are still " +
"being used")
}

pool.forEach(_.release())
pool.forEach(_.release())
}
}
}

Expand Down Expand Up @@ -1728,7 +1730,7 @@ class ChunkedPacker(table: Table,
if (closed) {
throw new IllegalStateException(s"ChunkedPacker is closed")
}
val bytesWritten = chunkedPack.next(bounceBuffer.buf.get)
val bytesWritten = chunkedPack.next(bounceBuffer.buf)
// we increment the refcount because the caller has no idea where
// this memory came from, so it should close it.
(bounceBuffer, bytesWritten)
Expand Down

0 comments on commit f404831

Please sign in to comment.