[SPARK-48012][SQL] SPJ: Support Transform Expressions for One Side Shuffle

### Why are the changes needed?

Support the SPJ one-side shuffle when the other side has a partition transform expression.

### How was this patch tested?

New unit tests in KeyGroupedPartitioningSuite.

### Was this patch authored or co-authored using generative AI tooling?

No.
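To make the feature concrete, a hedged sketch of the behavior (the catalog, table names, and session setup are illustrative, not from this patch; the config key is SQLConf's `spark.sql.sources.v2.bucketing.shuffle.enabled`):

```scala
import org.apache.spark.sql.functions.col

// Illustrative only: assumes a SparkSession `spark` and a catalog `testcat`
// where `items` reports KeyGroupedPartitioning(bucket(8, id)) and
// `purchases` reports no partitioning. Before this patch, the transform on
// the items side made canCreatePartitioning return false, so BOTH sides
// were shuffled.
spark.conf.set("spark.sql.sources.v2.bucketing.shuffle.enabled", "true")

val df = spark.table("testcat.ns.items")
  .join(spark.table("testcat.ns.purchases"), col("id") === col("item_id"))

// With this patch, only `purchases` is shuffled, into bucket(8, item_id)
// partitions that line up with the partitioning `items` already reports.
df.explain()
```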
szehon-ho committed Apr 27, 2024
1 parent 4957a40 commit 41481b4
Showing 6 changed files with 195 additions and 27 deletions.
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -19,6 +19,7 @@ package org.apache.spark

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

+ import scala.collection.immutable.ArraySeq
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.math.log10
@@ -149,7 +150,8 @@ private[spark] class KeyGroupedPartitioner(
override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = {
val keys = key.asInstanceOf[Seq[Any]]
- valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
+ val normalizedKeys = ArraySeq.from(keys)
+ valueMap.getOrElseUpdate(normalizedKeys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
}
}
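The `ArraySeq.from` copy looks like a defensive normalization: the key handed to `getPartition` can be a mutable `Seq` view over a buffer the caller reuses, and caching such a view as a map key corrupts `valueMap` once the buffer changes. A small self-contained illustration of the hazard (plain Scala, not Spark code):

```scala
import scala.collection.immutable.ArraySeq
import scala.collection.mutable

val backing = Array[Any](1, "a")
val liveView: Seq[Any] = mutable.ArraySeq.make(backing) // wraps, does not copy

val cache = mutable.Map.empty[Seq[Any], Int]
cache.getOrElseUpdate(liveView, 0)        // caches the live view as a key
backing(0) = 2                            // ...which now silently mutates
println(cache.contains(Seq[Any](1, "a"))) // false: the stored key changed

// An immutable copy, as getPartition now makes, pins the key's contents:
cache.getOrElseUpdate(ArraySeq.from(liveView), 1)
```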

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala
@@ -17,7 +17,10 @@

package org.apache.spark.sql.catalyst.expressions

- import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction}
+ import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+ import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction}
+ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.DataType

/**
@@ -30,7 +33,7 @@ import org.apache.spark.sql.types.DataType
case class TransformExpression(
function: BoundFunction,
children: Seq[Expression],
- numBucketsOpt: Option[Int] = None) extends Expression with Unevaluable {
+ numBucketsOpt: Option[Int] = None) extends Expression {

override def nullable: Boolean = true

@@ -113,4 +116,32 @@ case class TransformExpression(

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)

lazy val resolvedFunction: Option[Expression] = this match {
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) =>
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc,
Seq(Literal(numBuckets)) ++ arguments))
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) =>
Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments))
case _ => None
}

override def eval(input: InternalRow): Any = {
resolvedFunction match {
case Some(fn) => fn.eval(input)
case None => throw QueryExecutionErrors.cannotEvaluateExpressionError(this)
}
}

/**
 * Code generation is not supported for this expression. It is evaluated only
 * on the interpreted path (for example, when computing a row's partition key
 * during a one-side shuffle), so doGenCode always throws.
 */
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this)
}
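For a concrete picture of what `resolvedFunction` resolves, here is a hypothetical `ScalarFunction` in the style of the test functions changed below (illustrative, not part of this patch): `V2ExpressionUtils.resolveScalarFunction` finds the magic `invoke` method reflectively, and when `numBucketsOpt` is set the bucket count is prepended as a literal argument, matching the leading `Int` parameter.

```scala
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types._

// Hypothetical bucket-style function, for illustration only.
object BucketModFunction extends ScalarFunction[Int] {
  override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
  override def resultType(): DataType = IntegerType
  override def name(): String = "bucket_mod"
  override def canonicalName(): String = name()

  // The magic method found by reflection; numBuckets arrives first because
  // resolvedFunction prepends Literal(numBuckets) when numBucketsOpt is set.
  def invoke(numBuckets: Int, id: Long): Int =
    (((id % numBuckets) + numBuckets) % numBuckets).toInt
}
```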
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala
@@ -161,7 +161,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging {
val declaredInputTypes = scalarFunc.inputTypes().toImmutableArraySeq
val argClasses = declaredInputTypes.map(EncoderUtils.dataTypeJavaClass)
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
- case Some(m) if Modifier.isStatic(m.getModifiers) =>
+ case Some(m) if isStatic(m) =>
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
propagateNull = false, returnNullable = scalarFunc.isResultNullable,
@@ -204,4 +204,11 @@
None
}
}

private def isStatic(m: Method): Boolean = {
  // A Scala `object`'s methods are instance methods on the singleton stored in
  // the static MODULE$ field (getField would throw if the field is missing,
  // so scan the public fields instead of null-checking).
  val scalaStatic = m.getDeclaringClass.getFields.exists(f =>
    f.getName == "MODULE$" && Modifier.isStatic(f.getModifiers))
  Modifier.isStatic(m.getModifiers) || scalaStatic
}
}
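The `MODULE$` probe is needed because a Scala `object`'s methods compile to instance methods on a singleton class, with the singleton held in a static `MODULE$` field, so `Modifier.isStatic` alone would reject functions defined as objects (like the test functions below). A minimal demonstration with an illustrative object:

```scala
import java.lang.reflect.Modifier

object Doubler { def invoke(x: Int): Int = 2 * x }

val m = Doubler.getClass.getMethod("invoke", classOf[Int])
println(Modifier.isStatic(m.getModifiers))        // false: instance method on the singleton
val module = Doubler.getClass.getField("MODULE$") // static field holding the singleton
println(Modifier.isStatic(module.getModifiers))   // true
```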
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -870,12 +870,30 @@ case class KeyGroupedShuffleSpec(
if (results.forall(p => p.isEmpty)) None else Some(results)
}

- override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
-   // Only support partition expressions are AttributeReference for now
-   partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
+ override def canCreatePartitioning: Boolean = {
+   // For now, allow the one-side shuffle for SPJ only if partially-clustered
+   // distribution is disabled; when join keys may be a subset of the partition
+   // keys, additionally require that the partitioning has no transforms.
+   val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
+     e: Expression => e.isInstanceOf[AttributeReference]
+   } else {
+     e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression]
+   }
+   SQLConf.get.v2BucketingShuffleEnabled &&
+     !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
+     partitioning.expressions.forall(checkExprType)
+ }

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
- KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
+ val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
+   case (c, e: TransformExpression) =>
+     TransformExpression(e.function, Seq(c), e.numBucketsOpt)
+   case (c, _) => c
+ }
+ KeyGroupedPartitioning(newExpressions,
+   partitioning.numPartitions,
+   partitioning.partitionValues)
}
}
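A simplified, self-contained model of the rewrapping `createPartitioning` now performs (the mini-ADT is illustrative, not Spark's expression tree): the side being shuffled inherits this side's transforms, with its own join keys substituted as the transform arguments, so both sides bucket rows identically.

```scala
sealed trait Expr
case class Attr(name: String) extends Expr
case class Transform(fn: String, child: Expr, numBuckets: Option[Int]) extends Expr

def rewrap(clustering: Seq[Expr], partitioning: Seq[Expr]): Seq[Expr] =
  clustering.zip(partitioning).map {
    case (c, Transform(fn, _, n)) => Transform(fn, c, n) // keep transform, swap in the key
    case (c, _) => c                                     // plain attribute: use the key as-is
  }

// rewrap(Seq(Attr("item_id"), Attr("time")),
//        Seq(Transform("bucket", Attr("id"), Some(2)), Attr("arrive_time")))
// == Seq(Transform("bucket", Attr("item_id"), Some(2)), Attr("time"))
```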

sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala
@@ -1136,7 +1136,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
val df = createJoinTestDF(Seq("arrive_time" -> "time"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
if (shuffle) {
- assert(shuffles.size == 2, "partitioning with transform not work now")
+ assert(shuffles.size == 1, "partitioning with transform should trigger SPJ")
} else {
assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" +
" is not enabled")
@@ -1931,22 +1931,19 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
"(6, 50.0, cast('2023-02-01' as timestamp))")

Seq(true, false).foreach { pushdownValues =>
- Seq(true, false).foreach { partiallyClustered =>
-   withSQLConf(
-     SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
-     SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
-     SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key
-       -> partiallyClustered.toString,
-     SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
-     val df = createJoinTestDF(Seq("id" -> "item_id"))
-     val shuffles = collectShuffles(df.queryExecution.executedPlan)
-     assert(shuffles.size == 1, "SPJ should be triggered")
-     checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
-       Row(1, "aa", 30.0, 89.0),
-       Row(1, "aa", 40.0, 42.0),
-       Row(1, "aa", 40.0, 89.0),
-       Row(3, "bb", 10.0, 19.5)))
-   }
- }
+ withSQLConf(
+   SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
+   SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString,
+   SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
+   SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
+   val df = createJoinTestDF(Seq("id" -> "item_id"))
+   val shuffles = collectShuffles(df.queryExecution.executedPlan)
+   assert(shuffles.size == 1, "SPJ should be triggered")
+   checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
+     Row(1, "aa", 30.0, 89.0),
+     Row(1, "aa", 40.0, 42.0),
+     Row(1, "aa", 40.0, 89.0),
+     Row(3, "bb", 10.0, 19.5)))
+ }
}
@@ -1992,4 +1989,109 @@
}
}
}

test("SPARK-48012: one-side shuffle with partition transforms") {
val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
val items_partitions2 = Array(identity("arrive_time"), bucket(2, "id"))

Seq(items_partitions, items_partitions2).foreach { partition =>
catalog.clearTables()

createTable(items, itemsColumns, partition)
sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
"(1, 'cc', 30.0, cast('2020-01-02' as timestamp)), " +
"(3, 'dd', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'ee', 15.5, cast('2020-02-01' as timestamp)), " +
"(5, 'ff', 32.1, cast('2020-03-01' as timestamp))")

createTable(purchases, purchasesColumns, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(2, 10.7, cast('2020-01-01' as timestamp))," +
"(3, 19.5, cast('2020-02-01' as timestamp))," +
"(4, 56.5, cast('2020-02-01' as timestamp))")

withSQLConf(
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 1, "should shuffle only the side that does not report partitioning")

checkAnswer(df, Seq(
Row(1, "bb", 30.0, 42.0),
Row(1, "aa", 40.0, 42.0),
Row(4, "ee", 15.5, 56.5)))
}
}
}
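Roughly what the single-shuffle assertion checks, sketched as an illustrative plan shape (not verbatim EXPLAIN output; operator details vary by join type):

```scala
// SortMergeJoin [id, arrive_time], [item_id, time]
// :- BatchScan items          <- reports KeyGroupedPartitioning: no Exchange
// +- Exchange KeyGroupedPartitioning(bucket(2, item_id), identity(time))
//    +- BatchScan purchases   <- the one shuffle collectShuffles should find
```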

test("SPARK-48012: one-side shuffle with partition transforms and pushdown values") {
val items_partitions = Array(bucket(2, "id"), identity("arrive_time"))
createTable(items, itemsColumns, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " +
"(1, 'cc', 30.0, cast('2020-01-02' as timestamp))")

createTable(purchases, purchasesColumns, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(2, 10.7, cast('2020-01-01' as timestamp))")

Seq(true, false).foreach { pushDown =>
withSQLConf(
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDown.toString) {
val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 1, "should shuffle only the side that does not report partitioning")

checkAnswer(df, Seq(
Row(1, "bb", 30.0, 42.0),
Row(1, "aa", 40.0, 42.0)))
}
}
}

test("SPARK-48012: one-side shuffle with partition transforms " +
"with fewer join keys than partition kes") {
val items_partitions = Array(bucket(2, "id"), identity("name"))
createTable(items, itemsColumns, items_partitions)

sql(s"INSERT INTO testcat.ns.$items VALUES " +
"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
"(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " +
"(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " +
"(4, 'cc', 15.5, cast('2020-02-01' as timestamp))")

createTable(purchases, purchasesColumns, Array.empty)
sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
"(1, 42.0, cast('2020-01-01' as timestamp)), " +
"(1, 89.0, cast('2020-01-03' as timestamp)), " +
"(3, 19.5, cast('2020-02-01' as timestamp)), " +
"(5, 26.0, cast('2023-01-01' as timestamp)), " +
"(6, 50.0, cast('2023-02-01' as timestamp))")

withSQLConf(
SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false",
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false",
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
val df = createJoinTestDF(Seq("id" -> "item_id"))
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.size == 2, "SPJ should not be triggered for transform expressions with " +
  "fewer join keys than partition keys for now")
checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0),
Row(1, "aa", 30.0, 89.0),
Row(1, "aa", 40.0, 42.0),
Row(1, "aa", 40.0, 89.0),
Row(3, "bb", 10.0, 19.5)))
}
}
}
(test catalog transform functions, package org.apache.spark.sql.connector.catalog.functions)
@@ -16,9 +16,11 @@
*/
package org.apache.spark.sql.connector.catalog.functions

- import java.sql.Timestamp
+ import java.time.{Instant, LocalDate, ZoneId}
+ import java.time.temporal.ChronoUnit

import org.apache.spark.sql.catalyst.InternalRow
+ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -44,7 +46,13 @@ object YearsFunction extends ScalarFunction[Long] {
override def name(): String = "years"
override def canonicalName(): String = name()

- def invoke(ts: Long): Long = new Timestamp(ts).getYear + 1900
+ val UTC: ZoneId = ZoneId.of("UTC")
+ val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate
+
+ def invoke(ts: Long): Long = {
+   val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate
+   ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate)
+ }
}

object DaysFunction extends BoundFunction {
// ...
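The `YearsFunction.invoke` rewrite matters because Spark's internal `TimestampType` value is microseconds since the epoch in UTC, while the old `new Timestamp(ts)` misread it as milliseconds in the JVM's local zone; and since `TransformExpression.eval` now actually evaluates this function on the shuffled side, it must agree with the partition values the test catalog reports. A self-contained equivalent of the new computation (assuming only that `DateTimeUtils.microsToInstant` is the usual microseconds-to-`Instant` conversion):

```scala
import java.time.{Instant, ZoneId}
import java.time.temporal.ChronoUnit

val UTC = ZoneId.of("UTC")
val epochDate = Instant.EPOCH.atZone(UTC).toLocalDate

// Years between 1970-01-01 and the timestamp's UTC date.
def yearsSinceEpoch(tsMicros: Long): Long = {
  val instant = Instant.ofEpochSecond(tsMicros / 1000000L, (tsMicros % 1000000L) * 1000L)
  ChronoUnit.YEARS.between(epochDate, instant.atZone(UTC).toLocalDate)
}

// cast('2020-01-01' as timestamp) is 1577836800000000L microseconds:
println(yearsSinceEpoch(1577836800000000L)) // 50 (years since 1970), not 2020
```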
