diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java index 5f3c3385a9431..af742fe8cb24a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java @@ -21,13 +21,20 @@ /** * A 'reducer' for output of user-defined functions. * - * A user_defined function f_source(x) is 'reducible' on another user_defined function f_target(x), - * if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x. + * @see ReducibleFunction + * + * A user-defined function f_source(x) is 'reducible' on another user-defined function f_target(x) if + * there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x, or, more generally, + * if there exist reducers r1(x) and r2(x) such that r1(f_source(x)) = r2(f_target(x)) for all input x. + *
* Examples: *
- * Example: + * Example of returning a reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) *
+ * Example of reducing f_source = hours(x) on f_target = days(x) *
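
As a rough illustration of the contract documented above (a sketch of my own, not code from this patch; `ModReducer` is a hypothetical name and bucket(n, x) is assumed to be hash(x) % n), the bucket(4, x) on bucket(2, x) example reduces to a simple modulo:

```scala
// Sketch only: a Reducer that maps bucket(4, x) output onto bucket(2, x) output.
// Assumes bucket(n, x) = hash(x) % n, so (hash(x) % 4) % 2 == hash(x) % 2 == bucket(2, x).
import org.apache.spark.sql.connector.catalog.functions.Reducer

case class ModReducer(targetNumBuckets: Int) extends Reducer[Int, Int] {
  // r(f_source(x)) = f_target(x)
  override def reduce(bucket: Int): Int = bucket % targetNumBuckets
}
```

This is essentially what the test-side `BucketReducer` below does, except that the patch also handles the common-divisor case, where both sides are reduced to the gcd of the two bucket counts.
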
- * If a value is returned, there must be one Option[[Reducer]] per partition expression. + * If a value is returned, there must be one [[Reducer]] per partition expression. * A None value in the set indicates that the particular partition expression is not reducible * on the corresponding expression on the other shuffle spec. *
* Returning none also indicates that none of the partition expressions can be reduced on the * corresponding expression on the other shuffle spec. + * + * @param other other key-grouped shuffle spec */ - def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { - other match { - case otherSpec: KeyGroupedShuffleSpec => - val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map { - case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) - case (_, _) => None - } + def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { + val results = partitioning.expressions.zip(other.partitioning.expressions).map { + case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) + case (_, _) => None + } - // optimize to not return a value, if none of the partition expressions are reducible - if (results.forall(p => p.isEmpty)) None else Some(results) - case _ => None - } + // optimize to not return a value, if none of the partition expressions are reducible + if (results.forall(p => p.isEmpty)) None else Some(results) } override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && @@ -883,8 +881,8 @@ case class KeyGroupedShuffleSpec( object KeyGroupedShuffleSpec { def reducePartitionValue(row: InternalRow, - expressions: Seq[Expression], - reducers: Seq[Option[Reducer[_, _]]]): + expressions: Seq[Expression], + reducers: Seq[Option[Reducer[_, _]]]): InternalRowComparableWrapper = { val partitionVals = row.toSeq(expressions.map(_.dataType)) val reducedRow = partitionVals.zip(reducers).map{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b34990e1b7166..7ff682178ad27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -569,8 +569,8 @@ case class EnsureRequirements( } private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)], - expressions: Seq[Expression], - reducers: Option[Seq[Option[Reducer[_, _]]]]) = { + expressions: Seq[Expression], + reducers: Option[Seq[Option[Reducer[_, _]]]]) = { reducers match { case Some(reducers) => commonPartValues.groupBy { case (row, _) => KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 403081b66551e..ec275fe101fd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1475,6 +1475,168 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-47094: Support compatible buckets with common divisor") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq( + ((6, 4), (4, 6)), + ((6, 6), (4, 4)), + ((4, 4), (6, 6)), + ((4, 6), (6, 4))).foreach { + case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) => + catalog.clearTables() + + val partition1 = Array(bucket(table1buckets1, "store_id"), + bucket(table1buckets2, "dept_id")) + val partition2 = Array(bucket(table2buckets1, "store_id"), + bucket(table2buckets2, "dept_id")) + 
+ Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) => + createTable(tab, columns2, part) + val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " + + "(0, 0, 'aa'), " + + "(0, 0, 'ab'), " + // duplicate partition key + "(0, 1, 'ac'), " + + "(0, 2, 'ad'), " + + "(0, 3, 'ae'), " + + "(0, 4, 'af'), " + + "(0, 5, 'ag'), " + + "(1, 0, 'ah'), " + + "(1, 0, 'ai'), " + // duplicate partition key + "(1, 1, 'aj'), " + + "(1, 2, 'ak'), " + + "(1, 3, 'al'), " + + "(1, 4, 'am'), " + + "(1, 5, 'an'), " + + "(2, 0, 'ao'), " + + "(2, 0, 'ap'), " + // duplicate partition key + "(2, 1, 'aq'), " + + "(2, 2, 'ar'), " + + "(2, 3, 'as'), " + + "(2, 4, 'at'), " + + "(2, 5, 'au'), " + + "(3, 0, 'av'), " + + "(3, 0, 'aw'), " + // duplicate partition key + "(3, 1, 'ax'), " + + "(3, 2, 'ay'), " + + "(3, 3, 'az'), " + + "(3, 4, 'ba'), " + + "(3, 5, 'bb'), " + + "(4, 0, 'bc'), " + + "(4, 0, 'bd'), " + // duplicate partition key + "(4, 1, 'be'), " + + "(4, 2, 'bf'), " + + "(4, 3, 'bg'), " + + "(4, 4, 'bh'), " + + "(4, 5, 'bi'), " + + "(5, 0, 'bj'), " + + "(5, 0, 'bk'), " + // duplicate partition key + "(5, 1, 'bl'), " + + "(5, 2, 'bm'), " + + "(5, 3, 'bn'), " + + "(5, 4, 'bo'), " + + "(5, 5, 'bp')" + + // additional unmatched partitions to test push down + val finalStr = if (tab == table1) { + insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')" + } else { + insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')" + } + + sql(finalStr) + } + + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString, + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + + def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt + val expectedBuckets = gcd(table1buckets1, table2buckets1) * + gcd(table1buckets2, table2buckets2) + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 0, "aa", "aa"), + Row(0, 0, "aa", "ab"), + Row(0, 0, "ab", "aa"), + Row(0, 0, "ab", "ab"), + Row(0, 1, "ac", "ac"), + Row(0, 2, "ad", "ad"), + Row(0, 3, "ae", "ae"), + Row(0, 4, "af", "af"), + Row(0, 5, "ag", "ag"), + Row(1, 0, "ah", "ah"), + Row(1, 0, "ah", "ai"), + Row(1, 0, "ai", "ah"), + Row(1, 0, "ai", "ai"), + Row(1, 1, "aj", "aj"), + Row(1, 2, "ak", "ak"), + Row(1, 3, "al", "al"), + Row(1, 4, "am", "am"), + Row(1, 5, "an", "an"), + Row(2, 0, "ao", "ao"), + Row(2, 0, "ao", "ap"), + Row(2, 0, "ap", "ao"), + Row(2, 0, "ap", "ap"), + Row(2, 1, "aq", "aq"), + Row(2, 2, "ar", "ar"), + Row(2, 3, "as", "as"), + Row(2, 4, "at", "at"), + Row(2, 5, "au", "au"), + Row(3, 0, "av", "av"), + Row(3, 0, "av", "aw"), + Row(3, 0, "aw", "av"), + Row(3, 0, "aw", "aw"), + Row(3, 1, "ax", "ax"), + Row(3, 2, "ay", "ay"), + Row(3, 3, "az", "az"), + Row(3, 4, "ba", "ba"), + Row(3, 5, "bb", "bb"), + Row(4, 0, "bc", "bc"), + Row(4, 0, "bc", "bd"), + Row(4, 0, "bd", "bc"), + Row(4, 0, "bd", "bd"), + Row(4, 1, "be", "be"), + Row(4, 2, "bf", "bf"), + Row(4, 3, "bg", "bg"), + Row(4, 4, "bh", "bh"), + Row(4, 5, "bi", "bi"), + Row(5, 0, "bj", "bj"), + Row(5, 0, "bj", "bk"), + Row(5, 0, "bk", "bj"), + Row(5, 0, "bk", "bk"), + Row(5, 1, "bl", "bl"), + Row(5, 2, "bm", "bm"), + Row(5, 3, "bn", "bn"), + Row(5, 4, "bo", "bo"), + Row(5, 5, "bp", "bp") + )) + } + } + } + } + test("SPARK-47094: Support compatible buckets with less join keys than partition keys") { val table1 = "tab1e1" val table2 = "table2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 7a77c00b577fc..176e597fe44bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -87,21 +87,31 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In } override def reducer(func: ReducibleFunction[_, _], - thisNumBuckets: Option[_], - otherNumBuckets: Option[_]): Option[Reducer[Int, Int]] = { - (thisNumBuckets, otherNumBuckets) match { - case (Some(thisNumBucketsVal: Int), Some(otherNumBucketsVal: Int)) - if func == BucketFunction && - ((thisNumBucketsVal > otherNumBucketsVal) && - (thisNumBucketsVal % otherNumBucketsVal == 0)) => - Some(BucketReducer(thisNumBucketsVal, otherNumBucketsVal)) - case _ => None + thisNumBuckets: Int, + otherNumBuckets: Int): Reducer[Int, Int] = { + + if (func == BucketFunction) { + if ((thisNumBuckets > otherNumBuckets) + && (thisNumBuckets % otherNumBuckets == 0)) { + BucketReducer(thisNumBuckets, otherNumBuckets) + } else { + val gcd = this.gcd(thisNumBuckets, otherNumBuckets) + if (gcd != thisNumBuckets) { + BucketReducer(thisNumBuckets, gcd) + } else { + null + } + } + } else { + null } } + + private def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt } -case class BucketReducer(thisNumBuckets: Int, otherNumBuckets: Int) extends Reducer[Int, Int] { - override def reduce(bucket: Int): Int = bucket % otherNumBuckets +case class BucketReducer(thisNumBuckets: Int, divisor: Int) extends Reducer[Int, Int] { 
+ override def reduce(bucket: Int): Int = bucket % divisor } object UnboundStringSelfFunction extends UnboundFunction {
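
To make the common-divisor invariant behind the new test and `BucketFunction.reducer` concrete, here is a small self-contained sanity check (my own sketch, not part of the patch; `bucket` here stands in for the real hash-based bucket transform): reducing both sides to gcd(6, 4) = 2 buckets yields the same reduced bucket id for every key, which is why the join can be planned without a shuffle.

```scala
// Sketch only: checks r1(bucket(6, x)) == r2(bucket(4, x)) where both reducers take
// the bucket id modulo gcd(6, 4) = 2, assuming bucket(n, x) = hash(x) % n.
object CommonDivisorCheck {
  def bucket(numBuckets: Int, hash: Int): Int = Math.floorMod(hash, numBuckets)
  def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt

  def main(args: Array[String]): Unit = {
    val (leftBuckets, rightBuckets) = (6, 4)
    val divisor = gcd(leftBuckets, rightBuckets) // 2 reduced partitions per side

    (0 until 1000).foreach { hash =>
      val reducedLeft = bucket(leftBuckets, hash) % divisor   // r1(f_source(x))
      val reducedRight = bucket(rightBuckets, hash) % divisor // r2(f_target(x))
      assert(reducedLeft == reducedRight, s"mismatch for hash $hash")
    }
    println(s"all reduced bucket ids agree with $divisor buckets per side")
  }
}
```

The test's `expectedBuckets` is this gcd taken per partition column, e.g. gcd(6, 4) * gcd(4, 6) = 4 grouped partitions for the ((6, 4), (4, 6)) case.
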