Introduce ScopedExpression to control scope of expression resolution
EnricoMi committed Sep 6, 2024
1 parent fdeb288 commit 4b90b28
Showing 8 changed files with 399 additions and 41 deletions.
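This commit lets an expression be wrapped in a ScopedExpression that pins its resolution to an explicit set of attributes, instead of resolving it against everything the enclosing operator exposes. A minimal sketch of the user-facing scenario it addresses (hypothetical repro, assuming a local SparkSession; for a Dataset[Int], the key column appended by groupByKey is also named "value"):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

val ds = Seq(3, 1, 2).toDS()  // Dataset[Int] with a single column named "value"

// groupByKey plans an AppendColumns node whose serializer appends a key column
// that is also named "value"; ScopedExpression pins the sort expression below
// to the group's data attributes, so it no longer resolves ambiguously.
val sorted = ds.groupByKey(_ % 2)
  .flatMapSortedGroups($"value")((key, it) => it.map(_.toString))
```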
Analyzer.scala
@@ -274,7 +274,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
CTESubstitution,
WindowsSubstitution,
EliminateUnions,
-        SubstituteUnresolvedOrdinals),
+        SubstituteUnresolvedOrdinals,
+        ScopeExpressions),
Batch("Disable Hints", Once,
new ResolveHints.DisableHints),
Batch("Hints", fixedPoint,
@@ -295,6 +296,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveFieldNameAndPosition ::
AddMetadataColumns ::
DeduplicateRelations ::
new ResolveScopedExpression(resolver) ::
new ResolveReferences(catalogManager) ::
// Please do not insert any other rules in between. See the TODO comments in rule
// ResolveLateralColumnAliasReference for more details.
@@ -519,6 +521,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case e if !e.resolved => u
case g: Generator => MultiAlias(g, Nil)
case c @ Cast(ne: NamedExpression, _, _, _) => Alias(c, ne.name)()
case se @ ScopedExpression(ne: NamedExpression, _) => Alias(se, ne.name)()
case e: ExtractValue if extractOnly(e) => Alias(e, toPrettySQL(e))()
case e if optGenAliasFunc.isDefined =>
Alias(child, optGenAliasFunc.get.apply(e))()
@@ -1654,34 +1657,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}

-      case mg: MapGroups if mg.dataOrder.exists(!_.resolved) =>
-        // Resolve against `AppendColumns`'s children, instead of `AppendColumns`,
-        // because `AppendColumns`'s serializer might produce conflict attribute
-        // names leading to ambiguous references exception.
-        val planForResolve = mg.child match {
-          case appendColumns: AppendColumns => appendColumns.child
-          case plan => plan
-        }
-        val resolvedOrder = mg.dataOrder
-          .map(resolveExpressionByPlanOutput(_, planForResolve).asInstanceOf[SortOrder])
-        mg.copy(dataOrder = resolvedOrder)
-
      // Left and right sort expression have to be resolved against the respective child plan only
      case cg: CoGroup if cg.leftOrder.exists(!_.resolved) || cg.rightOrder.exists(!_.resolved) =>
-        // Resolve against `AppendColumns`'s children, instead of `AppendColumns`,
-        // because `AppendColumns`'s serializer might produce conflict attribute
-        // names leading to ambiguous references exception.
-        val (leftPlanForResolve, rightPlanForResolve) = Seq(cg.left, cg.right).map {
-          case appendColumns: AppendColumns => appendColumns.child
-          case plan => plan
-        } match {
-          case Seq(left, right) => (left, right)
-        }
-
        val resolvedLeftOrder = cg.leftOrder
-          .map(resolveExpressionByPlanOutput(_, leftPlanForResolve).asInstanceOf[SortOrder])
+          .map(resolveExpressionByPlanOutput(_, cg.left).asInstanceOf[SortOrder])
        val resolvedRightOrder = cg.rightOrder
-          .map(resolveExpressionByPlanOutput(_, rightPlanForResolve).asInstanceOf[SortOrder])
+          .map(resolveExpressionByPlanOutput(_, cg.right).asInstanceOf[SortOrder])

cg.copy(leftOrder = resolvedLeftOrder, rightOrder = resolvedRightOrder)

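The AppendColumns unwrapping deleted above is subsumed by the new ScopeExpressions and ResolveScopedExpression rules introduced later in this commit. A hedged sketch of the two-step flow they implement (illustrative attribute names, standing in for MapGroups.dataAttributes):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, ScopedExpression, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Stand-in for MapGroups.dataAttributes.
val dataAttributes = Seq(AttributeReference("b", IntegerType)())

// Step 1 (ScopeExpressions): wrap the not-yet-resolved sort order so that it
// can only see the data attributes.
val order = SortOrder(UnresolvedAttribute("b"), Ascending)
val scoped = order.copy(child = ScopedExpression(order.child, dataAttributes))

// Step 2 (ResolveScopedExpression): 'b' is resolved against dataAttributes
// only, so a conflicting attribute appended by AppendColumns cannot capture it.
```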
CheckAnalysis.scala
@@ -144,11 +144,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
operator: LogicalPlan,
a: Attribute,
errorClass: String): Nothing = {
+    failUnresolvedAttribute(a, operator.inputSet.toSeq, errorClass)
+  }
+
+  private def failUnresolvedAttribute(
+      a: Attribute,
+      candidates: Seq[Attribute],
+      errorClass: String): Nothing = {
    val missingCol = a.sql
-    val candidates = operator.inputSet.toSeq
-      .map(attr => attr.qualifier :+ attr.name)
+    val candidateNames = candidates.map(attr => attr.qualifier :+ attr.name)
    val orderedCandidates =
-      StringUtils.orderSuggestedIdentifiersBySimilarity(missingCol, candidates)
+      StringUtils.orderSuggestedIdentifiersBySimilarity(missingCol, candidateNames)
throw QueryCompilationErrors.unresolvedAttributeError(
errorClass, missingCol, orderedCandidates, a.origin)
}
@@ -310,6 +316,24 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
ColumnDefinition.checkColumnDefinitions(operator)

var stagedError: Option[() => Unit] = None

// Check for unresolved scoped expressions before any other check below;
// this improves error messages, see SPARK-42199.
getAllExpressions(operator).foreach(_.foreach {
case se: ScopedExpression if !se.resolved =>
se.expr.foreachUp {
case a: Attribute if !a.resolved =>
failUnresolvedAttribute(a, se.scope.attrs, "UNRESOLVED_COLUMN")
case _ =>
}
case se: ScopedExpression => throw SparkException.internalError(
"Resolved scoped expression operator should be removed " +
f"during analysis: ${toSQLExpr(se)}",
context = se.origin.getQueryContext,
summary = se.origin.context.summary)
case _ =>
})

getAllExpressions(operator).foreach(_.foreachUp {
case a: Attribute if !a.resolved =>
failUnresolvedAttribute(operator, a, "UNRESOLVED_COLUMN")
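The effect of this early pass, as a sketch (attribute names illustrative): for an unresolved column inside a ScopedExpression, the UNRESOLVED_COLUMN suggestions are now drawn from the expression's scope rather than from the whole operator input:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ScopedExpression}
import org.apache.spark.sql.types.IntegerType

// A scope containing only `a`.
val scope = Seq(AttributeReference("a", IntegerType)())

// `b` cannot resolve inside this scope, so the check above raises
// UNRESOLVED_COLUMN with candidates taken from the scope (just `a`),
// not from every attribute the enclosing operator can see.
val unresolvable = ScopedExpression(UnresolvedAttribute("b"), scope)
```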
ResolveReferencesInAggregate.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, IntegerLiteral, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE}
import org.apache.spark.sql.connector.catalog.CatalogManager

@@ -51,18 +51,10 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S
with ColumnResolutionHelper with AliasHelper {

def apply(a: Aggregate): Aggregate = {
-    val planForResolve = a.child match {
-      // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of
-      // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute
-      // names leading to ambiguous references exception.
-      case appendColumns: AppendColumns => appendColumns
-      case _ => a
-    }
-
    val resolvedGroupExprsBasic = a.groupingExpressions
-      .map(resolveExpressionByPlanChildren(_, planForResolve))
+      .map(resolveExpressionByPlanChildren(_, a))
    val resolvedAggExprsBasic = a.aggregateExpressions.map(
-      resolveExpressionByPlanChildren(_, planForResolve))
+      resolveExpressionByPlanChildren(_, a))
val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsBasic)
val resolvedAggExprsFinal = resolvedAggExprsWithLCA.map(resolveColsLastResort)
.map(_.asInstanceOf[NamedExpression])
ScopedExpression.scala (new file)
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.plans.logical.{CoGroup, LogicalPlan, MapGroups}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType


/**
* An expression that has to be resolved against a scope of resolved attributes.
*/
case class ScopedExpression(expr: Expression, scope: Seq[Attribute])
extends Expression with Unevaluable {
override def children: Seq[Expression] = expr +: scope
override def dataType: DataType = expr.dataType
override def nullable: Boolean = expr.nullable
override def prettyName: String = "scoped"
override def sql: String = s"$prettyName(${expr.sql}, $scope)"
override lazy val resolved: Boolean = expr.resolved

override protected def withNewChildrenInternal(children: IndexedSeq[Expression]): Expression = {
val scope = children.tail
assert(scope.forall(_.isInstanceOf[Attribute]), "Scope children have to be attributes")
copy(expr = children.head, scope = scope.map(_.asInstanceOf[Attribute]))
}
}

/**
* Restricts the scope in which certain expressions are resolved.
*/
object ScopeExpressions extends Rule[LogicalPlan] {
private def scopeOrder(scope: Seq[Attribute])(sortOrder: SortOrder): SortOrder = {
sortOrder match {
case so if so.child.isInstanceOf[ScopedExpression] => so
case so => so.copy(
child = ScopedExpression(so.child, scope),
sameOrderExpressions = so.sameOrderExpressions.map(soe => ScopedExpression(soe, scope))
)
}
}

private def isNotScoped(sortOrder: SortOrder): Boolean =
!sortOrder.child.isInstanceOf[ScopedExpression]

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
// SPARK-42199: the sort order of MapGroups must be scoped to its dataAttributes
case mg: MapGroups if mg.dataOrder.exists(isNotScoped) =>
mg.copy(dataOrder = mg.dataOrder.map(scopeOrder(mg.dataAttributes)))

// SPARK-42199: the left and right sort orders of CoGroup must be scoped to their respective dataAttributes
case cg: CoGroup if Seq(cg.leftOrder, cg.rightOrder).exists(_.exists(isNotScoped)) =>
val scopedLeftOrder = cg.leftOrder.map(scopeOrder(cg.leftAttr))
val scopedRightOrder = cg.rightOrder.map(scopeOrder(cg.rightAttr))
cg.copy(leftOrder = scopedLeftOrder, rightOrder = scopedRightOrder)
}
}

/**
* Resolves expressions against their scope of attributes.
*/
class ResolveScopedExpression(val resolver: Resolver) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
case se: ScopedExpression if se.resolved => se.expr
case se @ ScopedExpression(expr, attributes) =>
val resolved = expr.transformDown {
case u @ UnresolvedAttribute(nameParts) =>
attributes.resolve(nameParts, resolver).getOrElse(u)
}
if (resolved.fastEquals(expr)) {
se
} else {
resolved
}
}
}
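A hedged end-to-end sketch of the pieces in this new file over a minimal plan (assuming the Catalyst imports shown; caseInsensitiveResolution is the stock Resolver from the analysis package):

```scala
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ResolveScopedExpression, ScopedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val relation = LocalRelation(a)

// Wrap an unresolved reference so it can only see the given scope.
val scoped = ScopedExpression(UnresolvedAttribute("a"), Seq(a))

// A single application of the rule resolves 'a' against the scope; because the
// substitution made progress, the rule returns the expression already unwrapped.
val rule = new ResolveScopedExpression(caseInsensitiveResolution)
val resolvedPlan = rule(Project(Seq(Alias(scoped, "out")()), relation))
```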

AnalysisSuite.scala
@@ -1544,6 +1544,124 @@ class AnalysisSuite extends AnalysisTest with Matchers {
queryContext = Array(ExpectedContext("SELECT *\nFROM t1\nWHERE 'true'", 31, 59)))
}

test("SPARK-42199: resolve expression against scoped attributes") {
val plan = testRelation.select($"a".as("value")).analyze
implicit val intEncoder = ExpressionEncoder[Int]()
val appendCols = AppendColumns[Int, Int]((x: Int) => x, plan)

// AppendColumns adds a duplicate 'value' column, which makes $"value" ambiguous
assertAnalysisErrorClass(
Project(
Seq($"value"),
appendCols
),
"AMBIGUOUS_REFERENCE",
Map(
"name" -> "`value`",
"referenceNames" -> "[`value`, `value`]"))

// We can use ScopedExpression to restrict resolution of $"value" to the child of AppendColumns
checkAnalysis(
Project(
// 'value' is resolved against `plan`, not against this Project's child
Seq(ScopedExpression($"value", appendCols.child.output).as("value")),
// this appends another column 'value'
appendCols
),
Project(Seq(plan.output.head.as("value")), appendCols.analyze))
}

test("SPARK-42199: MapGroups scopes sort order expressions") {
def func(k: Int, it: Iterator[Int]): Iterator[Int] = {
Iterator.empty
}

implicit val intEncoder = ExpressionEncoder[Int]()

val rel = testRelation2.analyze
val group = MapGroups(
func,
rel.output.head :: Nil,
rel.output,
SortOrder($"b", Ascending, $"b".as("b2") :: Nil) :: Nil,
rel
)

// apply the ScopeExpressions rule
val actualPlan = ScopeExpressions(group)
val mg = actualPlan.collectFirst {
case mg: MapGroups => mg
}

def hasScope(scope: Seq[Attribute])(sortOrder: SortOrder): Boolean = {
sortOrder.child.isInstanceOf[ScopedExpression] &&
sortOrder.child.asInstanceOf[ScopedExpression].scope.equals(scope) &&
// because we have aliased the order column
sortOrder.sameOrderExpressions.nonEmpty &&
sortOrder.sameOrderExpressions.forall(_.isInstanceOf[ScopedExpression])
}

// assert that the sort order is scoped
assert(mg.isDefined)
mg.foreach { mg =>
assert(mg.dataOrder.size == 1)
assert(mg.dataOrder.forall(hasScope(mg.dataAttributes)), mg.dataOrder.mkString(", "))
}
}

test("SPARK-42199: CoGroup scopes sort order expressions") {
def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = {
Iterator.empty
}

implicit val intEncoder = ExpressionEncoder[Int]()

val left = testRelation2.select($"e").analyze
val right = testRelation3.select($"e", $"f").analyze
val leftWithKey = AppendColumns[Int, Int]((x: Int) => x, left)
val rightWithKey = AppendColumns[Int, Int]((x: Int) => x, right)
val leftOrder = SortOrder($"e", Ascending) :: Nil
val rightOrder =
SortOrder($"e", Ascending, $"e".as("e2") :: Nil) ::
SortOrder($"f", Descending, $"f".as("f2") :: Nil) ::
Nil

val cogroup = leftWithKey.cogroup[Int, Int, Int, Int](
rightWithKey,
func,
leftWithKey.newColumns,
rightWithKey.newColumns,
left.output,
right.output,
leftOrder,
rightOrder
)

// apply the ScopeExpressions rule
val actualPlan = ScopeExpressions(cogroup)
val cg = actualPlan.collectFirst {
case cg: CoGroup => cg
}

def hasScope(scope: Seq[Attribute], hasSameOrderExpr: Boolean)
(sortOrder: SortOrder): Boolean = {
sortOrder.child.isInstanceOf[ScopedExpression] &&
sortOrder.child.asInstanceOf[ScopedExpression].scope.equals(scope) &&
// because we (may) have aliased the order column
sortOrder.sameOrderExpressions.nonEmpty == hasSameOrderExpr &&
sortOrder.sameOrderExpressions.forall(_.isInstanceOf[ScopedExpression])
}

// assert that the sort order is scoped
assert(cg.isDefined)
cg.foreach { cg =>
assert(cg.leftOrder.size == 1)
assert(cg.rightOrder.size == 2)
assert(cg.leftOrder.forall(hasScope(left.output, hasSameOrderExpr = false)))
assert(cg.rightOrder.forall(hasScope(right.output, hasSameOrderExpr = true)))
}
}

test("SPARK-38591: resolve left and right CoGroup sort order on respective side only") {
def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = {
Iterator.empty
SparkConnectPlanner.scala
@@ -2241,6 +2241,9 @@ class SparkConnectPlanner(
val keyColumn = TypedAggUtils.aggKeyColumn(ds.kEncoder, ds.groupingAttributes)
val namedColumns = rel.getAggregateExpressionsList.asScala.toSeq
.map(expr => transformExpressionWithTypedReduceExpression(expr, input))
// SPARK-42199: resolve these aggregate expressions only against dataAttributes;
// this hides the key column from expression resolution
.map(ScopedExpression(_, ds.dataAttributes))
.map(toNamedExpression)
logical.Aggregate(ds.groupingAttributes, keyColumn +: namedColumns, ds.analyzed)
}
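A hypothetical client-side shape of what this protects (assuming a Connect SparkSession bound to `spark` with spark.implicits._ in scope): for a Dataset[Int], the grouping key appended by groupByKey is also named "value", so without scoping the typed aggregate below could resolve ambiguously:

```scala
import org.apache.spark.sql.functions.sum
import spark.implicits._

val ds = Seq(1, 2, 3).toDS()  // single data column named "value"

// On the server, each aggregate expression is wrapped in
// ScopedExpression(_, ds.dataAttributes), so $"value" resolves to the data
// column rather than the identically named grouping-key column.
val result = ds.groupByKey(_ % 2).agg(sum($"value").as[Long])
```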