Commit 361a725 (parent 8940f6e)
szehon-ho committed Jul 8, 2024
Showing 2 changed files with 27 additions and 17 deletions.
BatchScanExec.scala

```diff
@@ -118,25 +118,29 @@ case class BatchScanExec(
 
   override def outputPartitioning: Partitioning = {
     super.outputPartitioning match {
-      case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
-        // We allow duplicated partition values if
-        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
-        val newPartValues = spjParams.commonPartitionValues.get.flatMap {
-          case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
-        }
+      case k: KeyGroupedPartitioning =>
         val expressions = spjParams.joinKeyPositions match {
           case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
           case _ => k.expressions
         }
-        k.copy(expressions = expressions, numPartitions = newPartValues.length,
-          partitionValues = newPartValues)
-      case k: KeyGroupedPartitioning if spjParams.joinKeyPositions.isDefined =>
-        val expressions = spjParams.joinKeyPositions.get.map(i => k.expressions(i))
-        val newPartValues = k.partitionValues.map{r =>
-          val projectedRow = KeyGroupedPartitioning.project(expressions,
-            spjParams.joinKeyPositions.get, r)
-          InternalRowComparableWrapper(projectedRow, expressions)
-        }.distinct.map(_.row)
+
+        val newPartValues = spjParams.commonPartitionValues match {
+          case Some(commonPartValues) =>
+            // We allow duplicated partition values if
+            // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
+            commonPartValues.flatMap {
+              case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
+            }
+          case None =>
+            spjParams.joinKeyPositions match {
+              case Some(projectionPositions) => k.partitionValues.map{r =>
+                val projectedRow = KeyGroupedPartitioning.project(expressions,
+                  projectionPositions, r)
+                InternalRowComparableWrapper(projectedRow, expressions)
+              }.distinct.map(_.row)
+              case _ => k.partitionValues
+            }
+        }
         k.copy(expressions = expressions, numPartitions = newPartValues.length,
           partitionValues = newPartValues)
       case p => p
```
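The `case None` branch above projects each partition value down to the join key positions and then de-duplicates, since values that differed on non-join columns collapse to the same key after projection. A minimal sketch of that idea, using plain Scala collections as hypothetical stand-ins for Spark's `InternalRow`, `KeyGroupedPartitioning.project`, and `InternalRowComparableWrapper` (`ProjectionSketch`, `Row`, and `projectAndDedup` are illustrative names, not Spark APIs):

```scala
// Sketch only: project partition values onto the join key positions,
// then de-duplicate the projected values.
object ProjectionSketch {
  // Stand-in for InternalRow: a positional sequence of column values.
  type Row = Seq[Any]

  def projectAndDedup(partitionValues: Seq[Row], joinKeyPositions: Seq[Int]): Seq[Row] =
    partitionValues
      .map(r => joinKeyPositions.map(r(_))) // keep only the join-key columns
      .distinct                             // duplicates collapse after projection

  def main(args: Array[String]): Unit = {
    // Partition values over (id, dep); the join uses only position 0 (id).
    val partValues: Seq[Row] = Seq(Seq(1, "a"), Seq(1, "b"), Seq(2, "a"))
    println(projectAndDedup(partValues, Seq(0))) // List(List(1), List(2))
  }
}
```

The dedup step matters because `numPartitions` is taken from `newPartValues.length`: without it, projected duplicates would inflate the partition count reported by `outputPartitioning`.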
EnsureRequirements.scala

```diff
@@ -176,6 +176,11 @@ case class EnsureRequirements(
       case ((child, dist), idx) =>
         if (bestSpecOpt.isDefined && bestSpecOpt.get.isCompatibleWith(specs(idx))) {
           bestSpecOpt match {
+            // If keyGroupCompatible = false, we can still perform SPJ
+            // by shuffling the other side based on join keys (see the else case below).
+            // Hence we need to ensure that after this call, the outputPartitioning of the
+            // partitioned side's BatchScanExec is grouped by join keys to match,
+            // and we do that by pushing down the join keys
             case Some(KeyGroupedShuffleSpec(_, _, Some(joinKeyPositions))) =>
               populateJoinKeyPositions(child, Some(joinKeyPositions))
             case _ => child
```
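The new comment documents why the join key positions are pushed into the partitioned side's scan: its `outputPartitioning` must end up grouped by the join keys so the shuffled side can match it. A schematic of that pushdown, with simplified stand-in types (`PushdownSketch`, `KeyGroupedSpec`, `ScanParams`, `Scan`, and `ensure` are hypothetical, not Spark's classes):

```scala
// Sketch only: a shuffle spec optionally carries join key positions;
// when present, push them into the scan node's SPJ parameters.
object PushdownSketch {
  sealed trait Spec
  case class KeyGroupedSpec(joinKeyPositions: Option[Seq[Int]]) extends Spec

  case class ScanParams(joinKeyPositions: Option[Seq[Int]] = None)
  case class Scan(spjParams: ScanParams)

  def ensure(child: Scan, bestSpec: Option[Spec]): Scan = bestSpec match {
    case Some(KeyGroupedSpec(Some(positions))) =>
      // Push the positions down so the scan groups its output by join keys.
      child.copy(spjParams = child.spjParams.copy(joinKeyPositions = Some(positions)))
    case _ => child
  }

  def main(args: Array[String]): Unit = {
    val scan = Scan(ScanParams())
    println(ensure(scan, Some(KeyGroupedSpec(Some(Seq(0, 2))))))
    // Scan(ScanParams(Some(List(0, 2))))
  }
}
```

This mirrors the shape of the real rule: a pattern match on the best spec followed by a `copy` of the scan's params, leaving all other plans untouched.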
```diff
@@ -583,8 +588,9 @@ case class EnsureRequirements(
   }
 
 
-  private def populateJoinKeyPositions(plan: SparkPlan,
-      joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match {
+  private def populateJoinKeyPositions(
+      plan: SparkPlan,
+      joinKeyPositions: Option[Seq[Int]]): SparkPlan = plan match {
     case scan: BatchScanExec =>
       scan.copy(
         spjParams = scan.spjParams.copy(
```
