[SPARK-50593][SQL] SPJ: Support truncate transform
szehon-ho committed Dec 17, 2024
1 parent b2c8b30 commit 7862f50
Showing 4 changed files with 63 additions and 14 deletions.
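
For context: SPJ (storage-partitioned join) lets Spark skip shuffles when both sides of a join report compatible KeyGroupedPartitioning from the data source, and this commit extends that support to the `truncate` transform, which groups string values by a fixed-length prefix. The partition spec exercised by the new test at the bottom of this diff is built with the DSv2 `Expressions` API (copied from the test; the table and column names are the test's own):

import org.apache.spark.sql.connector.expressions.{Expressions, Transform}

// Partition `data` by its first two characters, e.g. "aaa" -> "aa".
val partitions: Array[Transform] = Array(
  Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))
)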
@@ -383,6 +383,7 @@ case class KeyGroupedPartitioning(
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())
.filter(KeyGroupedPartitioning.isReference)

if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// check that join keys (required clustering keys)
@@ -457,14 +458,7 @@ object KeyGroupedPartitioning {

def supportsExpressions(expressions: Seq[Expression]): Boolean = {
def isSupportedTransform(transform: TransformExpression): Boolean = {
transform.children.size == 1 && isReference(transform.children.head)
}

@tailrec
def isReference(e: Expression): Boolean = e match {
case _: Attribute => true
case g: GetStructField => isReference(g.child)
case _ => false
transform.children.count(isReference) == 1
}

expressions.forall {
@@ -473,6 +467,13 @@
case _ => false
}
}

@tailrec
def isReference(e: Expression): Boolean = e match {
case _: Attribute => true
case g: GetStructField => isReference(g.child)
case _ => false
}
}

/**
@@ -791,7 +792,7 @@ case class KeyGroupedShuffleSpec(
distKeyToPos.getOrElseUpdate(distKey.canonicalized, mutable.BitSet.empty).add(distKeyPos)
}
partitioning.expressions.map { e =>
val leaves = e.collectLeaves()
val leaves = e.collectLeaves().filter(KeyGroupedPartitioning.isReference)
assert(leaves.size == 1, s"Expected exactly one child from $e, but found ${leaves.size}")
distKeyToPos.getOrElse(leaves.head.canonicalized, mutable.BitSet.empty)
}
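
The changes in this file are the core of the fix. A truncate transform has two children, the partition column and the literal truncation width, so the old `children.size == 1` check (and the unfiltered `collectLeaves()`) rejected it; counting and keeping only reference children accepts multi-argument transforms while still requiring exactly one partition column. A rough sketch of the distinction, assuming the test suite's `TruncateFunction` bound function and the object-level `isReference` added above:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
import org.apache.spark.sql.types.StringType

val data = AttributeReference("data", StringType)()
// truncate(data, 2): two children, but only one of them is a column reference
val truncate = TransformExpression(TruncateFunction, Seq(data, Literal(2)))

truncate.children.size                                       // 2 -> the old check rejected it
truncate.children.count(KeyGroupedPartitioning.isReference)  // 1 -> the new check supports it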
@@ -648,7 +648,7 @@ case class PartitionInternalRow(keys: Array[Any])
return false
}
// Just compare by reference, not by value
this.keys == other.asInstanceOf[PartitionInternalRow].keys
this.keys sameElements other.asInstanceOf[PartitionInternalRow].keys
}
override def hashCode: Int = {
Objects.hashCode(keys)
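
The equals change matters because `==` on Scala arrays is reference equality, so two PartitionInternalRow instances wrapping equal partition values never compared equal; `sameElements` compares the wrapped values, which is presumably what comparing pushed partition values requires here. A minimal illustration of the underlying array semantics:

import org.apache.spark.unsafe.types.UTF8String

val a = Array[Any](UTF8String.fromString("aa"))
val b = Array[Any](UTF8String.fromString("aa"))

a == b            // false: arrays compare by reference
a sameElements b  // true: compares the wrapped values element-wise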
@@ -627,6 +627,7 @@ case class EnsureRequirements(
distribution: ClusteredDistribution): Option[KeyGroupedShuffleSpec] = {
def tryCreate(partitioning: KeyGroupedPartitioning): Option[KeyGroupedShuffleSpec] = {
val attributes = partitioning.expressions.flatMap(_.collectLeaves())
.filter(KeyGroupedPartitioning.isReference)
val clustering = distribution.clustering

val satisfies = if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
@@ -24,7 +24,8 @@ import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, TransformExpression}
import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog}
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
import org.apache.spark.sql.connector.catalog.{Column, Identifier, InMemoryTableCatalog, PartitionInternalRow}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.connector.distributions.Distributions
import org.apache.spark.sql.connector.expressions._
@@ -37,6 +38,7 @@ import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
private val functions = Seq(
@@ -195,10 +197,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase {
s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")

val df = sql(s"SELECT * FROM testcat.ns.$table")
val distribution = physical.ClusteredDistribution(
Seq(TransformExpression(TruncateFunction, Seq(attr("data"), Literal(2)))))
val transformExpression = Seq(TransformExpression(
TruncateFunction, Seq(attr("data"), Literal(2))))
val distribution = physical.ClusteredDistribution(transformExpression)
val partValues = Seq(
PartitionInternalRow(Array(UTF8String.fromString("aa"))),
PartitionInternalRow(Array(UTF8String.fromString("bb"))),
PartitionInternalRow(Array(UTF8String.fromString("cc"))))
val partitioning = new KeyGroupedPartitioning(transformExpression, 3, partValues, partValues)

checkQueryPlan(df, distribution, physical.UnknownPartitioning(0))
checkQueryPlan(df, distribution, partitioning)
}
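
The expected partition values above are just the two-character prefixes of the inserted rows; the catalog's truncate function performs the actual truncation, but the mapping itself is simply:

Seq("aaa", "bbb", "ccc").map(_.take(2))   // Seq("aa", "bb", "cc")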

/**
@@ -2504,4 +2512,43 @@
assert(scans.forall(_.inputRDD.partitions.length == 2))
}
}

test("SPARK-50593: Support truncate transform") {
val partitions: Array[Transform] = Array(
Expressions.apply("truncate", Expressions.column("data"), Expressions.literal(2))
)

// create a table with 3 partitions, partitioned by `truncate` transform
createTable("table", columns, partitions)
sql(s"INSERT INTO testcat.ns.table VALUES " +
s"(0, 'aaa', CAST('2022-01-01' AS timestamp)), " +
s"(1, 'bbb', CAST('2021-01-01' AS timestamp)), " +
s"(2, 'ccc', CAST('2020-01-01' AS timestamp))")

createTable("table2", columns2, partitions)
sql(s"INSERT INTO testcat.ns.table2 VALUES " +
s"(1, 5, 'aaa')," +
s"(5, 10, 'bbb')," +
s"(20, 40, 'bbb')," +
s"(40, 80, 'ddd')")

withSQLConf(
SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true") {

val df =
sql(
selectWithMergeJoinHint("table", "table2") +
"id, store_id, dept_id " +
"FROM testcat.ns.table JOIN testcat.ns.table2 " +
"ON table.data = table2.data " +
"SORT BY id, store_id, dept_id")
val shuffles = collectShuffles(df.queryExecution.executedPlan)
assert(shuffles.isEmpty, "should not add shuffle for both sides of the join")
checkAnswer(df,
Seq(Row(0, 1, 5), Row(1, 5, 10), Row(1, 20, 40))
)
val scans = collectScans(df.queryExecution.executedPlan)
assert(scans.forall(_.inputRDD.partitions.length == 4))
}
}
}
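
The partition count of 4 in the last assertion follows from pushing partition values: with V2_BUCKETING_PUSH_PART_VALUES_ENABLED the planner merges the grouped partition values of both sides, so each scan ends up with the union of truncated keys. Roughly, in plain Scala standing in for the planner's grouping (not the actual code path):

val left  = Seq("aaa", "bbb", "ccc").map(_.take(2)).distinct         // aa, bb, cc
val right = Seq("aaa", "bbb", "bbb", "ddd").map(_.take(2)).distinct  // aa, bb, dd
(left ++ right).distinct                                             // aa, bb, cc, dd -> 4 partitions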
