diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala
index 55e410f3..4970a4d0 100644
--- a/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala
@@ -23,16 +23,17 @@ import com.amazon.deequ.metrics.DistributionValue
 import com.amazon.deequ.profiles.ColumnProfile
 import com.amazon.deequ.suggestions.ConstraintSuggestion
 import com.amazon.deequ.suggestions.ConstraintSuggestionWithValue
+import com.amazon.deequ.suggestions.rules.FractionalCategoricalRangeRule.defaultIntervalStrategy
+import com.amazon.deequ.suggestions.rules.interval.{ConfidenceIntervalStrategy, WilsonScoreIntervalStrategy}
 import org.apache.commons.lang3.StringEscapeUtils
 
-import scala.math.BigDecimal.RoundingMode
-
 /** If we see a categorical range for most values in a column, we suggest an IS IN (...)
   * constraint that should hold for most values */
 case class FractionalCategoricalRangeRule(
   targetDataCoverageFraction: Double = 0.9,
   categorySorter: Array[(String, DistributionValue)] => Array[(String, DistributionValue)] =
-    categories => categories.sortBy({ case (_, value) => value.absolute }).reverse
+    categories => categories.sortBy({ case (_, value) => value.absolute }).reverse,
+  intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy
 ) extends ConstraintRule[ColumnProfile] {
 
   override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = {
@@ -79,11 +80,8 @@
 
     val p = ratioSums
     val n = numRecords
-    val z = 1.96
 
-    // TODO this needs to be more robust for p's close to 0 or 1
-    val targetCompliance = BigDecimal(p - z * math.sqrt(p * (1 - p) / n))
-      .setScale(2, RoundingMode.DOWN).toDouble
+    val targetCompliance = intervalStrategy.calculateTargetConfidenceInterval(p, n).lowerBound
 
     val description = s"'${profile.column}' has value range $categoriesSql for at least " +
       s"${targetCompliance * 100}% of values"
@@ -128,3 +126,7 @@ case class FractionalCategoricalRangeRule(
   override val ruleDescription: String = "If we see a categorical range for most values " +
     "in a column, we suggest an IS IN (...) constraint that should hold for most values"
 }
+
+object FractionalCategoricalRangeRule {
+  private val defaultIntervalStrategy: ConfidenceIntervalStrategy = WilsonScoreIntervalStrategy()
+}
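Note on behaviour: the inline computation removed from both rules was the lower bound of a 95% Wald interval with a hard-coded z-score of 1.96, rounded down to two decimal places. The WaldIntervalStrategy introduced later in this diff reproduces that computation exactly, while the rules now default to the Wilson score interval. A standalone sketch of the equivalence (illustrative only, not part of the change set; the object name is hypothetical):

    import com.amazon.deequ.suggestions.rules.interval.WaldIntervalStrategy

    import scala.math.BigDecimal.RoundingMode

    object OldVsNewBound extends App {
      val (p, n) = (0.9, 100L)

      // The removed code: hard-coded two-sided 95% z-score, rounded down.
      val oldBound = BigDecimal(p - 1.96 * math.sqrt(p * (1 - p) / n))
        .setScale(2, RoundingMode.DOWN).toDouble

      // The replacement: the same Wald lower bound behind the strategy interface.
      val newBound = WaldIntervalStrategy().calculateTargetConfidenceInterval(p, n).lowerBound

      println(oldBound == newBound) // true: both yield 0.84
    }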
diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala
index 9f995a11..35a287b6 100644
--- a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala
@@ -21,8 +21,7 @@ import com.amazon.deequ.profiles.ColumnProfile
 import com.amazon.deequ.suggestions.CommonConstraintSuggestion
 import com.amazon.deequ.suggestions.ConstraintSuggestion
 import com.amazon.deequ.suggestions.rules.RetainCompletenessRule._
-
-import scala.math.BigDecimal.RoundingMode
+import com.amazon.deequ.suggestions.rules.interval.{ConfidenceIntervalStrategy, WilsonScoreIntervalStrategy}
 
 /**
  * If a column is incomplete in the sample, we model its completeness as a binomial variable,
@@ -33,21 +32,15 @@ import scala.math.BigDecimal.RoundingMode
  */
 case class RetainCompletenessRule(
   minCompleteness: Double = defaultMinCompleteness,
-  maxCompleteness: Double = defaultMaxCompleteness
+  maxCompleteness: Double = defaultMaxCompleteness,
+  intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy
 ) extends ConstraintRule[ColumnProfile] {
 
   override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = {
     profile.completeness > minCompleteness && profile.completeness < maxCompleteness
   }
 
   override def candidate(profile: ColumnProfile, numRecords: Long): ConstraintSuggestion = {
-
-    val p = profile.completeness
-    val n = numRecords
-    val z = 1.96
-
-    // TODO this needs to be more robust for p's close to 0 or 1
-    val targetCompleteness = BigDecimal(p - z * math.sqrt(p * (1 - p) / n))
-      .setScale(2, RoundingMode.DOWN).toDouble
+    val targetCompleteness = intervalStrategy.calculateTargetConfidenceInterval(profile.completeness, numRecords).lowerBound
 
     val constraint = completenessConstraint(profile.column, _ >= targetCompleteness)
 
@@ -75,4 +68,5 @@ case class RetainCompletenessRule(
 object RetainCompletenessRule {
   private val defaultMinCompleteness: Double = 0.2
   private val defaultMaxCompleteness: Double = 1.0
+  private val defaultIntervalStrategy: ConfidenceIntervalStrategy = WilsonScoreIntervalStrategy()
 }
diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala
new file mode 100644
index 00000000..097bd911
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala
@@ -0,0 +1,34 @@
+package com.amazon.deequ.suggestions.rules.interval
+
+import breeze.stats.distributions.{Gaussian, Rand}
+import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence}
+
+/**
+ * Strategy for calculating confidence intervals
+ */
+trait ConfidenceIntervalStrategy {
+
+  /**
+   * Generates a confidence interval
+   *
+   * @param pHat fraction of the sample that shares the trait of interest
+   * @param numRecords overall number of records
+   * @param confidence confidence level of the method used to estimate the interval
+   * @return the estimated confidence interval
+   */
+  def calculateTargetConfidenceInterval(pHat: Double, numRecords: Long, confidence: Double = defaultConfidence): ConfidenceInterval
+
+  def validateInput(pHat: Double, confidence: Double): Unit = {
+    require(0.0 <= pHat && pHat <= 1.0, "pHat must be between 0.0 and 1.0")
+    require(0.0 <= confidence && confidence <= 1.0, "confidence must be between 0.0 and 1.0")
+  }
+
+  def calculateZScore(confidence: Double): Double =
+    Gaussian(0, 1)(Rand).inverseCdf(1 - ((1.0 - confidence) / 2.0))
+}
+
+object ConfidenceIntervalStrategy {
+  val defaultConfidence = 0.95
+
+  case class ConfidenceInterval(lowerBound: Double, upperBound: Double)
+}
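The trait derives the z-score from the requested confidence level instead of hard-coding it: a two-sided interval at confidence c leaves (1 - c) / 2 probability in each tail, so calculateZScore takes the standard normal quantile at 1 - (1 - c) / 2. A quick sanity check (illustrative only; the object name is hypothetical):

    import com.amazon.deequ.suggestions.rules.interval.WilsonScoreIntervalStrategy

    object ZScoreCheck extends App {
      // calculateZScore is a concrete method on ConfidenceIntervalStrategy,
      // so any strategy instance exposes it.
      val strategy = WilsonScoreIntervalStrategy()
      println(strategy.calculateZScore(0.95)) // ~1.9600, the previously hard-coded constant
      println(strategy.calculateZScore(0.99)) // ~2.5758
    }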
diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala
new file mode 100644
index 00000000..6e8d1d06
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala
@@ -0,0 +1,23 @@
+package com.amazon.deequ.suggestions.rules.interval
+
+import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence}
+
+import scala.math.BigDecimal.RoundingMode
+
+/**
+ * Implements the Wald interval method for creating a binomial proportion confidence interval.
+ *
+ * @see
+ *   Normal approximation interval (Wikipedia)
+ */
+case class WaldIntervalStrategy() extends ConfidenceIntervalStrategy {
+  def calculateTargetConfidenceInterval(pHat: Double, numRecords: Long, confidence: Double = defaultConfidence): ConfidenceInterval = {
+    validateInput(pHat, confidence)
+    val successRatio = BigDecimal(pHat)
+    val marginOfError = BigDecimal(calculateZScore(confidence) * math.sqrt(pHat * (1 - pHat) / numRecords))
+    val lowerBound = (successRatio - marginOfError).setScale(2, RoundingMode.DOWN).toDouble
+    val upperBound = (successRatio + marginOfError).setScale(2, RoundingMode.UP).toDouble
+    ConfidenceInterval(lowerBound, upperBound)
+  }
+}
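For comparison, with sample proportion \hat{p}, record count n, and z-score z, the two strategies compute:

    \text{Wald:}\quad \hat{p} \pm z\sqrt{\frac{\hat{p}(1 - \hat{p})}{n}}

    \text{Wilson:}\quad \frac{1}{1 + z^2/n}\left(\hat{p} + \frac{z^2}{2n} \pm z\sqrt{\frac{\hat{p}(1 - \hat{p})}{n} + \frac{z^2}{4n^2}}\right)

The Wald margin collapses to zero width at \hat{p} = 0 or 1, the failure mode the removed TODO comments warned about, while the Wilson bounds stay informative; compare the test expectations below, where pHat = 1.0 with 20 records yields (1.0, 1.0) under Wald but (0.83, 1.0) under Wilson.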
diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala
new file mode 100644
index 00000000..e76b8a0e
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala
@@ -0,0 +1,27 @@
+package com.amazon.deequ.suggestions.rules.interval
+
+import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence}
+
+import scala.math.BigDecimal.RoundingMode
+
+/**
+ * Implements the Wilson score method for creating a binomial proportion confidence interval.
+ *
+ * @see
+ *   Wilson score interval (Wikipedia)
+ */
+case class WilsonScoreIntervalStrategy() extends ConfidenceIntervalStrategy {
+
+  def calculateTargetConfidenceInterval(pHat: Double, numRecords: Long, confidence: Double = defaultConfidence): ConfidenceInterval = {
+    validateInput(pHat, confidence)
+    val zScore = calculateZScore(confidence)
+    val zSquareOverN = math.pow(zScore, 2) / numRecords
+    val factor = 1.0 / (1 + zSquareOverN)
+    val adjustedSuccessRatio = pHat + zSquareOverN / 2
+    val marginOfError = zScore * math.sqrt(pHat * (1 - pHat) / numRecords + zSquareOverN / (4 * numRecords))
+    val lowerBound = BigDecimal(factor * (adjustedSuccessRatio - marginOfError)).setScale(2, RoundingMode.DOWN).toDouble
+    val upperBound = BigDecimal(factor * (adjustedSuccessRatio + marginOfError)).setScale(2, RoundingMode.UP).toDouble
+    ConfidenceInterval(lowerBound, upperBound)
+  }
+}
diff --git a/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala
new file mode 100644
index 00000000..708fd285
--- /dev/null
+++ b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala
@@ -0,0 +1,32 @@
+package com.amazon.deequ.suggestions.rules.interval
+
+import com.amazon.deequ.SparkContextSpec
+import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval
+import com.amazon.deequ.utils.FixtureSupport
+import org.scalamock.scalatest.MockFactory
+import org.scalatest.wordspec.AnyWordSpec
+
+class IntervalStrategyTest extends AnyWordSpec with FixtureSupport with SparkContextSpec
+  with MockFactory {
+
+  "WaldIntervalStrategy" should {
+    "be calculated correctly" in {
+      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(1.0, 20L) == ConfidenceInterval(1.0, 1.0))
+      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.5, 100L) == ConfidenceInterval(0.4, 0.6))
+      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.4, 100L) == ConfidenceInterval(0.3, 0.5))
+      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.6, 100L) == ConfidenceInterval(0.5, 0.7))
+      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.90, 100L) == ConfidenceInterval(0.84, 0.96))
+      assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(1.0, 100L) == ConfidenceInterval(1.0, 1.0))
+    }
+  }
+
+  "WilsonScoreIntervalStrategy" should {
+    "be calculated correctly" in {
+      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(1.0, 20L) == ConfidenceInterval(0.83, 1.0))
+      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.5, 100L) == ConfidenceInterval(0.4, 0.6))
+      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.4, 100L) == ConfidenceInterval(0.3, 0.5))
+      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.6, 100L) == ConfidenceInterval(0.5, 0.7))
+      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.90, 100L) == ConfidenceInterval(0.82, 0.95))
+      assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(1.0, 100L) == ConfidenceInterval(0.96, 1.0))
+    }
+  }
+}
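A possible end-to-end use of the new parameter (illustrative sketch; assumes the usual ConstraintSuggestionRunner entry point and a local SparkSession; the object name and toy data are hypothetical):

    import com.amazon.deequ.suggestions.ConstraintSuggestionRunner
    import com.amazon.deequ.suggestions.rules.RetainCompletenessRule
    import com.amazon.deequ.suggestions.rules.interval.WaldIntervalStrategy
    import org.apache.spark.sql.SparkSession

    object SuggestWithWald extends App {
      val spark = SparkSession.builder()
        .master("local[*]")
        .appName("interval-strategy-demo")
        .getOrCreate()
      import spark.implicits._

      // Toy data: a column that is ~90% complete.
      val df = (Seq.fill(90)(Option("value")) ++ Seq.fill(10)(Option.empty[String])).toDF("item")

      val result = ConstraintSuggestionRunner()
        .onData(df)
        // Both rules default to WilsonScoreIntervalStrategy; passing a strategy
        // explicitly opts back into the old Wald behaviour.
        .addConstraintRule(RetainCompletenessRule(intervalStrategy = WaldIntervalStrategy()))
        .run()

      result.constraintSuggestions.foreach { case (column, suggestions) =>
        suggestions.foreach(s => println(s"$column: ${s.description}"))
      }
    }

Omitting intervalStrategy keeps the new Wilson score default, so existing callers only see slightly different (more robust) suggested thresholds near completeness 0 or 1.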