From 07c9648169059d9f20efafbf55fd73400ba2a3b6 Mon Sep 17 00:00:00 2001
From: Rahul Sharma
Date: Tue, 2 Apr 2024 22:46:16 -0400
Subject: [PATCH] Fix for satisfies row level results bug

- The satisfies constraint was incorrectly using the provided assertion to
  evaluate the row level outcomes. The assertion should only be used to
  evaluate the final outcome.
- As part of this change, we have updated the row level results to return
  true/false. The cast to an integer happens as part of the aggregation
  result.
- Added a test to verify the row level results using checks made up of
  different assertions.
---
 .../amazon/deequ/analyzers/Compliance.scala   | 15 ++--
 .../amazon/deequ/constraints/Constraint.scala |  6 +-
 .../amazon/deequ/VerificationSuiteTest.scala  | 70 +++++++++++++++++++
 .../deequ/analyzers/ComplianceTest.scala      | 30 +++-----
 4 files changed, 87 insertions(+), 34 deletions(-)

diff --git a/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala b/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala
index 0edf0197..247a02c1 100644
--- a/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala
+++ b/src/main/scala/com/amazon/deequ/analyzers/Compliance.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.functions._
 import Analyzers._
 import com.amazon.deequ.analyzers.Preconditions.hasColumn
 import com.google.common.annotations.VisibleForTesting
-import org.apache.spark.sql.types.DoubleType
 
 /**
  * Compliance is a measure of the fraction of rows that complies with the given column constraint.
@@ -43,29 +42,23 @@ case class Compliance(instance: String,
                       where: Option[String] = None,
                       columns: List[String] = List.empty[String],
                       analyzerOptions: Option[AnalyzerOptions] = None)
-  extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance)
-    with FilterableAnalyzer {
+  extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance) with FilterableAnalyzer {
 
   override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {
-
     ifNoNullsIn(result, offset, howMany = 2) { _ =>
       NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults))
     }
   }
 
   override def aggregationFunctions(): Seq[Column] = {
-
-    val summation = sum(criterion)
-
+    val summation = sum(criterion.cast(IntegerType))
     summation :: conditionalCount(where) :: Nil
   }
 
   override def filterCondition: Option[String] = where
 
   @VisibleForTesting
-  private def criterion: Column = {
-    conditionalSelection(expr(predicate), where).cast(IntegerType)
-  }
+  private def criterion: Column = conditionalSelection(expr(predicate), where)
 
   private def rowLevelResults: Column = {
     val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
@@ -73,7 +66,7 @@ case class Compliance(instance: String,
 
     filteredRowOutcome match {
       case FilteredRowOutcome.TRUE =>
-        conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true).cast(IntegerType)
+        conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true)
       case _ =>
         // The default behavior when using filtering for rows is to treat them as nulls. No special treatment needed.
         criterion
diff --git a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala
index 8df88165..c0e6e9b9 100644
--- a/src/main/scala/com/amazon/deequ/constraints/Constraint.scala
+++ b/src/main/scala/com/amazon/deequ/constraints/Constraint.scala
@@ -359,12 +359,10 @@ object Constraint {
     val constraint = AnalysisBasedConstraint[NumMatchesAndCount, Double, Double](
       compliance, assertion, hint = hint)
 
-    val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
-    new RowLevelAssertedConstraint(
+    new RowLevelConstraint(
       constraint,
       s"ComplianceConstraint($compliance)",
-      s"ColumnsCompliance-${compliance.predicate}",
-      sparkAssertion)
+      s"ColumnsCompliance-${compliance.predicate}")
   }
 
   /**
diff --git a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
index 587f8bf5..df13ea90 100644
--- a/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
+++ b/src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala
@@ -26,6 +26,7 @@ import com.amazon.deequ.constraints.Constraint
 import com.amazon.deequ.io.DfsUtils
 import com.amazon.deequ.metrics.DoubleMetric
 import com.amazon.deequ.metrics.Entity
+import com.amazon.deequ.metrics.Metric
 import com.amazon.deequ.repository.MetricsRepository
 import com.amazon.deequ.repository.ResultKey
 import com.amazon.deequ.repository.memory.InMemoryMetricsRepository
@@ -1993,6 +1994,75 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
     }
   }
 
+  "Verification Suite's Row Level Results" should {
+    "yield correct results for satisfies check" in withSparkSession { sparkSession =>
+      import sparkSession.implicits._
+      val df = Seq(
+        (1, "blue"),
+        (2, "green"),
+        (3, "blue"),
+        (4, "red"),
+        (5, "purple")
+      ).toDF("id", "color")
+
+      val columnCondition = "color in ('blue')"
+      val whereClause = "id <= 3"
+
+      case class CheckConfig(checkName: String,
+                             assertion: Double => Boolean,
+                             checkStatus: CheckStatus.Value,
+                             whereClause: Option[String] = None)
+
+      val success = CheckStatus.Success
+      val error = CheckStatus.Error
+
+      val checkConfigs = Seq(
+        // Without where clause: Expected compliance metric for full dataset for given condition is 0.4
+        CheckConfig("check with >", (d: Double) => d > 0.5, error),
+        CheckConfig("check with >=", (d: Double) => d >= 0.35, success),
+        CheckConfig("check with <", (d: Double) => d < 0.3, error),
+        CheckConfig("check with <=", (d: Double) => d <= 0.4, success),
+        CheckConfig("check with =", (d: Double) => d == 0.4, success),
+        CheckConfig("check with > / <", (d: Double) => d > 0.0 && d < 0.5, success),
+        CheckConfig("check with >= / <=", (d: Double) => d >= 0.41 && d <= 1.1, error),
+
+        // With where Clause: Expected compliance metric for full dataset for given condition with where clause is 0.67
+        CheckConfig("check w/ where and with >", (d: Double) => d > 0.7, error, Some(whereClause)),
+        CheckConfig("check w/ where and with >=", (d: Double) => d >= 0.66, success, Some(whereClause)),
+        CheckConfig("check w/ where and with <", (d: Double) => d < 0.6, error, Some(whereClause)),
+        CheckConfig("check w/ where and with <=", (d: Double) => d <= 0.67, success, Some(whereClause)),
+        CheckConfig("check w/ where and with =", (d: Double) => d == 0.66, error, Some(whereClause)),
+        CheckConfig("check w/ where and with > / <", (d: Double) => d > 0.0 && d < 0.5, error, Some(whereClause)),
+        CheckConfig("check w/ where and with >= / <=", (d: Double) => d >= 0.41 && d <= 1.1, success, Some(whereClause))
+      )
+
+      val checks = checkConfigs.map { checkConfig =>
+        val constraintName = s"Constraint for check: ${checkConfig.checkName}"
+        val check = Check(CheckLevel.Error, checkConfig.checkName)
+          .satisfies(columnCondition, constraintName, checkConfig.assertion)
+        checkConfig.whereClause.map(check.where).getOrElse(check)
+      }
+
+      val verificationResult = VerificationSuite().onData(df).addChecks(checks).run()
+      val actualResults = verificationResult.checkResults.map { case (c, r) => c.description -> r.status }
+      val expectedResults = checkConfigs.map { c => c.checkName -> c.checkStatus}.toMap
+      assert(actualResults == expectedResults)
+
+      verificationResult.metrics.values.foreach { metric =>
+        val metricValue = metric.asInstanceOf[Metric[Double]].value.toOption.getOrElse(0.0)
+        if (metric.instance.contains("where")) assert(math.abs(metricValue - 0.66) < 0.1)
+        else assert(metricValue == 0.4)
+      }
+
+      val rowLevelResults = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
+      checkConfigs.foreach { checkConfig =>
+        val results = rowLevelResults.select(checkConfig.checkName).collect().map { r => r.getAs[Boolean](0)}.toSeq
+        if (checkConfig.whereClause.isDefined) assert(results == Seq(true, false, true, true, true))
+        else assert(results == Seq(true, false, true, false, false))
+      }
+    }
+  }
+
   /** Run anomaly detection using a repository with some previous analysis results for testing */
   private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = {
diff --git a/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala
index 5aa4033b..54fc225f 100644
--- a/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala
+++ b/src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala
@@ -14,7 +14,6 @@
  *
  */
 
-
 package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.SparkContextSpec
@@ -25,22 +24,19 @@ import org.scalatest.matchers.should.Matchers
 import org.scalatest.wordspec.AnyWordSpec
 
 class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {
-
   "Compliance" should {
     "return row-level results for columns" in withSparkSession { session =>
-
       val data = getDfWithNumericValues(session)
 
       val att1Compliance = Compliance("rule1", "att1 > 3", columns = List("att1"))
       val state = att1Compliance.computeStateFrom(data)
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
-      data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Int]("new")
-      ) shouldBe Seq(0, 0, 0, 1, 1, 1)
+      data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")
+      ) shouldBe Seq(false, false, false, true, true, true)
     }
 
     "return row-level results for null columns" in withSparkSession { session =>
-
       val data = getDfWithNumericValues(session)
 
       val att1Compliance = Compliance("rule1", "attNull > 3", columns = List("att1"))
@@ -48,11 +44,10 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 1, 1, 1)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(null, null, null, true, true, true)
     }
 
     "return row-level results filtered with null" in withSparkSession { session =>
-
       val data = getDfWithNumericValues(session)
 
       val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"),
@@ -61,11 +56,10 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 0, 1, 1)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(null, null, null, false, true, true)
     }
 
     "return row-level results filtered with true" in withSparkSession { session =>
-
       val data = getDfWithNumericValues(session)
 
       val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"),
@@ -74,7 +68,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(1, 1, 1, 0, 1, 1)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(true, true, true, false, true, true)
     }
 
     "return row-level results for compliance in bounds" in withSparkSession { session =>
@@ -93,7 +87,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 0)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, true, true, false)
     }
 
     "return row-level results for compliance in bounds filtered as null" in withSparkSession { session =>
@@ -114,7 +108,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, null, null, null)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, null, null, null)
     }
 
     "return row-level results for compliance in bounds filtered as true" in withSparkSession { session =>
@@ -135,7 +129,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 1)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, true, true, true)
     }
 
     "return row-level results for compliance in array" in withSparkSession { session =>
@@ -157,7 +151,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 0)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, true, false)
     }
 
     "return row-level results for compliance in array filtered as null" in withSparkSession { session =>
@@ -180,7 +174,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, null, null) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, null, null) } "return row-level results for compliance in array filtered as true" in withSparkSession { session => @@ -196,7 +190,6 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val data = getDfWithNumericValues(session) - val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate, where = Option("att1 < 5"), columns = List("att3"), analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE))) @@ -204,8 +197,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state) data.withColumn("new", metric.fullColumn.get).collect().map(r => - if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 1) + if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, true, true) } } - }