-
Notifications
You must be signed in to change notification settings - Fork 542
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
130 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
34 changes: 34 additions & 0 deletions
34
src/main/scala/com/amazon/deequ/suggestions/rules/interval/ConfidenceIntervalStrategy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package com.amazon.deequ.suggestions.rules.interval | ||
|
||
import breeze.stats.distributions.{Gaussian, Rand} | ||
import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence} | ||
|
||
/**
 * Strategy for calculating a confidence interval around an observed proportion.
 */
trait ConfidenceIntervalStrategy {

  /**
   * Calculates the confidence interval for an observed proportion.
   *
   * @param pHat fraction of the sample that shares a trait
   * @param numRecords overall number of records
   * @param confidence confidence level of the method used to estimate the interval
   * @return the lower and upper bound of the estimated interval
   */
  def calculateTargetConfidenceInterval(
      pHat: Double,
      numRecords: Long,
      confidence: Double = defaultConfidence): ConfidenceInterval

  /**
   * Checks that both arguments lie in the closed range [0.0, 1.0].
   *
   * @param pHat proportion to check
   * @param confidence confidence level to check
   * @throws IllegalArgumentException if either value is outside [0.0, 1.0] (NaN is rejected too)
   */
  def validateInput(pHat: Double, confidence: Double): Unit = {
    require(pHat >= 0.0 && pHat <= 1.0, "pHat must be between 0.0 and 1.0")
    require(confidence >= 0.0 && confidence <= 1.0, "confidence must be between 0.0 and 1.0")
  }

  /**
   * Returns the two-sided z-score for the given confidence level, i.e. the standard normal
   * quantile at `1 - (1 - confidence) / 2`.
   *
   * NOTE(review): `confidence == 1.0` passes [[validateInput]] but asks for the quantile at 1.0,
   * which is presumably not finite - confirm against the Breeze `Gaussian.inverseCdf` behaviour.
   *
   * @param confidence confidence level
   * @return the quantile of the standard normal distribution for that level
   */
  def calculateZScore(confidence: Double): Double = {
    val singleTailProbability = (1.0 - confidence) / 2.0
    val standardNormal = Gaussian(0, 1)(Rand)
    standardNormal.inverseCdf(1 - singleTailProbability)
  }
}
|
||
/** Shared defaults and result type for confidence interval strategies. */
object ConfidenceIntervalStrategy {

  /**
   * Interval bounds produced by a strategy.
   *
   * @param lowerBound lower end of the interval
   * @param upperBound upper end of the interval
   */
  case class ConfidenceInterval(lowerBound: Double, upperBound: Double)

  /** Confidence level used when a caller does not pass one explicitly. */
  val defaultConfidence: Double = 0.95
}
|
||
|
23 changes: 23 additions & 0 deletions
23
src/main/scala/com/amazon/deequ/suggestions/rules/interval/WaldIntervalStrategy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package com.amazon.deequ.suggestions.rules.interval | ||
|
||
import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence} | ||
|
||
import scala.math.BigDecimal.RoundingMode | ||
|
||
/**
 * Implements the Wald Interval method for creating a binomial proportion confidence interval.
 *
 * The interval is `pHat +/- z * sqrt(pHat * (1 - pHat) / numRecords)`, with the lower bound
 * rounded with `RoundingMode.DOWN` and the upper bound with `RoundingMode.UP` to two decimals.
 * The bounds are not clamped, so for `pHat` close to 0 or 1 they can fall outside [0, 1].
 *
 * @see <a
 *      href="http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Normal_approximation_interval">
 *      Normal approximation interval (Wikipedia)</a>
 */
case class WaldIntervalStrategy() extends ConfidenceIntervalStrategy {

  /**
   * Calculates the Wald confidence interval for an observed proportion.
   *
   * @param pHat fraction of the sample that shares a trait, in [0.0, 1.0]
   * @param numRecords overall number of records, must be positive
   * @param confidence confidence level used to derive the z-score, in [0.0, 1.0]
   * @return interval bounds rounded to two decimal places
   * @throws IllegalArgumentException if `pHat` or `confidence` is out of range,
   *                                  or `numRecords` is not positive
   */
  def calculateTargetConfidenceInterval(
      pHat: Double,
      numRecords: Long,
      confidence: Double = defaultConfidence): ConfidenceInterval = {
    validateInput(pHat, confidence)
    // Without this guard a zero record count divides by zero and the resulting NaN/Infinity
    // only surfaces as an opaque NumberFormatException from BigDecimal; a negative count can
    // even slip through silently when pHat is exactly 0 or 1.
    require(numRecords > 0, "numRecords must be greater than 0")

    val successRatio = BigDecimal(pHat)
    val standardError = math.sqrt(pHat * (1 - pHat) / numRecords)
    val marginOfError = BigDecimal(calculateZScore(confidence) * standardError)
    val lowerBound = (successRatio - marginOfError).setScale(2, RoundingMode.DOWN).toDouble
    val upperBound = (successRatio + marginOfError).setScale(2, RoundingMode.UP).toDouble
    ConfidenceInterval(lowerBound, upperBound)
  }
}
27 changes: 27 additions & 0 deletions
27
src/main/scala/com/amazon/deequ/suggestions/rules/interval/WilsonScoreIntervalStrategy.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
package com.amazon.deequ.suggestions.rules.interval | ||
|
||
import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.{ConfidenceInterval, defaultConfidence} | ||
|
||
import scala.math.BigDecimal.RoundingMode | ||
|
||
/**
 * Uses the Wilson score method for creating a binomial proportion confidence interval.
 *
 * The lower bound is rounded with `RoundingMode.DOWN` and the upper bound with
 * `RoundingMode.UP` to two decimals.
 *
 * @see <a
 *      href="http://en.wikipedia.org/wiki/Binomial_proportion_confidence_interval#Wilson_score_interval">
 *      Wilson score interval (Wikipedia)</a>
 */
case class WilsonScoreIntervalStrategy() extends ConfidenceIntervalStrategy {

  /**
   * Calculates the Wilson score confidence interval for an observed proportion.
   *
   * @param pHat fraction of the sample that shares a trait, in [0.0, 1.0]
   * @param numRecords overall number of records, must be positive
   * @param confidence confidence level used to derive the z-score, in [0.0, 1.0]
   * @return interval bounds within [0.0, 1.0], rounded to two decimal places
   * @throws IllegalArgumentException if `pHat` or `confidence` is out of range,
   *                                  or `numRecords` is not positive
   */
  def calculateTargetConfidenceInterval(
      pHat: Double,
      numRecords: Long,
      confidence: Double = defaultConfidence): ConfidenceInterval = {
    validateInput(pHat, confidence)
    // A non-positive record count makes every term below meaningless (division by zero or a
    // negative variance), so reject it up front with a clear message.
    require(numRecords > 0, "numRecords must be greater than 0")

    val zScore = calculateZScore(confidence)
    val zSquareOverN = math.pow(zScore, 2) / numRecords
    val factor = 1.0 / (1 + zSquareOverN)
    val adjustedSuccessRatio = pHat + zSquareOverN / 2
    val marginOfError =
      zScore * math.sqrt(pHat * (1 - pHat) / numRecords + zSquareOverN / (4 * numRecords))

    // Wilson bounds are mathematically inside [0, 1]; clamping before rounding keeps a
    // floating-point overshoot such as 1.0000000000000002 from being rounded UP to 1.01.
    val rawLowerBound = math.max(0.0, factor * (adjustedSuccessRatio - marginOfError))
    val rawUpperBound = math.min(1.0, factor * (adjustedSuccessRatio + marginOfError))

    val lowerBound = BigDecimal(rawLowerBound).setScale(2, RoundingMode.DOWN).toDouble
    val upperBound = BigDecimal(rawUpperBound).setScale(2, RoundingMode.UP).toDouble
    ConfidenceInterval(lowerBound, upperBound)
  }
}
32 changes: 32 additions & 0 deletions
32
src/test/scala/com/amazon/deequ/suggestions/rules/interval/IntervalStrategyTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package com.amazon.deequ.suggestions.rules.interval | ||
|
||
import com.amazon.deequ.SparkContextSpec | ||
import com.amazon.deequ.suggestions.rules.interval.ConfidenceIntervalStrategy.ConfidenceInterval | ||
import com.amazon.deequ.utils.FixtureSupport | ||
import org.scalamock.scalatest.MockFactory | ||
import org.scalatest.wordspec.AnyWordSpec | ||
|
||
/** Checks the intervals produced by the Wald and Wilson score strategies at the default confidence. */
class IntervalStrategyTest extends AnyWordSpec with FixtureSupport with SparkContextSpec
  with MockFactory {

  /**
   * Asserts, for every `(pHat, numRecords) -> expected` entry, that `strategy` yields exactly
   * the expected interval when called with the default confidence.
   */
  private def verifyIntervals(
      strategy: ConfidenceIntervalStrategy,
      expectations: Seq[((Double, Long), ConfidenceInterval)]): Unit =
    expectations.foreach { case ((pHat, numRecords), expected) =>
      assert(strategy.calculateTargetConfidenceInterval(pHat, numRecords) == expected)
    }

  "WaldIntervalStrategy" should {
    "be calculated correctly" in {
      verifyIntervals(
        WaldIntervalStrategy(),
        Seq(
          (1.0, 20L) -> ConfidenceInterval(1.0, 1.0),
          (0.5, 100L) -> ConfidenceInterval(0.4, 0.6),
          (0.4, 100L) -> ConfidenceInterval(0.3, 0.5),
          (0.6, 100L) -> ConfidenceInterval(0.5, 0.7),
          (0.90, 100L) -> ConfidenceInterval(0.84, 0.96),
          (1.0, 100L) -> ConfidenceInterval(1.0, 1.0)
        )
      )
    }
  }

  "WilsonIntervalStrategy" should {
    "be calculated correctly" in {
      verifyIntervals(
        WilsonScoreIntervalStrategy(),
        Seq(
          (1.0, 20L) -> ConfidenceInterval(0.83, 1.0),
          (0.5, 100L) -> ConfidenceInterval(0.4, 0.6),
          (0.4, 100L) -> ConfidenceInterval(0.3, 0.5),
          (0.6, 100L) -> ConfidenceInterval(0.5, 0.7),
          (0.90, 100L) -> ConfidenceInterval(0.82, 0.95),
          (1.0, 100L) -> ConfidenceInterval(0.96, 1.0)
        )
      )
    }
  }
}