Custom SQL Analyzer (#509)
* Run scalastyle before tests

* New analyzer to evaluate a SQL statement that returns a single numeric value

---------

Co-authored-by: Yannis Mentekidis <[email protected]>
mentekid and yannis-mentekidis authored Oct 10, 2023
1 parent ca034c3 commit 8d94160
Showing 4 changed files with 158 additions and 0 deletions.
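Before the diffs, a minimal sketch of how the new analyzer might be wired into an analysis run. The DataFrame, path, view name, and SQL statement are hypothetical; the snippet assumes Deequ's existing AnalysisRunner API and an active SparkSession named session.

    import com.amazon.deequ.analyzers.CustomSql
    import com.amazon.deequ.analyzers.runners.{AnalysisRunner, AnalyzerContext}

    // Hypothetical input: register the data under a temp view so the
    // custom statement can refer to it by name.
    val reviews = session.read.parquet("s3://my-bucket/reviews/")
    reviews.createOrReplaceTempView("reviews")

    // The statement must return exactly one row with exactly one numeric column.
    val result = AnalysisRunner
      .onData(reviews)
      .addAnalyzer(CustomSql("SELECT COUNT(*) FROM reviews WHERE rating IS NULL"))
      .run()

    AnalyzerContext.successMetricsAsDataFrame(session, result).show()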
1 change: 1 addition & 0 deletions pom.xml
@@ -237,6 +237,7 @@
</configuration>
<executions>
<execution>
<phase>process-resources</phase>
<goals>
<goal>check</goal>
</goals>
72 changes: 72 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
@@ -0,0 +1,72 @@
/**
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/
package com.amazon.deequ.analyzers

import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

import scala.util.Failure
import scala.util.Success

case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] {
/**
* Compute the state (sufficient statistics) from the data
*
* @param data data frame
* @return the state, wrapping either the computed value or an error message
*/
override def computeStateFrom(data: DataFrame): Option[CustomSqlState] = {
val dfSql = data.sqlContext.sql(expression)
val cols = dfSql.columns.toSeq
cols match {
case Seq(resultCol) =>
val dfSqlCast = dfSql.withColumn(resultCol, col(resultCol).cast(DoubleType))
val results: Seq[Row] = dfSqlCast.collect()
if (results.size != 1) {
Some(CustomSqlState(Right("Custom SQL did not return exactly 1 row")))
} else {
Some(CustomSqlState(Left(results.head.get(0).asInstanceOf[Double])))
}
case _ => Some(CustomSqlState(Right("Custom SQL did not return exactly 1 column")))
}
}

/**
* Compute the metric from the state (sufficient statistics)
*
* @param state wrapper holding a state of type S (required due to typing issues...)
* @return a DoubleMetric carrying the computed value, or a Failure describing the error
*/
override def computeMetricFrom(state: Option[CustomSqlState]): DoubleMetric = {
state match {
// The returned state holds either a computed value (Left) or an error message (Right)
case Some(theState) => theState.stateOrError match {
case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Success(value))
case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException(error)))
}
case None =>
DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run")))
}
}

override private[deequ] def toFailureMetric(failure: Exception) = {
DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run", failure)))
}
}
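The two methods above can also be driven by hand, which makes the metric's fixed coordinates visible. A sketch, assuming a DataFrame df in whose session a temp view t is registered:

    val analyzer = CustomSql("SELECT COUNT(*) FROM t")
    val state = analyzer.computeStateFrom(df)      // Option[CustomSqlState]
    val metric = analyzer.computeMetricFrom(state) // DoubleMetric

    // Every CustomSql metric reports under the same coordinates:
    metric.entity   // Entity.Dataset
    metric.name     // "CustomSQL"
    metric.instance // "*"
    metric.value    // Success(count) or Failure(cause)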
11 changes: 11 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/Size.scala
@@ -20,6 +20,17 @@ import com.amazon.deequ.metrics.Entity
import org.apache.spark.sql.{Column, Row}
import Analyzers._

// Wraps either the value computed by the custom SQL statement (Left)
// or an error message describing the failure (Right).
case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] {
lazy val state = stateOrError.left.get
lazy val error = stateOrError.right.get

override def sum(other: CustomSqlState): CustomSqlState = {
CustomSqlState(Left(state + other.state))
}

override def metricValue(): Double = state
}

case class NumMatches(numMatches: Long) extends DoubleValuedState[NumMatches] {

override def sum(other: NumMatches): NumMatches = {
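Since CustomSqlState is a DoubleValuedState, Deequ can merge partial states with sum, which simply adds the two wrapped values. A toy illustration with made-up values; note that this merge is only meaningful for additive statements such as COUNT(*), and that calling sum on an error-carrying state would throw, since state is stateOrError.left.get:

    val batchA = CustomSqlState(Left(2.0)) // e.g. 2 matching rows in one batch
    val batchB = CustomSqlState(Left(3.0)) // 3 matching rows in another
    val merged = batchA.sum(batchB)
    merged.metricValue() // 5.0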
74 changes: 74 additions & 0 deletions src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala
@@ -0,0 +1,74 @@
/**
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/
package com.amazon.deequ.analyzers

import com.amazon.deequ.SparkContextSpec
import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.utils.FixtureSupport
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import scala.util.Failure
import scala.util.Success

class CustomSqlTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {

"CustomSql" should {
"return a metric when the statement returns exactly one value" in withSparkSession { session =>
val data = getDfWithStringColumns(session)
data.createOrReplaceTempView("primary")

val sql = CustomSql("SELECT COUNT(*) FROM primary WHERE `Address Line 2` IS NOT NULL")
val state = sql.computeStateFrom(data)
val metric: DoubleMetric = sql.computeMetricFrom(state)

metric.value.isSuccess shouldBe true
metric.value.get shouldBe 6.0
}

"returns a failed metric when the statement returns more than one row" in withSparkSession { session =>
val data = getDfWithStringColumns(session)
data.createOrReplaceTempView("primary")

val sql = CustomSql("Select `Address Line 2` FROM primary WHERE `Address Line 2` is NOT NULL")
val state = sql.computeStateFrom(data)
val metric = sql.computeMetricFrom(state)

metric.value.isFailure shouldBe true
metric.value match {
case Success(_) => fail("Should have failed")
case Failure(exception) => exception.getMessage shouldBe "Custom SQL did not return exactly 1 row"
}
}

"returns a failed metric when the statement returns more than one column" in withSparkSession { session =>
val data = getDfWithStringColumns(session)
data.createOrReplaceTempView("primary")

val sql = CustomSql(
"Select `Address Line 1`, `Address Line 2` FROM primary WHERE `Address Line 3` like 'Bandra%'")
val state = sql.computeStateFrom(data)
val metric = sql.computeMetricFrom(state)

metric.value.isFailure shouldBe true
metric.value match {
case Success(_) => fail("Should have failed")
case Failure(exception) => exception.getMessage shouldBe "Custom SQL did not return exactly 1 column"
}
}
}
}
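The tests lean on getDfWithStringColumns from FixtureSupport, which is not part of this diff. A hypothetical stand-in with the column names the queries assume (the real fixture's rows, and hence the expected 6.0, will differ):

    import org.apache.spark.sql.{DataFrame, SparkSession}

    // Hypothetical substitute for FixtureSupport.getDfWithStringColumns:
    // any DataFrame with these column names lets the queries above run.
    def dfWithStringColumns(session: SparkSession): DataFrame = {
      import session.implicits._
      Seq(
        ("12 Main St", "Suite 4", "Bandra West"),
        ("1 First Ave", null, "Somewhere Else")
      ).toDF("Address Line 1", "Address Line 2", "Address Line 3")
    }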
