Use multiple case classes for compile-time safety
szehon-ho committed Aug 5, 2024
1 parent 389b87d commit a77f15b
Showing 6 changed files with 63 additions and 102 deletions.
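In brief, this commit replaces the single mutable UpdateWriter builder with a chain of immutable case classes (UpdateWriter -> UpdateWithAssignment -> UpdateWithCondition), so calling execute() without a set clause, or chaining two where clauses, becomes a compile error instead of a runtime AnalysisException. A minimal usage sketch of the reworked API (the table name, column, and values are illustrative, and `spark` is assumed to be an existing SparkSession):

    import org.apache.spark.sql.functions.{expr, lit}

    // Stage 1: spark.update(...) returns UpdateWriter, which only exposes set().
    // Stage 2: set(...) returns UpdateWithAssignment, exposing where() and execute().
    // Stage 3: where(...) returns UpdateWithCondition, which only exposes execute().
    spark.update("testcat.ns1.employees")   // hypothetical table name
      .set(Map("salary" -> lit(200)))
      .where(expr("dep = 'software'"))
      .execute()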
18 changes: 0 additions & 18 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4922,24 +4922,6 @@
     },
     "sqlState" : "0A000"
   },
-  "UNSUPPORTED_UPDATE" : {
-    "message" : [
-      "Failed to update <tableName>:"
-    ],
-    "subClass" : {
-      "NO_SET_CLAUSE" : {
-        "message" : [
-          "No set clause is provided."
-        ]
-      },
-      "MULTIPLE_WHERE_CLAUSE" : {
-        "message" : [
-          "Multiple where clauses provided: <condition1> and <condition2>"
-        ]
-      }
-    },
-    "sqlState" : "0A000"
-  },
   "UNSUPPORTED_TYPED_LITERAL" : {
     "message" : [
       "Literals of the type <unsupportedType> are not supported. Supported types are <supportedTypes>."
@@ -318,7 +318,9 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.artifact.util.ArtifactUtils$"),

// Update Writer
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWriter")) ++
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWriter"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithAssignment"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithCondition")) ++
mergeIntoWriterExcludeRules

checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules)
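For context, these are MiMa (Migration Manager) binary-compatibility filters: the two new case classes live only in sql/core and have no counterpart in the Connect Scala client yet, so the compatibility check must be told to ignore their absence. A rough sketch of the same mechanism in a plain sbt-mima-plugin build (generic plugin usage, not Spark's custom harness):

    // build.sbt sketch, assuming sbt-mima-plugin is enabled on the project.
    import com.typesafe.tools.mima.core._

    mimaBinaryIssueFilters ++= Seq(
      // Tell MiMa not to report these classes as missing from the checked artifact.
      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithAssignment"),
      ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UpdateWithCondition")
    )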
@@ -3543,21 +3543,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
"cond" -> toSQLExpr(cond)))
}

def updateNoSetClause(tableName: String): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_UPDATE.NO_SET_CLAUSE", Map("tableName" -> tableName))
}

def updateMultipleWhereClause(tableName: String, condition1: Expression,
condition2: Expression): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_UPDATE.MULTIPLE_WHERE_CLAUSE",
messageParameters = Map(
"tableName" -> tableName,
"condition1" -> condition1.sql,
"condition2" -> condition2.sql))
}

def failedToParseExistenceDefaultAsLiteral(fieldName: String, defaultValue: String): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1344",
@@ -860,7 +860,7 @@ class SparkSession private(
errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED",
messageParameters = Map("methodName" -> toSQLId("update")))
}
UpdateWriter(tableName, tableDF)
UpdateWriter(tableDF)
}

// scalastyle:off
89 changes: 58 additions & 31 deletions sql/core/src/main/scala/org/apache/spark/sql/UpdateWriter.scala
@@ -18,66 +18,93 @@
 package org.apache.spark.sql

 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
-import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.functions.expr

 /**
  * `UpdateWriter` provides methods to define and execute an update action on a target table.
  *
- * @param tableName name of table.
  * @param tableDF DataFrame representing table to update.
  *
  * @since 4.0.0
  */
 @Experimental
-case class UpdateWriter (tableName: String, tableDF: DataFrame) {
-
-  private val df: DataFrame = tableDF.toDF()
-  private val sparkSession = tableDF.sparkSession
-  private val logicalPlan = df.queryExecution.logical
-
-  private var assignments: Option[Map[String, Column]] = None
-  private var expression: Option[Expression] = None
+case class UpdateWriter (tableDF: DataFrame) {

   /**
    * @param assignments A Map of column names to Column expressions representing the updates
    *                    to be applied.
    */
-  def set(assignments: Map[String, Column]): UpdateWriter = {
-    this.assignments = Some(assignments)
-    this
+  def set(assignments: Map[String, Column]): UpdateWithAssignment = {
+    UpdateWithAssignment(tableDF, assignments)
   }
+}

+/**
+ * A class for defining a condition on an update operation or directly executing it.
+ *
+ * @param tableDF DataFrame representing table to update.
+ * @param assignment A Map of column names to Column expressions representing the updates
+ *                   to be applied.
+ *
+ * @since 4.0.0
+ */
+@Experimental
+case class UpdateWithAssignment(tableDF: DataFrame, assignment: Map[String, Column]) {
+
+  private val sparkSession = tableDF.sparkSession
+  private val logicalPlan = tableDF.queryExecution.logical
+
   /**
    * Limits the update to rows matching the specified condition.
    *
    * @param condition the update condition
    * @return
    */
-  def where(condition: Column): UpdateWriter = {
-    expression match {
-      case Some(e) =>
-        throw QueryCompilationErrors.updateMultipleWhereClause(tableName, condition.expr, e)
-      case None => this.expression = Some(condition.expr)
-    }
-    this
+  def where(condition: Column): UpdateWithCondition = {
+    UpdateWithCondition(tableDF, assignment, condition)
   }

+  /**
+   * Executes the update operation.
+   */
+  def execute(): Unit = {
+    val update = UpdateTable(
+      logicalPlan,
+      assignment.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
+      None)
+    val qe = sparkSession.sessionState.executePlan(update)
+    qe.assertCommandExecuted()
+  }
+}
+
+/**
+ * A class for executing an update operation.
+ *
+ * @param tableDF DataFrame representing table to update.
+ * @param assignments A Map of column names to Column expressions representing the updates
+ *                    to be applied.
+ * @param condition the update condition
+ * @since 4.0.0
+ */
+@Experimental
+case class UpdateWithCondition(
+    tableDF: DataFrame,
+    assignments: Map[String, Column],
+    condition: Column) {
+
+  private val sparkSession = tableDF.sparkSession
+  private val logicalPlan = tableDF.queryExecution.logical
+
   /**
    * Executes the update operation.
    */
   def execute(): Unit = {
-    assignments match {
-      case Some(a) =>
-        val update = UpdateTable(
-          logicalPlan,
-          a.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
-          expression)
-        val qe = sparkSession.sessionState.executePlan(update)
-        qe.assertCommandExecuted()
-      case None => throw QueryCompilationErrors.updateNoSetClause(tableName)
-    }
+    val update = UpdateTable(
+      logicalPlan,
+      assignments.map(x => Assignment(expr(x._1).expr, x._2.expr)).toSeq,
+      Some(condition.expr))
+    val qe = sparkSession.sessionState.executePlan(update)
+    qe.assertCommandExecuted()
   }
 }
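The rewrite above is essentially the typestate (staged builder) pattern: each stage is a separate immutable type exposing only the methods that are legal next, so invalid call sequences simply do not type-check. Stripped of Spark internals, the idea reduces to a self-contained sketch like the following (all names here are invented for illustration):

    // Illustration of the staged-builder idea; not Spark code.
    object StagedUpdateDemo {

      case class Update(table: String) {
        // Before an assignment exists, only set() is available; no execute().
        def set(assignments: Map[String, String]): UpdateWithSet =
          UpdateWithSet(table, assignments)
      }

      case class UpdateWithSet(table: String, assignments: Map[String, String]) {
        // where() moves to a type that no longer offers where(),
        // so it can be called at most once.
        def where(condition: String): UpdateWithWhere =
          UpdateWithWhere(table, assignments, condition)

        def execute(): Unit =
          println(s"UPDATE ${table} SET ${assignments.mkString(", ")}")
      }

      case class UpdateWithWhere(
          table: String,
          assignments: Map[String, String],
          condition: String) {
        def execute(): Unit =
          println(s"UPDATE ${table} SET ${assignments.mkString(", ")} WHERE ${condition}")
      }

      def main(args: Array[String]): Unit = {
        Update("t").set(Map("a" -> "1")).where("pk > 2").execute()
        // Update("t").execute()                            // compile error: no execute() on Update
        // Update("t").set(Map.empty).where("x").where("y") // compile error: no where() on UpdateWithWhere
      }
    }

The trade-off is API surface: three public classes instead of one, which is why the MiMa exclusions earlier in this diff are needed.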
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.connector

-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.functions._

 class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
@@ -62,39 +62,4 @@ class UpdateDataFrameSuite extends RowLevelOperationSuiteBase {
       Row(2, 150, "software"),
       Row(3, 120, "software")))
   }
-
-  test("Update without set clause") {
-    createAndInitTable("pk INT, salary INT, dep STRING",
-      """{ "pk": 1, "salary": 300, "dep": 'hr' }
-        |{ "pk": 2, "salary": 150, "dep": 'software' }
-        |{ "pk": 3, "salary": 120, "dep": 'hr' }
-        |""".stripMargin)
-
-    val exception = intercept[AnalysisException](
-      spark.update(tableNameAsString)
-        .execute()
-    )
-
-    assert(exception.message.contains(
-      s"Failed to update $tableNameAsString: No set clause is provided"))
-  }
-
-  test("Update with multiple where clauses") {
-    createAndInitTable("pk INT, salary INT, dep STRING",
-      """{ "pk": 1, "salary": 300, "dep": 'hr' }
-        |{ "pk": 2, "salary": 150, "dep": 'software' }
-        |{ "pk": 3, "salary": 120, "dep": 'hr' }
-        |""".stripMargin)
-
-    val exception = intercept[AnalysisException](
-      spark.update(tableNameAsString)
-        .set(Map("dep" -> lit("software")))
-        .where($"pk" >= 2)
-        .where($"pk" < 3)
-        .execute()
-    )
-
-    assert(exception.message.contains(
-      s"Failed to update $tableNameAsString: Multiple where clauses provided: `<`(pk, 3) and `>=`(pk, 2)"))
-  }
 }
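The two deleted tests are deliberately not replaced: the states they exercised can no longer be expressed, so the failure moves from runtime to compile time. Roughly, and only as an illustration of what the compiler now rejects:

    // Neither of these type-checks after this commit:
    // spark.update(tableNameAsString).execute()
    //   -> error: value execute is not a member of org.apache.spark.sql.UpdateWriter
    // spark.update(tableNameAsString).set(...).where($"pk" >= 2).where($"pk" < 3)
    //   -> error: value where is not a member of org.apache.spark.sql.UpdateWithCondition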
