Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
szehon-ho committed Dec 20, 2024
1 parent 2f37bc9 commit bebae9f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -383,13 +383,11 @@ case class KeyGroupedPartitioning(
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())
.filter(KeyGroupedPartitioning.isReference)

if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// check that join keys (required clustering keys)
// overlap with partition keys (KeyGroupedPartitioning attributes)
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x)))
} else {
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
Expand Down Expand Up @@ -792,9 +790,16 @@ case class KeyGroupedShuffleSpec(
distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
}
partitioning.expressions.map { e =>
val leaves = e.collectLeaves().filter(KeyGroupedPartitioning.isReference)
assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}")
distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty)
val leaves = e.collectLeaves()
val attrs = leaves.filter(KeyGroupedPartitioning.isReference)
assert(leaves.size == 1 || attrs.size == 1,
s"Expected exactly one reference or child from $e, but found ${leaves.size}")
val head = if (attrs.size == 1) {
attrs.head
} else {
leaves.head
}
distKeyToPos.getOrElse(head.canonicalized, mutable.BitSet.empty)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,6 @@ case class EnsureRequirements(
distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = {
def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
val attributes = partitioning.expressions.flatMap(_.collectLeaves())
.filter(KeyGroupedPartitioning.isReference)
val clustering = distribution.clustering

val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2533,7 +2533,9 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
s"(40, 80, 'ddd')")

withSQLConf(
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") {
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {

val df =
sql(
Expand Down

0 comments on commit bebae9f

Please sign in to comment.