From a77f15b7b6459146217b905cbd40b50cbfb8adeb Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Mon, 5 Aug 2024 14:10:37 -0700
Subject: [PATCH] Use multiple case classes for compile-time safety

---
 .../resources/error/error-conditions.json    | 18 ----
 .../CheckConnectJvmClientCompatibility.scala |  4 +-
 .../sql/errors/QueryCompilationErrors.scala  | 15 ----
 .../org/apache/spark/sql/SparkSession.scala  |  2 +-
 .../org/apache/spark/sql/UpdateWriter.scala  | 89 ++++++++++++-------
 .../sql/connector/UpdateDataFrameSuite.scala | 37 +-------
 6 files changed, 63 insertions(+), 102 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 850f23d142fa3..7f54a77c94a0f 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -4922,24 +4922,6 @@
     },
     "sqlState" : "0A000"
   },
-  "UNSUPPORTED_UPDATE" : {
-    "message" : [
-      "Failed to update <tableName>:"
-    ],
-    "subClass" : {
-      "NO_SET_CLAUSE" : {
-        "message" : [
-          "No set clause is provided."
-        ]
-      },
-      "MULTIPLE_WHERE_CLAUSE" : {
-        "message" : [
-          "Multiple where clauses provided: <condition1> and <condition2>"
-        ]
-      }
-    },
-    "sqlState" : "0A000"
-  },
   "UNSUPPORTED_TYPED_LITERAL" : {
     "message" : [
       "Literals of the type <unsupportedType> are not supported. Supported types are <supportedTypes>."
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 6142d7aa55d7f..ccab0347e6d81 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -318,7 +318,9 @@ object CheckConnectJvmClientCompatibility {
         "org.apache.spark.sql.artifact.util.ArtifactUtils$"),

       // Update Writer
-      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWriter")) ++
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWriter"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithAssignment"),
+      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithCondition")) ++
       mergeIntoWriterExcludeRules

     checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 72fb486d4aef6..7d32bd48f1659 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3543,21 +3543,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with CompilationErrors {
         "cond" -> toSQLExpr(cond)))
   }

-  def updateNoSetClause(tableName: String): Throwable = {
-    new AnalysisException(
-      errorClass = "UNSUPPORTED_UPDATE.NO_SET_CLAUSE", Map("tableName" -> tableName))
-  }
-
-  def updateMultipleWhereClause(tableName: String, condition1: Expression,
-      condition2: Expression): Throwable = {
-    new AnalysisException(
-      errorClass = "UNSUPPORTED_UPDATE.MULTIPLE_WHERE_CLAUSE",
-      messageParameters = Map(
-        "tableName" -> tableName,
-        "condition1" -> condition1.sql,
-        "condition2" -> condition2.sql))
-  }
-
   def failedToParseExistenceDefaultAsLiteral(fieldName: String, defaultValue: String): Throwable = {
     new AnalysisException(
       errorClass = "_LEGACY_ERROR_TEMP_1344",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index c05d5ca3abff9..c960f1125aad3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -860,7 +860,7 @@ class SparkSession private(
         errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
         messageParameters = Map("methodName" -> toSQLId("update")))
     }
-    UpdateWriter(tableName, tableDF)
+    UpdateWriter(tableDF)
   }

   // scalastyle:off
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
index d58aeaa5b091a..5edd4d9dddc62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
@@ -18,37 +18,42 @@ package org.apache.spark.sql

 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
-import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.functions.expr

 /**
  * `UpdateWriter` provides methods to define and execute an update action on a target table.
  *
- * @param tableName name of table.
  * @param tableDF DataFrame representing table to update.
  *
  * @since 4.0.0
  */
 @Experimental
-case class UpdateWriter (tableName: String, tableDF: DataFrame) {
-
-  private val df: DataFrame = tableDF.toDF()
-  private val sparkSession = tableDF.sparkSession
-  private val logicalPlan = df.queryExecution.logical
-
-  private var assignments: Option[Map[String, Column]] = None
-  private var expression: Option[Expression] = None
+case class UpdateWriter (tableDF: DataFrame) {

   /**
    * @param assignments A Map of column names to Column expressions representing the updates
    *                    to be applied.
    */
-  def set(assignments: Map[String, Column]): UpdateWriter = {
-    this.assignments = Some(assignments)
-    this
+  def set(assignments: Map[String, Column]): UpdateWithAssignment = {
+    UpdateWithAssignment(tableDF, assignments)
   }
+}
+
+/**
+ * A class for defining a condition on an update operation or directly executing it.
+ *
+ * @param tableDF DataFrame representing table to update.
+ * @param assignment A Map of column names to Column expressions representing the updates
+ *                   to be applied.
+ *
+ * @since 4.0.0
+ */
+@Experimental
+case class UpdateWithAssignment(tableDF: DataFrame, assignment: Map[String, Column]) {
+
+  private val sparkSession = tableDF.sparkSession
+  private val logicalPlan = tableDF.queryExecution.logical

   /**
    * Limits the update to rows matching the specified condition.
@@ -56,28 +61,50 @@ case class UpdateWriter (tableName: String, tableDF: DataFrame) {
    *
    * @param condition the update condition
    * @return
    */
-  def where(condition: Column): UpdateWriter = {
-    expression match {
-      case Some(e) =>
-        throw QueryCompilationErrors.updateMultipleWhereClause(tableName, condition.expr, e)
-      case None => this.expression = Some(condition.expr)
-    }
-    this
+  def where(condition: Column): UpdateWithCondition = {
+    UpdateWithCondition(tableDF, assignment, condition)
+  }
+
+  /**
+   * Executes the update operation.
+   */
+  def execute(): Unit = {
+    val update = UpdateTable(
+      logicalPlan,
+      assignment.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
+      None)
+    val qe = sparkSession.sessionState.executePlan(update)
+    qe.assertCommandExecuted()
   }
+}
+
+/**
+ * A class for executing an update operation.
+ *
+ * @param tableDF DataFrame representing table to update.
+ * @param assignments A Map of column names to Column expressions representing the updates
+ *                    to be applied.
+ * @param condition the update condition
+ * @since 4.0.0
+ */
+@Experimental
+case class UpdateWithCondition(
+    tableDF: DataFrame,
+    assignments: Map[String, Column],
+    condition: Column) {
+
+  private val sparkSession = tableDF.sparkSession
+  private val logicalPlan = tableDF.queryExecution.logical

   /**
    * Executes the update operation.
    */
   def execute(): Unit = {
-    assignments match {
-      case Some(a) =>
-        val update = UpdateTable(
-          logicalPlan,
-          a.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
-          expression)
-        val qe = sparkSession.sessionState.executePlan(update)
-        qe.assertCommandExecuted()
-      case None => throw QueryCompilationErrors.updateNoSetClause(tableName)
-    }
+    val update = UpdateTable(
+      logicalPlan,
+      assignments.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
+      Some(condition.expr))
+    val qe = sparkSession.sessionState.executePlan(update)
+    qe.assertCommandExecuted()
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala
index 3e4017dfde2e6..b74e16201636f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/UpdateDataFrameSuite.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.connector

-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions._

 class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
@@ -62,39 +62,4 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
       Row(2, 150, "software"),
       Row(3, 120, "software")))
   }
-
-  test("Update without set clause") {
-    createAndInitTable("pk INT, salary INT, dep STRING",
-      """{ "pk": 1, "salary": 300, "dep": 'hr' }
-        |{ "pk": 2, "salary": 150, "dep": 'software' }
-        |{ "pk": 3, "salary": 120, "dep": 'hr' }
-        |""".stripMargin)
-
-    val exception = intercept[AnalysisException](
-      spark.update(tableNameAsString)
-        .execute()
-    )
-
-    assert(exception.message.contains(
-      s"Failed to update $tableNameAsString: No set clause is provided"))
-  }
-
-  test("Update with multiple where clauses") {
-    createAndInitTable("pk INT, salary INT, dep STRING",
-      """{ "pk": 1, "salary": 300, "dep": 'hr' }
-        |{ "pk": 2, "salary": 150, "dep": 'software' }
-        |{ "pk": 3, "salary": 120, "dep": 'hr' }
-        |""".stripMargin)
-
-    val exception = intercept[AnalysisException](
-      spark.update(tableNameAsString)
-        .set(Map("dep" -> lit("software")))
-        .where($"pk" >= 2)
-        .where($"pk" < 3)
-        .execute()
-    )
-
-    assert(exception.message.contains(
-      s"Failed to update $tableNameAsString: Multiple where clauses provided: `<`(pk, 3) and `>=`(pk, 2)"))
-  }
 }
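
Usage sketch (illustrative, not part of the patch; the table name "target" and an
in-scope SparkSession named "spark" are assumed, mirroring the test suite above):
after this change the builder chains through three distinct types, so the states
the removed runtime errors guarded against no longer compile.

  import org.apache.spark.sql.functions._
  import spark.implicits._ // for the $"col" column interpolator

  // spark.update(...) returns UpdateWriter; set(...) returns UpdateWithAssignment,
  // whose execute() plans an UpdateTable with no condition (None).
  spark.update("target")
    .set(Map("dep" -> lit("software")))
    .execute()

  // where(...) returns UpdateWithCondition, whose execute() passes
  // Some(condition.expr) to UpdateTable.
  spark.update("target")
    .set(Map("salary" -> ($"salary" + lit(1000))))
    .where($"dep" === "hr")
    .execute()

  // Neither of the following compiles anymore: UpdateWriter has no execute(),
  // which replaces UNSUPPORTED_UPDATE.NO_SET_CLAUSE, and UpdateWithCondition has
  // no where(), which replaces UNSUPPORTED_UPDATE.MULTIPLE_WHERE_CLAUSE.
  // spark.update("target").execute()
  // spark.update("target").set(Map("dep" -> lit("hr"))).where($"pk" >= 2).where($"pk" < 3).execute()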