fix: Use RDD partition index (#1112)
* fix: Use RDD partition index

* fix

* fix

* fix
viirya authored Nov 25, 2024
1 parent 7b1a290 commit 5400fd7
Showing 7 changed files with 64 additions and 21 deletions.
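
Background, not part of the commit message: inside a task, TaskContext.partitionId() is the index of the partition the task computes, which is not necessarily the partition index of the RDD whose closure is currently running. A minimal, self-contained Spark sketch of one way the two diverge, using a union (all names below are illustrative; none of this is taken from the repository):

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("partition-index-demo").getOrCreate()
val sc = spark.sparkContext

val a = sc.parallelize(1 to 4, 2)
// Tag each element of `b` with (this RDD's partition index, the task's partition id).
val b = sc.parallelize(5 to 8, 2).mapPartitionsWithIndex { (rddIdx, iter) =>
  iter.map(x => (rddIdx, TaskContext.get().partitionId(), x))
}

// In the unioned RDD, b's partitions become task partitions 2 and 3, so inside
// b's closure rddIdx is 0/1 while TaskContext.partitionId() reports 2/3.
a.map(x => (-1, -1, x)).union(b).collect().foreach(println)

The diff below threads the RDD's own partition index (and partition count) down to the native plan instead of relying on TaskContext.partitionId().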
12 changes: 10 additions & 2 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -39,13 +39,19 @@ import org.apache.comet.vector.NativeUtil
* The input iterators producing sequence of batches of Arrow Arrays.
* @param protobufQueryPlan
* The serialized bytes of Spark execution plan.
* @param numParts
* The number of partitions.
* @param partitionIndex
* The index of the partition.
*/
class CometExecIterator(
val id: Long,
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
protobufQueryPlan: Array[Byte],
nativeMetrics: CometMetricNode)
nativeMetrics: CometMetricNode,
numParts: Int,
partitionIndex: Int)
extends Iterator[ColumnarBatch] {

private val nativeLib = new Native()
@@ -92,11 +98,13 @@ class CometExecIterator(
}

def getNextBatch(): Option[ColumnarBatch] = {
assert(partitionIndex >= 0 && partitionIndex < numParts)

nativeUtil.getNextBatch(
numOutputCols,
(arrayAddrs, schemaAddrs) => {
val ctx = TaskContext.get()
nativeLib.executePlan(ctx.stageId(), ctx.partitionId(), plan, arrayAddrs, schemaAddrs)
nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs)
})
}

@@ -51,9 +51,10 @@ object CometExecUtils {
childPlan: RDD[ColumnarBatch],
outputAttribute: Seq[Attribute],
limit: Int): RDD[ColumnarBatch] = {
childPlan.mapPartitionsInternal { iter =>
val numParts = childPlan.getNumPartitions
childPlan.mapPartitionsWithIndexInternal { case (idx, iter) =>
val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit).get
CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp)
CometExec.getCometIterator(Seq(iter), outputAttribute.length, limitOp, numParts, idx)
}
}
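
The hunk above, and the CometTakeOrderedAndProjectExec hunk below, share one pattern: the partition count is read from the RDD on the driver before the closure is built, and the partition index comes from the mapPartitionsWithIndex* callback rather than from TaskContext. A rough sketch of that shape using the public mapPartitionsWithIndex API (runNative is a hypothetical stand-in, not a Comet API):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.vectorized.ColumnarBatch

def perPartitionNative(child: RDD[ColumnarBatch])(
    runNative: (Iterator[ColumnarBatch], Int, Int) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] = {
  // Capture the partition count on the driver; RDD methods such as getNumPartitions
  // are not available inside the task (see the ZippedPartitionsRDD comment further down).
  val numParts = child.getNumPartitions
  child.mapPartitionsWithIndex { (idx, iter) =>
    runNative(iter, numParts, idx) // idx is this RDD's own partition index
  }
}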

@@ -77,12 +77,13 @@ case class CometTakeOrderedAndProjectExec(
val localTopK = if (orderingSatisfies) {
CometExecUtils.getNativeLimitRDD(childRDD, child.output, limit)
} else {
childRDD.mapPartitionsInternal { iter =>
val numParts = childRDD.getNumPartitions
childRDD.mapPartitionsWithIndexInternal { case (idx, iter) =>
val topK =
CometExecUtils
.getTopKNativePlan(child.output, sortOrder, child, limit)
.get
CometExec.getCometIterator(Seq(iter), child.output.length, topK)
CometExec.getCometIterator(Seq(iter), child.output.length, topK, numParts, idx)
}
}

@@ -102,7 +103,7 @@ case class CometTakeOrderedAndProjectExec(
val topKAndProjection = CometExecUtils
.getProjectionNativePlan(projectList, child.output, sortOrder, child, limit)
.get
val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection)
val it = CometExec.getCometIterator(Seq(iter), output.length, topKAndProjection, 1, 0)
setSubqueries(it.id, this)

Option(TaskContext.get()).foreach { context =>
@@ -31,16 +31,20 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
*/
private[spark] class ZippedPartitionsRDD(
sc: SparkContext,
var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch],
var f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch],
var zipRdds: Seq[RDD[ColumnarBatch]],
preservesPartitioning: Boolean = false)
extends ZippedPartitionsBaseRDD[ColumnarBatch](sc, zipRdds, preservesPartitioning) {

// We need to get the number of partitions in `compute` but `getNumPartitions` is not available
// on the executors. So we need to capture it here.
private val numParts: Int = this.getNumPartitions

override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
val iterators =
zipRdds.zipWithIndex.map(pair => pair._1.iterator(partitions(pair._2), context))
f(iterators)
f(iterators, numParts, s.index)
}

override def clearDependencies(): Unit = {
@@ -52,7 +56,8 @@ private[spark] class ZippedPartitionsRDD(

object ZippedPartitionsRDD {
def apply(sc: SparkContext, rdds: Seq[RDD[ColumnarBatch]])(
f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
f: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch])
: RDD[ColumnarBatch] =
withScope(sc) {
new ZippedPartitionsRDD(sc, f, rdds)
}
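
With the three-argument function type, a caller of ZippedPartitionsRDD.apply now also receives the partition count and the index of the zipped partition. A hypothetical usage sketch, assuming ZippedPartitionsRDD is in scope (left, right, and makeIterator are illustrative names; the real caller is createCometExecIter in operators.scala below):

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.vectorized.ColumnarBatch

def zipColumnar(sc: SparkContext, left: RDD[ColumnarBatch], right: RDD[ColumnarBatch])(
    makeIterator: (Seq[Iterator[ColumnarBatch]], Int, Int) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
  ZippedPartitionsRDD(sc, Seq(left, right)) { (iters, numParts, partitionIndex) =>
    // iters holds the co-located partitions of left and right;
    // partitionIndex is the index of this zipped partition, in [0, numParts).
    makeIterator(iters, numParts, partitionIndex)
  }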
@@ -227,13 +227,14 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
outputPartitioning: Partitioning,
serializer: Serializer,
metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = {
val numParts = rdd.getNumPartitions
val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch](
rdd.map(
(0, _)
), // adding fake partitionId that is always 0 because ShuffleDependency requires it
serializer = serializer,
shuffleWriterProcessor =
new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics),
new CometShuffleWriteProcessor(outputPartitioning, outputAttributes, metrics, numParts),
shuffleType = CometNativeShuffle,
partitioner = new Partitioner {
override def numPartitions: Int = outputPartitioning.numPartitions
@@ -449,7 +450,8 @@ object CometShuffleExchangeExec extends ShimCometShuffleExchangeExec {
class CometShuffleWriteProcessor(
outputPartitioning: Partitioning,
outputAttributes: Seq[Attribute],
metrics: Map[String, SQLMetric])
metrics: Map[String, SQLMetric],
numParts: Int)
extends ShimCometShuffleWriteProcessor {

private val OFFSET_LENGTH = 8
@@ -499,7 +501,9 @@ class CometShuffleWriteProcessor(
Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
outputAttributes.length,
nativePlan,
nativeMetrics)
nativeMetrics,
numParts,
context.partitionId())

while (cometIter.hasNext) {
cometIter.next()
36 changes: 29 additions & 7 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -120,20 +120,37 @@ object CometExec {
def getCometIterator(
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
nativePlan: Operator): CometExecIterator = {
getCometIterator(inputs, numOutputCols, nativePlan, CometMetricNode(Map.empty))
nativePlan: Operator,
numParts: Int,
partitionIdx: Int): CometExecIterator = {
getCometIterator(
inputs,
numOutputCols,
nativePlan,
CometMetricNode(Map.empty),
numParts,
partitionIdx)
}

def getCometIterator(
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
nativePlan: Operator,
nativeMetrics: CometMetricNode): CometExecIterator = {
nativeMetrics: CometMetricNode,
numParts: Int,
partitionIdx: Int): CometExecIterator = {
val outputStream = new ByteArrayOutputStream()
nativePlan.writeTo(outputStream)
outputStream.close()
val bytes = outputStream.toByteArray
new CometExecIterator(newIterId, inputs, numOutputCols, bytes, nativeMetrics)
new CometExecIterator(
newIterId,
inputs,
numOutputCols,
bytes,
nativeMetrics,
numParts,
partitionIdx)
}

/**
@@ -214,13 +231,18 @@ abstract class CometNativeExec extends CometExec {
// TODO: support native metrics for all operators.
val nativeMetrics = CometMetricNode.fromCometPlan(this)

def createCometExecIter(inputs: Seq[Iterator[ColumnarBatch]]): CometExecIterator = {
def createCometExecIter(
inputs: Seq[Iterator[ColumnarBatch]],
numParts: Int,
partitionIndex: Int): CometExecIterator = {
val it = new CometExecIterator(
CometExec.newIterId,
inputs,
output.length,
serializedPlanCopy,
nativeMetrics)
nativeMetrics,
numParts,
partitionIndex)

setSubqueries(it.id, this)

@@ -295,7 +317,7 @@ abstract class CometNativeExec extends CometExec {
throw new CometRuntimeException(s"No input for CometNativeExec:\n $this")
}

ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
}
}

4 changes: 3 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometNativeSuite.scala
@@ -37,7 +37,9 @@ class CometNativeSuite extends CometTestBase {
override def next(): ColumnarBatch = throw new NullPointerException()
}),
1,
limitOp)
limitOp,
1,
0)
cometIter.next()
cometIter.close()
value