diff --git a/src/main/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzer.scala b/src/main/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzer.scala index cb15e130..eaccd2b2 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzer.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzer.scala @@ -1,6 +1,5 @@ /** * Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. - * * Licensed under the Apache License, Version 2.0 (the "License"). You may not * use this file except in compliance with the License. A copy of the License * is located at @@ -38,10 +37,13 @@ case class AggregatedMetricState(counts: Map[String, Int], totalRows: Int) } // Define the analyzer -case class ConditionalAggregationAnalyzer(aggregatorFunc: DataFrame => AggregatedMetricState, metricName: String, instance: String) +case class ConditionalAggregationAnalyzer(aggregatorFunc: DataFrame => AggregatedMetricState, + metricName: String, + instance: String) extends Analyzer[AggregatedMetricState, AttributeDoubleMetric] { - def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[AggregatedMetricState] = { + def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None) + : Option[AggregatedMetricState] = { Try(aggregatorFunc(data)) match { case Success(state) => Some(state) case Failure(_) => None diff --git a/src/test/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzerTest.scala b/src/test/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzerTest.scala index fe15b889..026d6b66 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzerTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/ConditionalAggregationAnalyzerTest.scala @@ -28,11 +28,13 @@ import org.apache.spark.sql.DataFrame import com.amazon.deequ.metrics.AttributeDoubleMetric -class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { +class ConditionalAggregationAnalyzerTest + extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport { "ConditionalAggregationAnalyzerTest" should { - "Example use: return correct counts for product sales in different categories" in withSparkSession + """Example use: return correct counts + |for product sales in different categories""".stripMargin in withSparkSession { session => val data = getDfWithIdColumn(session) val mockLambda: DataFrame => AggregatedMetricState = _ => @@ -53,7 +55,7 @@ class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers with val mockLambda: DataFrame => AggregatedMetricState = _ => AggregatedMetricState(Map.empty[String, Int], 100) - val analyzer = ConditionalAggregationAnalyzer(mockLambda,"WebsiteTraffic", "page") + val analyzer = ConditionalAggregationAnalyzer(mockLambda, "WebsiteTraffic", "page") val state = analyzer.computeStateFrom(data) val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state) @@ -64,9 +66,10 @@ class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers with "return a failure metric when the lambda function fails" in withSparkSession { session => val data = getDfWithIdColumn(session) - val failingLambda: DataFrame => AggregatedMetricState = _ => throw new RuntimeException("Test failure") + val failingLambda: DataFrame => AggregatedMetricState = + _ => throw new RuntimeException("Test failure") - val analyzer = ConditionalAggregationAnalyzer(failingLambda,"ProductSales", "category") + val analyzer = ConditionalAggregationAnalyzer(failingLambda, "ProductSales", "category") val state = analyzer.computeStateFrom(data) val metric = analyzer.computeMetricFrom(state) @@ -79,13 +82,16 @@ class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers with } "return a failure metric if there are no rows in DataFrame" in withSparkSession { session => - val emptyData = session.createDataFrame(session.sparkContext.emptyRDD[org.apache.spark.sql.Row], + val emptyData = session.createDataFrame( + session.sparkContext.emptyRDD[org.apache.spark.sql.Row], getDfWithIdColumn(session).schema) val mockLambda: DataFrame => AggregatedMetricState = df => - if (df.isEmpty) throw new RuntimeException("No data to analyze") // Explicitly fail if the data is empty + if (df.isEmpty) throw new RuntimeException("No data to analyze") else AggregatedMetricState(Map("ProductA" -> 0, "ProductB" -> 0), 0) - val analyzer = ConditionalAggregationAnalyzer(mockLambda,"ProductSales", "category") + val analyzer = ConditionalAggregationAnalyzer(mockLambda, + "ProductSales", + "category") val state = analyzer.computeStateFrom(emptyData) val metric = analyzer.computeMetricFrom(state)