diff --git a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala index e2d33672a..42a7e72e5 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala @@ -16,17 +16,20 @@ package com.amazon.deequ.analyzers +import com.amazon.deequ.analyzers.Histogram.{AggregateFunction, Count} import com.amazon.deequ.analyzers.runners.{IllegalAnalyzerParameterException, MetricCalculationException} import com.amazon.deequ.metrics.{Distribution, DistributionValue, HistogramMetric} import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.sql.functions.{col, sum} +import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType} import org.apache.spark.sql.{DataFrame, Row} + import scala.util.{Failure, Try} /** * Histogram is the summary of values in a column of a DataFrame. Groups the given column's values, - * and calculates the number of rows with that specific value and the fraction of this value. + * and calculates either number of rows or with that specific value and the fraction of this value or + * sum of values in other column. * * @param column Column to do histogram analysis on * @param binningUdf Optional binning function to run before grouping to re-categorize the @@ -37,13 +40,15 @@ import scala.util.{Failure, Try} * maxBins sets the N. * This limit does not affect what is being returned as number of bins. It * always returns the dictinct value count. + * @param aggregateFunction function that implements aggregation logic. */ case class Histogram( column: String, binningUdf: Option[UserDefinedFunction] = None, maxDetailBins: Integer = Histogram.MaximumAllowedDetailBins, where: Option[String] = None, - computeFrequenciesAsRatio: Boolean = true) + computeFrequenciesAsRatio: Boolean = true, + aggregateFunction: AggregateFunction = Count) extends Analyzer[FrequenciesAndNumRows, HistogramMetric] with FilterableAnalyzer { @@ -58,19 +63,15 @@ case class Histogram( // TODO figure out a way to pass this in if its known before hand val totalCount = if (computeFrequenciesAsRatio) { - data.count() + aggregateFunction.total(data) } else { 1 } - val frequencies = data + val df = data .transform(filterOptional(where)) .transform(binOptional(binningUdf)) - .select(col(column).cast(StringType)) - .na.fill(Histogram.NullFieldReplacement) - .groupBy(column) - .count() - .withColumnRenamed("count", Analyzers.COUNT_COL) + val frequencies = query(df) Some(FrequenciesAndNumRows(frequencies, totalCount)) } @@ -125,11 +126,67 @@ case class Histogram( case _ => data } } + + private def query(data: DataFrame): DataFrame = { + aggregateFunction.query(this.column, data) + } } object Histogram { val NullFieldReplacement = "NullValue" val MaximumAllowedDetailBins = 1000 + val count_function = "count" + val sum_function = "sum" + + sealed trait AggregateFunction { + def query(column: String, data: DataFrame): DataFrame + + def total(data: DataFrame): Long + + def aggregateColumn(): Option[String] + + def function(): String + } + + case object Count extends AggregateFunction { + override def query(column: String, data: DataFrame): DataFrame = { + data + .select(col(column).cast(StringType)) + .na.fill(Histogram.NullFieldReplacement) + .groupBy(column) + .count() + .withColumnRenamed("count", Analyzers.COUNT_COL) + } + + override def aggregateColumn(): Option[String] = None + + override def function(): String = count_function + + override def total(data: DataFrame): Long = { + data.count() + } + } + + case class Sum(aggColumn: String) extends AggregateFunction { + override def query(column: String, data: DataFrame): DataFrame = { + data + .select(col(column).cast(StringType), col(aggColumn).cast(LongType)) + .na.fill(Histogram.NullFieldReplacement) + .groupBy(column) + .sum(aggColumn) + .withColumnRenamed("count", Analyzers.COUNT_COL) + } + + override def total(data: DataFrame): Long = { + data.groupBy().sum(aggColumn).first().getLong(0) + } + + override def aggregateColumn(): Option[String] = { + Some(aggColumn) + } + + override def function(): String = sum_function + } } object OrderByAbsoluteCount extends Ordering[Row] { diff --git a/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala b/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala index af267ad90..4a2ca1058 100644 --- a/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala +++ b/src/main/scala/com/amazon/deequ/repository/AnalysisResultSerde.scala @@ -30,6 +30,7 @@ import scala.collection._ import scala.collection.JavaConverters._ import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap} import JsonSerializationConstants._ +import com.amazon.deequ.analyzers.Histogram.{AggregateFunction, Count => HistogramCount, Sum => HistogramSum} import org.apache.spark.sql.Column import org.apache.spark.sql.functions.expr @@ -302,6 +303,12 @@ private[deequ] object AnalyzerSerializer result.addProperty(ANALYZER_NAME_FIELD, "Histogram") result.addProperty(COLUMN_FIELD, histogram.column) result.addProperty("maxDetailBins", histogram.maxDetailBins) + // Count is initial and default implementation for Histogram + // We don't include fields below in json to preserve json backward compatibility. + if (histogram.aggregateFunction != Histogram.Count) { + result.addProperty("aggregateFunction", histogram.aggregateFunction.function()) + result.addProperty("aggregateColumn", histogram.aggregateFunction.aggregateColumn().get) + } case _ : Histogram => throw new IllegalArgumentException("Unable to serialize Histogram with binningUdf!") @@ -436,7 +443,10 @@ private[deequ] object AnalyzerDeserializer Histogram( json.get(COLUMN_FIELD).getAsString, None, - json.get("maxDetailBins").getAsInt) + json.get("maxDetailBins").getAsInt, + aggregateFunction = createAggregateFunction( + getOptionalStringParam(json, "aggregateFunction").getOrElse(Histogram.count_function), + getOptionalStringParam(json, "aggregateColumn").getOrElse(""))) case "DataType" => DataType( @@ -489,12 +499,24 @@ private[deequ] object AnalyzerDeserializer } private[this] def getOptionalWhereParam(jsonObject: JsonObject): Option[String] = { - if (jsonObject.has(WHERE_FIELD)) { - Option(jsonObject.get(WHERE_FIELD).getAsString) + getOptionalStringParam(jsonObject, WHERE_FIELD) + } + + private[this] def getOptionalStringParam(jsonObject: JsonObject, field: String): Option[String] = { + if (jsonObject.has(field)) { + Option(jsonObject.get(field).getAsString) } else { None } } + + private[this] def createAggregateFunction(function: String, aggregateColumn: String): AggregateFunction = { + function match { + case Histogram.count_function => HistogramCount + case Histogram.sum_function => HistogramSum(aggregateColumn) + case _ => throw new IllegalArgumentException("Wrong aggregate function name: " + function) + } + } } private[deequ] object MetricSerializer extends JsonSerializer[Metric[_]] { diff --git a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala index fc771d0d2..03787b886 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/AnalyzerTests.scala @@ -260,6 +260,26 @@ class AnalyzerTests extends AnyWordSpec with Matchers with SparkContextSpec with } } + "compute correct sum metrics " in withSparkSession { sparkSession => + val dfFull = getDateDf(sparkSession) + val histogram = Histogram("product", aggregateFunction = Histogram.Sum("units")).calculate(dfFull) + assert(histogram.value.isSuccess) + + histogram.value.get match { + case hv => + assert(hv.numberOfBins == 3) + assert(hv.values.size == 3) + assert(hv.values.keys == Set("Furniture", "Cosmetics", "Electronics")) + assert(hv("Furniture").absolute == 55) + assert(hv("Furniture").ratio == 55.0 / (55 + 20 + 60)) + assert(hv("Cosmetics").absolute == 20) + assert(hv("Cosmetics").ratio == 20.0 / (55 + 20 + 60)) + assert(hv("Electronics").absolute == 60) + assert(hv("Electronics").ratio == 60.0 / (55 + 20 + 60)) + + } + } + "compute correct metrics on numeric values" in withSparkSession { sparkSession => val dfFull = getDfWithNumericValues(sparkSession) val histogram = Histogram("att2").calculate(dfFull) diff --git a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala index 7c4d91625..254fac9b4 100644 --- a/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala +++ b/src/test/scala/com/amazon/deequ/analyzers/runners/AnalyzerContextTest.scala @@ -86,13 +86,13 @@ class AnalyzerContextTest extends AnyWordSpec |{"entity":"Column","instance":"item","name":"Distinctness","value":1.0}, |{"entity":"Column","instance":"att1","name":"Completeness","value":1.0}, |{"entity":"Multicolumn","instance":"att1,att2","name":"Uniqueness","value":0.25}, + |{"entity":"Dataset","instance":"*","name":"Size (where: att2 == 'd')","value":1.0}, + |{"entity":"Dataset","instance":"*","name":"Size","value":4.0}, |{"entity":"Column","instance":"att1","name":"Histogram.bins","value":2.0}, |{"entity":"Column","instance":"att1","name":"Histogram.abs.a","value":3.0}, |{"entity":"Column","instance":"att1","name":"Histogram.ratio.a","value":0.75}, |{"entity":"Column","instance":"att1","name":"Histogram.abs.b","value":1.0}, - |{"entity":"Column","instance":"att1","name":"Histogram.ratio.b","value":0.25}, - |{"entity":"Dataset","instance":"*","name":"Size (where: att2 == 'd')","value":1.0}, - |{"entity":"Dataset","instance":"*","name":"Size","value":4.0} + |{"entity":"Column","instance":"att1","name":"Histogram.ratio.b","value":0.25} |]""" .stripMargin.replaceAll("\n", "") diff --git a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala index 3b15582c0..7083a5c1d 100644 --- a/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala +++ b/src/test/scala/com/amazon/deequ/repository/AnalysisResultSerdeTest.scala @@ -173,6 +173,112 @@ class AnalysisResultSerdeTest extends FlatSpec with Matchers { assertCorrectlyConvertsAnalysisResults(Seq(result)) } + val histogramSumJson = + """[ + | { + | "resultKey": { + | "dataSetDate": 0, + | "tags": {} + | }, + | "analyzerContext": { + | "metricMap": [ + | { + | "analyzer": { + | "analyzerName": "Histogram", + | "column": "columnA", + | "maxDetailBins": 1000, + | "aggregateFunction": "sum", + | "aggregateColumn": "columnB" + | }, + | "metric": { + | "metricName": "HistogramMetric", + | "column": "columnA", + | "numberOfBins": 10, + | "value": { + | "numberOfBins": 10, + | "values": { + | "some": { + | "absolute": 10, + | "ratio": 0.5 + | } + | } + | } + | } + | } + | ] + | } + | } + |]""".stripMargin + val histogramCountJson = + """[ + | { + | "resultKey": { + | "dataSetDate": 0, + | "tags": {} + | }, + | "analyzerContext": { + | "metricMap": [ + | { + | "analyzer": { + | "analyzerName": "Histogram", + | "column": "columnA", + | "maxDetailBins": 1000 + | }, + | "metric": { + | "metricName": "HistogramMetric", + | "column": "columnA", + | "numberOfBins": 10, + | "value": { + | "numberOfBins": 10, + | "values": { + | "some": { + | "absolute": 10, + | "ratio": 0.5 + | } + | } + | } + | } + | } + | ] + | } + | } + |]""".stripMargin + + "Histogram serialization" should "be backward compatible for count" in { + val expected = histogramCountJson + val analyzer = Histogram("columnA") + val metric = HistogramMetric("columnA", Success(Distribution(Map("some" -> DistributionValue(10, 0.5)), 10))) + val context = AnalyzerContext(Map(analyzer -> metric)) + val result = new AnalysisResult(ResultKey(0), context) + assert(serialize(Seq(result)) == expected) + } + + "Histogram serialization" should "properly serialize sum" in { + val expected = histogramSumJson + val analyzer = Histogram("columnA", aggregateFunction = Histogram.Sum("columnB")) + val metric = HistogramMetric("columnA", Success(Distribution(Map("some" -> DistributionValue(10, 0.5)), 10))) + val context = AnalyzerContext(Map(analyzer -> metric)) + val result = new AnalysisResult(ResultKey(0), context) + assert(serialize(Seq(result)) == expected) + } + + "Histogram deserialization" should "be backward compatible for count" in { + val analyzer = Histogram("columnA") + val metric = HistogramMetric("columnA", Success(Distribution(Map("some" -> DistributionValue(10, 0.5)), 10))) + val context = AnalyzerContext(Map(analyzer -> metric)) + val expected = new AnalysisResult(ResultKey(0), context) + assert(deserialize(histogramCountJson) == List(expected)) + } + + "Histogram deserialization" should "properly deserialize sum" in { + val analyzer = Histogram("columnA", aggregateFunction = Histogram.Sum("columnB")) + val metric = HistogramMetric("columnA", Success(Distribution(Map("some" -> DistributionValue(10, 0.5)), 10))) + val context = AnalyzerContext(Map(analyzer -> metric)) + val expected = new AnalysisResult(ResultKey(0), context) + assert(deserialize(histogramSumJson) == List(expected)) + } + + def assertCorrectlyConvertsAnalysisResults( analysisResults: Seq[AnalysisResult], shouldFail: Boolean = false)