
[FEA] Support retry on the driver side #11980

Status: Closed · opened by firestarman (Collaborator) on Jan 20, 2025 · 2 comments
Labels: feature request (New feature or request)

firestarman commented Jan 20, 2025

Is your feature request related to a problem? Please describe.
[Not sure whether this is a must-have feature, but filing it for discussion.]

This is related to #11979.

Some calls to getColumnarBatch will happen on the driver side, e.g. reading a host batch for broadcasting. It is unclear whether we need to protect these calls with a retry block; if we do, then we need to support retry on the driver side.

/**
 * Create host columnar batches from either serialized buffers or device columnar batch. This
 * method can be safely called in both driver node and executor nodes. For now, it is used on
 * the driver side for reusing GPU broadcast results in the CPU.
 *
 * NOTE: The caller is responsible for releasing these host columnar batches.
 */
def hostBatch: ColumnarBatch = this.synchronized {
  maybeGpuBatch.map { spillable =>
    withResource(spillable.getColumnarBatch()) { batch =>
      val hostColumns: Array[ColumnVector] = GpuColumnVector
        .extractColumns(batch)
        .safeMap(_.copyToHost())
      new ColumnarBatch(hostColumns, numRows)
    }
  }.getOrElse {
    // ... (truncated in the original issue; per the comment below, this
    // branch rebuilds host columns from the serialized buffers via
    // buildHostColumns and JCudfSerialization.unpackHostColumnVectors)
  }
}
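
As the NOTE says, the caller owns the returned host batch, so it should be released deterministically. A minimal usage sketch (broadcastResult and consumeOnCpu are hypothetical placeholders, not names from the plugin):

withResource(broadcastResult.hostBatch) { hostCb =>
  // Work with the host-side ColumnarBatch on the CPU; it is closed for us
  // when this block exits, even on exception.
  consumeOnCpu(hostCb)
}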

Describe the solution you'd like
The current retry framework requires a task ID to work. Maybe we can introduce some special task IDs for the driver threads, as sketched below.
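
One hedged way to realize this, purely as a sketch (the object and method names below are invented, not part of the plugin), is to hand out pseudo task IDs from the negative range, which cannot collide with real Spark task attempt IDs since those are non-negative:

import java.util.concurrent.atomic.AtomicLong

// Hypothetical sketch: give each driver thread a unique pseudo "task" ID so
// the retry framework has something to key its per-task state on even when
// no real Spark task exists.
object DriverRetryTaskIds {
  private val nextId = new AtomicLong(-1L)

  // Run `body` with a unique pseudo task ID for this driver thread. Hooking
  // the ID in and out of the retry framework's bookkeeping is elided here.
  def withDriverTaskId[T](body: Long => T): T = {
    val id = nextId.getAndDecrement()
    body(id)
  }
}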

abellina (Collaborator) commented

There are two parts to this function (hostBatch): one part that only executes in the executor, and another part that must execute in the driver.

  1. (executor part) The getColumnarBatch part of this function cannot be called on the driver, as we don't have GPU memory there; on the driver we must always end up in the getOrElse block. But yes, getColumnarBatch could be materializing a batch from host or disk onto the GPU and then copying it back to host. This could OOM on the GPU, and it could OOM on the host. We should probably have retry blocks here (see the sketch after this list).

  2. (driver part) The getOrElse branch builds host columns using buildHostColumns and ultimately JCudfSerialization.unpackHostColumnVectors. This works off the JCudf-serialized buffer, and I don't see any host allocations here beyond what we had already allocated (HostConcatResult holds the host buffer).
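
For the executor part, here is a minimal, untested sketch of such a retry block, modeled on the hostBatch code above and on the plugin's withRetryNoSplit (the same helper used in the diff in the next comment); copyDeviceToHostWithRetry is a hypothetical name:

import com.nvidia.spark.rapids.{GpuColumnVector, SpillableColumnarBatch}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

// Hypothetical helper: retry the materialize-and-copy step so that an OOM
// on either the GPU or the host can be rolled back and retried after the
// framework spills other buffers.
def copyDeviceToHostWithRetry(
    spillable: SpillableColumnarBatch,
    numRows: Int): ColumnarBatch = {
  withRetryNoSplit {
    // getColumnarBatch may bring the batch back onto the GPU from host or
    // disk before the copy, which is where the OOMs can occur.
    withResource(spillable.getColumnarBatch()) { batch =>
      val hostColumns: Array[ColumnVector] =
        GpuColumnVector.extractColumns(batch).safeMap(_.copyToHost())
      new ColumnarBatch(hostColumns, numRows)
    }
  }
}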

firestarman (Collaborator, Author) commented Jan 22, 2025

Thanks for the info; closing this.
FYI, I hit an NPE when running the AQE tests after adding a retry block around getColumnarBatch in the doWriteObject method:

diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
index ac0922476..22ca230d7 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
@@ -25,7 +25,7 @@ import scala.concurrent.ExecutionContext
 import scala.ref.WeakReference
 import scala.util.control.NonFatal
 
-import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange}
+import ai.rapids.cudf.{HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange, Table}
 import ai.rapids.cudf.JCudfSerialization.HostConcatResult
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
@@ -182,8 +182,10 @@ class SerializeConcatHostBuffersDeserializeBatch(
       case justRows: JustRowsColumnarBatch =>
         JCudfSerialization.writeRowsToStream(out, justRows.numRows())
       case scb: SpillableColumnarBatch =>
-        val table = withResource(scb.getColumnarBatch()) { cb =>
-          GpuColumnVector.from(cb)
+        val table = withRetryNoSplit[Table] {
+          withResource(scb.getColumnarBatch()) { cb =>
+            GpuColumnVector.from(cb)
+          }
         }
         withResou

Error stack:

E                   Caused by: java.lang.NullPointerException
E                   	at com.nvidia.spark.rapids.RmmRapidsRetryIterator$RmmRapidsRetryIterator.$anonfun$next$2(RmmRapidsRetryIterator.scala:635)
E                   	at com.nvidia.spark.rapids.RmmRapidsRetryIterator$RmmRapidsRetryIterator.$anonfun$next$2$adapted(RmmRapidsRetryIterator.scala:630)
E                   	at scala.Option.foreach(Option.scala:407)
E                   	at com.nvidia.spark.rapids.RmmRapidsRetryIterator$RmmRapidsRetryIterator.next(RmmRapidsRetryIterator.scala:630)
E                   	at com.nvidia.spark.rapids.RmmRapidsRetryIterator$RmmRapidsRetryAutoCloseableIterator.next(RmmRapidsRetryIterator.scala:553)
E                   	at com.nvidia.spark.rapids.RmmRapidsRetryIterator$.drainSingleWithVerification(RmmRapidsRetryIterator.scala:291)
E                   	at com.nvidia.spark.rapids.RmmRapidsRetryIterator$.withRetryNoSplit(RmmRapidsRetryIterator.scala:185)
E                   	at org.apache.spark.sql.rapids.execution.SerializeConcatHostBuffersDeserializeBatch.$anonfun$doWriteObject$1(GpuBroadcastExchangeExec.scala:186)
E                   	at org.apache.spark.sql.rapids.execution.SerializeConcatHostBuffersDeserializeBatch.$anonfun$doWriteObject$1$adapted(GpuBroadcastExchangeExec.scala:181)
E                   	at scala.Option.map(Option.scala:230)
E                   	at org.apache.spark.sql.rapids.execution.SerializeConcatHostBuffersDeserializeBatch.doWriteObject(GpuBroadcastExchangeExec.scala:181)
E                   	at org.apache.spark.sql.rapids.execution.SerializeConcatHostBuffersDeserializeBatch.writeObject(GpuBroadcastExchangeExec.scala:156)
E                   	at sun.reflect.GeneratedMethodAccessor276.invoke(Unknown Source)

This is an intermittent issue: the tests pass when run separately, and which tests fail varies from run to run, but every failure is the same NPE. So I guess it is related to the ordering between task context initialization and the call to doWriteObject (see the guard sketch after the test summaries below). Failing tests from two separate runs:

=========================== short test summary info ============================
FAILED ../../src/main/python/aqe_test.py::test_aqe_struct_self_join[DATAGEN_SEED=1737514884, TZ=UTC, INJECT_OOM, IGNORE_ORDER({'local': True})]
FAILED ../../src/main/python/aqe_test.py::test_aqe_join_reused_exchange_inequality_condition[left anti][DATAGEN_SEED=1737514884, TZ=UTC, INJECT_OOM, IGNORE_ORDER({'local': True}), ALLOW_NON_GPU(BroadcastNestedLoopJoinExec,Cast,DateSub)]
FAILED ../../src/main/python/aqe_test.py::test_aqe_join_with_dpp[DATAGEN_SEED=1737514884, TZ=UTC, INJECT_OOM, IGNORE_ORDER({'local': True})]
FAILED ../../src/main/python/aqe_test.py::test_aqe_join_with_dpp_multi_columns[DATAGEN_SEED=1737514884, TZ=UTC, INJECT_OOM, IGNORE_ORDER({'local': True})]
=== 4 failed, 14 passed, 2 skipped, 38459 deselected, 13 warnings in 40.55s ====

FAILED ../../src/main/python/aqe_test.py::test_aqe_join_reused_exchange_inequality_condition[cross][DATAGEN_SEED=1737515272, TZ=UTC, INJECT_OOM, IGNORE_ORDER({'local': True}), ALLOW_NON_GPU(BroadcastNestedLoopJoinExec,Cast,DateSub)]
FAILED ../../src/main/python/aqe_test.py::test_aqe_join_reused_exchange_inequality_condition[left semi][DATAGEN_SEED=1737515272, TZ=UTC, INJECT_OOM, IGNORE_ORDER({'local': True}), ALLOW_NON_GPU(BroadcastNestedLoopJoinExec,Cast,DateSub)]
=== 2 failed, 16 passed, 2 skipped, 38459 deselected, 13 warnings in 41.17s ====
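
If the NPE is indeed caused by entering the retry block on a thread with no task context, one hedged workaround (a sketch only, not a vetted fix) would be to fall back to the plain path whenever TaskContext.get() returns null, i.e. whenever doWriteObject runs on the driver:

import ai.rapids.cudf.Table
import com.nvidia.spark.rapids.{GpuColumnVector, SpillableColumnarBatch}
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.RmmRapidsRetryIterator.withRetryNoSplit
import org.apache.spark.TaskContext

// Sketch: only use the retry framework when a Spark task context exists,
// since the retry machinery appears to key its state off the current task;
// on the driver (TaskContext.get() == null) take the unretried path.
def materializeTable(scb: SpillableColumnarBatch): Table = {
  def build(): Table =
    withResource(scb.getColumnarBatch()) { cb => GpuColumnVector.from(cb) }

  if (TaskContext.get() != null) {
    withRetryNoSplit[Table](build())
  } else {
    build()
  }
}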
