Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/100 0 n columns #111

Merged
merged 12 commits into from
Jan 2, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

package za.co.absa.atum.agent

import com.typesafe.config.{Config, ConfigFactory}
import za.co.absa.atum.agent.AtumContext.AtumPartitions
import za.co.absa.atum.agent.dispatcher.{ConsoleDispatcher, HttpDispatcher}
Expand Down
13 changes: 6 additions & 7 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import za.co.absa.atum.agent.model._
import za.co.absa.atum.model.dto._

import java.time.ZonedDateTime

import java.util.UUID
import scala.collection.immutable.ListMap

Expand All @@ -37,7 +36,7 @@ import scala.collection.immutable.ListMap
class AtumContext private[agent] (
val atumPartitions: AtumPartitions,
val agent: AtumAgent,
private var measures: Set[Measure] = Set.empty,
private var measures: Set[AtumMeasure] = Set.empty,
private var additionalData: Map[String, Option[String]] = Map.empty
) {

Expand All @@ -46,7 +45,7 @@ class AtumContext private[agent] (
*
* @return the current set of measures
*/
def currentMeasures: Set[Measure] = measures
def currentMeasures: Set[AtumMeasure] = measures

/**
* Returns the sub-partition context in the AtumContext.
Expand Down Expand Up @@ -145,7 +144,7 @@ class AtumContext private[agent] (
*
* @param newMeasure the measure to be added
*/
def addMeasure(newMeasure: Measure): AtumContext = {
def addMeasure(newMeasure: AtumMeasure): AtumContext = {
measures = measures + newMeasure
this
}
Expand All @@ -155,7 +154,7 @@ class AtumContext private[agent] (
*
* @param newMeasures the set of measures to be added
*/
def addMeasures(newMeasures: Set[Measure]): AtumContext = {
def addMeasures(newMeasures: Set[AtumMeasure]): AtumContext = {
measures = measures ++ newMeasures
this
}
Expand All @@ -165,15 +164,15 @@ class AtumContext private[agent] (
*
* @param measureToRemove the measure to be removed
*/
def removeMeasure(measureToRemove: Measure): AtumContext = {
def removeMeasure(measureToRemove: AtumMeasure): AtumContext = {
measures = measures - measureToRemove
this
}

private[agent] def copy(
atumPartitions: AtumPartitions = this.atumPartitions,
agent: AtumAgent = this.agent,
measures: Set[Measure] = this.measures,
measures: Set[AtumMeasure] = this.measures,
additionalData: Map[String, Option[String]] = this.additionalData
): AtumContext = {
new AtumContext(atumPartitions, agent, measures, additionalData)
Expand Down
213 changes: 84 additions & 129 deletions agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,197 +17,152 @@
package za.co.absa.atum.agent.model

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DecimalType, LongType, StringType}
import org.apache.spark.sql.types.{DataType, DecimalType, LongType, StringType}
import org.apache.spark.sql.{Column, DataFrame}
import za.co.absa.atum.agent.core.MeasurementProcessor
import za.co.absa.atum.agent.core.MeasurementProcessor.{MeasurementFunction, ResultOfMeasurement}
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

/**
* This trait represents a measure that can be applied to a column.
* Type of different measures to be applied to the columns.
*/
sealed trait Measure extends MeasurementProcessor with MeasureType {
val measuredColumn: String
sealed trait Measure {
val measureName: String
def controlColumns: Seq[String]
salamonpavel marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* This trait represents a measure type that can be applied to a column.
*/
trait MeasureType {
val measureName: String
trait AtumMeasure extends Measure with MeasurementProcessor {
val resultValueType: ResultValueType.ResultValueType
}

/**
* This object contains all the possible measures that can be applied to a column.
*/
object Measure {

private val valueColumnName: String = "value"
object AtumMeasure {

val supportedMeasures: Seq[MeasureType] = Seq(
RecordCount,
DistinctRecordCount,
SumOfValuesOfColumn,
AbsSumOfValuesOfColumn,
SumOfHashesOfColumn
val supportedMeasureNames: Seq[String] = Seq(
salamonpavel marked this conversation as resolved.
Show resolved Hide resolved
RecordCount.measureName,
DistinctRecordCount.measureName,
SumOfValuesOfColumn.measureName,
AbsSumOfValuesOfColumn.measureName,
SumOfHashesOfColumn.measureName
)
val supportedMeasureNames: Seq[String] = supportedMeasures.map(_.measureName)

case class RecordCount private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
case class RecordCount private (measureName: String) extends AtumMeasure {
private val columnExpression = count("*")

override def function: MeasurementFunction =
(ds: DataFrame) => {
val resultValue = ds.select(col(measuredColumn)).count().toString
ResultOfMeasurement(resultValue, resultValueType)
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(resultValue(0).toString, resultValueType)
}
}
object RecordCount extends MeasureType {
def apply(measuredColumn: String): RecordCount = RecordCount(measuredColumn, measureName, resultValueType)

override val measureName: String = "count"
override def controlColumns: Seq[String] = Seq.empty
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Long
}
object RecordCount {
private[agent] val measureName: String = "count"
def apply(): RecordCount = RecordCount(measureName)
}

case class DistinctRecordCount private (measureName: String, controlCols: Seq[String]) extends AtumMeasure {
require(controlCols.nonEmpty, "At least one control column has to be defined.")

case class DistinctRecordCount private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
private val columnExpression = countDistinct(col(controlCols.head), controlCols.tail.map(col): _*)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still not sure if this is what we want. I thought it was about a 'distinct -> count' operation performed on a single column. Do you remember what was used in the old Atum, and what our users might require, @benedeki?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the suggested code it can be performed on a single column or on multiple columns. If this particular measure is expected to work differently, please let me know. @benedeki

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the universality of supporting multi-column uniqueness 👍


override def function: MeasurementFunction =
(ds: DataFrame) => {
val resultValue = ds.select(col(measuredColumn)).distinct().count().toString
ResultOfMeasurement(resultValue, resultValueType)
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(resultValue(0)(0).toString, resultValueType)
}
}
object DistinctRecordCount extends MeasureType {
def apply(measuredColumn: String): DistinctRecordCount = {
DistinctRecordCount(measuredColumn, measureName, resultValueType)
}

override val measureName: String = "distinctCount"
override def controlColumns: Seq[String] = controlCols
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Long
}
object DistinctRecordCount {
private[agent] val measureName: String = "distinctCount"
def apply(controlCols: Seq[String]): DistinctRecordCount = DistinctRecordCount(measureName, controlCols)
}

case class SumOfValuesOfColumn private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
case class SumOfValuesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnAggFn: Column => Column = column => sum(column)

override def function: MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(col(valueColumnName))
val resultValue = aggregateColumn(ds, measuredColumn, aggCol)
ResultOfMeasurement(resultValue, resultValueType)
}
}
object SumOfValuesOfColumn extends MeasureType {
def apply(measuredColumn: String): SumOfValuesOfColumn = {
SumOfValuesOfColumn(measuredColumn, measureName, resultValueType)
val dataType = ds.select(controlCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(controlCol)))).collect()
ResultOfMeasurement(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override val measureName: String = "aggregatedTotal"
override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.BigDecimal
}
object SumOfValuesOfColumn {
private[agent] val measureName: String = "aggregatedTotal"
def apply(controlCol: String): SumOfValuesOfColumn = SumOfValuesOfColumn(measureName, controlCol)
}

case class AbsSumOfValuesOfColumn private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
case class AbsSumOfValuesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnAggFn: Column => Column = column => sum(abs(column))

override def function: MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(abs(col(valueColumnName)))
val resultValue = aggregateColumn(ds, measuredColumn, aggCol)
ResultOfMeasurement(resultValue, resultValueType)
}
}
object AbsSumOfValuesOfColumn extends MeasureType {
def apply(measuredColumn: String): AbsSumOfValuesOfColumn = {
AbsSumOfValuesOfColumn(measuredColumn, measureName, resultValueType)
val dataType = ds.select(controlCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(controlCol)))).collect()
ResultOfMeasurement(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override val measureName: String = "absAggregatedTotal"
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Double
override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.BigDecimal
}
object AbsSumOfValuesOfColumn {
private[agent] val measureName: String = "absAggregatedTotal"
def apply(controlCol: String): AbsSumOfValuesOfColumn = AbsSumOfValuesOfColumn(measureName, controlCol)
}

case class SumOfHashesOfColumn private (
measuredColumn: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {

case class SumOfHashesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnExpression: Column = sum(crc32(col(controlCol).cast("String")))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this function also support a list of columns, concatenated?
An extra bonus would be * for all columns. 😉

override def function: MeasurementFunction = (ds: DataFrame) => {

val aggregatedColumnName = ds.schema.getClosestUniqueName("sum_of_hashes")
val value = ds
.withColumn(aggregatedColumnName, crc32(col(measuredColumn).cast("String")))
.agg(sum(col(aggregatedColumnName)))
.collect()(0)(0)
val resultValue = if (value == null) "" else value.toString
ResultOfMeasurement(resultValue, ResultValueType.String)
}
}
object SumOfHashesOfColumn extends MeasureType {
def apply(measuredColumn: String): SumOfHashesOfColumn = {
SumOfHashesOfColumn(measuredColumn, measureName, resultValueType)
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(Option(resultValue(0)(0)).getOrElse("").toString, resultValueType)
}

override val measureName: String = "hashCrc32"
override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.String
}
object SumOfHashesOfColumn {
private[agent] val measureName: String = "hashCrc32"
def apply(controlCol: String): SumOfHashesOfColumn = SumOfHashesOfColumn(measureName, controlCol)
}

/**
* This method aggregates a column of a given data frame using a given aggregation expression.
* The result is converted to a string.
*
* @param df A data frame
* @param measureColumn A column to aggregate
* @param aggExpression An aggregation expression
* @return A string representation of the aggregated value
*/
private def aggregateColumn(
df: DataFrame,
measureColumn: String,
aggExpression: Column
): String = {
val dataType = df.select(measureColumn).schema.fields(0).dataType
val aggregatedValue = dataType match {
private def castForAggregation(
dataType: DataType,
column: Column
): Column = {
dataType match {
case _: LongType =>
// This is protection against long overflow, e.g. Long.MaxValue = 9223372036854775807:
// scala> sc.parallelize(List(Long.MaxValue, 1)).toDF.agg(sum("value")).take(1)(0)(0)
// res11: Any = -9223372036854775808
// Converting to BigDecimal fixes the issue
// val ds2 = ds.select(col(measurement.measuredColumn).cast(DecimalType(38, 0)).as("value"))
// ds2.agg(sum(abs($"value"))).collect()(0)(0)
val ds2 = df.select(
col(measureColumn).cast(DecimalType(38, 0)).as(valueColumnName)
)
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
column.cast(DecimalType(38, 0))
case _: StringType =>
// Support for string type aggregation
val ds2 = df.select(
col(measureColumn).cast(DecimalType(38, 18)).as(valueColumnName)
)
val collected = ds2.agg(aggExpression).collect()(0)(0)
column.cast(DecimalType(38, 18))
case _ =>
column
}
}

private def handleAggregationResult(dataType: DataType, result: Any): String = {
val aggregatedValue = dataType match {
case _: LongType =>
if (result == null) 0 else result
case _: StringType =>
val value =
if (collected == null) new java.math.BigDecimal(0)
else collected.asInstanceOf[java.math.BigDecimal]
if (result == null) new java.math.BigDecimal(0)
else result.asInstanceOf[java.math.BigDecimal]
value.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5, but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case _ =>
Comment on lines +152 to 162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the record, we might want to revisit and rethink this, but that is not part of this PR. 😉

val ds2 = df.select(col(measureColumn).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
if (result == null) 0 else result
}
// check if total is required to be presented as larger type - big decimal

workaroundBigDecimalIssues(aggregatedValue)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ object Measurement {
object MeasurementProvided {

private def handleSpecificType[T](
measure: Measure,
measure: AtumMeasure,
resultValue: T,
requiredType: ResultValueType.ResultValueType
): MeasurementProvided[T] = {
Expand All @@ -65,7 +65,7 @@ object Measurement {
* @tparam T A type of the result value.
* @return A measurement.
*/
def apply[T](measure: Measure, resultValue: T): Measurement = {
def apply[T](measure: AtumMeasure, resultValue: T): Measurement = {
resultValue match {
case l: Long =>
handleSpecificType[Long](measure, l, ResultValueType.Long)
Expand All @@ -86,11 +86,11 @@ object Measurement {
}
}

/**
* When the Atum Agent itself performs the measurements, using Spark, then in some cases some adjustments are
* needed - thus we are converting the results to strings always - but we need to keep the information about
* the actual type as well.
*/
case class MeasurementByAtum(measure: Measure, resultValue: String, resultType: ResultValueType.ResultValueType)
extends Measurement
/**
* When the Atum Agent itself performs the measurements, using Spark, then in some cases some adjustments are
* needed - thus we are converting the results to strings always - but we need to keep the information about
* the actual type as well.
*/
case class MeasurementByAtum(measure: AtumMeasure, resultValue: String, resultType: ResultValueType.ResultValueType)
extends Measurement
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ private [agent] object MeasurementBuilder {

private [agent] def buildMeasurementDTO(measurement: Measurement): MeasurementDTO = {
val measureName = measurement.measure.measureName
val measuredColumns = Seq(measurement.measure.measuredColumn)
val measureDTO = MeasureDTO(measureName, measuredColumns)

val measureDTO = MeasureDTO(measureName, measurement.measure.controlColumns)
val measureResultDTO = MeasureResultDTO(TypedValue(measurement.resultValue.toString, measurement.resultType))

MeasurementDTO(measureDTO, measureResultDTO)
Expand Down
Loading
Loading