[MinLength/MaxLength] Apply filtered row behavior at the row level evaluation #547

Merged 2 commits on Mar 10, 2024
11 changes: 7 additions & 4 deletions src/main/scala/com/amazon/deequ/analyzers/Analyzer.scala
@@ -519,13 +519,16 @@ private[deequ] object Analyzers {
}

def conditionalSelectionWithAugmentedOutcome(selection: Column,
condition: Option[String],
replaceWith: Double): Column = {
condition: Option[String]): Column = {
val origSelection = array(lit(InScopeData.name).as("source"), selection.as("selection"))
val filteredSelection = array(lit(FilteredData.name).as("source"), lit(replaceWith).as("selection"))

// The 2nd value in the array is set to null, but it can be set to anything.
// The value is not used to evaluate the row level outcome for filtered rows (to true/null).
// That decision is made using the 1st value which is set to "FilteredData" here.
val filteredSelection = array(lit(FilteredData.name).as("source"), lit(null).as("selection"))

condition
.map { cond => when(not(expr(cond)), filteredSelection).otherwise(origSelection) }
.map { cond => when(expr(cond), origSelection).otherwise(filteredSelection) }
.getOrElse(origSelection)
}

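For orientation, a minimal sketch of the augmented outcome column this helper now builds (not part of the diff: the DataFrame, the filter, and the literal strings standing in for InScopeData.name and FilteredData.name are invented for illustration):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, col, element_at, expr, lit, max, when}
import org.apache.spark.sql.types.DoubleType

val spark = SparkSession.builder().master("local[1]").appName("augmentedOutcomeSketch").getOrCreate()
val df = spark.createDataFrame(Seq((1, 5.0), (2, 20.0))).toDF("id", "value")

// Roughly what conditionalSelectionWithAugmentedOutcome(col("value"), Some("id = 1")) produces:
// in-scope rows keep their value, filtered rows carry a null placeholder.
val augmented = when(expr("id = 1"), array(lit("InScopeData"), col("value")))
  .otherwise(array(lit("FilteredData"), lit(null)))

// Row id=1 -> ["InScopeData", "5.0"]; row id=2 -> ["FilteredData", null].
// The aggregation only reads the 2nd element, so filtered rows contribute null and
// are ignored by max; the row-level outcome is decided from the 1st element instead.
df.select(max(element_at(augmented, 2).cast(DoubleType)).as("maxValue")).show()

With the sentinel parameter gone, the same helper serves Minimum, Maximum, MinLength and MaxLength without each analyzer having to choose a replacement value for filtered rows.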
47 changes: 16 additions & 31 deletions src/main/scala/com/amazon/deequ/analyzers/MaxLength.scala
@@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.element_at
import org.apache.spark.sql.functions.length
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType

@@ -35,12 +36,15 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
max(criterion) :: Nil
// The criterion returns a column where each row contains an array of 2 elements.
// The first element of the array is a string that indicates if the row is "in scope" or "filtered" out.
// The second element is the value used for calculating the metric. We use "element_at" to extract it.
max(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MaxState] = {
ifNoNullsIn(result, offset) { _ =>
MaxState(result.getDouble(offset), Some(rowLevelResults))
MaxState(result.getDouble(offset), Some(criterion))
}
}

@@ -51,35 +55,16 @@ case class MaxLength(column: String, where: Option[String] = None, analyzerOptio
override def filterCondition: Option[String] = where

private[deequ] def criterion: Column = {
transformColForNullBehavior
}

private[deequ] def rowLevelResults: Column = {
transformColForFilteredRow(criterion)
}

private def transformColForFilteredRow(col: Column): Column = {
val whereNotCondition = where.map { expression => not(expr(expression)) }
getRowLevelFilterTreatment(analyzerOptions) match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MinValue)
case _ =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null)
}
}

private def transformColForNullBehavior: Column = {
val isNullCheck = col(column).isNull
val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
getNullBehavior match {
case NullBehavior.Fail =>
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MaxValue)
case NullBehavior.EmptyString =>
// Empty String is 0 length string
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType)
case _ =>
colLengths
val colLength = length(col(column)).cast(DoubleType)
val updatedColumn = getNullBehavior match {
case NullBehavior.Fail => when(isNullCheck, Double.MaxValue).otherwise(colLength)
// Empty String is 0 length string
case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength)
case NullBehavior.Ignore => colLength
}

conditionalSelectionWithAugmentedOutcome(updatedColumn, where)
}

private def getNullBehavior: NullBehavior = {
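A hedged end-to-end sketch of the reworked MaxLength criterion under NullBehavior.Fail with a where filter (data, column name, and filter are invented; the string literals stand in for the InScopeData/FilteredData names):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, col, element_at, expr, length, lit, max, when}
import org.apache.spark.sql.types.DoubleType

val spark = SparkSession.builder().master("local[1]").appName("maxLengthSketch").getOrCreate()
val df = spark.createDataFrame(Seq((1, "abc"), (2, null), (3, "toolongvalue"))).toDF("id", "name")

// Step 1: null handling comes first (NullBehavior.Fail turns null lengths into Double.MaxValue).
val colLength = length(col("name")).cast(DoubleType)
val updatedColumn = when(col("name").isNull, Double.MaxValue).otherwise(colLength)

// Step 2: filter augmentation, as in conditionalSelectionWithAugmentedOutcome(updatedColumn, Some("id <= 2")).
val criterion = when(expr("id <= 2"), array(lit("InScopeData"), updatedColumn))
  .otherwise(array(lit("FilteredData"), lit(null)))

// id=1 -> ["InScopeData", "3.0"]                       (length of "abc")
// id=2 -> ["InScopeData", the Double.MaxValue sentinel] (null fails the length check)
// id=3 -> ["FilteredData", null]                        (filtered out, ignored by the aggregation)
df.select(max(element_at(criterion, 2).cast(DoubleType))).show()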
20 changes: 14 additions & 6 deletions src/main/scala/com/amazon/deequ/analyzers/Maximum.scala
@@ -16,13 +16,18 @@

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, element_at, max}
import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.Preconditions.hasColumn
import com.amazon.deequ.analyzers.Preconditions.isNumeric
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.element_at
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row

case class MaxState(maxValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MaxState] with FullColumn {
@@ -41,6 +46,9 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
// The criterion returns a column where each row contains an array of 2 elements.
// The first element of the array is a string that indicates if the row is "in scope" or "filtered" out.
// The second element is the value used for calculating the metric. We use "element_at" to extract it.
max(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

@@ -57,5 +65,5 @@ case class Maximum(column: String, where: Option[String] = None, analyzerOptions
override def filterCondition: Option[String] = where

@VisibleForTesting
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MinValue)
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where)
}
47 changes: 16 additions & 31 deletions src/main/scala/com/amazon/deequ/analyzers/MinLength.scala
@@ -23,10 +23,11 @@ import com.amazon.deequ.analyzers.Preconditions.isString
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.apache.spark.sql.functions.element_at
import org.apache.spark.sql.functions.length
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.functions.min
import org.apache.spark.sql.functions.not
import org.apache.spark.sql.functions.when
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType

@@ -35,12 +36,15 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
min(criterion) :: Nil
// The criterion returns a column where each row contains an array of 2 elements.
// The first element of the array is a string that indicates if the row is "in scope" or "filtered" out.
// The second element is the value used for calculating the metric. We use "element_at" to extract it.
min(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

override def fromAggregationResult(result: Row, offset: Int): Option[MinState] = {
ifNoNullsIn(result, offset) { _ =>
MinState(result.getDouble(offset), Some(rowLevelResults))
MinState(result.getDouble(offset), Some(criterion))
}
}

@@ -51,35 +55,16 @@ case class MinLength(column: String, where: Option[String] = None, analyzerOptio
override def filterCondition: Option[String] = where

private[deequ] def criterion: Column = {
transformColForNullBehavior
}

private[deequ] def rowLevelResults: Column = {
transformColForFilteredRow(criterion)
}

private def transformColForFilteredRow(col: Column): Column = {
val whereNotCondition = where.map { expression => not(expr(expression)) }
getRowLevelFilterTreatment(analyzerOptions) match {
case FilteredRowOutcome.TRUE =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = Double.MaxValue)
case _ =>
conditionSelectionGivenColumn(col, whereNotCondition, replaceWith = null)
}
}

private def transformColForNullBehavior: Column = {
val isNullCheck = col(column).isNull
val colLengths: Column = length(conditionalSelection(column, where)).cast(DoubleType)
getNullBehavior match {
case NullBehavior.Fail =>
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = Double.MinValue)
case NullBehavior.EmptyString =>
// Empty String is 0 length string
conditionSelectionGivenColumn(colLengths, Option(isNullCheck), replaceWith = 0.0).cast(DoubleType)
case _ =>
colLengths
val colLength = length(col(column)).cast(DoubleType)
Contributor review comment: nice, this is much cleaner

val updatedColumn = getNullBehavior match {
case NullBehavior.Fail => when(isNullCheck, Double.MinValue).otherwise(colLength)
// Empty String is 0 length string
case NullBehavior.EmptyString => when(isNullCheck, lit(0.0)).otherwise(colLength)
case NullBehavior.Ignore => colLength
}

conditionalSelectionWithAugmentedOutcome(updatedColumn, where)
}

private def getNullBehavior: NullBehavior = {
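MinLength mirrors MaxLength; the substantive difference is the direction of the Fail sentinel, sketched here under the same assumptions (invented column name):

import org.apache.spark.sql.functions.{col, length, when}
import org.apache.spark.sql.types.DoubleType

val colLength = length(col("name")).cast(DoubleType)

// MaxLength fails nulls upward with Double.MaxValue; MinLength fails them downward:
// a null becomes Double.MinValue, so it drags the minimum-length metric down and
// fails typical minimum-length assertions at the row level.
val failLength = when(col("name").isNull, Double.MinValue).otherwise(colLength)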
20 changes: 14 additions & 6 deletions src/main/scala/com/amazon/deequ/analyzers/Minimum.scala
@@ -16,13 +16,18 @@

package com.amazon.deequ.analyzers

import com.amazon.deequ.analyzers.Preconditions.{hasColumn, isNumeric}
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.functions.{col, element_at, min}
import org.apache.spark.sql.types.{DoubleType, StructType}
import Analyzers._
import com.amazon.deequ.analyzers.Analyzers._
import com.amazon.deequ.analyzers.Preconditions.hasColumn
import com.amazon.deequ.analyzers.Preconditions.isNumeric
import com.amazon.deequ.metrics.FullColumn
import com.google.common.annotations.VisibleForTesting
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.element_at
import org.apache.spark.sql.functions.min
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.Column
import org.apache.spark.sql.Row

case class MinState(minValue: Double, override val fullColumn: Option[Column] = None)
extends DoubleValuedState[MinState] with FullColumn {
@@ -41,6 +46,9 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
with FilterableAnalyzer {

override def aggregationFunctions(): Seq[Column] = {
// The criterion returns a column where each row contains an array of 2 elements.
// The first element of the array is a string that indicates if the row is "in scope" or "filtered" out.
// The second element is the value used for calculating the metric. We use "element_at" to extract it.
min(element_at(criterion, 2).cast(DoubleType)) :: Nil
}

@@ -57,5 +65,5 @@ case class Minimum(column: String, where: Option[String] = None, analyzerOptions
override def filterCondition: Option[String] = where

@VisibleForTesting
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where, Double.MaxValue)
private def criterion: Column = conditionalSelectionWithAugmentedOutcome(col(column), where)
}
61 changes: 47 additions & 14 deletions src/main/scala/com/amazon/deequ/constraints/Constraint.scala
@@ -558,7 +558,8 @@ object Constraint {
val constraint = AnalysisBasedConstraint[MaxState, Double, Double](maxLength, assertion,
hint = hint)

val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, maxLength.analyzerOptions)
val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

new RowLevelAssertedConstraint(
constraint,
@@ -593,7 +594,8 @@ object Constraint {
val constraint = AnalysisBasedConstraint[MinState, Double, Double](minLength, assertion,
hint = hint)

val sparkAssertion = org.apache.spark.sql.functions.udf(assertion)
val updatedAssertion = getUpdatedRowLevelAssertionForLengthConstraint(assertion, minLength.analyzerOptions)
val sparkAssertion = org.apache.spark.sql.functions.udf(updatedAssertion)

new RowLevelAssertedConstraint(
constraint,
@@ -953,26 +955,57 @@
}
}

def filteredRowOutcome: java.lang.Boolean = {
analyzerOptions match {
case Some(opts) =>
opts.filteredRow match {
case FilteredRowOutcome.TRUE => true
case FilteredRowOutcome.NULL => null
}
// https://github.com/awslabs/deequ/issues/530
// Filtered rows should be marked as true by default.
// They can be set to null using the FilteredRowOutcome option.
case None => true
scope match {
case FilteredData.name => filteredRowOutcome(analyzerOptions)
case InScopeData.name => inScopeRowOutcome(value)
}
}
}

private[this] def getUpdatedRowLevelAssertionForLengthConstraint(assertion: Double => Boolean,
analyzerOptions: Option[AnalyzerOptions])
: Seq[String] => java.lang.Boolean = {
(d: Seq[String]) => {
val (scope, value) = (d.head, Option(d.last).map(_.toDouble))

def inScopeRowOutcome(value: Option[Double]): java.lang.Boolean = {
if (value.isDefined) {
// If value is defined, run it through the assertion.
assertion(value.get)
} else {
// If value is not defined (value is null), apply NullBehavior.
analyzerOptions match {
case Some(opts) =>
opts.nullBehavior match {
case NullBehavior.EmptyString => assertion(0.0)
case NullBehavior.Fail => false
case NullBehavior.Ignore => null
}
case None => null
}
}
}

scope match {
case FilteredData.name => filteredRowOutcome
case FilteredData.name => filteredRowOutcome(analyzerOptions)
case InScopeData.name => inScopeRowOutcome(value)
}
}
}

private def filteredRowOutcome(analyzerOptions: Option[AnalyzerOptions]): java.lang.Boolean = {
analyzerOptions match {
case Some(opts) =>
opts.filteredRow match {
case FilteredRowOutcome.TRUE => true
case FilteredRowOutcome.NULL => null
}
// https://github.com/awslabs/deequ/issues/530
// Filtered rows should be marked as true by default.
// They can be set to null using the FilteredRowOutcome option.
case None => true
}
}
}
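To show what the updated row-level assertion does with the (source, value) pairs coming out of the augmented criterion, here is a hedged plain-Scala analogue using the default options (NullBehavior.Ignore, filtered rows treated as true); the assertion and the sample inputs are invented:

// Illustrative analogue of the UDF built by getUpdatedRowLevelAssertionForLengthConstraint,
// hard-coding the default NullBehavior.Ignore and FilteredRowOutcome.TRUE branches.
def rowOutcome(assertion: Double => Boolean)(row: Seq[String]): java.lang.Boolean = {
  val (scope, value) = (row.head, Option(row.last).map(_.toDouble))
  scope match {
    case "FilteredData" => true                     // filtered rows pass by default (issue #530)
    case "InScopeData" => value match {
      case Some(v) => assertion(v)                  // in-scope, non-null value: run the assertion
      case None => null                             // in-scope null under NullBehavior.Ignore
    }
  }
}

val maxLengthAtMostFive = rowOutcome((d: Double) => d <= 5.0) _

println(maxLengthAtMostFive(Seq("InScopeData", "3.0")))   // true: length 3 passes
println(maxLengthAtMostFive(Seq("InScopeData", "8.0")))   // false: length 8 violates the limit
println(maxLengthAtMostFive(Seq("InScopeData", null)))    // null: null length is ignored
println(maxLengthAtMostFive(Seq("FilteredData", null)))   // true: filtered row, default outcome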

/**