-
Notifications
You must be signed in to change notification settings - Fork 543
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
New analyzer to evaluate a SQL statement that returns a boolean value
- Loading branch information
1 parent
836bfcf
commit c63caed
Showing
3 changed files
with
157 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/** | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not | ||
* use this file except in compliance with the License. A copy of the License | ||
* is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on | ||
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
* express or implied. See the License for the specific language governing | ||
* permissions and limitations under the License. | ||
* | ||
*/ | ||
package com.amazon.deequ.analyzers | ||
|
||
import com.amazon.deequ.metrics.DoubleMetric | ||
import com.amazon.deequ.metrics.Entity | ||
import org.apache.spark.sql.DataFrame | ||
import org.apache.spark.sql.Row | ||
import org.apache.spark.sql.functions.col | ||
import org.apache.spark.sql.types.DoubleType | ||
|
||
import scala.util.Failure | ||
import scala.util.Success | ||
|
||
case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] { | ||
/** | ||
* Compute the state (sufficient statistics) from the data | ||
* | ||
* @param data data frame | ||
* @return | ||
*/ | ||
override def computeStateFrom(data: DataFrame): Option[CustomSqlState] = { | ||
val dfSql = data.sqlContext.sql(expression) | ||
val cols = dfSql.columns.toSeq | ||
cols match { | ||
case Seq(resultCol) => | ||
val dfSqlCast = dfSql.withColumn(resultCol, col(resultCol).cast(DoubleType)) | ||
val results: Seq[Row] = dfSqlCast.collect() | ||
if (results.size != 1) { | ||
Some(CustomSqlState(Right("Custom SQL did not return exactly 1 row"))) | ||
} else { | ||
Some(CustomSqlState(Left(results.head.get(0).asInstanceOf[Double]))) | ||
} | ||
case _ => Some(CustomSqlState(Right("Custom SQL did not return exactly 1 column"))) | ||
} | ||
} | ||
|
||
/** | ||
* Compute the metric from the state (sufficient statistics) | ||
* | ||
* @param state wrapper holding a state of type S (required due to typing issues...) | ||
* @return | ||
*/ | ||
override def computeMetricFrom(state: Option[CustomSqlState]): DoubleMetric = { | ||
state match { | ||
// The returned state may | ||
case Some(theState) => theState.stateOrError match { | ||
case Left(value) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Success(value)) | ||
case Right(error) => DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException(error))) | ||
} | ||
case None => | ||
DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) | ||
} | ||
} | ||
|
||
override private[deequ] def toFailureMetric(failure: Exception) = { | ||
DoubleMetric(Entity.Dataset, "CustomSQL", "*", Failure(new RuntimeException("CustomSql Failed To Run"))) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
src/test/scala/com/amazon/deequ/analyzers/CustomSqlTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/** | ||
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"). You may not | ||
* use this file except in compliance with the License. A copy of the License | ||
* is located at | ||
* | ||
* http://aws.amazon.com/apache2.0/ | ||
* | ||
* or in the "license" file accompanying this file. This file is distributed on | ||
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either | ||
* express or implied. See the License for the specific language governing | ||
* permissions and limitations under the License. | ||
* | ||
*/ | ||
package com.amazon.deequ.analyzers | ||
|
||
import com.amazon.deequ.SparkContextSpec | ||
import com.amazon.deequ.metrics.DoubleMetric | ||
import com.amazon.deequ.utils.FixtureSupport | ||
import org.scalatest.matchers.should.Matchers | ||
import org.scalatest.wordspec.AnyWordSpec | ||
|
||
import scala.util.Failure | ||
import scala.util.Success | ||
|
||
class CustomSqlTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { | ||
|
||
"CustomSql" should { | ||
"return a metric when the statement returns exactly one value" in withSparkSession { session => | ||
val data = getDfWithStringColumns(session) | ||
data.createOrReplaceTempView("primary") | ||
|
||
val sql = CustomSql("SELECT COUNT(*) FROM primary WHERE `Address Line 2` IS NOT NULL") | ||
val state = sql.computeStateFrom(data) | ||
val metric: DoubleMetric = sql.computeMetricFrom(state) | ||
|
||
metric.value.isSuccess shouldBe true | ||
metric.value.get shouldBe 6.0 | ||
} | ||
|
||
"returns a failed metric when the statement returns more than one row" in withSparkSession { session => | ||
val data = getDfWithStringColumns(session) | ||
data.createOrReplaceTempView("primary") | ||
|
||
val sql = CustomSql("Select `Address Line 2` FROM primary WHERE `Address Line 2` is NOT NULL") | ||
val state = sql.computeStateFrom(data) | ||
val metric = sql.computeMetricFrom(state) | ||
|
||
metric.value.isFailure shouldBe true | ||
metric.value match { | ||
case Success(_) => fail("Should have failed") | ||
case Failure(exception) => exception.getMessage shouldBe "Custom SQL did not return exactly 1 row" | ||
} | ||
} | ||
|
||
"returns a failed metric when the statement returns more than one column" in withSparkSession { session => | ||
val data = getDfWithStringColumns(session) | ||
data.createOrReplaceTempView("primary") | ||
|
||
val sql = CustomSql( | ||
"Select `Address Line 1`, `Address Line 2` FROM primary WHERE `Address Line 3` like 'Bandra%'") | ||
val state = sql.computeStateFrom(data) | ||
val metric = sql.computeMetricFrom(state) | ||
|
||
metric.value.isFailure shouldBe true | ||
metric.value match { | ||
case Success(_) => fail("Should have failed") | ||
case Failure(exception) => exception.getMessage shouldBe "Custom SQL did not return exactly 1 column" | ||
} | ||
|
||
} | ||
} | ||
} |