diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java
index 5f3c3385a9431..af742fe8cb24a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java
@@ -21,13 +21,20 @@
 /**
  * A 'reducer' for output of user-defined functions.
  *
- * A user_defined function f_source(x) is 'reducible' on another user_defined function f_target(x),
- * if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x.
+ * @see ReducibleFunction
+ *
+ * A user-defined function f_source(x) is 'reducible' on another user-defined function f_target(x)
+ * if there exists a reducer r(x) such that r(f_source(x)) = f_target(x) for all input x or, more
+ * generally, if there exist reducers r1(x) and r2(x) such that r1(f_source(x)) = r2(f_target(x))
+ * for all input x. See {@link ReducibleFunction} for examples.
  *
  * @param <I> reducer input type
  * @param <O> reducer output type
  * @since 4.0.0
  */
 @Evolving
 public interface Reducer<I, O> {
-  O reduce(I arg1);
+  O reduce(I arg);
 }
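Reviewer note, not part of the patch: implementations of this interface are typically tiny value classes. Below is a minimal Scala sketch of one, modeled on the BucketReducer updated in the test helpers at the end of this diff; the name ModReducer is illustrative only.

    import org.apache.spark.sql.connector.catalog.functions.Reducer

    // Illustrative only: reduce a bucket id into a coarser bucket space by taking the
    // remainder against a common divisor of the two sides' bucket counts.
    case class ModReducer(divisor: Int) extends Reducer[Int, Int] {
      override def reduce(bucket: Int): Int = bucket % divisor
    }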
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
index 6d57909e1b984..9d2215c1167cc 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java
@@ -17,23 +17,34 @@
 package org.apache.spark.sql.connector.catalog.functions;
 
 import org.apache.spark.annotation.Evolving;
-import scala.Option;
 
 /**
  * Base class for user-defined functions that can be 'reduced' on another function.
  *
  * A function f_source(x) is 'reducible' on another function f_target(x) if
- * there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
- *
+ * <ul>
+ *   <li>There exists a reducer function r(x) such that r(f_source(x)) = f_target(x)
+ *   for all input x, or</li>
+ *   <li>More generally, there exist two reducer functions r1(x) and r2(x) such that
+ *   r1(f_source(x)) = r2(f_target(x)) for all input x.</li>
+ * </ul>
+ *
  * <p>
  * Examples:
  * <ul>
- *   <li>Bucket functions
+ *   <li>Bucket functions where one side has a reducer
  *     <ul>
  *       <li>f_source(x) = bucket(4, x)</li>
  *       <li>f_target(x) = bucket(2, x)</li>
- *       <li>r(x) = x / 2</li>
+ *       <li>r(x) = x % 2</li>
  *     </ul>
+ *   </li>
+ *   <li>Bucket functions where both sides have reducers
+ *     <ul>
+ *       <li>f_source(x) = bucket(16, x)</li>
+ *       <li>f_target(x) = bucket(12, x)</li>
+ *       <li>r1(x) = x % 4</li>
+ *       <li>r2(x) = x % 4</li>
+ *     </ul>
+ *   </li>
  *   <li>Date functions
  *     <ul>
  *       <li>f_source(x) = days(x)</li>
@@ -49,24 +60,42 @@ public interface ReducibleFunction<I, O> {
 
   /**
-   * If this function is 'reducible' on another function, return the {@link Reducer} function.
+   * This method is for bucket functions.
+   *
+   * If this bucket function is 'reducible' on another bucket function,
+   * return the {@link Reducer} function.
    * <p>
-   * Example:
+   * Example to return a reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x):
    * <ul>
-   *   <li>this_function = bucket(4, x)</li>
-   *   <li>other function = bucket(2, x)</li>
+   *   <li>thisFunction = bucket</li>
+   *   <li>otherFunction = bucket</li>
+   *   <li>thisNumBuckets = Int(4)</li>
+   *   <li>otherNumBuckets = Int(2)</li>
    * </ul>
-   * Invoke with arguments
+   *
+   * @param otherFunction the other bucket function
+   * @param thisNumBuckets number of buckets for this bucket function
+   * @param otherNumBuckets number of buckets for the other bucket function
+   * @return a reduction function if it is reducible, null if not
+   */
+  default Reducer<I, O> reducer(
+      ReducibleFunction<?, ?> otherFunction,
+      int thisNumBuckets,
+      int otherNumBuckets) {
+    return reducer(otherFunction);
+  }
+
+  /**
+   * This method is for all other functions.
+   *
+   * If this function is 'reducible' on another function, return the {@link Reducer} function.
+   * <p>
+   * Example of reducing f_source = days(x) on f_target = hours(x):
    * <ul>
-   *   <li>other = bucket</li>
-   *   <li>this param = Int(4)</li>
-   *   <li>other param = Int(2)</li>
+   *   <li>thisFunction = days</li>
+   *   <li>otherFunction = hours</li>
    * </ul>
      - * @param other the other function - * @param thisParam param for this function - * @param otherParam param for the other function - * @return a reduction function if it is reducible, none if not + * + * @param otherFunction the other function + * @return a reduction function if it is reducible, null if not. */ - Option> reducer(ReducibleFunction other, Option thisParam, - Option otherParam); + default Reducer reducer(ReducibleFunction otherFunction) { + return reducer(otherFunction, 0, 0); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 5dc048f5dafbd..eff0a0ddfe71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -70,9 +70,10 @@ case class TransformExpression( } else { (function, other.function) match { case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) => - val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt) - val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt) - reducer.isDefined || otherReducer.isDefined + val reducer = f.reducer(o, numBucketsOpt.getOrElse(0), other.numBucketsOpt.getOrElse(0)) + val otherReducer = + o.reducer(f, other.numBucketsOpt.getOrElse(0), numBucketsOpt.getOrElse(0)) + reducer != null || otherReducer != null case _ => false } } @@ -90,7 +91,10 @@ case class TransformExpression( def reducers(other: TransformExpression): Option[Reducer[_, _]] = { (function, other.function) match { case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => - e1.reducer(e2, numBucketsOpt, other.numBucketsOpt) + val reducer = e1.reducer(e2, + numBucketsOpt.getOrElse(0), + other.numBucketsOpt.getOrElse(0)) + Option(reducer) case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0bdb1fde67b22..33ea5ad52cd5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -851,25 +851,23 @@ case class KeyGroupedShuffleSpec( * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is * 'reducible' on the corresponding partition expression function of the other shuffle spec. *

-   * If a value is returned, there must be one Option[[Reducer]] per partition expression.
+   * If a value is returned, there must be one [[Reducer]] per partition expression.
    * A None value in the set indicates that the particular partition expression is not reducible
    * on the corresponding expression on the other shuffle spec.
    *
      * Returning none also indicates that none of the partition expressions can be reduced on the * corresponding expression on the other shuffle spec. + * + * @param other other key-grouped shuffle spec */ - def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { - other match { - case otherSpec: KeyGroupedShuffleSpec => - val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map { - case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) - case (_, _) => None - } + def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { + val results = partitioning.expressions.zip(other.partitioning.expressions).map { + case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) + case (_, _) => None + } - // optimize to not return a value, if none of the partition expressions are reducible - if (results.forall(p => p.isEmpty)) None else Some(results) - case _ => None - } + // optimize to not return a value, if none of the partition expressions are reducible + if (results.forall(p => p.isEmpty)) None else Some(results) } override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && @@ -883,8 +881,8 @@ case class KeyGroupedShuffleSpec( object KeyGroupedShuffleSpec { def reducePartitionValue(row: InternalRow, - expressions: Seq[Expression], - reducers: Seq[Option[Reducer[_, _]]]): + expressions: Seq[Expression], + reducers: Seq[Option[Reducer[_, _]]]): InternalRowComparableWrapper = { val partitionVals = row.toSeq(expressions.map(_.dataType)) val reducedRow = partitionVals.zip(reducers).map{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b34990e1b7166..7ff682178ad27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -569,8 +569,8 @@ case class EnsureRequirements( } private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)], - expressions: Seq[Expression], - reducers: Option[Seq[Option[Reducer[_, _]]]]) = { + expressions: Seq[Expression], + reducers: Option[Seq[Option[Reducer[_, _]]]]) = { reducers match { case Some(reducers) => commonPartValues.groupBy { case (row, _) => KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 403081b66551e..ec275fe101fd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1475,6 +1475,168 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-47094: Support compatible buckets with common divisor") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq( + ((6, 4), (4, 6)), + ((6, 6), (4, 4)), + ((4, 4), (6, 6)), + ((4, 6), (6, 4))).foreach { + case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) => + catalog.clearTables() + + val partition1 = Array(bucket(table1buckets1, "store_id"), + bucket(table1buckets2, "dept_id")) + val partition2 = Array(bucket(table2buckets1, "store_id"), + bucket(table2buckets2, 
"dept_id")) + + Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) => + createTable(tab, columns2, part) + val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " + + "(0, 0, 'aa'), " + + "(0, 0, 'ab'), " + // duplicate partition key + "(0, 1, 'ac'), " + + "(0, 2, 'ad'), " + + "(0, 3, 'ae'), " + + "(0, 4, 'af'), " + + "(0, 5, 'ag'), " + + "(1, 0, 'ah'), " + + "(1, 0, 'ai'), " + // duplicate partition key + "(1, 1, 'aj'), " + + "(1, 2, 'ak'), " + + "(1, 3, 'al'), " + + "(1, 4, 'am'), " + + "(1, 5, 'an'), " + + "(2, 0, 'ao'), " + + "(2, 0, 'ap'), " + // duplicate partition key + "(2, 1, 'aq'), " + + "(2, 2, 'ar'), " + + "(2, 3, 'as'), " + + "(2, 4, 'at'), " + + "(2, 5, 'au'), " + + "(3, 0, 'av'), " + + "(3, 0, 'aw'), " + // duplicate partition key + "(3, 1, 'ax'), " + + "(3, 2, 'ay'), " + + "(3, 3, 'az'), " + + "(3, 4, 'ba'), " + + "(3, 5, 'bb'), " + + "(4, 0, 'bc'), " + + "(4, 0, 'bd'), " + // duplicate partition key + "(4, 1, 'be'), " + + "(4, 2, 'bf'), " + + "(4, 3, 'bg'), " + + "(4, 4, 'bh'), " + + "(4, 5, 'bi'), " + + "(5, 0, 'bj'), " + + "(5, 0, 'bk'), " + // duplicate partition key + "(5, 1, 'bl'), " + + "(5, 2, 'bm'), " + + "(5, 3, 'bn'), " + + "(5, 4, 'bo'), " + + "(5, 5, 'bp')" + + // additional unmatched partitions to test push down + val finalStr = if (tab == table1) { + insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')" + } else { + insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')" + } + + sql(finalStr) + } + + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString, + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + + def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt + val expectedBuckets = gcd(table1buckets1, table2buckets1) * + gcd(table1buckets2, table2buckets2) + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 0, "aa", "aa"), + Row(0, 0, "aa", "ab"), + Row(0, 0, "ab", "aa"), + Row(0, 0, "ab", "ab"), + Row(0, 1, "ac", "ac"), + Row(0, 2, "ad", "ad"), + Row(0, 3, "ae", "ae"), + Row(0, 4, "af", "af"), + Row(0, 5, "ag", "ag"), + Row(1, 0, "ah", "ah"), + Row(1, 0, "ah", "ai"), + Row(1, 0, "ai", "ah"), + Row(1, 0, "ai", "ai"), + Row(1, 1, "aj", "aj"), + Row(1, 2, "ak", "ak"), + Row(1, 3, "al", "al"), + Row(1, 4, "am", "am"), + Row(1, 5, "an", "an"), + Row(2, 0, "ao", "ao"), + Row(2, 0, "ao", "ap"), + Row(2, 0, "ap", "ao"), + Row(2, 0, "ap", "ap"), + Row(2, 1, "aq", "aq"), + Row(2, 2, "ar", "ar"), + Row(2, 3, "as", "as"), + Row(2, 4, "at", "at"), + Row(2, 5, "au", "au"), + Row(3, 0, "av", "av"), + Row(3, 0, "av", "aw"), + Row(3, 0, "aw", "av"), + Row(3, 0, "aw", "aw"), + Row(3, 1, "ax", "ax"), + Row(3, 2, "ay", "ay"), + Row(3, 3, "az", "az"), + Row(3, 4, "ba", "ba"), + Row(3, 5, "bb", "bb"), + Row(4, 0, "bc", "bc"), + Row(4, 0, "bc", "bd"), + Row(4, 0, "bd", "bc"), + Row(4, 0, "bd", "bd"), + Row(4, 1, "be", "be"), + Row(4, 2, "bf", "bf"), + Row(4, 3, "bg", "bg"), + Row(4, 4, "bh", "bh"), + Row(4, 5, "bi", "bi"), + Row(5, 0, "bj", "bj"), + Row(5, 0, "bj", "bk"), + Row(5, 0, "bk", "bj"), + Row(5, 0, "bk", "bk"), + Row(5, 1, "bl", "bl"), + Row(5, 2, "bm", "bm"), + Row(5, 3, "bn", "bn"), + Row(5, 4, "bo", "bo"), + Row(5, 5, "bp", "bp") + )) + } + } + } + } + test("SPARK-47094: Support compatible buckets with less join keys than partition keys") { val table1 = "tab1e1" val table2 = "table2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 7a77c00b577fc..176e597fe44bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -87,21 +87,31 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In } override def reducer(func: ReducibleFunction[_, _], - thisNumBuckets: Option[_], - otherNumBuckets: Option[_]): Option[Reducer[Int, Int]] = { - (thisNumBuckets, otherNumBuckets) match { - case (Some(thisNumBucketsVal: Int), Some(otherNumBucketsVal: Int)) - if func == BucketFunction && - ((thisNumBucketsVal > otherNumBucketsVal) && - (thisNumBucketsVal % otherNumBucketsVal == 0)) => - Some(BucketReducer(thisNumBucketsVal, otherNumBucketsVal)) - case _ => None + thisNumBuckets: Int, + otherNumBuckets: Int): Reducer[Int, Int] = { + + if (func == BucketFunction) { + if ((thisNumBuckets > otherNumBuckets) + && (thisNumBuckets % otherNumBuckets == 0)) { + BucketReducer(thisNumBuckets, otherNumBuckets) + } else { + val gcd = this.gcd(thisNumBuckets, otherNumBuckets) + if (gcd != thisNumBuckets) { + BucketReducer(thisNumBuckets, gcd) + } else { + null + } + } + } else { + null } } + + private def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt } -case class BucketReducer(thisNumBuckets: Int, otherNumBuckets: Int) extends Reducer[Int, Int] { - override def reduce(bucket: Int): Int = bucket % otherNumBuckets +case class BucketReducer(thisNumBuckets: Int, divisor: Int) extends Reducer[Int, Int] { 
+  override def reduce(bucket: Int): Int = bucket % divisor
 }
 
 object UnboundStringSelfFunction extends UnboundFunction {
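Reviewer note, not part of the patch: the common-divisor strategy above rests on a small arithmetic fact: if d divides both bucket counts n1 and n2, then (k mod n1) mod d = (k mod n2) mod d = k mod d for every key k, so reducing both sides' bucket ids modulo d = gcd(n1, n2) places rows with equal keys in the same reduced bucket. The sketch below checks that invariant in plain Scala, using k % n as a stand-in for the real bucket function; every name in it is illustrative.

    object CommonDivisorBucketCheck {
      // Stand-in for a bucket transform: assigns key k to one of n buckets.
      private def bucket(n: Int, k: Int): Int = Math.floorMod(k, n)

      // Same helper the new test uses to compute the common divisor of two bucket counts.
      private def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt

      def main(args: Array[String]): Unit = {
        val (n1, n2) = (6, 4)   // bucket counts on the two sides of the join
        val d = gcd(n1, n2)     // both sides are reduced into d = gcd(n1, n2) buckets

        // Because d divides both n1 and n2, (k mod n1) mod d == (k mod n2) mod d == k mod d,
        // so the reduced partitions can be matched one-to-one without a shuffle.
        val agree = (-1000 to 1000).forall { k =>
          bucket(n1, k) % d == bucket(n2, k) % d
        }
        println(s"reduced buckets agree for all keys: $agree") // prints: true
      }
    }

This is also why the new test expects gcd(table1buckets1, table2buckets1) * gcd(table1buckets2, table2buckets2) scan partitions on each side: one reduced bucket count per join column, multiplied across the two bucket columns.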