diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0164af945ca28..2d15c978bee6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -274,7 +274,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor CTESubstitution, WindowsSubstitution, EliminateUnions, - SubstituteUnresolvedOrdinals), + SubstituteUnresolvedOrdinals, + ScopeExpressions), Batch("Disable Hints", Once, new ResolveHints.DisableHints), Batch("Hints", fixedPoint, @@ -295,6 +296,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolveFieldNameAndPosition :: AddMetadataColumns :: DeduplicateRelations :: + new ResolveScopedExpression(resolver) :: new ResolveReferences(catalogManager) :: // Please do not insert any other rules in between. See the TODO comments in rule // ResolveLateralColumnAliasReference for more details. @@ -519,6 +521,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)() + case se @ ScopedExpression(ne: NamedExpression, _) => Alias(se, ne.name)() case e: ExtractValue if extractOnly(e) => Alias(e, toPrettySQL(e))() case e if optGenAliasFunc.isDefined => Alias(child, optGenAliasFunc.get.apply(e))() @@ -1654,34 +1657,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) } - case mg: MapGroups if mg.dataOrder.exists(!_.resolved) => - // Resolve against `AppendColumns`'s children, instead of `AppendColumns`, - // because `AppendColumns`'s serializer might produce conflict attribute - // names leading to ambiguous references exception. - val planForResolve = mg.child match { - case appendColumns: AppendColumns => appendColumns.child - case plan => plan - } - val resolvedOrder = mg.dataOrder - .map(resolveExpressionByPlanOutput(_, planForResolve).asInstanceOf[SortOrder]) - mg.copy(dataOrder = resolvedOrder) - // Left and right sort expression have to be resolved against the respective child plan only case cg: CoGroup if cg.leftOrder.exists(!_.resolved) || cg.rightOrder.exists(!_.resolved) => - // Resolve against `AppendColumns`'s children, instead of `AppendColumns`, - // because `AppendColumns`'s serializer might produce conflict attribute - // names leading to ambiguous references exception. - val (leftPlanForResolve, rightPlanForResolve) = Seq(cg.left, cg.right).map { - case appendColumns: AppendColumns => appendColumns.child - case plan => plan - } match { - case Seq(left, right) => (left, right) - } - val resolvedLeftOrder = cg.leftOrder - .map(resolveExpressionByPlanOutput(_, leftPlanForResolve).asInstanceOf[SortOrder]) + .map(resolveExpressionByPlanOutput(_, cg.left).asInstanceOf[SortOrder]) val resolvedRightOrder = cg.rightOrder - .map(resolveExpressionByPlanOutput(_, rightPlanForResolve).asInstanceOf[SortOrder]) + .map(resolveExpressionByPlanOutput(_, cg.right).asInstanceOf[SortOrder]) cg.copy(leftOrder = resolvedLeftOrder, rightOrder = resolvedRightOrder) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a9fbe548ba39e..668c5f61dde3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -144,11 +144,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB operator: LogicalPlan, a: Attribute, errorClass: String): Nothing = { + failUnresolvedAttribute(a, operator.inputSet.toSeq, errorClass) + } + + private def failUnresolvedAttribute( + a: Attribute, + candidates: Seq[Attribute], + errorClass: String): Nothing = { val missingCol = a.sql - val candidates = operator.inputSet.toSeq - .map(attr => attr.qualifier :+ attr.name) + val candidateNames = candidates.map(attr => attr.qualifier :+ attr.name) val orderedCandidates = - StringUtils.orderSuggestedIdentifiersBySimilarity(missingCol, candidates) + StringUtils.orderSuggestedIdentifiersBySimilarity(missingCol, candidateNames) throw QueryCompilationErrors.unresolvedAttributeError( errorClass, missingCol, orderedCandidates, a.origin) } @@ -310,6 +316,24 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB ColumnDefinition.checkColumnDefinitions(operator) var stagedError: Option[() => Unit] = None + + // check for unresolved scoped expressions before any other check below, + // this improves error messages, see SPARK-42199 + getAllExpressions(operator).foreach(_.foreach { + case se: ScopedExpression if !se.resolved => + se.expr.foreachUp { + case a: Attribute if !a.resolved => + failUnresolvedAttribute(a, se.scope.attrs, "UNRESOLVED_COLUMN") + case _ => + } + case se: ScopedExpression => throw SparkException.internalError( + "Resolved scoped expression operator should be removed " + + f"during analysis: ${toSQLExpr(se)}", + context = se.origin.getQueryContext, + summary = se.origin.context.summary) + case _ => + }) + getAllExpressions(operator).foreach(_.foreachUp { case a: Attribute if !a.resolved => failUnresolvedAttribute(operator, a, "UNRESOLVED_COLUMN") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala index 7ea90854932e5..c84ee15004444 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, IntegerLiteral, Literal, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE} import org.apache.spark.sql.connector.catalog.CatalogManager @@ -51,18 +51,10 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S with ColumnResolutionHelper with AliasHelper { def apply(a: Aggregate): Aggregate = { - val planForResolve = a.child match { - // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of - // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute - // names leading to ambiguous references exception. - case appendColumns: AppendColumns => appendColumns - case _ => a - } - val resolvedGroupExprsBasic = a.groupingExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve)) + .map(resolveExpressionByPlanChildren(_, a)) val resolvedAggExprsBasic = a.aggregateExpressions.map( - resolveExpressionByPlanChildren(_, planForResolve)) + resolveExpressionByPlanChildren(_, a)) val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsBasic) val resolvedAggExprsFinal = resolvedAggExprsWithLCA.map(resolveColsLastResort) .map(_.asInstanceOf[NamedExpression]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/scopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/scopes.scala new file mode 100644 index 0000000000000..6739586d66101 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/scopes.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.plans.logical.{CoGroup, LogicalPlan, MapGroups} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.DataType + + +/** + * An expression that has to be resolved against a scope of resolved attributes. + */ +case class ScopedExpression(expr: Expression, scope: Seq[Attribute]) + extends Expression with Unevaluable { + override def children: Seq[Expression] = expr +: scope + override def dataType: DataType = expr.dataType + override def nullable: Boolean = expr.nullable + override def prettyName: String = "scoped" + override def sql: String = s"$prettyName(${expr.sql}, $scope)" + override lazy val resolved: Boolean = expr.resolved + + override protected def withNewChildrenInternal(children: IndexedSeq[Expression]): Expression = { + val scope = children.tail + assert(scope.forall(_.isInstanceOf[Attribute]), "Scope children have to be attributes") + copy(expr = children.head, scope = scope.map(_.asInstanceOf[Attribute])) + } +} + +/** + * Restricts the scope of resolving some expressions. + */ +object ScopeExpressions extends Rule[LogicalPlan] { + private def scopeOrder(scope: Seq[Attribute])(sortOrder: SortOrder): SortOrder = { + sortOrder match { + case so if so.child.isInstanceOf[ScopedExpression] => so + case so => so.copy( + child = ScopedExpression(so.child, scope), + sameOrderExpressions = so.sameOrderExpressions.map(soe => ScopedExpression(soe, scope)) + ) + } + } + + private def isNotScoped(sortOrder: SortOrder): Boolean = + !sortOrder.child.isInstanceOf[ScopedExpression] + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + // SPARK-42199: sort order of MapGroups must be scoped to their dataAttributes + case mg: MapGroups if mg.dataOrder.exists(isNotScoped) => + mg.copy(dataOrder = mg.dataOrder.map(scopeOrder(mg.dataAttributes))) + + // SPARK-42199: sort order of CoGroups must be scoped to their respective dataAttributes + case cg: CoGroup if Seq(cg.leftOrder, cg.rightOrder).exists(_.exists(isNotScoped)) => + val scopedLeftOrder = cg.leftOrder.map(scopeOrder(cg.leftAttr)) + val scopedRightOrder = cg.rightOrder.map(scopeOrder(cg.rightAttr)) + cg.copy(leftOrder = scopedLeftOrder, rightOrder = scopedRightOrder) + } +} + +/** + * Resolves expressions against their scope of attributes. + */ +class ResolveScopedExpression(val resolver: Resolver) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { + case se: ScopedExpression if se.resolved => se.expr + case se @ ScopedExpression(expr, attributes) => + val resolved = expr.transformDown { + case u@UnresolvedAttribute(nameParts) => + attributes.resolve(nameParts, resolver).getOrElse(u) + } + if (resolved.fastEquals(expr)) { + se + } else { + resolved + } + } +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 1816c620414c9..419438df2cdab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -1544,6 +1544,124 @@ class AnalysisSuite extends AnalysisTest with Matchers { queryContext = Array(ExpectedContext("SELECT *\nFROM t1\nWHERE 'true'", 31, 59))) } + test("SPARK-42199: resolve expression against scoped attributes") { + val plan = testRelation.select($"a".as("value")).analyze + implicit val intEncoder = ExpressionEncoder[Int]() + val appendCols = AppendColumns[Int, Int]((x: Int) => x, plan) + + // AppendColumns adds a duplicate 'value' column, which makes $"value" ambiguous + assertAnalysisErrorClass( + Project( + Seq($"value"), + appendCols + ), + "AMBIGUOUS_REFERENCE", + Map( + "name" -> "`value`", + "referenceNames" -> "[`value`, `value`]")) + + // We can use ScopedExpression to restrict resolution of $"value" to child of AppendColumns + checkAnalysis( + Project( + // 'value' is resolved against plan, not this Project's child + Seq(ScopedExpression($"value", appendCols.child.output).as("value")), + // this appends another column 'value' + appendCols + ), + Project(Seq(plan.output.head.as("value")), appendCols.analyze)) + } + + test("SPARK-42199: MapGroups scopes sort order expressions") { + def func(k: Int, it: Iterator[Int]): Iterator[Int] = { + Iterator.empty + } + + implicit val intEncoder = ExpressionEncoder[Int]() + + val rel = testRelation2.analyze + val group = MapGroups( + func, + rel.output.head :: Nil, + rel.output, + SortOrder($"b", Ascending, $"b".as("b2") :: Nil) :: Nil, + rel + ) + + // apply the ScopeExpressions + val actualPlan = ScopeExpressions(group) + val mg = actualPlan.collectFirst { + case mg: MapGroups => mg + } + + def hasScope(scope: Seq[Attribute])(sortOrder: SortOrder): Boolean = { + sortOrder.child.isInstanceOf[ScopedExpression] && + sortOrder.child.asInstanceOf[ScopedExpression].scope.equals(scope) && + // because we have aliased the order column + sortOrder.sameOrderExpressions.nonEmpty && + sortOrder.sameOrderExpressions.forall(_.isInstanceOf[ScopedExpression]) + } + + // assert sort order to be scoped + assert(mg.isDefined) + mg.foreach { mg => + assert(mg.dataOrder.size == 1) + assert(mg.dataOrder.forall(hasScope(mg.dataAttributes)), mg.dataOrder.mkString(", ")) + } + } + + test("SPARK-42199: CoGroup scopes sort order expressions") { + def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = { + Iterator.empty + } + + implicit val intEncoder = ExpressionEncoder[Int]() + + val left = testRelation2.select($"e").analyze + val right = testRelation3.select($"e", $"f").analyze + val leftWithKey = AppendColumns[Int, Int]((x: Int) => x, left) + val rightWithKey = AppendColumns[Int, Int]((x: Int) => x, right) + val leftOrder = SortOrder($"e", Ascending) :: Nil + val rightOrder = + SortOrder($"e", Ascending, $"e".as("e2") :: Nil) :: + SortOrder($"f", Descending, $"f".as("f2") :: Nil) :: + Nil + + val cogroup = leftWithKey.cogroup[Int, Int, Int, Int]( + rightWithKey, + func, + leftWithKey.newColumns, + rightWithKey.newColumns, + left.output, + right.output, + leftOrder, + rightOrder + ) + + // apply the ScopeExpressions + val actualPlan = ScopeExpressions(cogroup) + val cg = actualPlan.collectFirst { + case cg: CoGroup => cg + } + + def hasScope(scope: Seq[Attribute], hasSameOrderExpr: Boolean) + (sortOrder: SortOrder): Boolean = { + sortOrder.child.isInstanceOf[ScopedExpression] && + sortOrder.child.asInstanceOf[ScopedExpression].scope.equals(scope) && + // because we (may) have aliased the order column + sortOrder.sameOrderExpressions.nonEmpty == hasSameOrderExpr && + sortOrder.sameOrderExpressions.forall(_.isInstanceOf[ScopedExpression]) + } + + // assert sort order to be scoped + assert(cg.isDefined) + cg.foreach { cg => + assert(cg.leftOrder.size == 1) + assert(cg.rightOrder.size == 2) + assert(cg.leftOrder.forall(hasScope(left.output, hasSameOrderExpr = false))) + assert(cg.rightOrder.forall(hasScope(right.output, hasSameOrderExpr = true))) + } + } + test("SPARK-38591: resolve left and right CoGroup sort order on respective side only") { def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = { Iterator.empty diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index bb6d52308c192..9018b2e500586 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -2241,6 +2241,9 @@ class SparkConnectPlanner( val keyColumn = TypedAggUtils.aggKeyColumn(ds.kEncoder, ds.groupingAttributes) val namedColumns = rel.getAggregateExpressionsList.asScala.toSeq .map(expr => transformExpressionWithTypedReduceExpression(expr, input)) + // SPARK-42199: resolve these aggregate expressions only against dataAttributes + // this is to hide key column from expression resolution + .map(ScopedExpression(_, ds.dataAttributes)) .map(toNamedExpression) logical.Aggregate(ds.groupingAttributes, keyColumn +: namedColumns, ds.analyzed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 1ebdd57f1962b..821b617d349bd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.analysis.{EliminateEventTimeWatermark, UnresolvedAttribute} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, ScopedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.internal.ExpressionColumnNode import org.apache.spark.sql.internal.TypedAggUtils.{aggKeyColumn, withInputType} import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeMode} @@ -305,7 +306,12 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** @inheritdoc */ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(c => encoderFor(c.encoder)) - val namedColumns = columns.map(c => withInputType(c.named, vExprEnc, dataAttributes)) + val namedColumns = + columns + // SPARK-42199: resolve these aggregate expressions only against dataAttributes + // this is to hide key column from expression resolution + .map(scopeTypedColumn(dataAttributes)) + .map(c => withInputType(c.named, vExprEnc, dataAttributes)) val keyColumn = aggKeyColumn(kExprEnc, groupingAttributes) val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sparkSession, aggregate) @@ -334,6 +340,11 @@ class KeyValueGroupedDataset[K, V] private[sql]( other.logicalPlan)) } + private def scopeTypedColumn( + scope: Seq[Attribute])( + typedCol: TypedColumn[_, _]): TypedColumn[_, _] = + new TypedColumn(ExpressionColumnNode(ScopedExpression(typedCol.expr, scope)), typedCol.encoder) + override def toString: String = { val builder = new StringBuilder val kFields = kExprEnc.schema.map { f => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index fdb2ec30fdd2d..5240e07b2cf5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -567,6 +567,68 @@ class DatasetSuite extends QueryTest (1, 1)) } + test("SPARK-42199: groupBy function, agg, unresolved reference suggestions") { + checkError( + exception = intercept[AnalysisException] { + spark.range(10).groupByKey(id => id).agg(count("unknown")) + }, + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`unknown`", "proposal" -> "`id`"), + context = ExpectedContext(fragment = "count", getCurrentClassCallSitePattern)) + } + + test("SPARK-42199: groupBy function, flatMapSortedGroups, unresolved reference suggestions") { + checkError( + exception = intercept[AnalysisException] { + spark.range(10).groupByKey(id => id).flatMapSortedGroups($"unknown") { + case (g, it) => Iterator((g, it.mkString(", "))) + } + }, + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`unknown`", "proposal" -> "`id`"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) + } + + test("SPARK-42199: groupBy function, mapValues, agg, unresolved reference suggestions") { + checkError( + exception = intercept[AnalysisException] { + spark.range(10).groupByKey(id => id).mapValues(id => id).agg(count("unknown")) + }, + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`unknown`", "proposal" -> "`value`"), + context = ExpectedContext(fragment = "count", getCurrentClassCallSitePattern)) + } + + test("SPARK-42199: groupBy function, mapValues, flatMapSortedGroups, " + + "unresolved reference suggestions") { + checkError( + exception = intercept[AnalysisException] { + spark.range(10).groupByKey(id => id).mapValues(id => id).flatMapSortedGroups($"unknown") { + case (g, it) => Iterator((g, it.mkString(", "))) + } + }, + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`unknown`", "proposal" -> "`value`"), + context = ExpectedContext(fragment = "$", getCurrentClassCallSitePattern)) + } + + test("SPARK-42199: groupBy function, agg expr resolution") { + val actual = spark.range(3) + .withColumnRenamed("id", "value").as[Long] // add column 'value' to dataset + .groupByKey(value => value * 2) // produces key column 'value' + .agg(sum("value").as[Long]) // 'value' does not resolve to key column + .collect() + assert(actual.sorted === Seq((0, 0), (2, 1), (4, 2))) + + val actual2 = spark.range(3) + .withColumnRenamed("id", "value").as[Long] // add column 'value' to dataset + .groupByKey(value => value * 2) // produces key column 'value' + .mapValues(value => value * -1) // replaces value column 'value' + .agg(sum("value").as[Long]) // 'value' does not resolve to key column + .collect() + assert(actual2.sorted === Seq((0, 0), (2, -1), (4, -2))) + } + test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupByKey(v => (v._1, "word")) @@ -700,6 +762,42 @@ class DatasetSuite extends QueryTest assert(result.sortBy(_.a) === Seq(K1(0), K1(0), K1(1), K1(1))) } + test("SPARK-42199: groupBy function, flatMapSortedGroups expr resolution") { + val ds = Seq(("a", 1, 10), ("a", 2, 20), ("b", 2, 1), ("b", 1, 2), ("c", 1, 1)) + .toDF("key", "seq", "value") + + // groupByKey produces key column 'value' + val grouped = ds.groupByKey(v => v.getString(0)) + // 'value' does not resolve to key column + val aggregated = grouped.flatMapSortedGroups($"value") { + (g, iter) => Iterator(g, iter.mkString(", ")) + } + + checkDatasetUnorderly( + aggregated, + "a", "[a,1,10], [a,2,20]", + "b", "[b,2,1], [b,1,2]", + "c", "[c,1,1]" + ) + + + // groupByKey produces key column 'value' + val grouped2 = ds.groupByKey(v => v.getString(0)) + // mapValues replaces value column 'value' + .mapValues(v => v.getInt(1) * -1) + // 'value' does not resolve to key column + val aggregated2 = grouped2.flatMapSortedGroups($"value") { + (g, iter) => Iterator(g, iter.mkString(", ")) + } + + checkDatasetUnorderly( + aggregated2, + "a", "-2, -1", + "b", "-2, -1", + "c", "-1" + ) + } + test("groupBy function, mapValues, flatMap") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val keyValue = ds.groupByKey(_._1).mapValues(_._2) @@ -972,6 +1070,44 @@ class DatasetSuite extends QueryTest checkAnswer(joined, Row(0L, 123L, 0L) :: Nil) } + test("SPARK-42199: cogroup sorted, sort expr resolution") { + val left = Seq(1 -> "a", 3 -> "xyz", 5 -> "hello", 3 -> "abcx", 3 -> "ijk").toDF("id", "value") + val right = Seq(2 -> "q", 3 -> "w", 5 -> "x", 5 -> "z", 3 -> "a", 5 -> "y").toDF("id", "value") + + // dataset column 'value' coexists with key column 'value' produced by groupByKey + val groupedLeft = left.groupByKey(_.getInt(0)) + val groupedRight = right.groupByKey(_.getInt(0)) + + // 'value' does not resolve to key column + val actual = groupedLeft.cogroupSorted(groupedRight)($"value")($"value".desc) { + (key, left, right) => Iterator(key -> (left.mkString + "#" + right.mkString)) + } + checkDatasetUnorderly( + actual, + 1 -> "[1,a]#", + 2 -> "#[2,q]", + 3 -> "[3,abcx][3,ijk][3,xyz]#[3,w][3,a]", + 5 -> "[5,hello]#[5,z][5,y][5,x]") + + + val groupedLeft2 = groupedLeft + // mapValues replaces value column 'value' + .mapValues(_.getString(1).reverse) + val groupedRight2 = groupedRight + // mapValues replaces value column 'value' + .mapValues(_.getString(1).reverse) + // 'value' does not resolve to key column + val actual2 = groupedLeft2.cogroupSorted(groupedRight2)($"value")($"value".desc) { + (key, left, right) => Iterator(key -> (left.mkString + "#" + right.mkString)) + } + checkDatasetUnorderly( + actual2, + 1 -> "a#", + 2 -> "#q", + 3 -> "kjixcbazyx#wa", + 5 -> "olleh#zyx") + } + test("SPARK-34806: observation on datasets") { val namedObservation = Observation("named") val unnamedObservation = Observation()