diff --git a/src/main/scala/com/amazon/deequ/VerificationResult.scala b/src/main/scala/com/amazon/deequ/VerificationResult.scala index b9b450f2..418a622e 100644 --- a/src/main/scala/com/amazon/deequ/VerificationResult.scala +++ b/src/main/scala/com/amazon/deequ/VerificationResult.scala @@ -31,7 +31,7 @@ import com.amazon.deequ.repository.SimpleResultSerde import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrame import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.functions.{col, monotonically_increasing_id} +import org.apache.spark.sql.functions.monotonically_increasing_id import java.util.UUID @@ -96,10 +96,9 @@ object VerificationResult { data: DataFrame): DataFrame = { val columnNamesToMetrics: Map[String, Column] = verificationResultToColumn(verificationResult) - val columnsAliased = columnNamesToMetrics.toSeq.map { case (name, col) => col.as(name) } val dataWithID = data.withColumn(UNIQUENESS_ID, monotonically_increasing_id()) - dataWithID.select(col("*") +: columnsAliased: _*).drop(UNIQUENESS_ID) + dataWithID.withColumns(columnNamesToMetrics).drop(UNIQUENESS_ID) } def checkResultsAsJson(verificationResult: VerificationResult, diff --git a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala index fbdb5b2b..742b2ba6 100644 --- a/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala +++ b/src/main/scala/com/amazon/deequ/analyzers/Histogram.scala @@ -84,26 +84,11 @@ case class Histogram( case Some(theState) => val value: Try[Distribution] = Try { - val countColumnName = theState.frequencies.schema.fields - .find(field => field.dataType == LongType && field.name != column) - .map(_.name) - .getOrElse(throw new IllegalStateException(s"Count column not found in the frequencies DataFrame")) - - val topNRowsDF = theState.frequencies - .orderBy(col(countColumnName).desc) - .limit(maxDetailBins) - .collect() - + val topNRows = theState.frequencies.rdd.top(maxDetailBins)(OrderByAbsoluteCount) val binCount = theState.frequencies.count() - val columnName = theState.frequencies.columns - .find(_ == column) - .getOrElse(throw new IllegalStateException(s"Column $column not found")) - - val histogramDetails = topNRowsDF - .map { row => - val discreteValue = row.getAs[String](columnName) - val absolute = row.getAs[Long](countColumnName) + val histogramDetails = topNRows + .map { case Row(discreteValue: String, absolute: Long) => val ratio = absolute.toDouble / theState.numRows discreteValue -> DistributionValue(absolute, ratio) }