diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index 90e3bdcd082cd..d2bdad2d880de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering} -import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition} import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.NonFateSharingCache @@ -85,22 +84,25 @@ object InternalRowComparableWrapper { } def mergePartitions( - leftPartitioning: KeyGroupedPartitioning, - rightPartitioning: KeyGroupedPartitioning, - partitionExpression: Seq[Expression]): Seq[InternalRow] = { + leftPartitioning: Seq[InternalRow], + rightPartitioning: Seq[InternalRow], + partitionExpression: Seq[Expression], + intersect: Boolean = false): Seq[InternalRowComparableWrapper] = { val partitionDataTypes = partitionExpression.map(_.dataType) - val partitionsSet = new mutable.HashSet[InternalRowComparableWrapper] - leftPartitioning.partitionValues + val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] + leftPartitioning .map(new InternalRowComparableWrapper(_, partitionDataTypes)) - .foreach(partition => partitionsSet.add(partition)) - rightPartitioning.partitionValues + .foreach(partition => leftPartitionSet.add(partition)) + val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] + rightPartitioning .map(new InternalRowComparableWrapper(_, partitionDataTypes)) - .foreach(partition => 
partitionsSet.add(partition)) - // SPARK-41471: We keep to order of partitions to make sure the order of - // partitions is deterministic in different case. - val partitionOrdering: Ordering[InternalRow] = { - RowOrdering.createNaturalAscendingOrdering(partitionDataTypes) + .foreach(partition => rightPartitionSet.add(partition)) + + val result = if (intersect) { + leftPartitionSet.intersect(rightPartitionSet) + } else { + leftPartitionSet.union(rightPartitionSet) } - partitionsSet.map(_.row).toSeq.sorted(partitionOrdering) + result.toSeq } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f50eb9b121589..ac5aa0f6bbdc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1635,6 +1635,17 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_PARTITION_FILTER_ENABLED = + buildConf("spark.sql.sources.v2.bucketing.partition.filter.enabled") + .doc(s"Whether to filter partitions when running storage-partition join. " + + s"When enabled, partitions without matches on the other side can be omitted for " + + s"scanning, if allowed by the join type. 
This config requires both " + + s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " + + s"enabled.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala index cc28e85525162..f3dd232129e8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala @@ -61,7 +61,7 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase { val leftPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions) val rightPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions) val merged = InternalRowComparableWrapper.mergePartitions( - leftPartitioning, rightPartitioning, expressions) + leftPartitioning.partitionValues, rightPartitioning.partitionValues, expressions) assert(merged.size == bucketNum) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 997576a396d20..6a502a44fad58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -200,7 +200,12 @@ case class BatchScanExec( .get .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap - val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) => + val 
filteredGroupedPartitions = finalGroupedPartitions.filter { + case (partValues, _) => + commonPartValuesMap.keySet.contains( + InternalRowComparableWrapper(partValues, partExpressions)) + } + val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. val numSplits = commonPartValuesMap .get(InternalRowComparableWrapper(partValue, partExpressions)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 0470aacd4f823..90287c2028467 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -429,8 +429,19 @@ case class EnsureRequirements( // expressions val partitionExprs = leftSpec.partitioning.expressions - var mergedPartValues = InternalRowComparableWrapper - .mergePartitions(leftSpec.partitioning, rightSpec.partitioning, partitionExprs) + // in case of compatible but not identical partition expressions, we apply 'reduce' + // transforms to group one side's partitions as well as the common partition values + val leftReducers = leftSpec.reducers(rightSpec) + val leftParts = reducePartValues(leftSpec.partitioning.partitionValues, + partitionExprs, + leftReducers) + val rightReducers = rightSpec.reducers(leftSpec) + val rightParts = reducePartValues(rightSpec.partitioning.partitionValues, + partitionExprs, + rightReducers) + + // merge values on both sides + var mergedPartValues = mergePartitions(leftParts, rightParts, partitionExprs, joinType) .map(v => (v, 1)) logInfo(log"After merging, there are " + @@ -525,23 +536,6 @@ case class EnsureRequirements( } } - // in case of compatible but not identical partition expressions, we apply 'reduce' - // transforms to group one side's 
partitions as well as the common partition values - val leftReducers = leftSpec.reducers(rightSpec) - val rightReducers = rightSpec.reducers(leftSpec) - - if (leftReducers.isDefined || rightReducers.isDefined) { - mergedPartValues = reduceCommonPartValues(mergedPartValues, - leftSpec.partitioning.expressions, - leftReducers) - mergedPartValues = reduceCommonPartValues(mergedPartValues, - rightSpec.partitioning.expressions, - rightReducers) - val rowOrdering = RowOrdering - .createNaturalAscendingOrdering(partitionExprs.map(_.dataType)) - mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) - } - // Now we need to push-down the common partition information to the scan in each child newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions, leftReducers, applyPartialClustering, replicateLeftSide) @@ -602,15 +596,15 @@ case class EnsureRequirements( child, joinKeyPositions)) } - private def reduceCommonPartValues( - commonPartValues: Seq[(InternalRow, Int)], + private def reducePartValues( + partValues: Seq[InternalRow], expressions: Seq[Expression], reducers: Option[Seq[Option[Reducer[_, _]]]]) = { reducers match { - case Some(reducers) => commonPartValues.groupBy { case (row, _) => + case Some(reducers) => partValues.map { row => KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) - }.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq - case _ => commonPartValues + }.distinct.map(_.row) + case _ => partValues } } @@ -651,6 +645,46 @@ case class EnsureRequirements( } } + /** + * Merge and sort partition values for SPJ and optionally enable partition filtering. + * Both sides must have + * matching partition expressions. 
+ * @param leftPartitioning left side partition values + * @param rightPartitioning right side partition values + * @param partitionExpression partition expressions + * @param joinType join type for optional partition filtering + * @return merged and sorted partition values + */ + private def mergePartitions( + leftPartitioning: Seq[InternalRow], + rightPartitioning: Seq[InternalRow], + partitionExpression: Seq[Expression], + joinType: JoinType): Seq[InternalRow] = { + + val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) { + joinType match { + case Inner => InternalRowComparableWrapper.mergePartitions( + leftPartitioning, rightPartitioning, partitionExpression, intersect = true) + case LeftOuter => leftPartitioning.map( + InternalRowComparableWrapper(_, partitionExpression)) + case RightOuter => rightPartitioning.map( + InternalRowComparableWrapper(_, partitionExpression)) + case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning, + rightPartitioning, partitionExpression) + } + } else { + InternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, + partitionExpression) + } + + // SPARK-41471: We sort the partition values to make sure the order of + // partitions is deterministic across different cases. 
+ val partitionOrdering: Ordering[InternalRow] = { + RowOrdering.createNaturalAscendingOrdering(partitionExpression.map(_.dataType)) + } + merged.map(_.row).sorted(partitionOrdering) + } + def apply(plan: SparkPlan): SparkPlan = { val newPlan = plan.transformUp { case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin, _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 5e5453b4cd500..03cad12364daa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -667,11 +667,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { s"(5, 30.0, cast('2023-01-01' as timestamp))") Seq(true, false).foreach { pushDownValues => - Seq(("true", 10), ("false", 5)).foreach { - case (enable, expected) => + Seq((true, true, 8), (false, true, 3), (true, false, 10), (false, false, 5)).foreach { + case (partial, filter, expected) => withSQLConf( - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) { + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> partial.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { @@ -692,6 +693,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-42038: partially clustered: with different partition keys and missing keys on " + "left-hand side") { val items_partitions = Array(identity("id")) @@ -715,11 +717,13 
@@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { s"(5, 30.0, cast('2023-01-01' as timestamp))") Seq(true, false).foreach { pushDownValues => - Seq(("true", 9), ("false", 5)).foreach { - case (enable, expected) => + Seq((true, true, 3), (false, true, 2), (true, false, 9), (false, false, 5)).foreach { + case(partial, filter, expected) => withSQLConf( SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) { + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partial.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { @@ -759,11 +763,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { s"(5, 30.0, cast('2023-01-01' as timestamp))") Seq(true, false).foreach { pushDownValues => - Seq(("true", 6), ("false", 5)).foreach { - case (enable, expected) => + Seq((true, true, 2), (false, true, 2), (true, false, 6), (false, false, 5)).foreach { + case (partial, filter, expected) => withSQLConf( SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) { + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partial.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { @@ -802,12 +808,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { // In a left-outer join, and when the left side has larger stats, partially clustered // distribution should kick in and pick the right hand side to replicate partitions. 
Seq(true, false).foreach { pushDownValues => - Seq(("true", 7), ("false", 5)).foreach { - case (enable, expected) => + Seq((true, true, 5), (false, true, 3), (true, false, 7), (false, false, 5)).foreach { + case (partial, filter, expected) => withSQLConf( SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) { + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partial.toString) { val df = createJoinTestDF( Seq("id" -> "item_id", "arrive_time" -> "time"), joinType = "LEFT") val shuffles = collectShuffles(df.queryExecution.executedPlan) @@ -815,7 +823,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(shuffles.isEmpty, "should not contain any shuffle") val scans = collectScans(df.queryExecution.executedPlan) assert(scans.forall(_.inputRDD.partitions.length == expected), - s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") + s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -1336,62 +1344,71 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { "(2, 'ww', cast('2020-01-01' as timestamp))") Seq(true, false).foreach { pushDownValues => - Seq(true, false).foreach { partiallyClustered => - Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => - - withSQLConf( - SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> - partiallyClustered.toString, - SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> - 
allowJoinKeysSubsetOfPartitionKeys.toString) { - val df = sql( - s""" - |${selectWithMergeJoinHint("t1", "t2")} - |t1.id AS id, t1.data AS t1data, t2.data AS t2data - |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 - |ON t1.id = t2.id ORDER BY t1.id, t1data, t2data - |""".stripMargin) - val shuffles = collectShuffles(df.queryExecution.executedPlan) - if (allowJoinKeysSubsetOfPartitionKeys) { - assert(shuffles.isEmpty, "SPJ should be triggered") - } else { - assert(shuffles.nonEmpty, "SPJ should not be triggered") - } - - val scans = collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) - - (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { - // SPJ and partially-clustered - case (true, true) => assert(scans == Seq(8, 8)) - // SPJ and not partially-clustered - case (true, false) => assert(scans == Seq(4, 4)) - // No SPJ - case _ => assert(scans == Seq(5, 4)) + Seq(true, false).foreach { filter => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.id AS id, t1.data AS t1data, t2.data AS t2data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.id = t2.id ORDER BY t1.id, t1data, t2data + |""".stripMargin) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be 
triggered") + } + + val scannedPartitions = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered, filter) match { + // SPJ, partially-clustered, with filter + case (true, true, true) => assert(scannedPartitions == Seq(6, 6)) + + // SPJ, partially-clustered, no filter + case (true, true, false) => assert(scannedPartitions == Seq(8, 8)) + + // SPJ and not partially-clustered, with filter + case (true, false, true) => assert(scannedPartitions == Seq(2, 2)) + + // SPJ and not partially-clustered, no filter + case (true, false, false) => assert(scannedPartitions == Seq(4, 4)) + + // No SPJ + case _ => assert(scannedPartitions == Seq(5, 4)) + } + + checkAnswer(df, Seq( + Row(2, "bb", "ww"), + Row(2, "cc", "ww"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy") + )) } - - checkAnswer(df, Seq( - Row(2, "bb", "ww"), - Row(2, "cc", "ww"), - Row(3, "dd", "xx"), - Row(3, "dd", "xx"), - Row(3, "dd", "xx"), - Row(3, "dd", "xx"), - Row(3, "dd", "yy"), - Row(3, "dd", "yy"), - Row(3, "dd", "yy"), - Row(3, "dd", "yy"), - Row(3, "ee", "xx"), - Row(3, "ee", "xx"), - Row(3, "ee", "xx"), - Row(3, "ee", "xx"), - Row(3, "ee", "yy"), - Row(3, "ee", "yy"), - Row(3, "ee", "yy"), - Row(3, "ee", "yy") - )) } } } @@ -2144,7 +2161,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { test("SPARK-48012: one-side shuffle with partition transforms " + "with fewer join keys than partition kes") { val items_partitions = Array(bucket(2, "id"), identity("name")) - createTable(items, itemsColumns, items_partitions) + createTable(items, itemsColumns, items_partitions) sql(s"INSERT INTO 
testcat.ns.$items VALUES " + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + @@ -2176,4 +2193,194 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(3, "bb", 10.0, 19.5))) } } + + test("SPARK-48949: test partition filters inner join") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 41.0, cast('2020-01-03' as timestamp)), " + + s"(3, 'bb', 42.0, cast('2020-01-04' as timestamp)), " + + s"(4, 'cc', 43.5, cast('2020-01-05' as timestamp)), " + + s"(5, 'cc', 44.5, cast('2020-01-15' as timestamp)), " + + s"(6, 'dd', 45.5, cast('2020-02-07' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp)), " + + s"(7, 46.5, cast('2020-02-08' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + checkAnswer(df, + Seq(Row(1, "aa", 40.0, 42.0), Row(5, "cc", 44.5, 44.0)) + ) + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 2)) + } + } + + test("SPARK-48949: test partition filters with no matches") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 
39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-02' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(4, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + assert(df.collect().isEmpty, "should return no results") + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 0)) + } + } + + test("SPARK-48949: test partition filters with right outer") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-02' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 40.0, cast('2020-01-01' as timestamp)), " + + s"(4, 42.0, cast('2020-01-02' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id"), joinType = "RIGHT OUTER") + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + + checkAnswer(df, + Seq(Row(null, null, null, 42.0), + Row(null, 
null, null, 44.0), + Row(1, "aa", 40.0, 40.0)) + ) + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 3)) + } + } + + test("SPARK-48949: test partition filters with full outer") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-02' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 40.0, cast('2020-01-01' as timestamp)), " + + s"(4, 42.0, cast('2020-01-02' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id"), joinType = "FULL OUTER") + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + + checkAnswer(df, + Seq(Row(null, null, null, 42.0), + Row(null, null, null, 44.0), + Row(0, "aa", 39.0, null), + Row(1, "aa", 40.0, 40.0)) + ) + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 4)) + } + } + + test("SPARK-48949: test partition filters with left outer") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 38.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 39.0, cast('2020-01-02' as timestamp)), " + + s"(4, 'aa', 40.0, cast('2020-01-02' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + 
sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(4, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id"), joinType = "LEFT OUTER") + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + + checkAnswer(df, + Seq(Row(0, "aa", 38.0, null), + Row(1, "aa", 39.0, null), + Row(4, "aa", 40.0, 42.0)) + ) + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 3)) + } + } + + test("SPARK-48949: test partition filters with compatible transforms") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 41.0, cast('2020-01-03' as timestamp)), " + + s"(3, 'bb', 42.0, cast('2020-01-04' as timestamp)), " + + s"(4, 'cc', 43.5, cast('2020-01-05' as timestamp)), " + + s"(5, 'cc', 44.5, cast('2020-01-15' as timestamp)), " + + s"(6, 'dd', 45.5, cast('2020-02-07' as timestamp))") + + val purchases_partitions = Array(bucket(4, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp)), " + + s"(7, 46.5, cast('2020-02-08' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id")) + val shuffles = 
collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + checkAnswer(df, + Seq(Row(1, "aa", 40.0, 42.0), Row(5, "cc", 44.5, 44.0)) + ) + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 2)) + } + } }