diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala index 7fdb3d932f9..40e39aba318 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/spill/SpillFramework.scala @@ -22,7 +22,7 @@ import java.nio.channels.{Channels, FileChannel, WritableByteChannel} import java.nio.file.StandardOpenOption import java.util import java.util.UUID -import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue} import scala.collection.mutable @@ -43,6 +43,7 @@ import org.apache.spark.sql.types.DataType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.BlockId + /** * Spark-RAPIDS Spill Framework * @@ -256,7 +257,7 @@ object SpillableHostBufferHandle extends Logging { while (chunkedPacker.hasNext) { val (bb, len) = chunkedPacker.next() withResource(bb) { _ => - builder.copyNext(bb.buf.get, len, Cuda.DEFAULT_STREAM) + builder.copyNext(bb.buf, len, Cuda.DEFAULT_STREAM) // copyNext is synchronous w.r.t. the cuda stream passed, // no need to synchronize here. } @@ -1559,7 +1560,7 @@ object SpillFramework extends Logging { def withHostSpillBounceBuffer[T](body: HostMemoryBuffer => T): T = { withResource(hostSpillBounceBufferPool.nextBuffer()) { hmb => - body(hmb.buf.get) + body(hmb.buf) } } @@ -1602,7 +1603,7 @@ object SpillFramework extends Logging { * @param pool - the pool to which this buffer belongs */ private[spill] class BounceBuffer[T <: AutoCloseable]( - var buf: Option[T], + var buf: T, private val pool: BounceBufferPool[T]) extends AutoCloseable { override def close(): Unit = { @@ -1610,8 +1611,8 @@ private[spill] class BounceBuffer[T <: AutoCloseable]( } def release(): Unit = { - buf.foreach(_.close()) - buf = None + buf.close() + buf = null.asInstanceOf[T] } } @@ -1626,34 +1627,35 @@ private[spill] class BounceBuffer[T <: AutoCloseable]( */ class BounceBufferPool[T <: AutoCloseable](private val bufSize: Long, private val bbCount: Long, - private val allocator: Long => T) extends AutoCloseable { - private val pool = new ConcurrentLinkedQueue[BounceBuffer[T]] + private val allocator: Long => T) + extends AutoCloseable with Logging { + + private val pool = new LinkedBlockingQueue[BounceBuffer[T]] for (_ <- 1L to bbCount) { - pool.offer(new BounceBuffer[T](Some(allocator(bufSize)), this)) + pool.offer(new BounceBuffer[T](allocator(bufSize), this)) } def bufferSize: Long = bufSize - def nextBuffer(): BounceBuffer[T] = synchronized { - var res = pool.poll() - while (res == null) { - wait() - res = pool.poll() + def nextBuffer(): BounceBuffer[T] = pool.take() + def returnBuffer(buffer: BounceBuffer[T]): Unit = { + if (closed) { + buffer.release() + } else { + pool.offer(buffer) } - - res - } - def returnBuffer(buffer: BounceBuffer[T]): Unit = synchronized { - pool.offer(buffer) - notifyAll() } + private var closed = false override def close(): Unit = { - if (pool.size() < bbCount) { - throw new IllegalStateException("tried to close BounceBufferPool when buffers are still " + - "being used") - } + if (!closed) { + closed = true + if (pool.size() < bbCount) { + logError("tried to close BounceBufferPool when buffers are still " + + "being used") + } - pool.forEach(_.release()) + pool.forEach(_.release()) + } } } @@ -1728,7 +1730,7 @@ class ChunkedPacker(table: Table, if (closed) { throw new IllegalStateException(s"ChunkedPacker is closed") } - val bytesWritten = chunkedPack.next(bounceBuffer.buf.get) + val bytesWritten = chunkedPack.next(bounceBuffer.buf) // we increment the refcount because the caller has no idea where // this memory came from, so it should close it. (bounceBuffer, bytesWritten)