From bebae9f3d33f1b967c480e315dc12800c2c52179 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Thu, 19 Dec 2024 18:29:43 -0800
Subject: [PATCH] Fix tests

---
 .../catalyst/plans/physical/partitioning.scala   | 17 +++++++++++------
 .../execution/exchange/EnsureRequirements.scala  |  1 -
 .../connector/KeyGroupedPartitioningSuite.scala  |  4 +++-
 3 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index de9bc645a83b1..1efcc6e36a181 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -383,13 +383,11 @@ case class KeyGroupedPartitioning(
           } else {
             // We'll need to find leaf attributes from the partition expressions first.
             val attributes = expressions.flatMap(_.collectLeaves())
-              .filter(KeyGroupedPartitioning.isReference)
             if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
               // check that join keys (required clustering keys)
               // overlap with partition keys (KeyGroupedPartitioning attributes)
-              requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
-                expressions.forall(_.collectLeaves().size == 1)
+              requiredClustering.exists(x => attributes.exists(_.semanticEquals(x)))
             } else {
               attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
             }
           }
@@ -792,9 +790,16 @@ case class KeyGroupedShuffleSpec(
       distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
     }
     partitioning.expressions.map { e =>
-      val leaves = e.collectLeaves().filter(KeyGroupedPartitioning.isReference)
-      assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}")
-      distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty)
+      val leaves = e.collectLeaves()
+      val attrs = leaves.filter(KeyGroupedPartitioning.isReference)
+      assert(leaves.size == 1 || attrs.size == 1,
+        s"Expected exactly one reference or child from $e, but found ${leaves.size}")
+      val head = if (attrs.size == 1) {
+        attrs.head
+      } else {
+        leaves.head
+      }
+      distKeyToPos.getOrElse(head.canonicalized, mutable.BitSet.empty)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 10dfa21c1b57c..8ec903f8e61da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -627,7 +627,6 @@ case class EnsureRequirements(
       distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = {
     def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
       val attributes = partitioning.expressions.flatMap(_.collectLeaves())
-        .filter(KeyGroupedPartitioning.isReference)
       val clustering = distribution.clustering
 
       val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index 81534ef73588c..1ffd20644307f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -2533,7 +2533,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       s"(40, 80, 'ddd')")
 
     withSQLConf(
-      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") {
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
      val df = sql(
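Note on the KeyGroupedShuffleSpec hunk above: because partition expressions may now retain non-reference leaves (e.g. a literal argument to a partition transform), the spec picks the single reference attribute among an expression's leaves and falls back to the sole leaf only when no reference is present. Below is a minimal, standalone Scala sketch of that selection rule; Leaf, Reference, and Constant are hypothetical stand-ins for Spark's Expression, Attribute, and Literal types, not the actual Spark code.

object LeafSelectionSketch {
  sealed trait Leaf
  final case class Reference(name: String) extends Leaf // an attribute reference
  final case class Constant(value: Any) extends Leaf    // a literal transform argument

  def selectKeyLeaf(leaves: Seq[Leaf]): Leaf = {
    val refs = leaves.collect { case r: Reference => r }
    // Mirrors the patched assertion: the expression either has exactly one
    // leaf, or exactly one of its leaves is a reference.
    assert(leaves.size == 1 || refs.size == 1,
      s"Expected exactly one reference or child, but found ${leaves.size}")
    if (refs.size == 1) refs.head else leaves.head
  }

  def main(args: Array[String]): Unit = {
    // A transform over a column plus a literal collects two leaves; the
    // reference wins, so the distribution-key lookup still keys on "id".
    println(selectKeyLeaf(Seq(Constant(4), Reference("id")))) // Reference(id)
    println(selectKeyLeaf(Seq(Reference("ts"))))              // Reference(ts)
  }
}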