Skip to content

Commit

Permalink
Fail when CustomSql has syntax errors (#510)
Browse files Browse the repository at this point in the history
* Fail when CustomSql has syntax errors

* Relax CustomSql syntax error test expectation

---------

Co-authored-by: Yannis Mentekidis <[email protected]>
  • Loading branch information
mentekid and yannis-mentekidis authored Oct 12, 2023
1 parent 8d94160 commit d7eb316
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 11 deletions.
29 changes: 18 additions & 11 deletions src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import org.apache.spark.sql.types.DoubleType

import scala.util.Failure
import scala.util.Success
import scala.util.Try

case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] {
/**
Expand All @@ -33,18 +34,24 @@ case class CustomSql(expression: String) extends Analyzer[CustomSqlState, Double
* @return
*/
override def computeStateFrom(data: DataFrame): Option[CustomSqlState] = {
val dfSql = data.sqlContext.sql(expression)
val cols = dfSql.columns.toSeq
cols match {
case Seq(resultCol) =>
val dfSqlCast = dfSql.withColumn(resultCol, col(resultCol).cast(DoubleType))
val results: Seq[Row] = dfSqlCast.collect()
if (results.size != 1) {
Some(CustomSqlState(Right("Custom SQL did not return exactly 1 row")))
} else {
Some(CustomSqlState(Left(results.head.get(0).asInstanceOf[Double])))

Try {
data.sqlContext.sql(expression)
} match {
case Failure(e) => Some(CustomSqlState(Right(e.getMessage)))
case Success(dfSql) =>
val cols = dfSql.columns.toSeq
cols match {
case Seq(resultCol) =>
val dfSqlCast = dfSql.withColumn(resultCol, col(resultCol).cast(DoubleType))
val results: Seq[Row] = dfSqlCast.collect()
if (results.size != 1) {
Some(CustomSqlState(Right("Custom SQL did not return exactly 1 row")))
} else {
Some(CustomSqlState(Left(results.head.get(0).asInstanceOf[Double])))
}
case _ => Some(CustomSqlState(Right("Custom SQL did not return exactly 1 column")))
}
case _ => Some(CustomSqlState(Right("Custom SQL did not return exactly 1 column")))
}
}

Expand Down
14 changes: 14 additions & 0 deletions src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,21 @@ class CustomSqlTest extends AnyWordSpec with Matchers with SparkContextSpec with
case Success(_) => fail("Should have failed")
case Failure(exception) => exception.getMessage shouldBe "Custom SQL did not return exactly 1 column"
}
}

"returns the error if the SQL statement has a syntax error" in withSparkSession { session =>
val data = getDfWithStringColumns(session)
data.createOrReplaceTempView("primary")

val sql = CustomSql("Select `foo` from primary")
val state = sql.computeStateFrom(data)
val metric = sql.computeMetricFrom(state)

metric.value.isFailure shouldBe true
metric.value match {
case Success(_) => fail("Should have failed")
case Failure(exception) => exception.getMessage should include("`foo`")
}
}
}
}

0 comments on commit d7eb316

Please sign in to comment.