From 0ba95ac2c48ed0548fb1dc027adceac9bdbab624 Mon Sep 17 00:00:00 2001 From: Joshua Zexter Date: Fri, 26 Jul 2024 11:56:21 -0400 Subject: [PATCH] Add support for EntityTypes dqdl rule --- .../ConditionalAggregationAnalyzer.scala | 67 +++++++++++ .../com/amazon/deequ/metrics/Metric.scala | 17 +++ .../ConditionalAggregationAnalyzerTest.scala | 110 ++++++++++++++++++ 3 files changed, 194 insertions(+) create mode 100644 src/main/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzer.scala create mode 100644 src/test/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzerTest.scala diff --git a/src/main/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzer.scala new file mode 100644 index 000000000..29a3198ff --- /dev/null +++ b/src/main/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzer.scala @@ -0,0 +1,67 @@ +/** + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ +package com.amazon.deequ.analyzers + +import com.amazon.deequ.metrics.AttributeDoubleMetric +import com.amazon.deequ.metrics.Entity +import org.apache.spark.sql.DataFrame + +import scala.util.Failure +import scala.util.Success +import scala.util.Try + +// Define a custom state to hold aggregation results +case class AggregatedMetricState(counts: Map[String, Int], totalRows: Int) + extends DoubleValuedState[AggregatedMetricState] { + + def sum(other: AggregatedMetricState): AggregatedMetricState = { + val combinedCounts = counts ++ other + .counts + .map { case (k, v) => k -> (v + counts.getOrElse(k, 0)) } + AggregatedMetricState(combinedCounts, totalRows + other.totalRows) + } + + def metricValue(): Double = counts.values.sum.toDouble / totalRows +} + +// Define the analyzer +case class ConditionalAggregationAnalyzer(aggregatorFunc: DataFrame => AggregatedMetricState, metricName: String, instance: String) + extends Analyzer[AggregatedMetricState, AttributeDoubleMetric] { + + def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[AggregatedMetricState] = { + Try(aggregatorFunc(data)) match { + case Success(state) => Some(state) + case Failure(_) => None + } + } + + def computeMetricFrom(state: Option[AggregatedMetricState]): AttributeDoubleMetric = { + state match { + case Some(detState) => + val metrics = detState.counts.map { case (key, count) => + key -> (count.toDouble / detState.totalRows) + } + AttributeDoubleMetric(Entity.Column, metricName, instance, Success(metrics)) + case None => + AttributeDoubleMetric(Entity.Column, metricName, instance, + Failure(new RuntimeException("Metric computation failed"))) + } + } + + override private[deequ] def toFailureMetric(failure: Exception): AttributeDoubleMetric = { + AttributeDoubleMetric(Entity.Column, metricName, instance, Failure(failure)) + } +} diff --git a/src/main/scala/com/amazon/deequ/metrics/Metric.scala b/src/main/scala/com/amazon/deequ/metrics/Metric.scala index 30225e246..307b278d1 100644 --- a/src/main/scala/com/amazon/deequ/metrics/Metric.scala +++ b/src/main/scala/com/amazon/deequ/metrics/Metric.scala @@ -89,3 +89,20 @@ case class KeyedDoubleMetric( } } } + +case class AttributeDoubleMetric( + entity: Entity.Value, + name: String, + instance: String, + value: Try[Map[String, Double]]) + extends Metric[Map[String, Double]] { + + override def flatten(): Seq[DoubleMetric] = { + value match { + case Success(valuesMap) => valuesMap.map { case (key, metricValue) => + DoubleMetric(entity, s"$name.$key", instance, Success(metricValue)) + }.toSeq + case Failure(ex) => Seq(DoubleMetric(entity, name, instance, Failure(ex))) + } + } +} diff --git a/src/test/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzerTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzerTest.scala new file mode 100644 index 000000000..fe15b8892 --- /dev/null +++ b/src/test/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzerTest.scala @@ -0,0 +1,110 @@ +/** + * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ +package com.amazon.deequ.analyzers + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.utils.FixtureSupport +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpec + +import scala.util.Failure +import scala.util.Success + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.DataFrame + +import com.amazon.deequ.metrics.AttributeDoubleMetric + +class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { + + "ConditionalAggregationAnalyzerTest" should { + + "Example use: return correct counts for product sales in different categories" in withSparkSession + { session => + val data = getDfWithIdColumn(session) + val mockLambda: DataFrame => AggregatedMetricState = _ => + AggregatedMetricState(Map("ProductA" -> 50, "ProductB" -> 45), 100) + + val analyzer = ConditionalAggregationAnalyzer(mockLambda, "ProductSales", "category") + + val state = analyzer.computeStateFrom(data) + val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + metric.value.get should contain ("ProductA" -> 0.5) + metric.value.get should contain ("ProductB" -> 0.45) + } + + "handle scenarios with no data points effectively" in withSparkSession { session => + val data = getDfWithIdColumn(session) + val mockLambda: DataFrame => AggregatedMetricState = _ => + AggregatedMetricState(Map.empty[String, Int], 100) + + val analyzer = ConditionalAggregationAnalyzer(mockLambda,"WebsiteTraffic", "page") + + val state = analyzer.computeStateFrom(data) + val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state) + + metric.value.isSuccess shouldBe true + metric.value.get shouldBe empty + } + + "return a failure metric when the lambda function fails" in withSparkSession { session => + val data = getDfWithIdColumn(session) + val failingLambda: DataFrame => AggregatedMetricState = _ => throw new RuntimeException("Test failure") + + val analyzer = ConditionalAggregationAnalyzer(failingLambda,"ProductSales", "category") + + val state = analyzer.computeStateFrom(data) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isFailure shouldBe true + metric.value match { + case Success(_) => fail("Should have failed due to lambda function failure") + case Failure(exception) => exception.getMessage shouldBe "Metric computation failed" + } + } + + "return a failure metric if there are no rows in DataFrame" in withSparkSession { session => + val emptyData = session.createDataFrame(session.sparkContext.emptyRDD[org.apache.spark.sql.Row], + getDfWithIdColumn(session).schema) + val mockLambda: DataFrame => AggregatedMetricState = df => + if (df.isEmpty) throw new RuntimeException("No data to analyze") // Explicitly fail if the data is empty + else AggregatedMetricState(Map("ProductA" -> 0, "ProductB" -> 0), 0) + + val analyzer = ConditionalAggregationAnalyzer(mockLambda,"ProductSales", "category") + + val state = analyzer.computeStateFrom(emptyData) + val metric = analyzer.computeMetricFrom(state) + + metric.value.isFailure shouldBe true + metric.value match { + case Success(_) => fail("Should have failed due to no data") + case Failure(exception) => exception.getMessage should include("Metric computation failed") + } + } + } + + def getDfWithIdColumn(session: SparkSession): DataFrame = { + import session.implicits._ + Seq( + ("ProductA", "North"), + ("ProductA", "South"), + ("ProductB", "East"), + ("ProductA", "West") + ).toDF("product", "region") + } +}