Skip to content

Commit

Permalink
Add support for ConditionalAggregationAnalyzer
Browse files (browse the repository at this point in the history)
  • Loading branch information
Joshua Zexter committed Jul 29, 2024
2 parents 4118c50 + 0ba95ac commit e120ec5
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
/**
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
Expand Down Expand Up @@ -38,10 +37,13 @@ case class AggregatedMetricState(counts: Map[String, Int], totalRows: Int)
}

// Define the analyzer
case class ConditionalAggregationAnalyzer(aggregatorFunc: DataFrame => AggregatedMetricState, metricName: String, instance: String)
case class ConditionalAggregationAnalyzer(aggregatorFunc: DataFrame => AggregatedMetricState,
metricName: String,
instance: String)
extends Analyzer[AggregatedMetricState, AttributeDoubleMetric] {

def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None): Option[AggregatedMetricState] = {
def computeStateFrom(data: DataFrame, filterCondition: Option[String] = None)
: Option[AggregatedMetricState] = {
Try(aggregatorFunc(data)) match {
case Success(state) => Some(state)
case Failure(_) => None
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,11 +28,13 @@ import org.apache.spark.sql.DataFrame

import com.amazon.deequ.metrics.AttributeDoubleMetric

class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {
class ConditionalAggregationAnalyzerTest
extends AnyWordSpec with Matchers with SparkContextSpec with FixtureSupport {

"ConditionalAggregationAnalyzerTest" should {

"Example use: return correct counts for product sales in different categories" in withSparkSession
"""Example use: return correct counts
|for product sales in different categories""".stripMargin in withSparkSession
{ session =>
val data = getDfWithIdColumn(session)
val mockLambda: DataFrame => AggregatedMetricState = _ =>
Expand All @@ -53,7 +55,7 @@ class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers with
val mockLambda: DataFrame => AggregatedMetricState = _ =>
AggregatedMetricState(Map.empty[String, Int], 100)

val analyzer = ConditionalAggregationAnalyzer(mockLambda,"WebsiteTraffic", "page")
val analyzer = ConditionalAggregationAnalyzer(mockLambda, "WebsiteTraffic", "page")

val state = analyzer.computeStateFrom(data)
val metric: AttributeDoubleMetric = analyzer.computeMetricFrom(state)
Expand All @@ -64,9 +66,10 @@ class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers with

"return a failure metric when the lambda function fails" in withSparkSession { session =>
val data = getDfWithIdColumn(session)
val failingLambda: DataFrame => AggregatedMetricState = _ => throw new RuntimeException("Test failure")
val failingLambda: DataFrame => AggregatedMetricState =
_ => throw new RuntimeException("Test failure")

val analyzer = ConditionalAggregationAnalyzer(failingLambda,"ProductSales", "category")
val analyzer = ConditionalAggregationAnalyzer(failingLambda, "ProductSales", "category")

val state = analyzer.computeStateFrom(data)
val metric = analyzer.computeMetricFrom(state)
Expand All @@ -79,13 +82,16 @@ class ConditionalAggregationAnalyzerTest extends AnyWordSpec with Matchers with
}

"return a failure metric if there are no rows in DataFrame" in withSparkSession { session =>
val emptyData = session.createDataFrame(session.sparkContext.emptyRDD[org.apache.spark.sql.Row],
val emptyData = session.createDataFrame(
session.sparkContext.emptyRDD[org.apache.spark.sql.Row],
getDfWithIdColumn(session).schema)
val mockLambda: DataFrame => AggregatedMetricState = df =>
if (df.isEmpty) throw new RuntimeException("No data to analyze") // Explicitly fail if the data is empty
if (df.isEmpty) throw new RuntimeException("No data to analyze")
else AggregatedMetricState(Map("ProductA" -> 0, "ProductB" -> 0), 0)

val analyzer = ConditionalAggregationAnalyzer(mockLambda,"ProductSales", "category")
val analyzer = ConditionalAggregationAnalyzer(mockLambda,
"ProductSales",
"category")

val state = analyzer.computeStateFrom(emptyData)
val metric = analyzer.computeMetricFrom(state)
Expand Down

0 comments on commit e120ec5

Please sign in to comment.