diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala
index 55e410f3..4970a4d0 100644
--- a/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/FractionalCategoricalRangeRule.scala
@@ -23,16 +23,17 @@ import com.amazon.deequ.metrics.DistributionValue
import com.amazon.deequ.profiles.ColumnProfile
import com.amazon.deequ.suggestions.ConstraintSuggestion
import com.amazon.deequ.suggestions.ConstraintSuggestionWithValue
+import com.amazon.deequ.suggestions.rules.FractionalCategoricalRangeRule.defaultIntervalStrategy
+import com.amazon.deequ.suggestions.rules.interval.{ConfidenceIntervalStrategy, WilsonScoreIntervalStrategy}
import org.apache.commons.lang3.StringEscapeUtils
-import scala.math.BigDecimal.RoundingMode
-
/** If we see a categorical range for most values in a column, we suggest an IS IN (...)
* constraint that should hold for most values */
case class FractionalCategoricalRangeRule(
targetDataCoverageFraction: Double = 0.9,
categorySorter: Array[(String, DistributionValue)] => Array[(String, DistributionValue)] =
- categories => categories.sortBy({ case (_, value) => value.absolute }).reverse
+ categories => categories.sortBy({ case (_, value) => value.absolute }).reverse,
+ intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy
) extends ConstraintRule[ColumnProfile] {
override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = {
@@ -79,11 +80,8 @@ case class FractionalCategoricalRangeRule(
val p = ratioSums
val n = numRecords
- val z = 1.96
- // TODO this needs to be more robust for p's close to 0 or 1
- val targetCompliance = BigDecimal(p - z * math.sqrt(p * (1 - p) / n))
- .setScale(2, RoundingMode.DOWN).toDouble
+ val targetCompliance = intervalStrategy.calculateTargetConfidenceInterval(p, n).lowerBound
val description = s"'${profile.column}' has value range $categoriesSql for at least " +
s"${targetCompliance * 100}% of values"
@@ -128,3 +126,7 @@ case class FractionalCategoricalRangeRule(
override val ruleDescription: String = "If we see a categorical range for most values " +
"in a column, we suggest an IS IN (...) constraint that should hold for most values"
}
+
+object FractionalCategoricalRangeRule {
+ private val defaultIntervalStrategy: ConfidenceIntervalStrategy = WilsonScoreIntervalStrategy()
+}
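
With the interval method now injected through the constructor, callers can swap strategies per rule. A minimal usage sketch, assuming the patch as shown (the `rule` value name is illustrative, not part of the change):

val rule = FractionalCategoricalRangeRule(intervalStrategy = WaldIntervalStrategy())
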
diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala
index 9f995a11..35a287b6 100644
--- a/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/RetainCompletenessRule.scala
@@ -21,8 +21,7 @@ import com.amazon.deequ.profiles.ColumnProfile
import com.amazon.deequ.suggestions.CommonConstraintSuggestion
import com.amazon.deequ.suggestions.ConstraintSuggestion
import com.amazon.deequ.suggestions.rules.RetainCompletenessRule._
-
-import scala.math.BigDecimal.RoundingMode
+import com.amazon.deequ.suggestions.rules.interval.{ConfidenceIntervalStrategy, WilsonScoreIntervalStrategy}
/**
* If a column is incomplete in the sample, we model its completeness as a binomial variable,
@@ -33,21 +32,15 @@ import scala.math.BigDecimal.RoundingMode
*/
case class RetainCompletenessRule(
minCompleteness: Double = defaultMinCompleteness,
- maxCompleteness: Double = defaultMaxCompleteness
+ maxCompleteness: Double = defaultMaxCompleteness,
+ intervalStrategy: ConfidenceIntervalStrategy = defaultIntervalStrategy
) extends ConstraintRule[ColumnProfile] {
override def shouldBeApplied(profile: ColumnProfile, numRecords: Long): Boolean = {
profile.completeness > minCompleteness && profile.completeness < maxCompleteness
}
override def candidate(profile: ColumnProfile, numRecords: Long): ConstraintSuggestion = {
-
- val p = profile.completeness
- val n = numRecords
- val z = 1.96
-
- // TODO this needs to be more robust for p's close to 0 or 1
- val targetCompleteness = BigDecimal(p - z * math.sqrt(p * (1 - p) / n))
- .setScale(2, RoundingMode.DOWN).toDouble
+    val targetCompleteness =
+      intervalStrategy.calculateTargetConfidenceInterval(profile.completeness, numRecords).lowerBound
val constraint = completenessConstraint(profile.column, _ >= targetCompleteness)
@@ -75,4 +68,5 @@ case class RetainCompletenessRule(
object RetainCompletenessRule {
private val defaultMinCompleteness: Double = 0.2
private val defaultMaxCompleteness: Double = 1.0
+ private val defaultIntervalStrategy: ConfidenceIntervalStrategy = WilsonScoreIntervalStrategy()
}
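
The same parameter is threaded through here, so completeness suggestions can opt back into the Wald interval. A hedged sketch (value name illustrative):

val completenessRule = RetainCompletenessRule(intervalStrategy = WaldIntervalStrategy())
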
diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala
new file mode 100644
index 00000000..097bd911
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala
@@ -0,0 +1,34 @@
+package com.amazon.deequ.suggestions.rules.interval
+
+import breeze.stats.distributions.{Gaussian, Rand}
+import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence}
+
+/**
+ * Strategy for calculating a binomial proportion confidence interval.
+ */
+trait ConfidenceIntervalStrategy {
+
+ /**
+   * Generates a confidence interval for a sample proportion.
+   * @param pHat fraction of the sample that shares the trait (the sample proportion)
+   * @param numRecords overall number of records in the sample
+   * @param confidence confidence level of the method used to estimate the interval
+   * @return the lower and upper bounds of the estimated interval
+ */
+ def calculateTargetConfidenceInterval(pHat: Double, numRecords: Long, confidence: Double = defaultConfidence): ConfidenceInterval
+
+ def validateInput(pHat: Double, confidence: Double): Unit = {
+ require(0.0 <= pHat && pHat <= 1.0, "pHat must be between 0.0 and 1.0")
+ require(0.0 <= confidence && confidence <= 1.0, "confidence must be between 0.0 and 1.0")
+ }
+
+  // Two-sided z-score: the (1 - alpha / 2) quantile of the standard normal distribution
+  def calculateZScore(confidence: Double): Double = Gaussian(0, 1)(Rand).inverseCdf(1 - ((1.0 - confidence) / 2.0))
+}
+
+object ConfidenceIntervalStrategy {
+ val defaultConfidence = 0.95
+
+ case class ConfidenceInterval(lowerBound: Double, upperBound: Double)
+}
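
The trait's concrete `calculateZScore` maps a confidence level to a two-sided z-score through the inverse normal CDF. A small sketch of the expected value at the default confidence (the `z95` name is illustrative):

val z95 = WilsonScoreIntervalStrategy().calculateZScore(0.95) // ≈ 1.96, the constant previously hardcoded in the rules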
+
+
diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala
new file mode 100644
index 00000000..6e8d1d06
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala
@@ -0,0 +1,23 @@
+package com.amazon.deequ.suggestions.rules.interval
+
+import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence}
+
+import scala.math.BigDecimal.RoundingMode
+
+/**
+ * Implements the Wald Interval method for creating a binomial proportion confidence interval.
+ *
+ * @see <a href="https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Normal_approximation_interval">
+ *   Normal approximation interval (Wikipedia)</a>
+ */
+case class WaldIntervalStrategy() extends ConfidenceIntervalStrategy {
+ def calculateTargetConfidenceInterval(pHat: Double, numRecords: Long, confidence: Double = defaultConfidence): ConfidenceInterval = {
+ validateInput(pHat, confidence)
+ val successRatio = BigDecimal(pHat)
+ val marginOfError = BigDecimal(calculateZScore(confidence) * math.sqrt(pHat * (1 - pHat) / numRecords))
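+    // Round the lower bound down and the upper bound up so the reported interval is conservative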
+ val lowerBound = (successRatio - marginOfError).setScale(2, RoundingMode.DOWN).toDouble
+ val upperBound = (successRatio + marginOfError).setScale(2, RoundingMode.UP).toDouble
+ ConfidenceInterval(lowerBound, upperBound)
+ }
+}
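
For intuition, a worked pass through the Wald computation at pHat = 0.9, n = 100, and the default 95% confidence (z ≈ 1.96), consistent with the test further below:

// margin of error = 1.96 * sqrt(0.9 * 0.1 / 100) ≈ 0.0588
// lower = 0.9 - 0.0588 = 0.8412 -> setScale(2, DOWN) -> 0.84
// upper = 0.9 + 0.0588 = 0.9588 -> setScale(2, UP)   -> 0.96
WaldIntervalStrategy().calculateTargetConfidenceInterval(0.9, 100L) // ConfidenceInterval(0.84, 0.96)
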
diff --git a/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala
new file mode 100644
index 00000000..e76b8a0e
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala
@@ -0,0 +1,27 @@
+package com.amazon.deequ.suggestions.rules.interval
+
+import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence}
+
+import scala.math.BigDecimal.RoundingMode
+
+/**
+ * Implements the Wilson score method for creating a binomial proportion confidence interval.
+ *
+ * @see <a href="https://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Wilson_score_interval">
+ *   Wilson score interval (Wikipedia)</a>
+ */
+case class WilsonScoreIntervalStrategy() extends ConfidenceIntervalStrategy {
+
+ def calculateTargetConfidenceInterval(pHat: Double, numRecords: Long, confidence: Double = defaultConfidence): ConfidenceInterval = {
+ validateInput(pHat, confidence)
+ val zScore = calculateZScore(confidence)
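+    // Wilson score interval: center (pHat + z^2/2n) / (1 + z^2/n),
+    // margin z / (1 + z^2/n) * sqrt(pHat * (1 - pHat) / n + z^2 / (4 * n^2))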
+ val zSquareOverN = math.pow(zScore, 2) / numRecords
+ val factor = 1.0 / (1 + zSquareOverN)
+ val adjustedSuccessRatio = pHat + zSquareOverN/2
+ val marginOfError = zScore * math.sqrt(pHat * (1 - pHat)/numRecords + zSquareOverN/(4 * numRecords))
+ val lowerBound = BigDecimal(factor * (adjustedSuccessRatio - marginOfError)).setScale(2, RoundingMode.DOWN).toDouble
+ val upperBound = BigDecimal(factor * (adjustedSuccessRatio + marginOfError)).setScale(2, RoundingMode.UP).toDouble
+ ConfidenceInterval(lowerBound, upperBound)
+ }
+}
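
A worked pass at the boundary the removed TODO warned about (pHat = 1.0, n = 20, 95% confidence): the Wald margin collapses to zero because pHat * (1 - pHat) = 0, while the Wilson score interval still produces a useful lower bound:

// z ≈ 1.96, z^2/n ≈ 0.1921, factor = 1 / (1 + 0.1921) ≈ 0.8389
// margin = 1.96 * sqrt(0 + 0.1921 / 80) ≈ 0.0960
// lower = 0.8389 * (1.0960 - 0.0960) ≈ 0.8389 -> setScale(2, DOWN) -> 0.83
// upper = 0.8389 * (1.0960 + 0.0960) ≈ 1.0
WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(1.0, 20L) // ConfidenceInterval(0.83, 1.0)
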
diff --git a/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala
new file mode 100644
index 00000000..708fd285
--- /dev/null
+++ b/src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala
@@ -0,0 +1,32 @@
+package com.amazon.deequ.suggestions.rules.interval
+
+import com.amazon.deequ.SparkContextSpec
+import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval
+import com.amazon.deequ.utils.FixtureSupport
+import org.scalamock.scalatest.MockFactory
+import org.scalatest.wordspec.AnyWordSpec
+
+class IntervalStrategyTest extends AnyWordSpec with FixtureSupport with SparkContextSpec
+ with MockFactory {
+ "WaldIntervalStrategy" should {
+ "be calculated correctly" in {
+ assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(1.0, 20L) == ConfidenceInterval(1.0, 1.0))
+ assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.5, 100L) == ConfidenceInterval(0.4, 0.6))
+ assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.4, 100L) == ConfidenceInterval(0.3, 0.5))
+ assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.6, 100L) == ConfidenceInterval(0.5, 0.7))
+ assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(0.90, 100L) == ConfidenceInterval(0.84, 0.96))
+ assert(WaldIntervalStrategy().calculateTargetConfidenceInterval(1.0, 100L) == ConfidenceInterval(1.0, 1.0))
+ }
+ }
+
+ "WilsonIntervalStrategy" should {
+ "be calculated correctly" in {
+ assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(1.0, 20L) == ConfidenceInterval(0.83, 1.0))
+ assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.5, 100L) == ConfidenceInterval(0.4, 0.6))
+ assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.4, 100L) == ConfidenceInterval(0.3, 0.5))
+ assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.6, 100L) == ConfidenceInterval(0.5, 0.7))
+ assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(0.90, 100L) == ConfidenceInterval(0.82, 0.95))
+ assert(WilsonScoreIntervalStrategy().calculateTargetConfidenceInterval(1.0, 100L) == ConfidenceInterval(0.96, 1.0))
+ }
+ }
+}