diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index ae39e2e183e4a..9950d336074d2 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 
 import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 
+import scala.collection.immutable.ArraySeq
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.math.log10
@@ -149,7 +150,8 @@ private[spark] class KeyGroupedPartitioner(
     override val numPartitions: Int) extends Partitioner {
   override def getPartition(key: Any): Int = {
     val keys = key.asInstanceOf[Seq[Any]]
-    valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
+    val normalizedKeys = ArraySeq.from(keys)
+    valueMap.getOrElseUpdate(normalizedKeys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
index d37c9d9f6452a..98b5c641096fb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
@@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction}
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -30,7 +33,7 @@ import org.apache.spark.sql.types.DataType
 case class TransformExpression(
     function: BoundFunction,
     children: Seq[Expression],
-    numBucketsOpt: Option[Int] = None) extends Expression with Unevaluable {
+    numBucketsOpt: Option[Int] = None) extends Expression {
 
   override def nullable: Boolean = true
 
@@ -113,4 +116,32 @@ case class TransformExpression(
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
     copy(children = newChildren)
+
+  lazy val resolvedFunction: Option[Expression] = this match {
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
+      Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc,
+        Seq(Literal(numBuckets)) ++ arguments))
+    case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
+      Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments))
+    case _ => None
+  }
+
+  override def eval(input: InternalRow): Any = {
+    resolvedFunction match {
+      case Some(fn) => fn.eval(input)
+      case None => throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
+    }
+  }
+
+  /**
+   * Returns Java source code that can be compiled to evaluate this expression.
+   * The default behavior is to call the eval method of the expression. Concrete expression
+   * implementations should override this to do actual code generation.
+   *
+   * @param ctx a [[CodegenContext]]
+   * @param ev an [[ExprCode]] with unique terms.
+   * @return an [[ExprCode]] containing the Java source code to generate the given expression
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
index c6cfccb74c161..47e6427aa6de3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -161,7 +161,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
     val declaredInputTypes = scalarFunc.inputTypes().toImmutableArraySeq
     val argClasses = declaredInputTypes.map(EncoderUtils.dataTypeJavaClass)
     findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
-      case Some(m) if Modifier.isStatic(m.getModifiers) =>
+      case Some(m) if isStatic(m) =>
         StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
           MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
           propagateNull = false, returnNullable = scalarFunc.isResultNullable,
@@ -204,4 +204,12 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
       None
     }
   }
+
+  private def isStatic(m: Method): Boolean = {
+    val javaStatic = Modifier.isStatic(m.getModifiers)
+    // also treat methods on Scala objects (classes carrying a static MODULE$ field) as static
+    val scalaStatic = m.getDeclaringClass.getDeclaredFields.exists(f =>
+      f.getName == "MODULE$" && Modifier.isStatic(f.getModifiers))
+    javaStatic || scalaStatic
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 2364130f79e4c..93505d06b5b7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -870,12 +870,30 @@ case class KeyGroupedShuffleSpec(
     if (results.forall(p => p.isEmpty)) None else Some(results)
   }
 
-  override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
-    // Only support partition expressions are AttributeReference for now
-    partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
+  override def canCreatePartitioning: Boolean = {
+    // Allow one-side shuffle for SPJ for now only if partially-clustered is not enabled, and,
+    // when join keys may be a subset of the partition keys, only if there are no transforms.
+    val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
+      e: Expression => e.isInstanceOf[AttributeReference]
+    } else {
+      e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
+    }
+    SQLConf.get.v2BucketingShuffleEnabled &&
+      !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
+      partitioning.expressions.forall(checkExprType)
+  }
+
+
   override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
-    KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
+    val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
+      case (c, e: TransformExpression) => TransformExpression(
+        e.function, Seq(c), e.numBucketsOpt)
+      case (c, _) => c
+    }
+    KeyGroupedPartitioning(newExpressions,
+      partitioning.numPartitions,
+      partitioning.partitionValues)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
index ec275fe101fd6..9a9cd2a65ff13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1136,7 +1136,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       val df = createJoinTestDF(Seq("arrive_time" -> "time"))
       val shuffles = collectShuffles(df.queryExecution.executedPlan)
       if (shuffle) {
-        assert(shuffles.size == 2, "partitioning with transform not work now")
+        assert(shuffles.size == 1, "partitioning with transform should trigger SPJ")
       } else {
         assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" +
           " is not enabled")
@@ -1931,22 +1931,19 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       "(6, 50.0, cast('2023-02-01' as timestamp))")
 
     Seq(true, false).foreach { pushdownValues =>
-      Seq(true, false).foreach { partiallyClustered =>
-        withSQLConf(
-          SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
-          SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
-          SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key
-            -> partiallyClustered.toString,
-          SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
-          val df = createJoinTestDF(Seq("id" -> "item_id"))
-          val shuffles = collectShuffles(df.queryExecution.executedPlan)
-          assert(shuffles.size == 1, "SPJ should be triggered")
-          checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
-            Row(1, "aa", 30.0, 89.0),
-            Row(1, "aa", 40.0, 42.0),
-            Row(1, "aa", 40.0, 89.0),
-            Row(3, "bb", 10.0, 19.5)))
-        }
+      withSQLConf(
+        SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+        SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
+        SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+        SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+        val df = createJoinTestDF(Seq("id" -> "item_id"))
+        val shuffles = collectShuffles(df.queryExecution.executedPlan)
+        assert(shuffles.size == 1, "SPJ should be triggered")
+        checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
+          Row(1, "aa", 30.0, 89.0),
+          Row(1, "aa", 40.0, 42.0),
+          Row(1, "aa", 40.0, 89.0),
+          Row(3, "bb", 10.0, 19.5)))
       }
     }
   }
@@ -1992,4 +1989,109 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
       }
     }
   }
+
+  test("SPARK-48012: one-side shuffle with partition transforms") {
partition transforms") { + val items_partitions = Array(bucket(2, "id"), identity("arrive_time")) + val items_partitions2 = Array(identity("arrive_time"), bucket(2, "id")) + + Seq(items_partitions, items_partitions2).foreach { partition => + catalog.clearTables() + + createTable(items, itemsColumns, partition) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " + + "(1, 'cc', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'dd', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'ee', 15.5, cast('2020-02-01' as timestamp)), " + + "(5, 'ff', 32.1, cast('2020-03-01' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 10.7, cast('2020-01-01' as timestamp))," + + "(3, 19.5, cast('2020-02-01' as timestamp))," + + "(4, 56.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "only shuffle side that does not report partitioning") + + checkAnswer(df, Seq( + Row(1, "bb", 30.0, 42.0), + Row(1, "aa", 40.0, 42.0), + Row(4, "ee", 15.5, 56.5))) + } + } + } + + test("SPARK-48012: one-side shuffle with partition transforms and pushdown values") { + val items_partitions = Array(bucket(2, "id"), identity("arrive_time")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " + + "(1, 'cc', 30.0, cast('2020-01-02' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 10.7, cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDown => { + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> + pushDown.toString) { + val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "only shuffle side that does not report partitioning") + + checkAnswer(df, Seq( + Row(1, "bb", 30.0, 42.0), + Row(1, "aa", 40.0, 42.0))) + } + } + } + } + + test("SPARK-48012: one-side shuffle with partition transforms " + + "with fewer join keys than partition kes") { + val items_partitions = Array(bucket(2, "id"), identity("name")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 89.0, cast('2020-01-03' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + 
+      SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
+      SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+      val df = createJoinTestDF(Seq("id" -> "item_id"))
+      val shuffles = collectShuffles(df.queryExecution.executedPlan)
+      assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with " +
+        "fewer join keys than partition keys for now.")
+      checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
+        Row(1, "aa", 30.0, 89.0),
+        Row(1, "aa", 40.0, 42.0),
+        Row(1, "aa", 40.0, 89.0),
+        Row(3, "bb", 10.0, 19.5)))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
index 5cdb900901056..5364fc5d62423 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala
@@ -16,9 +16,11 @@
  */
 package org.apache.spark.sql.connector.catalog.functions
 
-import java.sql.Timestamp
+import java.time.{Instant, LocalDate, ZoneId}
+import java.time.temporal.ChronoUnit
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -44,7 +46,13 @@ object YearsFunction extends ScalarFunction[Long] {
   override def name(): String = "years"
   override def canonicalName(): String = name()
 
-  def invoke(ts: Long): Long = new Timestamp(ts).getYear + 1900
+  val UTC: ZoneId = ZoneId.of("UTC")
+  val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate
+
+  def invoke(ts: Long): Long = {
+    val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate
+    ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+  }
 }
 
 object DaysFunction extends BoundFunction {
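
Note on the KeyGroupedPartitioner change at the top of this patch: the diff does not spell out why the incoming key is copied into an immutable ArraySeq before being cached, but one hazard that such a defensive copy removes is a caller-supplied mutable Seq being mutated after it has become a map key. A standalone sketch of that hazard (the names below are illustrative, not Spark code):

import scala.collection.immutable.ArraySeq
import scala.collection.mutable

object KeyNormalizationSketch extends App {
  val valueMap = mutable.Map.empty[Seq[Any], Int]

  // A reusable buffer standing in for a key the caller may mutate later.
  val reusedKey = mutable.ArrayBuffer[Any](1, "2020-01-01")
  valueMap.getOrElseUpdate(reusedKey, 0)             // the buffer itself becomes the map key
  reusedKey(0) = 2                                   // caller mutates it in place
  println(valueMap.contains(Seq(1, "2020-01-01")))   // false: lookup by the original value no longer finds it

  // Copying into an immutable ArraySeq decouples the cached key from the caller's object.
  val normalizedKey = ArraySeq.from(reusedKey)
  valueMap.getOrElseUpdate(normalizedKey, 1)
  println(valueMap.contains(ArraySeq(2, "2020-01-01")))  // true: content-based lookup stays stable
}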
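
On the new isStatic check in V2ExpressionUtils: a ScalarFunction implemented as a Scala object does not expose its magic method as a JVM-static method. The method lives on the generated singleton class, which instead carries a public static MODULE$ field (a shape StaticInvoke can already call into for Scala objects). A minimal reflection sketch of what the two checks observe; YearsLike is a made-up stand-in, not one of the functions in this patch:

import java.lang.reflect.Modifier

object ScalaObjectStaticSketch extends App {
  // Stand-in for a ScalarFunction implemented as a Scala object with a magic `invoke` method.
  object YearsLike { def invoke(ts: Long): Long = ts }

  val m = YearsLike.getClass.getMethod("invoke", classOf[Long])
  println(Modifier.isStatic(m.getModifiers))  // false: `invoke` is an instance method on the singleton

  // The singleton's class is marked by a static MODULE$ field.
  val moduleField = m.getDeclaringClass.getDeclaredFields.find(_.getName == "MODULE$")
  println(moduleField.exists(f => Modifier.isStatic(f.getModifiers)))  // true
}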
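
On the YearsFunction rewrite at the end of the patch: the new invoke returns the number of whole years between the Unix epoch and the given timestamp, taking the input as microseconds since the epoch (Spark's internal encoding), instead of the value the old java.sql.Timestamp code produced (which treated the microsecond value as milliseconds and returned a calendar year). A self-contained check of the same arithmetic using only java.time, with plain Instant math standing in for DateTimeUtils.microsToInstant (names here are illustrative):

import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit

object YearsSinceEpochSketch extends App {
  val utc = ZoneId.of("UTC")
  val epochDate = Instant.EPOCH.atZone(utc).toLocalDate

  // 2020-01-01 00:00:00 UTC expressed in microseconds since the epoch.
  val micros = 1577836800000000L
  val date = Instant.ofEpochSecond(micros / 1000000L).atZone(utc).toLocalDate

  println(ChronoUnit.YEARS.between(epochDate, date))  // 50, i.e. 2020 - 1970
}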