From 3896488ed99b8857c6802d586db1a4f38e04b176 Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 17 Jan 2024 14:38:17 +0100 Subject: [PATCH] Add scope attributes to ScopedExpression children --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../expressions/namedExpressions.scala | 19 ++++++++++++------- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ed0e91314903d..0304bc1c00815 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3513,7 +3513,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor */ object ResolveScopedExpression extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case se: ScopedExpression if se.resolved => se.child + case se: ScopedExpression if se.resolved => se.expr case se @ ScopedExpression(expr, attributes) => val resolved = expr.transformDown { case u@UnresolvedAttribute(nameParts) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index e97ac2e507e7c..b5820c198a8da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -305,7 +305,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB getAllExpressions(operator).foreach(_.foreach { case se: ScopedExpression if !se.resolved => - se.child.foreachUp { + se.expr.foreachUp { case a: Attribute if !a.resolved => failUnresolvedAttribute(a, se.scope.attrs, "UNRESOLVED_COLUMN") case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 4296a38996a3d..bbcadce2537a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -426,15 +426,20 @@ case class PrettyAttribute( /** * An expression that has to be resolved against a scope of resolved attributes. */ -case class ScopedExpression(child: Expression, scope: Seq[Attribute]) - extends UnaryExpression with Unevaluable { - override def dataType: DataType = child.dataType - override def nullable: Boolean = child.nullable +case class ScopedExpression(expr: Expression, scope: Seq[Attribute]) + extends Expression with Unevaluable { + override def children: Seq[Expression] = expr +: scope + override def dataType: DataType = expr.dataType + override def nullable: Boolean = expr.nullable override def prettyName: String = "scoped" + override def sql: String = s"$prettyName(${expr.sql}, $scope)" + override lazy val resolved: Boolean = expr.resolved - override def sql: String = s"$prettyName(${child.sql}, $scope)" - - override def withNewChildInternal(newChild: Expression): ScopedExpression = copy(child = newChild) + override protected def withNewChildrenInternal(children: IndexedSeq[Expression]): Expression = { + val scope = children.tail + assert(scope.forall(_.isInstanceOf[Attribute]), "Scope children have to be attributes") + copy(expr = children.head, scope = scope.map(_.asInstanceOf[Attribute])) + } } /**