Skip to content

Commit

Permalink
fix: subquery execution under CometTakeOrderedAndProjectExec should n…
Browse files Browse the repository at this point in the history
…ot fail (#748)
  • Loading branch information
viirya authored Aug 1, 2024
1 parent 2318a8e commit 2c9be0a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.spark.sql.comet

import org.apache.spark.TaskContext
import org.apache.spark.rdd.{ParallelCollectionRDD, RDD}
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
Expand Down Expand Up @@ -104,7 +105,17 @@ case class CometTakeOrderedAndProjectExec(
val topKAndProjection = CometExecUtils
.getProjectionNativePlan(projectList, output, sortOrder, child, limit)
.get
CometExec.getCometIterator(Seq(iter), topKAndProjection)
val it = CometExec.getCometIterator(Seq(iter), topKAndProjection)
setSubqueries(it.id, this)

Option(TaskContext.get()).foreach { context =>
context.addTaskCompletionListener[Unit] { _ =>
it.close()
cleanSubqueries(it.id, this)
}
}

it
}
}
}
Expand Down
40 changes: 20 additions & 20 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ abstract class CometExec extends CometPlan {
CometExec.decodeBatches(countAndBytes._2, this.getClass.getSimpleName))
(total, rows)
}

protected def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
sparkPlan.children.foreach(setSubqueries(planId, _))

sparkPlan.expressions.foreach {
_.collect { case sub: ScalarSubquery =>
CometScalarSubquery.setSubquery(planId, sub)
}
}
}

protected def cleanSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
sparkPlan.children.foreach(cleanSubqueries(planId, _))

sparkPlan.expressions.foreach {
_.collect { case sub: ScalarSubquery =>
CometScalarSubquery.removeSubquery(planId, sub)
}
}
}
}

object CometExec {
Expand Down Expand Up @@ -177,26 +197,6 @@ abstract class CometNativeExec extends CometExec {
runningSubqueries.clear()
}

private def setSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
sparkPlan.children.foreach(setSubqueries(planId, _))

sparkPlan.expressions.foreach {
_.collect { case sub: ScalarSubquery =>
CometScalarSubquery.setSubquery(planId, sub)
}
}
}

private def cleanSubqueries(planId: Long, sparkPlan: SparkPlan): Unit = {
sparkPlan.children.foreach(cleanSubqueries(planId, _))

sparkPlan.expressions.foreach {
_.collect { case sub: ScalarSubquery =>
CometScalarSubquery.removeSubquery(planId, sub)
}
}
}

override def doExecuteColumnar(): RDD[ColumnarBatch] = {
serializedPlanOpt.plan match {
case None =>
Expand Down
23 changes: 22 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.unsafe.types.UTF8String

import org.apache.comet.CometConf
import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark40Plus}
import org.apache.comet.CometSparkSessionExtensions.{isSpark33Plus, isSpark34Plus, isSpark35Plus, isSpark40Plus}

class CometExecSuite extends CometTestBase {
import testImplicits._
Expand All @@ -63,6 +63,27 @@ class CometExecSuite extends CometTestBase {
}
}

test("subquery execution under CometTakeOrderedAndProjectExec should not fail") {
assume(isSpark35Plus, "SPARK-45584 is fixed in Spark 3.5+")

withTable("t1") {
sql("""
|CREATE TABLE t1 USING PARQUET
|AS SELECT * FROM VALUES
|(1, "a"),
|(2, "a"),
|(3, "a") t(id, value)
|""".stripMargin)
val df = sql("""
|WITH t2 AS (
| SELECT * FROM t1 ORDER BY id
|)
|SELECT *, (SELECT COUNT(*) FROM t2) FROM t2 LIMIT 10
|""".stripMargin)
checkSparkAnswer(df)
}
}

test("Window range frame should fall back to Spark") {
val df =
Seq((1L, "1"), (1L, "1"), (2147483650L, "1"), (3L, "2"), (2L, "1"), (2147483650L, "2"))
Expand Down

0 comments on commit 2c9be0a

Please sign in to comment.