Inject AsIsExchangeExec
EnricoMi committed Oct 13, 2023
1 parent 0257b77 commit f53ee17
Showing 4 changed files with 114 additions and 1 deletion.
@@ -19,9 +19,9 @@ package org.apache.spark.sql.catalyst.plans.physical

import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -258,6 +258,20 @@ case object SinglePartition extends Partitioning {
SinglePartitionShuffleSpec
}

case class AsIsPartitioning(numPartitions: Int)
  extends LeafExpression with Partitioning with Unevaluable {

  override def nullable: Boolean = false
  override def dataType: DataType = IntegerType

  override def satisfies0(required: Distribution): Boolean = false

  override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
    AsIsShuffleSpec(this, distribution)

  def partitionIdExpression: Expression = SparkPartitionID()
}

/**
* Represents a partitioning where rows are split up across partitions based on the hash
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
@@ -642,6 +656,13 @@ case class RangeShuffleSpec(
}
}

case class AsIsShuffleSpec(
    partitioning: AsIsPartitioning, distribution: ClusteredDistribution) extends ShuffleSpec {
  override def numPartitions: Int = partitioning.numPartitions
  override def isCompatibleWith(other: ShuffleSpec): Boolean = false
  override def canCreatePartitioning: Boolean = false
}

case class HashShuffleSpec(
partitioning: HashPartitioning,
distribution: ClusteredDistribution) extends ShuffleSpec {
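For context (not part of the commit): the two additions above opt out of every co-partitioning optimization. AsIsPartitioning never satisfies a required distribution, and AsIsShuffleSpec is neither compatible with a sibling spec nor able to drive partitioning of the other side, so the planner always places a fresh shuffle on top of it. A minimal sketch of what that contract means, assuming a build of this branch (the attribute and distribution below are our own illustration):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{AsIsPartitioning, ClusteredDistribution}
import org.apache.spark.sql.types.LongType

val key  = AttributeReference("id", LongType)()
val part = AsIsPartitioning(numPartitions = 8)
val dist = ClusteredDistribution(Seq(key))

// Never satisfies a clustered distribution, so a required child distribution
// always forces an explicit shuffle on top of this partitioning.
assert(!part.satisfies(dist))
// The resulting spec cannot be reused to co-partition the other side of a join either.
assert(!part.createShuffleSpec(dist).canCreatePartitioning)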
@@ -587,6 +587,9 @@ case class EnsureRequirements(

  def apply(plan: SparkPlan): SparkPlan = {
    val newPlan = plan.transformUp {
      case operator @ ShuffleExchangeExec(_: RangePartitioning, child, _, _)
          if !child.isInstanceOf[AsIsExchangeExec] =>
        operator.withNewChildren(Seq(AsIsExchangeExec(child)))
      case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin, _)
          if optimizeOutRepartition &&
            (shuffleOrigin == REPARTITION_BY_COL || shuffleOrigin == REPARTITION_BY_NUM) =>
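A hypothetical way to see the new rule fire, assuming a build that includes this commit and that AsIsExchangeExec lives in org.apache.spark.sql.execution.exchange (the session setup and query are ours, not part of the change): a range repartition plans a ShuffleExchangeExec with RangePartitioning, so the case added above wraps its child once, and the isInstanceOf guard keeps the rule idempotent when EnsureRequirements runs again.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.exchange.AsIsExchangeExec
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder()
  .master("local[2]")
  .config("spark.sql.adaptive.enabled", "false") // keep the full plan visible in executedPlan
  .getOrCreate()

val df = spark.range(0, 1000).repartitionByRange(8, col("id"))

// The injected node should appear directly under the range shuffle.
val injected = df.queryExecution.executedPlan.collect { case e: AsIsExchangeExec => e }
assert(injected.nonEmpty)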
@@ -91,3 +91,11 @@ case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchan
|""".stripMargin
}
}

case class AsIsExchangeExec(override val child: SparkPlan) extends Exchange {
  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
  override def outputPartitioning: Partitioning = child.outputPartitioning
  override protected def doExecute(): RDD[InternalRow] = child.execute()
  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    copy(child = newChild)
}
@@ -116,6 +116,86 @@ case object REBALANCE_PARTITIONS_BY_NONE extends ShuffleOrigin
// the output needs to be partitioned by the given columns.
case object REBALANCE_PARTITIONS_BY_COL extends ShuffleOrigin

case class AsIsShuffleExchangeExec(override val child: SparkPlan) extends ShuffleExchangeLike {
  private lazy val writeMetrics =
    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
  private[sql] lazy val readMetrics =
    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
  override lazy val metrics = Map(
    "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
    "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions")
  ) ++ readMetrics ++ writeMetrics

  override def numMappers: Int = child.outputPartitioning.numPartitions
  override def numPartitions: Int = child.outputPartitioning.numPartitions
  override def advisoryPartitionSize: Option[Long] = None
  override def shuffleOrigin: ShuffleOrigin = REPARTITION_BY_NUM

  override def outputPartitioning: Partitioning =
    AsIsPartitioning(child.outputPartitioning.numPartitions)

  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] = {
    new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs)
  }

  private lazy val serializer: Serializer =
    new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))

  @transient lazy val inputRDD: RDD[InternalRow] = child.execute()

  // 'mapOutputStatisticsFuture' is only needed when AQE is enabled.
  @transient
  override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] = {
    if (inputRDD.getNumPartitions == 0) {
      Future.successful(null)
    } else {
      sparkContext.submitMapStage(shuffleDependency)
    }
  }

  override def runtimeStatistics: Statistics = {
    val dataSize = metrics("dataSize").value
    val rowCount = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
    Statistics(dataSize, Some(rowCount))
  }

  /**
   * A [[ShuffleDependency]] that will partition rows of its child based on
   * the partitioning scheme defined in `outputPartitioning`. Those partitions of
   * the returned ShuffleDependency will be the input of shuffle.
   */
  @transient
  lazy val shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow] = {
    val dep = ShuffleExchangeExec.prepareShuffleDependency(
      inputRDD,
      child.output,
      outputPartitioning,
      serializer,
      writeMetrics)
    metrics("numPartitions").set(dep.partitioner.numPartitions)
    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    SQLMetrics.postDriverMetricUpdates(
      sparkContext, executionId, metrics("numPartitions") :: Nil)
    dep
  }

  /**
   * Caches the created ShuffledRowRDD so we can reuse it.
   */
  private var cachedShuffleRDD: ShuffledRowRDD = null

  protected override def doExecute(): RDD[InternalRow] = {
    // Returns the same ShuffledRowRDD if this plan is used by multiple plans.
    if (cachedShuffleRDD == null) {
      cachedShuffleRDD = new ShuffledRowRDD(shuffleDependency, readMetrics)
    }
    cachedShuffleRDD
  }

  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
    copy(child = newChild)
}

/**
* Performs a shuffle that will result in the desired partitioning.
*/
@@ -275,6 +355,7 @@ object ShuffleExchangeExec {
    : ShuffleDependency[Int, InternalRow, InternalRow] = {
    val part: Partitioner = newPartitioning match {
      case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
      case AsIsPartitioning(n) => new PartitionIdPassthrough(n)
      case HashPartitioning(_, n) =>
        // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
        // `HashPartitioning.partitionIdExpression` to produce partitioning key.
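For reference (not part of the commit): PartitionIdPassthrough, the existing Spark partitioner matched on the new AsIsPartitioning branch above, forwards a precomputed partition id instead of hashing. A minimal re-implementation of that behaviour, ours and purely illustrative:

import org.apache.spark.Partitioner

// The shuffle key for an as-is shuffle is meant to already be a partition id
// (see AsIsPartitioning.partitionIdExpression = SparkPartitionID()), so the
// partitioner just casts the key instead of hashing it.
class IdentityPartitioner(override val numPartitions: Int) extends Partitioner {
  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}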
