Fix for satisfies row level results bug #553

Merged — 1 commit, Apr 3, 2024
src/main/scala/com/amazon/deequ/analyzers/Compliance.scala (15 changes: 4 additions & 11 deletions)
@@ -22,7 +22,6 @@ import org.apache.spark.sql.functions._
 import Analyzers._
 import com.amazon.deequ.analyzers.Preconditions.hasColumn
 import com.google.common.annotations.VisibleForTesting
-import org.apache.spark.sql.types.DoubleType
 
 /**
   * Compliance is a measure of the fraction of rows that complies with the given column constraint.
@@ -43,37 +42,31 @@ case class Compliance(instance: String,
                       where: Option[String] = None,
                       columns: List[String] = List.empty[String],
                       analyzerOptions: Option[AnalyzerOptions] = None)
-  extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance)
-    with FilterableAnalyzer {
+  extends StandardScanShareableAnalyzer[NumMatchesAndCount]("Compliance", instance) with FilterableAnalyzer {
 
   override def fromAggregationResult(result: Row, offset: Int): Option[NumMatchesAndCount] = {
-
     ifNoNullsIn(result, offset, howMany = 2) { _ =>
       NumMatchesAndCount(result.getLong(offset), result.getLong(offset + 1), Some(rowLevelResults))
     }
   }
 
   override def aggregationFunctions(): Seq[Column] = {
-
-    val summation = sum(criterion)
-
+    val summation = sum(criterion.cast(IntegerType))
     summation :: conditionalCount(where) :: Nil
   }
 
   override def filterCondition: Option[String] = where
 
   @VisibleForTesting
-  private def criterion: Column = {
-    conditionalSelection(expr(predicate), where).cast(IntegerType)
-  }
+  private def criterion: Column = conditionalSelection(expr(predicate), where)
 
   private def rowLevelResults: Column = {
     val filteredRowOutcome = getRowLevelFilterTreatment(analyzerOptions)
     val whereNotCondition = where.map { expression => not(expr(expression)) }
 
     filteredRowOutcome match {
       case FilteredRowOutcome.TRUE =>
-        conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true).cast(IntegerType)
+        conditionSelectionGivenColumn(expr(predicate), whereNotCondition, replaceWith = true)
       case _ =>
         // The default behavior when using filtering for rows is to treat them as nulls. No special treatment needed.
         criterion
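
Note: the net effect of this file's change is that `criterion` stays a Boolean column, so row-level results surface as true/false, while the cast to IntegerType moves into `aggregationFunctions()`, the one place a numeric value is needed. A minimal standalone Spark sketch of that split (not deequ code; the DataFrame and column names are made up):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{expr, sum}
import org.apache.spark.sql.types.IntegerType

object CriterionSketch extends App {
  val spark = SparkSession.builder.master("local[1]").appName("criterion-sketch").getOrCreate()
  import spark.implicits._

  val data = Seq(1, 2, 3, 4, 5, 6).toDF("att1")

  // Kept as a Boolean column: usable directly as a per-row true/false result.
  val criterion = expr("att1 > 3")

  // Cast to IntegerType only at aggregation time, where sum() needs a number.
  val matches = data.agg(sum(criterion.cast(IntegerType))).first().getLong(0)
  println(s"$matches of ${data.count()} rows comply") // 3 of 6 rows comply

  spark.stop()
}
```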
src/main/scala/com/amazon/deequ/constraints/Constraint.scala (6 changes: 2 additions & 4 deletions)
@@ -359,12 +359,10 @@ object Constraint {
     val constraint = AnalysisBasedConstraint[NumMatchesAndCount, Double, Double](
       compliance, assertion, hint = hint)
 
-    val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
-    new RowLevelAssertedConstraint(
+    new RowLevelConstraint(
      constraint,
      s"ComplianceConstraint($compliance)",
-     s"ColumnsCompliance-${compliance.predicate}",
-     sparkAssertion)
+     s"ColumnsCompliance-${compliance.predicate}")
   }
 
   /**
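
Note: as I read this hunk, it is the heart of the fix. The old `RowLevelAssertedConstraint` wrapped the dataset-level assertion in a Spark UDF and re-applied it to each row's 0/1 outcome, so assertions such as `_ < 0.3` inverted the row-level results. A plain-Scala illustration of that failure mode (hypothetical values, not deequ code):

```scala
object AssertionMisuseSketch extends App {
  // The assertion targets the dataset-level compliance fraction...
  val assertion: Double => Boolean = _ < 0.3

  // ...but the old path also applied it per row, to each row's 0/1 pass/fail outcome.
  val rowOutcomes = Seq(1.0, 0.0, 1.0, 0.0, 0.0) // rows 1 and 3 comply

  println(assertion(rowOutcomes.sum / rowOutcomes.size)) // false: 0.4 is not < 0.3, as intended
  println(rowOutcomes.map(assertion))                    // List(false, true, false, true, true): inverted per-row results
}
```

Dropping the UDF and carrying the predicate's own Boolean outcome per row (the Compliance.scala change above) removes that inversion.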
src/test/scala/com/amazon/deequ/VerificationSuiteTest.scala (70 changes: 70 additions & 0 deletions)
@@ -26,6 +26,7 @@ import com.amazon.deequ.constraints.Constraint
 import com.amazon.deequ.io.DfsUtils
 import com.amazon.deequ.metrics.DoubleMetric
 import com.amazon.deequ.metrics.Entity
+import com.amazon.deequ.metrics.Metric
 import com.amazon.deequ.repository.MetricsRepository
 import com.amazon.deequ.repository.ResultKey
 import com.amazon.deequ.repository.memory.InMemoryMetricsRepository
@@ -1993,6 +1994,75 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec
     }
   }
 
+  "Verification Suite's Row Level Results" should {
+    "yield correct results for satisfies check" in withSparkSession { sparkSession =>
+      import sparkSession.implicits._
+      val df = Seq(
+        (1, "blue"),
+        (2, "green"),
+        (3, "blue"),
+        (4, "red"),
+        (5, "purple")
+      ).toDF("id", "color")
+
+      val columnCondition = "color in ('blue')"
+      val whereClause = "id <= 3"
+
+      case class CheckConfig(checkName: String,
+                             assertion: Double => Boolean,
+                             checkStatus: CheckStatus.Value,
+                             whereClause: Option[String] = None)
+
+      val success = CheckStatus.Success
+      val error = CheckStatus.Error
+
+      val checkConfigs = Seq(
+        // Without where clause: expected compliance metric for the full dataset and given condition is 0.4
+        CheckConfig("check with >", (d: Double) => d > 0.5, error),
+        CheckConfig("check with >=", (d: Double) => d >= 0.35, success),
+        CheckConfig("check with <", (d: Double) => d < 0.3, error),
+        CheckConfig("check with <=", (d: Double) => d <= 0.4, success),
+        CheckConfig("check with =", (d: Double) => d == 0.4, success),
+        CheckConfig("check with > / <", (d: Double) => d > 0.0 && d < 0.5, success),
+        CheckConfig("check with >= / <=", (d: Double) => d >= 0.41 && d <= 1.1, error),
+
+        // With where clause: expected compliance metric on the filtered rows for the given condition is ~0.67
+        CheckConfig("check w/ where and with >", (d: Double) => d > 0.7, error, Some(whereClause)),
+        CheckConfig("check w/ where and with >=", (d: Double) => d >= 0.66, success, Some(whereClause)),
+        CheckConfig("check w/ where and with <", (d: Double) => d < 0.6, error, Some(whereClause)),
+        CheckConfig("check w/ where and with <=", (d: Double) => d <= 0.67, success, Some(whereClause)),
+        CheckConfig("check w/ where and with =", (d: Double) => d == 0.66, error, Some(whereClause)),
+        CheckConfig("check w/ where and with > / <", (d: Double) => d > 0.0 && d < 0.5, error, Some(whereClause)),
+        CheckConfig("check w/ where and with >= / <=", (d: Double) => d >= 0.41 && d <= 1.1, success, Some(whereClause))
+      )
+
+      val checks = checkConfigs.map { checkConfig =>
+        val constraintName = s"Constraint for check: ${checkConfig.checkName}"
+        val check = Check(CheckLevel.Error, checkConfig.checkName)
+          .satisfies(columnCondition, constraintName, checkConfig.assertion)
+        checkConfig.whereClause.map(check.where).getOrElse(check)
+      }
+
+      val verificationResult = VerificationSuite().onData(df).addChecks(checks).run()
+      val actualResults = verificationResult.checkResults.map { case (c, r) => c.description -> r.status }
+      val expectedResults = checkConfigs.map { c => c.checkName -> c.checkStatus }.toMap
+      assert(actualResults == expectedResults)
+
+      verificationResult.metrics.values.foreach { metric =>
+        val metricValue = metric.asInstanceOf[Metric[Double]].value.toOption.getOrElse(0.0)
+        if (metric.instance.contains("where")) assert(math.abs(metricValue - 0.66) < 0.1)
+        else assert(metricValue == 0.4)
+      }
+
+      val rowLevelResults = VerificationResult.rowLevelResultsAsDataFrame(sparkSession, verificationResult, df)
+      checkConfigs.foreach { checkConfig =>
+        val results = rowLevelResults.select(checkConfig.checkName).collect().map { r => r.getAs[Boolean](0) }.toSeq
+        if (checkConfig.whereClause.isDefined) assert(results == Seq(true, false, true, true, true))
+        else assert(results == Seq(true, false, true, false, false))
+      }
+    }
+  }
+
   /** Run anomaly detection using a repository with some previous analysis results for testing */
   private[this] def evaluateWithRepositoryWithHistory(test: MetricsRepository => Unit): Unit = {
 
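
Note: the expected values in this test follow directly from the fixture: two of the five rows are 'blue' (ids 1 and 3), and the where clause keeps ids 1 through 3. A quick sanity check of the arithmetic (plain Scala, mirroring the test data):

```scala
object ExpectedMetricsSketch extends App {
  val colors = Seq("blue", "green", "blue", "red", "purple") // ids 1..5

  val fullCompliance = colors.count(_ == "blue").toDouble / colors.size
  val filteredCompliance = colors.take(3).count(_ == "blue").toDouble / 3

  println(fullCompliance)     // 0.4 = 2 of 5 rows
  println(filteredCompliance) // 0.666... = 2 of the 3 rows with id <= 3
}
```

Per the assertions above, rows outside the filter (ids 4 and 5) surface as true in the row-level DataFrame for the where-clause checks, giving Seq(true, false, true, true, true).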
src/test/scala/com/amazon/deequ/analyzers/ComplianceTest.scala (30 changes: 11 additions & 19 deletions)
@@ -14,7 +14,6 @@
  *
  */
 
-
 package com.amazon.deequ.analyzers
 
 import com.amazon.deequ.SparkContextSpec
@@ -25,34 +24,30 @@ import org.scalatest.matchers.should.Matchers
 import org.scalatest.wordspec.AnyWordSpec
 
 class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {
-
   "Compliance" should {
     "return row-level results for columns" in withSparkSession { session =>
-
       val data = getDfWithNumericValues(session)
 
       val att1Compliance = Compliance("rule1", "att1 > 3", columns = List("att1"))
       val state = att1Compliance.computeStateFrom(data)
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
-      data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Int]("new")
-      ) shouldBe Seq(0, 0, 0, 1, 1, 1)
+      data.withColumn("new", metric.fullColumn.get).collect().map(_.getAs[Boolean]("new")
+      ) shouldBe Seq(false, false, false, true, true, true)
     }
 
     "return row-level results for null columns" in withSparkSession { session =>
-
       val data = getDfWithNumericValues(session)
 
       val att1Compliance = Compliance("rule1", "attNull > 3", columns = List("att1"))
       val state = att1Compliance.computeStateFrom(data)
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 1, 1, 1)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(null, null, null, true, true, true)
     }
 
     "return row-level results filtered with null" in withSparkSession { session =>
-
       val data = getDfWithNumericValues(session)
 
       val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"),
@@ -61,11 +56,10 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(null, null, null, 0, 1, 1)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(null, null, null, false, true, true)
     }
 
     "return row-level results filtered with true" in withSparkSession { session =>
-
       val data = getDfWithNumericValues(session)
 
       val att1Compliance = Compliance("rule1", "att1 > 4", where = Option("att2 != 0"),
@@ -74,7 +68,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(1, 1, 1, 0, 1, 1)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(true, true, true, false, true, true)
     }
 
     "return row-level results for compliance in bounds" in withSparkSession { session =>
@@ -93,7 +87,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 0)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, true, true, false)
     }
 
     "return row-level results for compliance in bounds filtered as null" in withSparkSession { session =>
@@ -114,7 +108,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, null, null, null)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, null, null, null)
     }
 
     "return row-level results for compliance in bounds filtered as true" in withSparkSession { session =>
@@ -135,7 +129,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 1, 1, 1, 1, 1)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, true, true, true, true, true)
     }
 
     "return row-level results for compliance in array" in withSparkSession { session =>
@@ -157,7 +151,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 0)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, true, false)
     }
 
     "return row-level results for compliance in array filtered as null" in withSparkSession { session =>
@@ -180,7 +174,7 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, null, null)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, null, null)
     }
 
     "return row-level results for compliance in array filtered as true" in withSparkSession { session =>
@@ -196,16 +190,14 @@ class ComplianceTest extends AnyWordSpec with Matchers with SparkContextSpec wit
-
       val data = getDfWithNumericValues(session)
 
-
       val att1Compliance = Compliance(s"$column contained in ${allowedValues.mkString(",")}", predicate,
         where = Option("att1 < 5"), columns = List("att3"),
         analyzerOptions = Option(AnalyzerOptions(filteredRow = FilteredRowOutcome.TRUE)))
       val state = att1Compliance.computeStateFrom(data)
       val metric: DoubleMetric with FullColumn = att1Compliance.computeMetricFrom(state)
 
       data.withColumn("new", metric.fullColumn.get).collect().map(r =>
-        if (r == null) null else r.getAs[Int]("new")) shouldBe Seq(0, 0, 1, 1, 1, 1)
+        if (r == null) null else r.getAs[Boolean]("new")) shouldBe Seq(false, false, true, true, true, true)
     }
   }
 
 }
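
Note: these tests now read the row-level column back as Boolean rather than Int. One subtlety the null tests rely on: SQL comparisons against NULL yield NULL, and a nullable Boolean column must be read as `java.lang.Boolean` (or behind a null check, as the tests do) to avoid a primitive-read failure. A standalone sketch, assuming a made-up column name rather than the deequ fixture:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

object NullableBooleanSketch extends App {
  val spark = SparkSession.builder.master("local[1]").appName("nullable-boolean").getOrCreate()
  import spark.implicits._

  // Mirrors the "attNull > 3" case: comparing NULL to anything yields NULL, not false.
  val data = Seq(Some(1), Some(2), None, Some(4)).toDF("attNull")
  val outcomes = data
    .withColumn("new", expr("attNull > 3"))
    .collect()
    .map(_.getAs[java.lang.Boolean]("new"))
    .toSeq

  println(outcomes) // e.g. List(false, false, null, true)

  spark.stop()
}
```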