Commit
Add support for EntityTypes dqdl rule
Joshua Zexter committed Jul 29, 2024
1 parent 101142e commit 0ba95ac
Showing 3 changed files with 194 additions and 0 deletions.
@@ -0,0 +1,67 @@
/**
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/
package com.amazon.deequ.analyzers

import com.amazon.deequ.metrics.AttributeDoubleMetric
import com.amazon.deequ.metrics.Entity
import org.apache.spark.sql.DataFrame

import scala.util.Failure
import scala.util.Success
import scala.util.Try

// A custom state holding per-key counts plus the total row count, so that
// per-key ratios can be derived when the metric is computed.
case class AggregatedMetricState(counts: Map[String, Int], totalRows: Int)
  extends DoubleValuedState[AggregatedMetricState] {

  // Merge two partial states by summing matching counts and the row totals.
  def sum(other: AggregatedMetricState): AggregatedMetricState = {
    val combinedCounts = counts ++ other.counts.map { case (k, v) =>
      k -> (v + counts.getOrElse(k, 0))
    }
    AggregatedMetricState(combinedCounts, totalRows + other.totalRows)
  }

  def metricValue(): Double = counts.values.sum.toDouble / totalRows
}

// The analyzer delegates state computation to a caller-supplied aggregation
// function and reports the per-key ratios as an AttributeDoubleMetric.
case class ConditionalAggregationAnalyzer(
    aggregatorFunc: DataFrame => AggregatedMetricState,
    metricName: String,
    instance: String)
  extends Analyzer[AggregatedMetricState, AttributeDoubleMetric] {

  def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None)
      : Option[AggregatedMetricState] = {
    Try(aggregatorFunc(data)) match {
      case Success(state) => Some(state)
      case Failure(_) => None
    }
  }

  def computeMetricFrom(state: Option[AggregatedMetricState]): AttributeDoubleMetric = {
    state match {
      case Some(detState) =>
        val metrics = detState.counts.map { case (key, count) =>
          key -> (count.toDouble / detState.totalRows)
        }
        AttributeDoubleMetric(Entity.Column, metricName, instance, Success(metrics))
      case None =>
        AttributeDoubleMetric(Entity.Column, metricName, instance,
          Failure(new RuntimeException("Metric computation failed")))
    }
  }

  override private[deequ] def toFailureMetric(failure: Exception): AttributeDoubleMetric = {
    AttributeDoubleMetric(Entity.Column, metricName, instance, Failure(failure))
  }
}
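
For illustration, here is a minimal sketch of an aggregator function one might pass to ConditionalAggregationAnalyzer. It counts rows per value of a hypothetical "category" column with plain Spark aggregations; the column name and the wrapper object are assumptions for the example, not part of this commit:

// Hypothetical usage sketch (not part of this commit).
object CategoryShareExample {
  // Count rows per distinct value of the assumed "category" column.
  def categoryCounts(df: DataFrame): AggregatedMetricState = {
    val counts = df.groupBy("category").count().collect()
      .map(row => row.getString(0) -> row.getLong(1).toInt)
      .toMap
    AggregatedMetricState(counts, df.count().toInt)
  }

  val analyzer =
    ConditionalAggregationAnalyzer(categoryCounts, "CategoryShare", "category")
}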
17 changes: 17 additions & 0 deletions src/main/scala/com/amazon/deequ/metrics/Metric.scala
@@ -89,3 +89,20 @@ case class KeyedDoubleMetric(
    }
  }
}

case class AttributeDoubleMetric(
    entity: Entity.Value,
    name: String,
    instance: String,
    value: Try[Map[String, Double]])
  extends Metric[Map[String, Double]] {

  // Expose one DoubleMetric per key, named "<name>.<key>"; on failure,
  // surface a single DoubleMetric carrying the exception.
  override def flatten(): Seq[DoubleMetric] = {
    value match {
      case Success(valuesMap) => valuesMap.map { case (key, metricValue) =>
        DoubleMetric(entity, s"$name.$key", instance, Success(metricValue))
      }.toSeq
      case Failure(ex) => Seq(DoubleMetric(entity, name, instance, Failure(ex)))
    }
  }
}
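
To make flatten() concrete, a small sketch with assumed values (not part of this commit) showing how a successful keyed metric expands:

// Hypothetical illustration: flatten() yields one DoubleMetric per key.
val keyed = AttributeDoubleMetric(Entity.Column, "ProductSales", "category",
  Success(Map("ProductA" -> 0.5, "ProductB" -> 0.45)))
val flattened = keyed.flatten()
// flattened holds DoubleMetric entries named "ProductSales.ProductA" (0.5)
// and "ProductSales.ProductB" (0.45); a Failure would instead produce a
// single DoubleMetric named "ProductSales" carrying the exception.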
@@ -0,0 +1,110 @@
/**
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/
package com.amazon.deequ.analyzers

import com.amazon.deequ.SparkContextSpec
import com.amazon.deequ.metrics.AttributeDoubleMetric
import com.amazon.deequ.utils.FixtureSupport
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import scala.util.Failure
import scala.util.Success

class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers
  with SparkContextSpec with FixtureSupport {

  "ConditionalAggregationAnalyzer" should {

    "return correct counts for product sales in different categories" in withSparkSession { session =>
      val data = getDfWithIdColumn(session)
      val mockLambda: DataFrame => AggregatedMetricState = _ =>
        AggregatedMetricState(Map("ProductA" -> 50, "ProductB" -> 45), 100)

      val analyzer = ConditionalAggregationAnalyzer(mockLambda, "ProductSales", "category")

      val state = analyzer.computeStateFrom(data)
      val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state)

      metric.value.isSuccess shouldBe true
      metric.value.get should contain ("ProductA" -> 0.5)
      metric.value.get should contain ("ProductB" -> 0.45)
    }

    "handle scenarios with no data points effectively" in withSparkSession { session =>
      val data = getDfWithIdColumn(session)
      val mockLambda: DataFrame => AggregatedMetricState = _ =>
        AggregatedMetricState(Map.empty[String, Int], 100)

      val analyzer = ConditionalAggregationAnalyzer(mockLambda, "WebsiteTraffic", "page")

      val state = analyzer.computeStateFrom(data)
      val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state)

      metric.value.isSuccess shouldBe true
      metric.value.get shouldBe empty
    }

    "return a failure metric when the lambda function fails" in withSparkSession { session =>
      val data = getDfWithIdColumn(session)
      val failingLambda: DataFrame => AggregatedMetricState =
        _ => throw new RuntimeException("Test failure")

      val analyzer = ConditionalAggregationAnalyzer(failingLambda, "ProductSales", "category")

      val state = analyzer.computeStateFrom(data)
      val metric = analyzer.computeMetricFrom(state)

      metric.value.isFailure shouldBe true
      metric.value match {
        case Success(_) => fail("Should have failed due to lambda function failure")
        case Failure(exception) => exception.getMessage shouldBe "Metric computation failed"
      }
    }

    "return a failure metric if there are no rows in DataFrame" in withSparkSession { session =>
      val emptyData = session.createDataFrame(
        session.sparkContext.emptyRDD[org.apache.spark.sql.Row],
        getDfWithIdColumn(session).schema)
      val mockLambda: DataFrame => AggregatedMetricState = df =>
        if (df.isEmpty) throw new RuntimeException("No data to analyze") // Fail explicitly on empty input
        else AggregatedMetricState(Map("ProductA" -> 0, "ProductB" -> 0), 0)

      val analyzer = ConditionalAggregationAnalyzer(mockLambda, "ProductSales", "category")

      val state = analyzer.computeStateFrom(emptyData)
      val metric = analyzer.computeMetricFrom(state)

      metric.value.isFailure shouldBe true
      metric.value match {
        case Success(_) => fail("Should have failed due to no data")
        case Failure(exception) => exception.getMessage should include("Metric computation failed")
      }
    }
  }

  def getDfWithIdColumn(session: SparkSession): DataFrame = {
    import session.implicits._
    Seq(
      ("ProductA", "North"),
      ("ProductA", "South"),
      ("ProductB", "East"),
      ("ProductA", "West")
    ).toDF("product", "region")
  }
}
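
Tying the pieces together, here is a hedged end-to-end sketch (not part of this commit) that replaces the mock lambdas above with a real aggregation over the fixture DataFrame; it assumes a SparkSession named "session" is in scope:

// Hypothetical sketch: share of rows per product in getDfWithIdColumn's data.
def productShare(df: DataFrame): AggregatedMetricState = {
  val counts = df.groupBy("product").count().collect()
    .map(row => row.getString(0) -> row.getLong(1).toInt)
    .toMap
  AggregatedMetricState(counts, df.count().toInt)
}

val data = getDfWithIdColumn(session) // assumes a SparkSession named "session"
val analyzer = ConditionalAggregationAnalyzer(productShare, "ProductShare", "product")
val metric = analyzer.computeMetricFrom(analyzer.computeStateFrom(data))
// With the four fixture rows this yields ProductA -> 0.75 and ProductB -> 0.25.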
