From 83733a52d9e4f28e790286fa255f0d9c643f4982 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Fri, 2 Aug 2024 10:52:27 -0700
Subject: [PATCH] [SQL] Add write options for INSERT

### What changes were proposed in this pull request?

Add `tbl WITH (k1=v1, k2=v2)` syntax for INSERT, INSERT OVERWRITE, and
INSERT ... REPLACE WHERE.

### Why are the changes needed?

Follow-up to SPARK-36680, which added the WITH options clause for the
SELECT statement.

### Does this PR introduce _any_ user-facing change?

Yes, it adds new SQL syntax.

### How was this patch tested?

New tests in DataSourceV2SQLSuite.

### Was this patch authored or co-authored using generative AI tooling?

No
---
 .../sql/catalyst/parser/SqlBaseParser.g4     |  6 +-
 .../sql/catalyst/parser/AstBuilder.scala     | 42 ++++++++-----
 .../sql/connector/DataSourceV2SQLSuite.scala | 63 ++++++++++++++++++-
 3 files changed, 92 insertions(+), 19 deletions(-)

diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index c7aa56cf920ad..11e2d41664886 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -406,9 +406,9 @@ query
     ;
 
 insertInto
-    : INSERT OVERWRITE TABLE? identifierReference (partitionSpec (IF errorCapturingNot EXISTS)?)? ((BY NAME) | identifierList)? #insertOverwriteTable
-    | INSERT INTO TABLE? identifierReference partitionSpec? (IF errorCapturingNot EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable
-    | INSERT INTO TABLE? identifierReference REPLACE whereClause #insertIntoReplaceWhere
+    : INSERT OVERWRITE TABLE? identifierReference optionsClause? (partitionSpec (IF errorCapturingNot EXISTS)?)? ((BY NAME) | identifierList)? #insertOverwriteTable
+    | INSERT INTO TABLE? identifierReference optionsClause? partitionSpec? (IF errorCapturingNot EXISTS)? ((BY NAME) | identifierList)? #insertIntoTable
+    | INSERT INTO TABLE? identifierReference optionsClause? REPLACE whereClause #insertIntoReplaceWhere
     | INSERT OVERWRITE LOCAL? DIRECTORY path=stringLit rowFormat? createFileFormat? #insertOverwriteHiveDir
     | INSERT OVERWRITE LOCAL? DIRECTORY (path=stringLit)? tableProvider (OPTIONS options=propertyList)? #insertOverwriteDir
     ;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index a046ededf964c..fbd9c076d4295 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -382,10 +382,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
   /**
    * Parameters used for writing query to a table:
-   * (table ident, tableColumnList, partitionKeys, ifPartitionNotExists, byName).
+   * (table ident, options, tableColumnList, partitionKeys, ifPartitionNotExists, byName).
    */
   type InsertTableParams =
-    (IdentifierReferenceContext, Seq[String], Map[String, Option[String]], Boolean, Boolean)
+    (IdentifierReferenceContext, Option[OptionsClauseContext], Seq[String],
+      Map[String, Option[String]], Boolean, Boolean)
 
   /**
    * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider).
@@ -412,11 +413,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
     // 2. Write commands do not hold the table logical plan as a child, and we need to add
     //    additional resolution code to resolve identifiers inside the write commands.
     case table: InsertIntoTableContext =>
-      val (relationCtx, cols, partition, ifPartitionNotExists, byName)
+      val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
         = visitInsertIntoTable(table)
       withIdentClause(relationCtx, ident => {
         InsertIntoStatement(
-          createUnresolvedRelation(relationCtx, ident),
+          createUnresolvedRelation(relationCtx, ident, options),
           partition,
           cols,
           query,
@@ -425,11 +426,11 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
           byName)
       })
     case table: InsertOverwriteTableContext =>
-      val (relationCtx, cols, partition, ifPartitionNotExists, byName)
+      val (relationCtx, options, cols, partition, ifPartitionNotExists, byName)
         = visitInsertOverwriteTable(table)
       withIdentClause(relationCtx, ident => {
         InsertIntoStatement(
-          createUnresolvedRelation(relationCtx, ident),
+          createUnresolvedRelation(relationCtx, ident, options),
           partition,
           cols,
           query,
@@ -440,7 +441,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
     case ctx: InsertIntoReplaceWhereContext =>
       withIdentClause(ctx.identifierReference, ident => {
         OverwriteByExpression.byPosition(
-          createUnresolvedRelation(ctx.identifierReference, ident),
+          createUnresolvedRelation(ctx.identifierReference, ident, Option(ctx.optionsClause())),
           query,
           expression(ctx.whereClause().booleanExpression()))
       })
@@ -469,7 +470,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
       invalidStatement("INSERT INTO ... IF NOT EXISTS", ctx)
     }
 
-    (ctx.identifierReference, cols, partitionKeys, false, ctx.NAME() != null)
+    (ctx.identifierReference, Option(ctx.optionsClause()), cols, partitionKeys, false,
+      ctx.NAME() != null)
   }
 
   /**
@@ -489,7 +491,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
         dynamicPartitionKeys.keys.mkString(", "), ctx)
     }
 
-    (ctx.identifierReference, cols, partitionKeys, ctx.EXISTS() != null, ctx.NAME() != null)
+    (ctx.identifierReference, Option(ctx.optionsClause()), cols, partitionKeys,
+      ctx.EXISTS() != null, ctx.NAME() != null)
   }
 
   /**
@@ -3067,9 +3070,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
   private def createUnresolvedRelation(
       ctx: IdentifierReferenceContext,
       optionsClause: Option[OptionsClauseContext] = None): LogicalPlan = withOrigin(ctx) {
-    val options = optionsClause.map{ clause =>
-      new CaseInsensitiveStringMap(visitPropertyKeyValues(clause.options).asJava)
-    }.getOrElse(CaseInsensitiveStringMap.empty)
+    val options = resolveOptions(optionsClause)
     withIdentClause(ctx, parts => new UnresolvedRelation(parts, options, isStreaming = false))
   }
 
@@ -3078,8 +3079,18 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
    * Create an [[UnresolvedRelation]] from a multi-part identifier.
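+   * If an options clause is given, it is resolved into a
+   * [[CaseInsensitiveStringMap]] and attached to the relation.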
   */
  private def createUnresolvedRelation(
-      ctx: ParserRuleContext, ident: Seq[String]): UnresolvedRelation = withOrigin(ctx) {
-    UnresolvedRelation(ident)
+      ctx: ParserRuleContext,
+      ident: Seq[String],
+      optionsClause: Option[OptionsClauseContext]): UnresolvedRelation = withOrigin(ctx) {
+    val options = resolveOptions(optionsClause)
+    new UnresolvedRelation(ident, options, isStreaming = false)
+  }
+
+  private def resolveOptions(
+      optionsClause: Option[OptionsClauseContext]): CaseInsensitiveStringMap = {
+    optionsClause.map { clause =>
+      new CaseInsensitiveStringMap(visitPropertyKeyValues(clause.options).asJava)
+    }.getOrElse(CaseInsensitiveStringMap.empty)
   }
 
   /**
@@ -4948,7 +4959,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper
       if (query.isDefined) {
         CacheTableAsSelect(ident.head, query.get, source(ctx.query()), isLazy, options)
       } else {
-        CacheTable(createUnresolvedRelation(ctx.identifierReference, ident), ident, isLazy, options)
+        CacheTable(createUnresolvedRelation(ctx.identifierReference, ident, None),
+          ident, isLazy, options)
       }
     })
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 51d7f270e1a53..ba7760223a487 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchNamespaceException, TableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, ColumnStat, CommandResult, OverwriteByExpression}
 import org.apache.spark.sql.catalyst.statsEstimation.StatsEstimationTestBase
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.catalog.{Column => ColumnV2, _}
@@ -41,6 +41,7 @@ import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.execution.FilterExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
@@ -3547,6 +3548,66 @@ class DataSourceV2SQLSuiteV1Filter
     }
   }
 
+  test("SPARK-36680: Supports Dynamic Table Options for Insert") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
+      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
+
+      val collected = df.queryExecution.optimizedPlan.collect {
+        case CommandResult(_, AppendData(relation: DataSourceV2Relation, _, _, _, _, _), _, _) =>
+          assert(relation.options.get("write.split-size") == "10")
+      }
+      assert(collected.size == 1)
+
+      val insertResult = sql(s"SELECT * FROM $t1")
+      checkAnswer(insertResult, Seq(Row(1, "a"), Row(2, "b")))
+    }
+  }
+
+  test("SPARK-36680: Supports Dynamic Table Options for Insert Overwrite") {
+    val t1 = s"${catalogAndNamespace}table"
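+    // The parsed WITH options should surface on the DataSourceV2Relation inside
+    // the OverwriteByExpression node of the optimized plan.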
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
+      sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
+
+      val df = sql(s"INSERT OVERWRITE $t1 WITH (`write.split-size` = 10) " +
+        s"VALUES (3, 'c'), (4, 'd')")
+      val collected = df.queryExecution.optimizedPlan.collect {
+        case CommandResult(_,
+            OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
+            _, _) =>
+          assert(relation.options.get("write.split-size") == "10")
+      }
+      assert(collected.size == 1)
+
+      val insertResult = sql(s"SELECT * FROM $t1")
+      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
+    }
+  }
+
+  test("SPARK-36680: Supports Dynamic Table Options for Insert Replace") {
+    val t1 = s"${catalogAndNamespace}table"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
+      sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) VALUES (1, 'a'), (2, 'b')")
+
+      val df = sql(s"INSERT INTO $t1 WITH (`write.split-size` = 10) " +
+        s"REPLACE WHERE TRUE " +
+        s"VALUES (3, 'c'), (4, 'd')")
+      val collected = df.queryExecution.optimizedPlan.collect {
+        case CommandResult(_,
+            OverwriteByExpression(relation: DataSourceV2Relation, _, _, _, _, _, _),
+            _, _) =>
+          assert(relation.options.get("write.split-size") == "10")
+      }
+      assert(collected.size == 1)
+
+      val insertResult = sql(s"SELECT * FROM $t1")
+      checkAnswer(insertResult, Seq(Row(3, "c"), Row(4, "d")))
+    }
+  }
+
   private def testNotSupportedV2Command(
       sqlCommand: String,
       sqlParams: String,
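
The new clause can be exercised as follows (a minimal sketch: `testcat.ns.tbl` is a hypothetical DSv2 table name, and `write.split-size` is just the illustrative option key reused from the tests above; whether a given option has any effect is up to the underlying connector, which receives it through the relation's write options):

```sql
-- Append, passing per-write options through to the connector
INSERT INTO testcat.ns.tbl WITH (`write.split-size` = 10) VALUES (1, 'a');

-- Overwrite, with the same options clause
INSERT OVERWRITE testcat.ns.tbl WITH (`write.split-size` = 10) VALUES (2, 'b');

-- Conditional overwrite: replace only the rows matching the predicate
INSERT INTO testcat.ns.tbl WITH (`write.split-size` = 10)
REPLACE WHERE id = 2
VALUES (3, 'c');
```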