diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 220a73f784374..5019023379750 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -265,7 +265,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor CTESubstitution, WindowsSubstitution, EliminateUnions, - SubstituteUnresolvedOrdinals), + SubstituteUnresolvedOrdinals, + ScopeExpressions), Batch("Disable Hints", Once, new ResolveHints.DisableHints), Batch("Hints", fixedPoint, @@ -3447,6 +3448,33 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + /** + * Restricts the scope in which some expressions are resolved: wraps the sort orders of MapGroups and CoGroup in ScopedExpression so they resolve only against their respective data attributes (SPARK-42199). + */ + object ScopeExpressions extends Rule[LogicalPlan] { + private def scopeOrder(scope: Seq[Attribute])(sortOrder: SortOrder): SortOrder = { + sortOrder match { + case so if so.child.isInstanceOf[ScopedExpression] => so + case so => so.copy(child = ScopedExpression(so.child, scope)) + } + } + + private def isNotScoped(sortOrder: SortOrder): Boolean = + !sortOrder.child.isInstanceOf[ScopedExpression] + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + // SPARK-42199: sort order of MapGroups must be scoped to their dataAttributes + case mg: MapGroups if mg.dataOrder.exists(isNotScoped) => + mg.copy(dataOrder = mg.dataOrder.map(scopeOrder(mg.dataAttributes))) + + // SPARK-42199: sort order of CoGroups must be scoped to their respective dataAttributes + case cg: CoGroup if Seq(cg.leftOrder, cg.rightOrder).exists(_.exists(isNotScoped)) => + val scopedLeftOrder = cg.leftOrder.map(scopeOrder(cg.leftAttr)) + val scopedRightOrder = cg.rightOrder.map(scopeOrder(cg.rightAttr)) + cg.copy(leftOrder = scopedLeftOrder, rightOrder = scopedRightOrder) + } + } + + /** + * 
Resolves expressions against their scope of attributes. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index d969bde284e79..9ee9f677fa7bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -1566,6 +1566,87 @@ class AnalysisSuite extends AnalysisTest with Matchers { Project(Seq(plan.output.head.as("value")), appendCols.analyze)) } + test("SPARK-42199: MapGroups scopes sort order expressions") { + def func(k: Int, it: Iterator[Int]): Iterator[Int] = { + Iterator.empty + } + + implicit val intEncoder = ExpressionEncoder[Int]() + + val rel = testRelation2.analyze + val group = MapGroups( + func, + rel.output.head :: Nil, + rel.output, + SortOrder($"b", Ascending) :: Nil, + rel + ) + + // apply the ScopeExpressions + val actualPlan = getAnalyzer.ScopeExpressions(group) + val mg = actualPlan.collectFirst { + case mg: MapGroups => mg + } + + def hasScope(scope: Seq[Attribute])(sortOrder: SortOrder): Boolean = { + sortOrder.child.isInstanceOf[ScopedExpression] && + sortOrder.child.asInstanceOf[ScopedExpression].scope.equals(scope) + } + + // assert sort order to be scoped + assert(mg.isDefined) + mg.foreach { mg => + assert(mg.dataOrder.size == 1) + assert(mg.dataOrder.forall(hasScope(mg.dataAttributes))) + } + } + + test("SPARK-42199: CoGroup scopes sort order expressions") { + def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = { + Iterator.empty + } + + implicit val intEncoder = ExpressionEncoder[Int]() + + val left = testRelation2.select($"e").analyze + val right = testRelation3.select($"e", $"f").analyze + val leftWithKey = AppendColumns[Int, Int]((x: Int) => x, left) + val rightWithKey = AppendColumns[Int, Int]((x: Int) => x, right) + val leftOrder = 
SortOrder($"e", Ascending) :: Nil + val rightOrder = SortOrder($"e", Ascending) :: SortOrder($"f", Descending) :: Nil + + val cogroup = leftWithKey.cogroup[Int, Int, Int, Int]( + rightWithKey, + func, + leftWithKey.newColumns, + rightWithKey.newColumns, + left.output, + right.output, + leftOrder, + rightOrder + ) + + // apply the ScopeExpressions + val actualPlan = getAnalyzer.ScopeExpressions(cogroup) + val cg = actualPlan.collectFirst { + case cg: CoGroup => cg + } + + def hasScope(scope: Seq[Attribute])(sortOrder: SortOrder): Boolean = { + sortOrder.child.isInstanceOf[ScopedExpression] && + sortOrder.child.asInstanceOf[ScopedExpression].scope.equals(scope) + } + + // assert sort order to be scoped + assert(cg.isDefined) + cg.foreach { cg => + assert(cg.leftOrder.size == 1) + assert(cg.rightOrder.size == 2) + assert(cg.leftOrder.forall(hasScope(left.output))) + assert(cg.rightOrder.forall(hasScope(right.output))) + } + } + test("SPARK-38591: resolve left and right CoGroup sort order on respective side only") { def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = { Iterator.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 3618290bcd6b7..c8abfb054ad81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -200,8 +200,6 @@ class KeyValueGroupedDataset[K, V] private[sql]( sortExprs: Column*)( f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { val sortOrder: Seq[SortOrder] = MapGroups.sortOrder(sortExprs.map(_.expr)) - // SPARK-42199: resolve these sort expressions only against dataAttributes - .map(scopeSortOrder(dataAttributes)) Dataset[U]( sparkSession, @@ -870,13 +868,8 @@ class KeyValueGroupedDataset[K, V] private[sql]( case expr: Expression => SortOrder(expr, Ascending) } - // SPARK-42199: resolve 
these sort expressions only against dataAttributes - val thisSortOrder: Seq[SortOrder] = thisSortExprs - .map(toSortOrder) - .map(scopeSortOrder(dataAttributes)) - val otherSortOrder: Seq[SortOrder] = otherSortExprs - .map(toSortOrder) - .map(scopeSortOrder(other.dataAttributes)) + val thisSortOrder: Seq[SortOrder] = thisSortExprs.map(toSortOrder) + val otherSortOrder: Seq[SortOrder] = otherSortExprs.map(toSortOrder) implicit val uEncoder = other.vExprEnc Dataset[R]( @@ -919,9 +912,6 @@ class KeyValueGroupedDataset[K, V] private[sql]( (key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } - private def scopeSortOrder(scope: Seq[Attribute])(sortOrder: SortOrder): SortOrder = - sortOrder.copy(child = ScopedExpression(sortOrder.child, scope)) - private def scopeTypedColumn( scope: Seq[Attribute])( typedCol: TypedColumn[_, _]): TypedColumn[_, _] =