
Commit 2e5c792
Replace AsIsExchangeExec with AsIsShuffleExchangeExec
EnricoMi committed Oct 24, 2023
1 parent f53ee17 commit 2e5c792
Showing 6 changed files with 17 additions and 12 deletions.
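
In short, this commit renames the pass-through exchange node AsIsExchangeExec to AsIsShuffleExchangeExec, duplicates the range-shuffle wrapping into a standalone ReuseSampledStage rule (EnsureRequirements keeps its own copy of the case, renamed), and registers that rule in both the regular and the adaptive preparation pipelines, directly after EnsureRequirements. The renamed node's definition is not part of this diff; judging from the AsIsExchangeExec body removed in the Exchange hunk below, it is presumably the same pass-through Exchange under the new name. A minimal sketch under that assumption:

// Sketch only — assumes AsIsShuffleExchangeExec mirrors the removed
// AsIsExchangeExec: a pass-through Exchange that forwards ordering,
// partitioning, and execution to its child unchanged.
case class AsIsShuffleExchangeExec(override val child: SparkPlan) extends Exchange {
  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
  override def outputPartitioning: Partitioning = child.outputPartitioning
  override protected def doExecute(): RDD[InternalRow] = child.doExecute()
  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    copy(child = newChild)
}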
@@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.plans.physical

import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
 import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
 import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
-import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseSampledStage}
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, WatermarkPropagator}
 import org.apache.spark.sql.internal.SQLConf
@@ -456,6 +456,7 @@ object QueryExecution
       PlanSubqueries(sparkSession),
       RemoveRedundantProjects,
       EnsureRequirements(),
+      ReuseSampledStage,
       // `ReplaceHashWithSortAgg` needs to be added after `EnsureRequirements` to guarantee the
       // sort order of each node is checked to be valid.
       ReplaceHashWithSortAgg,
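
ReuseSampledStage is placed immediately after EnsureRequirements(), so it only sees plans whose required exchanges are already inserted. For orientation, a sketch (not the exact Spark code) of how such a Seq[Rule[SparkPlan]] is applied — each rule transforms the output of the previous one, so position in the list fixes execution order:

// Sketch, assuming `preparations` and `initialPlan` are in scope:
val prepared: SparkPlan =
  preparations.foldLeft(initialPlan) { (plan, rule) => rule.apply(plan) }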
@@ -123,6 +123,7 @@ case class AdaptiveSparkPlanExec(
       CoalesceBucketsInJoin,
       RemoveRedundantProjects,
       ensureRequirements,
+      ReuseSampledStage,
       AdjustShuffleExchangePosition,
       ValidateSparkPlan,
       ReplaceHashWithSortAgg,
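
The same rule is also registered in AdaptiveSparkPlanExec's stage-preparation rules, again right after the ensureRequirements step, so range shuffles get wrapped whether or not adaptive query execution is enabled. Which of the two pipelines runs is controlled by the standard AQE switch:

// Standard Spark flag (not introduced by this commit):
// true  -> AdaptiveSparkPlanExec's preparation rules apply ReuseSampledStage
// false -> QueryExecution.preparations applies it
spark.conf.set("spark.sql.adaptive.enabled", "true")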
@@ -588,8 +588,8 @@ case class EnsureRequirements(
   def apply(plan: SparkPlan): SparkPlan = {
     val newPlan = plan.transformUp {
       case operator @ ShuffleExchangeExec(_: RangePartitioning, child, _, _)
-          if ! child.isInstanceOf[AsIsExchangeExec] =>
-        operator.withNewChildren(Seq(AsIsExchangeExec(child)))
+          if ! child.isInstanceOf[AsIsShuffleExchangeExec] =>
+        operator.withNewChildren(Seq(AsIsShuffleExchangeExec(child)))
       case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin, _)
           if optimizeOutRepartition &&
             (shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
@@ -637,3 +637,11 @@
     }
   }
 }
+
+object ReuseSampledStage extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    case operator @ ShuffleExchangeExec(_: RangePartitioning, child, _, _)
+        if !child.isInstanceOf[AsIsShuffleExchangeExec] =>
+      operator.withNewChildren(Seq(AsIsShuffleExchangeExec(child)))
+  }
+}
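
The guard !child.isInstanceOf[AsIsShuffleExchangeExec] makes the rewrite idempotent: preparation rules may run over the same plan more than once (notably when AQE re-optimizes), and without the check each pass would nest another wrapper around the shuffle's child. A sketch of the intended invariant, assuming somePlan contains a range-partitioned shuffle:

// Illustrative only: applying the rule twice must equal applying it once.
// Pass 1: ShuffleExchangeExec(range, child)
//      -> ShuffleExchangeExec(range, AsIsShuffleExchangeExec(child))
// Pass 2: the guard sees the wrapper and leaves the plan untouched.
val once = ReuseSampledStage(somePlan)
val twice = ReuseSampledStage(once)
assert(once == twice) // plan nodes are case classes, so equality is structural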
@@ -91,11 +91,3 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange)
       |""".stripMargin
   }
 }
-
-case class AsIsExchangeExec(override val child: SparkPlan) extends Exchange {
-  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
-  override def outputPartitioning: Partitioning = child.outputPartitioning
-  override protected def doExecute(): RDD[InternalRow] = child.doExecute()
-  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
-    copy(child = newChild)
-}
@@ -406,6 +406,9 @@ object ShuffleExchangeExec
             position += 1
             position
           }
+        case a: AsIsPartitioning =>
+          val projection = UnsafeProjection.create(a.partitionIdExpression :: Nil, outputAttributes)
+          row => projection(row).getInt(0)
         case h: HashPartitioning =>
           val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
           row => projection(row).getInt(0)
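
AsIsPartitioning itself is not defined in this diff (it presumably lives in the catalyst partitioning file touched by the first hunk). The new case mirrors the HashPartitioning branch directly below it: an UnsafeProjection evaluates partitionIdExpression per row and the resulting integer is used as the partition id verbatim, with no hashing. A hypothetical shape, for illustration only:

// Hypothetical — the real definition is not shown in this commit. By analogy
// with HashPartitioning.partitionIdExpression, an "as is" partitioning would
// hand back a precomputed partition id rather than hashing its keys:
case class AsIsPartitioning(partIdExpr: Expression, numPartitions: Int)
    extends Partitioning {
  // Evaluated per row by the extractor added in ShuffleExchangeExec above.
  def partitionIdExpression: Expression = partIdExpr
}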
