Skip to content

Commit

Permalink
Use rule to scope sort order expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
EnricoMi committed Jul 30, 2024
1 parent 3896488 commit 2443f8a
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
SubstituteUnresolvedOrdinals),
SubstituteUnresolvedOrdinals,
ScopeExpressions),
Batch("Disable Hints", Once,
new ResolveHints.DisableHints),
Batch("Hints", fixedPoint,
Expand Down Expand Up @@ -3508,6 +3509,33 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

/**
* Restricts the scope of resolving some expressions.
*/
object ScopeExpressions extends Rule[LogicalPlan] {
private def scopeOrder(scope: Seq[Attribute])(sortOrder: SortOrder): SortOrder = {
sortOrder match {
case so if so.child.isInstanceOf[ScopedExpression] => so
case so => so.copy(child = ScopedExpression(so.child, scope))
}
}

private def isNotScoped(sortOrder: SortOrder): Boolean =
!sortOrder.child.isInstanceOf[ScopedExpression]

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// SPARK-42199: sort order of MapGroups must be scoped to their dataAttributes
case mg: MapGroups if mg.dataOrder.exists(isNotScoped) =>
mg.copy(dataOrder = mg.dataOrder.map(scopeOrder(mg.dataAttributes)))

// SPARK-42199: sort order of CoGroups must be scoped to their respective dataAttributes
case cg: CoGroup if Seq(cg.leftOrder, cg.rightOrder).exists(_.exists(isNotScoped)) =>
val scopedLeftOrder = cg.leftOrder.map(scopeOrder(cg.leftAttr))
val scopedRightOrder = cg.rightOrder.map(scopeOrder(cg.rightAttr))
cg.copy(leftOrder = scopedLeftOrder, rightOrder = scopedRightOrder)
}
}

/**
* Resolves expressions against their scope of attributes.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1570,6 +1570,87 @@ class AnalysisSuite extends AnalysisTest with Matchers {
Project(Seq(plan.output.head.as("value")), appendCols.analyze))
}

test("SPARK-42199: MapGroups scopes sort order expressions") {
def func(k: Int, it: Iterator[Int]): Iterator[Int] = {
Iterator.empty
}

implicit val intEncoder = ExpressionEncoder[Int]()

val rel = testRelation2.analyze
val group = MapGroups(
func,
rel.output.head :: Nil,
rel.output,
SortOrder($"b", Ascending) :: Nil,
rel
)

// apply the ScopeExpressions
val actualPlan = getAnalyzer.ScopeExpressions(group)
val mg = actualPlan.collectFirst {
case mg: MapGroups => mg
}

def hasScope(scope: Seq[Attribute])(sortOrder: SortOrder): Boolean = {
sortOrder.child.isInstanceOf[ScopedExpression] &&
sortOrder.child.asInstanceOf[ScopedExpression].scope.equals(scope)
}

// assert sort order to be scoped
assert(mg.isDefined)
mg.foreach { mg =>
assert(mg.dataOrder.size == 1)
assert(mg.dataOrder.forall(hasScope(mg.dataAttributes)))
}
}

test("SPARK-42199: CoGroup scopes sort order expressions") {
def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = {
Iterator.empty
}

implicit val intEncoder = ExpressionEncoder[Int]()

val left = testRelation2.select($"e").analyze
val right = testRelation3.select($"e", $"f").analyze
val leftWithKey = AppendColumns[Int, Int]((x: Int) => x, left)
val rightWithKey = AppendColumns[Int, Int]((x: Int) => x, right)
val leftOrder = SortOrder($"e", Ascending) :: Nil
val rightOrder = SortOrder($"e", Ascending) :: SortOrder($"f", Descending) :: Nil

val cogroup = leftWithKey.cogroup[Int, Int, Int, Int](
rightWithKey,
func,
leftWithKey.newColumns,
rightWithKey.newColumns,
left.output,
right.output,
leftOrder,
rightOrder
)

// apply the ScopeExpressions
val actualPlan = getAnalyzer.ScopeExpressions(cogroup)
val cg = actualPlan.collectFirst {
case cg: CoGroup => cg
}

def hasScope(scope: Seq[Attribute])(sortOrder: SortOrder): Boolean = {
sortOrder.child.isInstanceOf[ScopedExpression] &&
sortOrder.child.asInstanceOf[ScopedExpression].scope.equals(scope)
}

// assert sort order to be scoped
assert(cg.isDefined)
cg.foreach { cg =>
assert(cg.leftOrder.size == 1)
assert(cg.rightOrder.size == 2)
assert(cg.leftOrder.forall(hasScope(left.output)))
assert(cg.rightOrder.forall(hasScope(right.output)))
}
}

test("SPARK-38591: resolve left and right CoGroup sort order on respective side only") {
def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = {
Iterator.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,6 @@ class KeyValueGroupedDataset[K, V] private[sql](
sortExprs: Column*)(
f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = {
val sortOrder: Seq[SortOrder] = MapGroups.sortOrder(sortExprs.map(_.expr))
// SPARK-42199: resolve these sort expressions only against dataAttributes
.map(scopeSortOrder(dataAttributes))

Dataset[U](
sparkSession,
Expand Down Expand Up @@ -1166,13 +1164,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
case expr: Expression => SortOrder(expr, Ascending)
}

// SPARK-42199: resolve these sort expressions only against dataAttributes
val thisSortOrder: Seq[SortOrder] = thisSortExprs
.map(toSortOrder)
.map(scopeSortOrder(dataAttributes))
val otherSortOrder: Seq[SortOrder] = otherSortExprs
.map(toSortOrder)
.map(scopeSortOrder(other.dataAttributes))
val thisSortOrder: Seq[SortOrder] = thisSortExprs.map(toSortOrder)
val otherSortOrder: Seq[SortOrder] = otherSortExprs.map(toSortOrder)

implicit val uEncoder = other.vExprEnc
Dataset[R](
Expand Down Expand Up @@ -1215,9 +1208,6 @@ class KeyValueGroupedDataset[K, V] private[sql](
(key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
}

private def scopeSortOrder(scope: Seq[Attribute])(sortOrder: SortOrder): SortOrder =
sortOrder.copy(child = ScopedExpression(sortOrder.child, scope))

private def scopeTypedColumn(
scope: Seq[Attribute])(
typedCol: TypedColumn[_, _]): TypedColumn[_, _] =
Expand Down

0 comments on commit 2443f8a

Please sign in to comment.