From 8d941609a9aa4bda75c5b70e0e7e1a1dbeb71fa3 Mon Sep 17 00:00:00 2001 From: Yannis Mentekidis Date: Tue, 10 Oct 2023 16:51:17 -0400 Subject: [PATCH] Custom SQL Analyzer (#509) * Run scalastyle before tests * New analyzer to evaluate a SQL statement that returns a boolean value --------- Co-authored-by: Yannis Mentekidis --- pom.xml | 1 + .../amazon/deequ/analyzers/CustomSql.scala | 72 ++++++++++++++++++ .../com/amazon/deequ/analyzers/Size.scala | 11 +++ .../deequ/analyzers/CustomSqlTest.scala | 74 +++++++++++++++++++ 4 files changed, 158 insertions(+) create mode 100644 src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala create mode 100644 src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala diff --git a/pom.xml b/pom.xml index 318ba4ab7..262a670f6 100644 --- a/pom.xml +++ b/pom.xml @@ -237,6 +237,7 @@ + process-resources check diff --git a/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala new file mode 100644 index 000000000..caf90f423 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala @@ -0,0 +1,72 @@ +/** + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. 
+ * + */ +package com.amazon.deequ.analyzers + +import com.amazon.deequ.metrics.DoubleMetric +import com.amazon.deequ.metrics.Entity +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.DoubleType + +import scala.util.Failure +import scala.util.Success + +case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] { + /** + * Compute the state (sufficient statistics) from the data + * + * @param data data frame + * @return + */ + override def computeStateFrom(data: DataFrame): Option[CustomSqlState] = { + val dfSql = data.sqlContext.sql(expression) + val cols = dfSql.columns.toSeq + cols match { + case Seq(resultCol) => + val dfSqlCast = dfSql.withColumn(resultCol, col(resultCol).cast(DoubleType)) + val results: Seq[Row] = dfSqlCast.collect() + if (results.size != 1) { + Some(CustomSqlState(Right("Custom SQL did not return exactly 1 row"))) + } else { + Some(CustomSqlState(Left(results.head.get(0).asInstanceOf[Double]))) + } + case _ => Some(CustomSqlState(Right("Custom SQL did not return exactly 1 column"))) + } + } + + /** + * Compute the metric from the state (sufficient statistics) + * + * @param state wrapper holding a state of type S (required due to typing issues...) 
+ * @return + */ + override def computeMetricFrom(state: Option[CustomSqlState]): DoubleMetric = { + state match { + // The returned state may + case Some(theState) => theState.stateOrError match { + case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Success(value)) + case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException(error))) + } + case None => + DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) + } + } + + override private[deequ] def toFailureMetric(failure: Exception) = { + DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) + } +} diff --git a/src/main/scala/com/amazon/deequ/analyzers/Size.scala b/src/main/scala/com/amazon/deequ/analyzers/Size.scala index c56083abe..a5080084a 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Size.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Size.scala @@ -20,6 +20,17 @@ import com.amazon.deequ.metrics.Entity import org.apache.spark.sql.{Column, Row} import Analyzers._ +case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] { + lazy val state = stateOrError.left.get + lazy val error = stateOrError.right.get + + override def sum(other: CustomSqlState): CustomSqlState = { + CustomSqlState(Left(state + other.state)) + } + + override def metricValue(): Double = state +} + case class NumMatches(numMatches: Long) extends DoubleValuedState[NumMatches] { override def sum(other: NumMatches): NumMatches = { diff --git a/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala new file mode 100644 index 000000000..8b3990545 --- /dev/null +++ b/src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala @@ -0,0 +1,74 @@ +/** + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
class CustomSqlTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {

  "CustomSql" should {
    "return a metric when the statement returns exactly one value" in withSparkSession { session =>
      val df = getDfWithStringColumns(session)
      df.createOrReplaceTempView("primary")

      val analyzer = CustomSql("SELECT COUNT(*) FROM primary WHERE `Address Line 2` IS NOT NULL")
      val metric: DoubleMetric = analyzer.computeMetricFrom(analyzer.computeStateFrom(df))

      // Six rows in the fixture have a non-null second address line.
      metric.value shouldBe Success(6.0)
    }

    "returns a failed metric when the statement returns more than one row" in withSparkSession { session =>
      val df = getDfWithStringColumns(session)
      df.createOrReplaceTempView("primary")

      val analyzer = CustomSql("Select `Address Line 2` FROM primary WHERE `Address Line 2` is NOT NULL")
      val metric = analyzer.computeMetricFrom(analyzer.computeStateFrom(df))

      metric.value.isFailure shouldBe true
      val thrown = metric.value.failed.get
      thrown.getMessage shouldBe "Custom SQL did not return exactly 1 row"
    }

    "returns a failed metric when the statement returns more than one column" in withSparkSession { session =>
      val df = getDfWithStringColumns(session)
      df.createOrReplaceTempView("primary")

      val analyzer = CustomSql(
        "Select `Address Line 1`, `Address Line 2` FROM primary WHERE `Address Line 3` like 'Bandra%'")
      val metric = analyzer.computeMetricFrom(analyzer.computeStateFrom(df))

      metric.value.isFailure shouldBe true
      val thrown = metric.value.failed.get
      thrown.getMessage shouldBe "Custom SQL did not return exactly 1 column"
    }
  }
}