Skip to content

Commit

Permalink
Column Count Analyzer and Check (awslabs#555)
Browse files Browse the repository at this point in the history
* Fix flaky KLL test

* Move CustomSql state to CustomSql analyzer

* Implement new Analyzer to count columns

* Improve documentation, remove unused parameter, replace if/else with map

---------

Co-authored-by: Yannis Mentekidis <[email protected]>
  • Loading branch information
2 people authored and rdsharma26 committed Apr 16, 2024
1 parent 614a91f commit b141f4b
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 13 deletions.
63 changes: 63 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/ColumnCount.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*
*/

package com.amazon.deequ.analyzers

import com.amazon.deequ.metrics.DoubleMetric
import com.amazon.deequ.metrics.Entity
import org.apache.spark.sql.DataFrame

case class ColumnCount() extends Analyzer[NumMatches, DoubleMetric] {

val name = "ColumnCount"
val instance = "*"
val entity = Entity.Dataset

/**
* Compute the state (sufficient statistics) from the data
*
* @param data the input dataframe
* @return the number of columns in the input
*/
override def computeStateFrom(data: DataFrame, filterCondition: Option[String]): Option[NumMatches] = {
if (filterCondition.isDefined) {
throw new IllegalArgumentException("ColumnCount does not accept a filter condition")
} else {
val numColumns = data.columns.size
Some(NumMatches(numColumns))
}
}

/**
* Compute the metric from the state (sufficient statistics)
*
* @param state the computed state from [[computeStateFrom]]
* @return a double metric indicating the number of columns for this analyzer
*/
override def computeMetricFrom(state: Option[NumMatches]): DoubleMetric = {
state
.map(v => Analyzers.metricFromValue(v.metricValue(), name, instance, entity))
.getOrElse(Analyzers.metricFromEmpty(this, name, instance, entity))
}

/**
* Compute the metric from a failure - reports the exception thrown while trying to count columns
*/
override private[deequ] def toFailureMetric(failure: Exception): DoubleMetric = {
Analyzers.metricFromFailure(failure, name, instance, entity)
}
}
11 changes: 11 additions & 0 deletions src/main/scala/com/amazon/deequ/analyzers/CustomSql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ import scala.util.Failure
import scala.util.Success
import scala.util.Try

case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] {
lazy val state = stateOrError.left.get
lazy val error = stateOrError.right.get

override def sum(other: CustomSqlState): CustomSqlState = {
CustomSqlState(Left(state + other.state))
}

override def metricValue(): Double = state
}

case class CustomSql(expression: String) extends Analyzer[CustomSqlState, DoubleMetric] {
/**
* Compute the state (sufficient statistics) from the data
Expand Down
11 changes: 0 additions & 11 deletions src/main/scala/com/amazon/deequ/analyzers/Size.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,6 @@ import com.amazon.deequ.metrics.Entity
import org.apache.spark.sql.{Column, Row}
import Analyzers._

case class CustomSqlState(stateOrError: Either[Double, String]) extends DoubleValuedState[CustomSqlState] {
lazy val state = stateOrError.left.get
lazy val error = stateOrError.right.get

override def sum(other: CustomSqlState): CustomSqlState = {
CustomSqlState(Left(state + other.state))
}

override def metricValue(): Double = state
}

case class NumMatches(numMatches: Long) extends DoubleValuedState[NumMatches] {

override def sum(other: NumMatches): NumMatches = {
Expand Down
7 changes: 7 additions & 0 deletions src/main/scala/com/amazon/deequ/checks/Check.scala
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ case class Check(
addFilterableConstraint { filter => Constraint.sizeConstraint(assertion, filter, hint) }
}

def hasColumnCount(assertion: Long => Boolean, hint: Option[String] = None)
: CheckWithLastConstraintFilterable = {
addFilterableConstraint {
filter => Constraint.columnCountConstraint(assertion, hint)
}
}

/**
* Creates a constraint that asserts on a column completion.
*
Expand Down
12 changes: 12 additions & 0 deletions src/main/scala/com/amazon/deequ/constraints/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,18 @@ object Constraint {
new NamedConstraint(constraint, s"SizeConstraint($size)")
}

def columnCountConstraint(assertion: Long => Boolean, hint: Option[String] = None): Constraint = {
val colCount = ColumnCount()
fromAnalyzer(colCount, assertion, hint)
}


def fromAnalyzer(colCount: ColumnCount, assertion: Long => Boolean, hint: Option[String]): Constraint = {
val constraint = AnalysisBasedConstraint[NumMatches, Double, Long](colCount, assertion, Some(_.toLong), hint)

new NamedConstraint(constraint, name = s"ColumnCountConstraint($colCount)")
}

/**
* Runs Histogram analysis on the given column and executes the assertion
*
Expand Down
5 changes: 3 additions & 2 deletions src/test/scala/com/amazon/deequ/KLL/KLLDistanceTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import com.amazon.deequ.analyzers.{Distance, QuantileNonSample}
import com.amazon.deequ.metrics.BucketValue
import com.amazon.deequ.utils.FixtureSupport
import org.scalatest.WordSpec
import com.amazon.deequ.metrics.{BucketValue}
import com.amazon.deequ.metrics.BucketValue
import org.scalactic.Tolerance.convertNumericToPlusOrMinusWrapper

class KLLDistanceTest extends WordSpec with SparkContextSpec
with FixtureSupport{
Expand Down Expand Up @@ -88,7 +89,7 @@ class KLLDistanceTest extends WordSpec with SparkContextSpec
val sample2 = scala.collection.mutable.Map(
"a" -> 22L, "b" -> 20L, "c" -> 25L, "d" -> 12L, "e" -> 13L, "f" -> 15L)
val distance = Distance.categoricalDistance(sample1, sample2, method = LInfinityMethod(alpha = Some(0.003)))
assert(distance == 0.2726338046550349)
assert(distance === 0.2726338046550349 +- 1E-14)
}

"Categorial distance should compute correct linf_robust with different alpha value .1" in {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class VerificationSuiteTest extends WordSpec with Matchers with SparkContextSpec

val checkToSucceed = Check(CheckLevel.Error, "group-1")
.isComplete("att1")
.hasColumnCount(_ == 3)
.hasCompleteness("att1", _ == 1.0)

val checkToErrorOut = Check(CheckLevel.Error, "group-2-E")
Expand Down
45 changes: 45 additions & 0 deletions src/test/scala/com/amazon/deequ/analyzers/ColumnCountTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
* either express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*
*/

package com.amazon.deequ.analyzers

import com.amazon.deequ.SparkContextSpec
import com.amazon.deequ.utils.FixtureSupport
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import scala.util.Failure
import scala.util.Success

class ColumnCountTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {
"ColumnCount" should {
"return column count for a dataset" in withSparkSession { session =>
val data = getDfWithStringColumns(session)
val colCount = ColumnCount()

val state = colCount.computeStateFrom(data)
state.isDefined shouldBe true
state.get.metricValue() shouldBe 5.0

val metric = colCount.computeMetricFrom(state)
metric.fullColumn shouldBe None
metric.value shouldBe Success(5.0)
}
}
}

0 comments on commit b141f4b

Please sign in to comment.