diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
index 3f09e2fe2..8f4c7424e 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala
@@ -19,6 +19,7 @@
 
 package org.apache.spark.sql.comet
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.{ParallelCollectionRDD, RDD}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
@@ -104,7 +105,17 @@ case class CometTakeOrderedAndProjectExec(
         val topKAndProjection = CometExecUtils
           .getProjectionNativePlan(projectList, output, sortOrder, child, limit)
           .get
-        CometExec.getCometIterator(Seq(iter), topKAndProjection)
+        val it = CometExec.getCometIterator(Seq(iter), topKAndProjection)
+        setSubqueries(it.id, this)
+
+        Option(TaskContext.get()).foreach { context =>
+          context.addTaskCompletionListener[Unit] { _ =>
+            it.close()
+            cleanSubqueries(it.id, this)
+          }
+        }
+
+        it
       }
     }
   }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 1c9bf9947..60ff17eb4 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -89,6 +89,26 @@ abstract class CometExec extends CometPlan {
         CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName))
     (total, rows)
   }
+
+  protected def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
+    sparkPlan.children.foreach(setSubqueries(planId, _))
+
+    sparkPlan.expressions.foreach {
+      _.collect { case sub: ScalarSubquery =>
+        CometScalarSubquery.setSubquery(planId, sub)
+      }
+    }
+  }
+
+  protected def cleanSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
+    sparkPlan.children.foreach(cleanSubqueries(planId, _))
+
+    sparkPlan.expressions.foreach {
+      _.collect { case sub: ScalarSubquery =>
+        CometScalarSubquery.removeSubquery(planId, sub)
+      }
+    }
+  }
 }
 
 object CometExec {
@@ -177,26 +197,6 @@ abstract class CometNativeExec extends CometExec {
     runningSubqueries.clear()
   }
 
-  private def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
-    sparkPlan.children.foreach(setSubqueries(planId, _))
-
-    sparkPlan.expressions.foreach {
-      _.collect { case sub: ScalarSubquery =>
-        CometScalarSubquery.setSubquery(planId, sub)
-      }
-    }
-  }
-
-  private def cleanSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
-    sparkPlan.children.foreach(cleanSubqueries(planId, _))
-
-    sparkPlan.expressions.foreach {
-      _.collect { case sub: ScalarSubquery =>
-        CometScalarSubquery.removeSubquery(planId, sub)
-      }
-    }
-  }
-
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
     serializedPlanOpt.plan match {
       case None =>
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 0e1204df8..40ec349e2 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.CometConf
-import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark40Plus}
+import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark35Plus, isSpark40Plus}
 
 class CometExecSuite extends CometTestBase {
   import testImplicits._
@@ -63,6 +63,27 @@ class CometExecSuite extends CometTestBase {
     }
   }
 
+  test("subquery execution under CometTakeOrderedAndProjectExec should not fail") {
+    assume(isSpark35Plus, "SPARK-45584 is fixed in Spark 3.5+")
+
+    withTable("t1") {
+      sql("""
+          |CREATE TABLE t1 USING PARQUET
+          |AS SELECT * FROM VALUES
+          |(1, "a"),
+          |(2, "a"),
+          |(3, "a") t(id, value)
+          |""".stripMargin)
+      val df = sql("""
+          |WITH t2 AS (
+          |  SELECT * FROM t1 ORDER BY id
+          |)
+          |SELECT *, (SELECT COUNT(*) FROM t2) FROM t2 LIMIT 10
+          |""".stripMargin)
+      checkSparkAnswer(df)
+    }
+  }
+
   test("Window range frame should fall back to Spark") {
     val df =
       Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2"))
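
Below is a minimal, standalone sketch (not part of the patch) of the task-completion-listener pattern the fix relies on: cleanup registered via `TaskContext.addTaskCompletionListener` runs when the task ends, whether or not downstream code (for example a LIMIT) fully consumes the partition iterator. `NativeHandle` and `CompletionListenerSketch` are hypothetical names introduced for illustration; only `TaskContext` and its listener API come from Spark.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

object CompletionListenerSketch {
  // Hypothetical stand-in for the native iterator handle that
  // CometExec.getCometIterator returns in the patch above.
  final class NativeHandle(val id: Long) extends AutoCloseable {
    override def close(): Unit = println(s"closed native handle $id")
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    try {
      val doubled = spark.sparkContext
        .parallelize(1 to 4, numSlices = 2)
        .mapPartitions { iter =>
          val handle = new NativeHandle(TaskContext.get().taskAttemptId())
          // Tie cleanup to the task lifetime, not to full consumption of
          // the iterator: with a LIMIT, downstream code may stop pulling
          // rows before the iterator is exhausted.
          Option(TaskContext.get()).foreach { ctx =>
            ctx.addTaskCompletionListener[Unit](_ => handle.close())
          }
          iter.map(_ * 2)
        }
        .collect()
      doubled.foreach(println)
    } finally {
      spark.stop()
    }
  }
}
```

As in the patch, `Option(TaskContext.get())` guards the case where no task context is available (for example, code evaluated on the driver), so the listener is only registered inside a real task.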