
WIP: Update analyzer to resolve multipart identifiers. #6

Open · wants to merge 1 commit into master
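This change moves relation resolution from TableIdentifier (a table name plus an optional database) to a general multipart identifier (Seq[String]), so that names such as testcat.db.tbl can eventually be routed to a v2 catalog through the LookupCatalog / CatalogPlugin API. Below is a rough standalone sketch of the splitting idea only; the catalog registry and the name "testcat" are made up for illustration, and this is not the lookup code the PR actually uses.

object MultipartIdentifierSketch {
  // Hypothetical stand-in for the session's registered v2 catalogs.
  val knownCatalogs = Set("testcat")

  // Split a multipart name into (optional catalog, namespace, table),
  // loosely mirroring the CatalogObjectIdentifier extractor used in this diff.
  def split(parts: Seq[String]): (Option[String], Seq[String], String) = parts match {
    case Seq(table) => (None, Nil, table)
    case head +: rest if knownCatalogs.contains(head) => (Some(head), rest.init, rest.last)
    case _ => (None, parts.init, parts.last)
  }

  def main(args: Array[String]): Unit = {
    println(split(Seq("tbl")))                  // (None,List(),tbl)
    println(split(Seq("db", "tbl")))            // (None,List(db),tbl)
    println(split(Seq("testcat", "db", "tbl"))) // (Some(testcat),List(db),tbl)
  }
}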
@@ -397,7 +397,7 @@ queryTerm

queryPrimary
: querySpecification #queryPrimaryDefault
| TABLE tableIdentifier #table
| TABLE multipartIdentifier #table
| inlineTable #inlineTableDefault1
| '(' queryNoWith ')' #subquery
;
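With TABLE and FROM references parsed as multipartIdentifier, the parser yields a sequence of name parts instead of a TableIdentifier. A hedged illustration in terms of the parseMultipartIdentifier entry point referenced later in this diff (the catalog name is made up):

// spark.sessionState.sqlParser.parseMultipartIdentifier("testcat.db.tbl")
//   returns Seq("testcat", "db", "tbl")
// spark.sessionState.sqlParser.parseMultipartIdentifier("db.tbl")
//   returns Seq("db", "tbl")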
@@ -536,7 +536,7 @@ identifierComment
;

relationPrimary
: tableIdentifier sample? tableAlias #tableName
: multipartIdentifier sample? tableAlias #tableName
| '(' queryNoWith ')' sample? tableAlias #aliasedQuery
| '(' relation ')' sample? tableAlias #aliasedRelation
| inlineTable #inlineTableDefault2
@@ -150,7 +150,7 @@ class Analyzer(

lazy val batches: Seq[Batch] = Seq(
Batch("Hints", fixedPoint,
new ResolveHints.ResolveJoinStrategyHints(conf),
new ResolveHints.ResolveJoinStrategyHints(conf, lookupCatalog),
ResolveHints.ResolveCoalesceHints,
ResolveHints.RemoveAllHints),
Batch("Simple Sanity Check", Once,
@@ -224,9 +224,8 @@ class Analyzer(

def substituteCTE(plan: LogicalPlan, cteRelations: Seq[(String, LogicalPlan)]): LogicalPlan = {
plan resolveOperatorsDown {
case u: UnresolvedRelation =>
cteRelations.find(x => resolver(x._1, u.tableIdentifier.table))
.map(_._2).getOrElse(u)
case u @ UnresolvedRelation(Seq(table)) =>
cteRelations.find(x => resolver(x._1, table)).map(_._2).getOrElse(u)
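// Only single-part names are candidates for CTE substitution here; a multipart
// reference such as Seq("db", "tbl") does not match Seq(table) and is left for
// catalog resolution instead.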
case other =>
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other transformExpressions {
@@ -686,10 +685,15 @@ class Analyzer(
// Note this is compatible with the views defined by older versions of Spark(before 2.2), which
// have empty defaultDatabase and all the relations in viewText have database part defined.
def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
case u: UnresolvedRelation if !isRunningDirectlyOnFiles(u.tableIdentifier) =>
case u @ UnresolvedRelation(AsTableIdentifier(ident)) if !isRunningDirectlyOnFiles(ident) =>
val defaultDatabase = AnalysisContext.get.defaultDatabase
val foundRelation = lookupTableFromCatalog(u, defaultDatabase)
resolveRelation(foundRelation)
val foundRelation = lookupTableFromCatalog(ident, u, defaultDatabase)
if (foundRelation != u) {
resolveRelation(foundRelation)
} else {
u
}

// The view's child should be a logical plan parsed from the `desc.viewText`, the variable
// `viewText` should be defined, or else we throw an error on the generation of the View
// operator.
@@ -712,8 +716,9 @@
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
case i @ InsertIntoTable(u @ UnresolvedRelation(AsTableIdentifier(ident)), _, child, _, _)
if child.resolved =>
EliminateSubqueryAliases(lookupTableFromCatalog(ident, u)) match {
case v: View =>
u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.")
case other => i.copy(table = other)
@@ -728,20 +733,16 @@
// and the default database is only used to look up a view);
// 3. Use the currentDb of the SessionCatalog.
private def lookupTableFromCatalog(
tableIdentifier: TableIdentifier,
u: UnresolvedRelation,
defaultDatabase: Option[String] = None): LogicalPlan = {
val tableIdentWithDb = u.tableIdentifier.copy(
database = u.tableIdentifier.database.orElse(defaultDatabase))
val tableIdentWithDb = tableIdentifier.copy(
database = tableIdentifier.database.orElse(defaultDatabase))
try {
catalog.lookupRelation(tableIdentWithDb)
} catch {
case e: NoSuchTableException =>
u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}", e)
// If the database is defined and that database is not found, throw an AnalysisException.
// Note that if the database is not defined, it is possible we are looking up a temp view.
case e: NoSuchDatabaseException =>
u.failAnalysis(s"Table or view not found: ${tableIdentWithDb.unquotedString}, the " +
s"database ${e.db} doesn't exist.", e)
case _: NoSuchTableException | _: NoSuchDatabaseException =>
u
}
}
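lookupTableFromCatalog no longer turns NoSuchTableException or NoSuchDatabaseException into an immediate analysis failure: the UnresolvedRelation is handed back unchanged so that a later rule (see the DataSourceResolution change further down) can try a v2 catalog, and CheckAnalysis reports whatever is still unresolved at the end. A rough standalone sketch of that fall-through shape, using made-up names and none of Spark's actual types:

import scala.util.Try

object FallThroughSketch {
  // Try the primary lookup; if it does not know the name, keep the node
  // unresolved and let a later pass (or the final check) decide.
  def resolveOrKeep[A](unresolved: A)(lookup: => A): A =
    Try(lookup).getOrElse(unresolved)

  def main(args: Array[String]): Unit = {
    println(resolveOrKeep("still unresolved")(sys.error("no such table"))) // falls back
    println(resolveOrKeep("still unresolved")("resolved plan"))            // lookup wins
  }
}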

@@ -90,7 +90,7 @@ trait CheckAnalysis extends PredicateHelper {
case p if p.analyzed => // Skip already analyzed sub-plans

case u: UnresolvedRelation =>
u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}")
u.failAnalysis(s"Table or view not found: ${u.multipartIdentifier.quoted}")

case operator: LogicalPlan =>
// Check argument data types of higher-order functions downwards first.
@@ -22,6 +22,7 @@ import java.util.Locale
import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalog.v2.{CatalogPlugin, LookupCatalog}
import org.apache.spark.sql.catalyst.expressions.IntegerLiteral
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -50,11 +51,15 @@ object ResolveHints {
*
* This rule must happen before common table expressions.
*/
class ResolveJoinStrategyHints(conf: SQLConf) extends Rule[LogicalPlan] {
class ResolveJoinStrategyHints(
conf: SQLConf,
catalogLookup: String => CatalogPlugin) extends Rule[LogicalPlan] with LookupCatalog {
private val STRATEGY_HINT_NAMES = JoinStrategyHint.strategies.flatMap(_.hintAliases)

def resolver: Resolver = conf.resolver

override protected def lookupCatalog(name: String): CatalogPlugin = catalogLookup(name)

private def createHintInfo(hintName: String): HintInfo = {
HintInfo(strategy =
JoinStrategyHint.strategies.find(
@@ -71,18 +76,20 @@

val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
case ResolvedHint(u: UnresolvedRelation, hint)
if relations.exists(resolver(_, u.tableIdentifier.table)) =>
relations.remove(u.tableIdentifier.table)
case ResolvedHint(u @ UnresolvedRelation(ident), hint)
if relations.exists(resolver(_, ident.last)) =>
relations.remove(ident.last)
ResolvedHint(u, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo))

case ResolvedHint(r: SubqueryAlias, hint)
if relations.exists(resolver(_, r.alias)) =>
relations.remove(r.alias)
ResolvedHint(r, createHintInfo(hintName).merge(hint, handleOverriddenHintInfo))

case u: UnresolvedRelation if relations.exists(resolver(_, u.tableIdentifier.table)) =>
relations.remove(u.tableIdentifier.table)
case u @ UnresolvedRelation(ident) if relations.exists(resolver(_, ident.last)) =>
relations.remove(ident.last)
ResolvedHint(plan, createHintInfo(hintName))

case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) =>
relations.remove(r.alias)
ResolvedHint(plan, createHintInfo(hintName))
@@ -38,19 +38,32 @@ class UnresolvedException[TreeType <: TreeNode[_]](tree: TreeType, function: Str
/**
* Holds the name of a relation that has yet to be looked up in a catalog.
*
* @param tableIdentifier table name
* @param multipartIdentifier table name
*/
case class UnresolvedRelation(tableIdentifier: TableIdentifier)
extends LeafNode {
case class UnresolvedRelation(multipartIdentifier: Seq[String]) extends LeafNode {
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._

/** Returns a `.` separated name for this relation. */
def tableName: String = tableIdentifier.unquotedString
def tableName: String = multipartIdentifier.quoted

override def output: Seq[Attribute] = Nil

override lazy val resolved = false
}

object UnresolvedRelation {
def apply(tableIdentifier: TableIdentifier): UnresolvedRelation = {
val multipartIdentifier = tableIdentifier.database match {
case Some(db) =>
Seq(db, tableIdentifier.table)
case None =>
Seq(tableIdentifier.table)
}

UnresolvedRelation(multipartIdentifier)
}
}

/**
* An inline table that has not been resolved yet. Once resolved, it is turned by the analyzer into
* a [[org.apache.spark.sql.catalyst.plans.logical.LocalRelation]].
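The new UnresolvedRelation companion object keeps existing TableIdentifier-based call sites working by translating them into name parts. A hedged usage sketch, assuming this branch's catalyst classes are on the classpath:

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation

// Both constructions should yield equal nodes after this change.
val fromTableIdent = UnresolvedRelation(TableIdentifier("tbl", Some("db")))
val fromNameParts  = UnresolvedRelation(Seq("db", "tbl"))
assert(fromTableIdent == fromNameParts)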
@@ -844,14 +844,14 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
* }}}
*/
override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) {
UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier))
UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier))
}

/**
* Create an aliased table reference. This is typically used in FROM clauses.
*/
override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) {
val tableId = visitTableIdentifier(ctx.tableIdentifier)
val tableId = visitMultipartIdentifier(ctx.multipartIdentifier)
val table = mayApplyAliasPlan(ctx.tableAlias, UnresolvedRelation(tableId))
table.optionalMap(ctx.sample)(withSample)
}
@@ -624,13 +624,17 @@ class SparkSession private(
* @since 2.0.0
*/
def table(tableName: String): DataFrame = {
table(sessionState.sqlParser.parseTableIdentifier(tableName))
table(sessionState.sqlParser.parseMultipartIdentifier(tableName))
}

private[sql] def table(tableIdent: TableIdentifier): DataFrame = {
Dataset.ofRows(self, UnresolvedRelation(tableIdent))
}

private[sql] def table(multipartIdentifier: Seq[String]): DataFrame = {
Dataset.ofRows(self, UnresolvedRelation(multipartIdentifier))
}

/* ----------------- *
| Everything else |
* ----------------- */
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command
import scala.collection.mutable

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalog.v2.{CatalogPlugin, LookupCatalog}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
@@ -181,6 +182,11 @@ case class CreateViewCommand(
* Permanent views are not allowed to reference temp objects, including temp function and views
*/
private def verifyTemporaryObjectsNotExists(sparkSession: SparkSession): Unit = {
val lookup = new LookupCatalog {
override protected def lookupCatalog(name: String): CatalogPlugin = sparkSession.catalog(name)
}
import lookup._

if (!isTemporary) {
// This func traverses the unresolved plan `child`. Below are the reasons:
// 1) Analyzer replaces unresolved temporary views by a SubqueryAlias with the corresponding
@@ -190,10 +196,11 @@
// package (e.g., HiveGenericUDF).
child.collect {
// Disallow creating permanent views based on temporary views.
case s: UnresolvedRelation
if sparkSession.sessionState.catalog.isTemporaryTable(s.tableIdentifier) =>
case UnresolvedRelation(AsTableIdentifier(ident))
if sparkSession.sessionState.catalog.isTemporaryTable(ident) =>
// temporary views are only stored in the session catalog
throw new AnalysisException(s"Not allowed to create a permanent view $name by " +
s"referencing a temporary view ${s.tableIdentifier}")
s"referencing a temporary view $ident")
case other if !other.resolved => other.expressions.flatMap(_.collect {
// Disallow creating permanent views based on temporary UDFs.
case e: UnresolvedFunction
@@ -20,20 +20,23 @@ package org.apache.spark.sql.execution.datasources
import java.util.Locale

import scala.collection.mutable
import scala.util.Try

import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog}
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils}
import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DropTableCommand
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.TableProvider
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

case class DataSourceResolution(
conf: SQLConf,
@@ -98,6 +101,14 @@

case DropViewStatement(AsTableIdentifier(tableName), ifExists) =>
DropTableCommand(tableName, ifExists, isView = true, purge = false)

case u @ UnresolvedRelation(CatalogObjectIdentifier(Some(catalog), ident)) =>
Try(catalog.asTableCatalog.loadTable(ident)).toOption match {
case Some(table) =>
DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty)
case _ =>
u
}
}
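// Hedged end-to-end sketch of what the new UnresolvedRelation case above enables
// (the catalog name and plugin class are made up, and registration is assumed to
// use the usual spark.sql.catalog.<name> configuration):
//   spark.sql.catalog.testcat = com.example.TestCatalogPlugin
//   spark.sql("SELECT * FROM testcat.db.tbl")
// The relation reaches this rule as UnresolvedRelation(Seq("testcat", "db", "tbl"));
// if loadTable succeeds it becomes a DataSourceV2Relation, otherwise it is left
// unresolved for CheckAnalysis to report as "Table or view not found".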

object V1WriteProvider {
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
import java.util.Locale

import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession}
import org.apache.spark.sql.catalog.v2.{CatalogPlugin, LookupCatalog}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering}
@@ -35,18 +36,22 @@ import org.apache.spark.sql.util.SchemaUtils
/**
* Replaces [[UnresolvedRelation]]s if the plan is for direct query on files.
*/
class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] with LookupCatalog {

override protected def lookupCatalog(name: String): CatalogPlugin = sparkSession.catalog(name)

private def maybeSQLFile(u: UnresolvedRelation): Boolean = {
sparkSession.sessionState.conf.runSQLonFile && u.tableIdentifier.database.isDefined
sparkSession.sessionState.conf.runSQLonFile && u.multipartIdentifier.size == 2 &&
u.multipartIdentifier.last.contains("/")
}
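// Hedged example of what maybeSQLFile now matches: a direct query on files such as
//   SELECT * FROM parquet.`/path/to/data`
// parses to Seq("parquet", "/path/to/data"), a two-part name whose last part
// contains "/", while an ordinary two-part name like Seq("db", "tbl") no longer
// looks like a file query and falls through to catalog resolution.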

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case u: UnresolvedRelation if maybeSQLFile(u) =>
try {
val dataSource = DataSource(
sparkSession,
paths = u.tableIdentifier.table :: Nil,
className = u.tableIdentifier.database.get)
paths = u.multipartIdentifier.last :: Nil,
className = u.multipartIdentifier.head)

// `dataSource.providingClass` may throw ClassNotFoundException, then the outer try-catch
// will catch it and return the original plan, so that the analyzer can report table not
@@ -55,7 +60,7 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
if (!isFileFormat ||
dataSource.className.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
throw new AnalysisException("Unsupported data source type for direct query on files: " +
s"${u.tableIdentifier.database.get}")
s"${dataSource.className}")
}
LogicalRelation(dataSource.resolveRelation())
} catch {
@@ -18,6 +18,7 @@ package org.apache.spark.sql.internal

import org.apache.spark.SparkConf
import org.apache.spark.annotation.{Experimental, Unstable}
import org.apache.spark.sql.catalog.v2.CatalogPlugin
import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
@@ -185,6 +186,8 @@ abstract class BaseSessionStateBuilder(
V2WriteSupportCheck +:
V2StreamingScanSupportCheck +:
customCheckRules

override protected def lookupCatalog(name: String): CatalogPlugin = session.catalog(name)
}

/**
@@ -2160,7 +2160,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
assert(output.contains(
"""== Parsed Logical Plan ==
|'Project [*]
|+- 'UnresolvedRelation `tmp`""".stripMargin))
|+- 'UnresolvedRelation [tmp]""".stripMargin))
assert(output.contains(
"""== Physical Plan ==
|*(1) Range (0, 10, step=1, splits=2)""".stripMargin))
@@ -202,7 +202,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
assert(commands(2)._1 == "insertInto")
assert(commands(2)._2.isInstanceOf[InsertIntoTable])
assert(commands(2)._2.asInstanceOf[InsertIntoTable].table
.asInstanceOf[UnresolvedRelation].tableIdentifier.table == "tab")
.asInstanceOf[UnresolvedRelation].multipartIdentifier == Seq("tab"))
}
// exiting withTable adds commands(3) via onSuccess (drops tab)

@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive

import org.apache.spark.annotation.{Experimental, Unstable}
import org.apache.spark.sql._
import org.apache.spark.sql.catalog.v2.CatalogPlugin
import org.apache.spark.sql.catalyst.analysis.Analyzer
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -91,6 +92,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
V2WriteSupportCheck +:
V2StreamingScanSupportCheck +:
customCheckRules

override protected def lookupCatalog(name: String): CatalogPlugin = session.catalog(name)
}

/**