Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/100 0 n columns #111

Merged
merged 12 commits into from
Jan 2, 2024
22 changes: 13 additions & 9 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package za.co.absa.atum.agent

import org.apache.spark.sql.DataFrame
import za.co.absa.atum.agent.AtumContext.AtumPartitions
import za.co.absa.atum.agent.model.Measurement.MeasurementByAtum
import za.co.absa.atum.agent.model.Measurement.MeasurementProvided.MeasurementByAtum
import za.co.absa.atum.agent.model._
import za.co.absa.atum.model.dto._

Expand All @@ -35,11 +35,11 @@ import scala.collection.immutable.ListMap
class AtumContext private[agent] (
val atumPartitions: AtumPartitions,
val agent: AtumAgent,
private var measures: Set[Measure] = Set.empty,
private var measures: Set[AtumMeasure] = Set.empty,
private var additionalData: Map[String, Option[String]] = Map.empty
) {

def currentMeasures: Set[Measure] = measures
def currentMeasures: Set[AtumMeasure] = measures

def subPartitionContext(subPartitions: AtumPartitions): AtumContext = {
agent.getOrCreateAtumSubContext(atumPartitions ++ subPartitions)(this)
Expand Down Expand Up @@ -72,7 +72,11 @@ class AtumContext private[agent] (
this
}

def createCheckpointOnProvidedData(checkpointName: String, author: String, measurements: Seq[Measurement]): AtumContext = {
def createCheckpointOnProvidedData(
checkpointName: String,
author: String,
measurements: Seq[Measurement]
): AtumContext = {
val offsetDateTimeNow = OffsetDateTime.now()

val checkpointDTO = CheckpointDTO(
Expand All @@ -97,25 +101,25 @@ class AtumContext private[agent] (
this.additionalData
}

def addMeasure(newMeasure: Measure): AtumContext = {
def addMeasure(newMeasure: AtumMeasure): AtumContext = {
measures = measures + newMeasure
this
}

def addMeasures(newMeasures: Set[Measure]): AtumContext = {
def addMeasures(newMeasures: Set[AtumMeasure]): AtumContext = {
measures = measures ++ newMeasures
this
}

def removeMeasure(measureToRemove: Measure): AtumContext = {
def removeMeasure(measureToRemove: AtumMeasure): AtumContext = {
measures = measures - measureToRemove
this
}

private[agent] def copy(
atumPartitions: AtumPartitions = this.atumPartitions,
agent: AtumAgent = this.agent,
measures: Set[Measure] = this.measures,
measures: Set[AtumMeasure] = this.measures,
additionalData: Map[String, Option[String]] = this.additionalData
): AtumContext = {
new AtumContext(atumPartitions, agent, measures, additionalData)
Expand All @@ -131,7 +135,7 @@ object AtumContext {
}

def apply(elems: Seq[(String, String)]): AtumPartitions = {
ListMap(elems:_*)
ListMap(elems: _*)
}

private[agent] def toSeqPartitionDTO(atumPartitions: AtumPartitions): Seq[PartitionDTO] = {
Expand Down
200 changes: 87 additions & 113 deletions agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,182 +17,156 @@
package za.co.absa.atum.agent.model

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DecimalType, LongType, StringType}
import org.apache.spark.sql.types.{DataType, DecimalType, LongType, StringType}
import org.apache.spark.sql.{Column, DataFrame}
import za.co.absa.atum.agent.core.MeasurementProcessor
import za.co.absa.atum.agent.core.MeasurementProcessor.{MeasurementFunction, ResultOfMeasurement}
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType
import za.co.absa.spark.commons.implicits.StructTypeImplicits.StructTypeEnhancements

/**
* Type of different measures to be applied to the columns.
*/
sealed trait Measure extends MeasurementProcessor with MeasureType {
val controlCol: String
sealed trait Measure {
val measureName: String
def controlColumns: Seq[String]
salamonpavel marked this conversation as resolved.
Show resolved Hide resolved
}

trait MeasureType {
val measureName: String
val resultValueType: ResultValueType.ResultValueType
case class CustomMeasure(measureName: String, controlCols: Seq[String]) extends Measure {
override def controlColumns: Seq[String] = controlCols
}

object Measure {
trait AtumMeasure extends Measure with MeasurementProcessor {
val resultValueType: ResultValueType.ResultValueType
}

private val valueColumnName: String = "value"
object AtumMeasure {

val supportedMeasures: Seq[MeasureType] = Seq(
RecordCount,
DistinctRecordCount,
SumOfValuesOfColumn,
AbsSumOfValuesOfColumn,
SumOfHashesOfColumn
val supportedMeasureNames: Seq[String] = Seq(
salamonpavel marked this conversation as resolved.
Show resolved Hide resolved
RecordCount.measureName,
DistinctRecordCount.measureName,
SumOfValuesOfColumn.measureName,
AbsSumOfValuesOfColumn.measureName,
SumOfHashesOfColumn.measureName
)
val supportedMeasureNames: Seq[String] = supportedMeasures.map(_.measureName)

case class RecordCount private (
controlCol: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
case class RecordCount private (measureName: String) extends AtumMeasure {
private val columnExpression = count("*")

override def function: MeasurementFunction =
(ds: DataFrame) => {
val resultValue = ds.select(col(controlCol)).count().toString
ResultOfMeasurement(resultValue, resultValueType)
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(resultValue(0).toString, resultValueType)
}
}
object RecordCount extends MeasureType {
def apply(controlCol: String): RecordCount = RecordCount(controlCol, measureName, resultValueType)

override val measureName: String = "count"
override def controlColumns: Seq[String] = Seq.empty
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Long
}
object RecordCount {
private[agent] val measureName: String = "count"
def apply(): RecordCount = RecordCount(measureName)
}

case class DistinctRecordCount private (measureName: String, controlCols: Seq[String]) extends AtumMeasure {
require(controlCols.nonEmpty, "At least one control column has to be defined.")

case class DistinctRecordCount private (
controlCol: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
private val columnExpression = countDistinct(col(controlCols.head), controlCols.tail.map(col): _*)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still not sure if this is what we want. I thought that it's about a 'distinct -> count' operation, performed on a single column. Do you remember what was used in the old Atum & what our users might require, @benedeki?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the suggested code it can be performed on single column or multiple columns. If this particular measure is expected to work differently please let me know. @benedeki

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the universality of supporting multi-column uniqueness 👍


override def function: MeasurementFunction =
(ds: DataFrame) => {
val resultValue = ds.select(col(controlCol)).distinct().count().toString
ResultOfMeasurement(resultValue, resultValueType)
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(resultValue(0)(0).toString, resultValueType)
}
}
object DistinctRecordCount extends MeasureType {
def apply(controlCol: String): DistinctRecordCount = {
DistinctRecordCount(controlCol, measureName, resultValueType)
}

override val measureName: String = "distinctCount"
override def controlColumns: Seq[String] = controlCols
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Long
}
object DistinctRecordCount {
private[agent] val measureName: String = "distinctCount"
def apply(controlCols: Seq[String]): DistinctRecordCount = DistinctRecordCount(measureName, controlCols)
}

case class SumOfValuesOfColumn private (
controlCol: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
case class SumOfValuesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnAggFn: Column => Column = column => sum(column)

override def function: MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(col(valueColumnName))
val resultValue = aggregateColumn(ds, controlCol, aggCol)
ResultOfMeasurement(resultValue, resultValueType)
}
}
object SumOfValuesOfColumn extends MeasureType {
def apply(controlCol: String): SumOfValuesOfColumn = {
SumOfValuesOfColumn(controlCol, measureName, resultValueType)
val dataType = ds.select(controlCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(controlCol)))).collect()
ResultOfMeasurement(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override val measureName: String = "aggregatedTotal"
override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.BigDecimal
}
object SumOfValuesOfColumn {
private[agent] val measureName: String = "aggregatedTotal"
def apply(controlCol: String): SumOfValuesOfColumn = SumOfValuesOfColumn(measureName, controlCol)
}

case class AbsSumOfValuesOfColumn private (
controlCol: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {
case class AbsSumOfValuesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnAggFn: Column => Column = column => sum(abs(column))

override def function: MeasurementFunction = (ds: DataFrame) => {
val aggCol = sum(abs(col(valueColumnName)))
val resultValue = aggregateColumn(ds, controlCol, aggCol)
ResultOfMeasurement(resultValue, resultValueType)
}
}
object AbsSumOfValuesOfColumn extends MeasureType {
def apply(controlCol: String): AbsSumOfValuesOfColumn = {
AbsSumOfValuesOfColumn(controlCol, measureName, resultValueType)
val dataType = ds.select(controlCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(controlCol)))).collect()
ResultOfMeasurement(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override val measureName: String = "absAggregatedTotal"
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Double
override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.BigDecimal
}
object AbsSumOfValuesOfColumn {
private[agent] val measureName: String = "absAggregatedTotal"
def apply(controlCol: String): AbsSumOfValuesOfColumn = AbsSumOfValuesOfColumn(measureName, controlCol)
}

case class SumOfHashesOfColumn private (
controlCol: String,
measureName: String,
resultValueType: ResultValueType.ResultValueType
) extends Measure {

case class SumOfHashesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnExpression: Column = sum(crc32(col(controlCol).cast("String")))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this function also support a list of columns, concatenated?
Extra bonus: `*` for all columns. 😉

override def function: MeasurementFunction = (ds: DataFrame) => {

val aggregatedColumnName = ds.schema.getClosestUniqueName("sum_of_hashes")
val value = ds
.withColumn(aggregatedColumnName, crc32(col(controlCol).cast("String")))
.agg(sum(col(aggregatedColumnName)))
.collect()(0)(0)
val resultValue = if (value == null) "" else value.toString
ResultOfMeasurement(resultValue, ResultValueType.String)
}
}
object SumOfHashesOfColumn extends MeasureType {
def apply(controlCol: String): SumOfHashesOfColumn = {
SumOfHashesOfColumn(controlCol, measureName, resultValueType)
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(Option(resultValue(0)(0)).getOrElse("").toString, resultValueType)
}

override val measureName: String = "hashCrc32"
override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.String
}
object SumOfHashesOfColumn {
private[agent] val measureName: String = "hashCrc32"
def apply(controlCol: String): SumOfHashesOfColumn = SumOfHashesOfColumn(measureName, controlCol)
}

private def aggregateColumn(
ds: DataFrame,
measureColumn: String,
aggExpression: Column
): String = {
val dataType = ds.select(measureColumn).schema.fields(0).dataType
val aggregatedValue = dataType match {
private def castForAggregation(
dataType: DataType,
column: Column
): Column = {
dataType match {
case _: LongType =>
// This is protection against long overflow, e.g. Long.MaxValue = 9223372036854775807:
// scala> sc.parallelize(List(Long.MaxValue, 1)).toDF.agg(sum("value")).take(1)(0)(0)
// res11: Any = -9223372036854775808
// Converting to BigDecimal fixes the issue
// val ds2 = ds.select(col(measurement.controlCol).cast(DecimalType(38, 0)).as("value"))
// ds2.agg(sum(abs($"value"))).collect()(0)(0)
val ds2 = ds.select(
col(measureColumn).cast(DecimalType(38, 0)).as(valueColumnName)
)
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
column.cast(DecimalType(38, 0))
case _: StringType =>
// Support for string type aggregation
val ds2 = ds.select(
col(measureColumn).cast(DecimalType(38, 18)).as(valueColumnName)
)
val collected = ds2.agg(aggExpression).collect()(0)(0)
column.cast(DecimalType(38, 18))
case _ =>
column
}
}

private def handleAggregationResult(dataType: DataType, result: Any): String = {
val aggregatedValue = dataType match {
case _: LongType =>
if (result == null) 0 else result
case _: StringType =>
val value =
if (collected == null) new java.math.BigDecimal(0)
else collected.asInstanceOf[java.math.BigDecimal]
if (result == null) new java.math.BigDecimal(0)
else result.asInstanceOf[java.math.BigDecimal]
value.stripTrailingZeros // removes trailing zeros (2001.500000 -> 2001.5), but can introduce scientific notation (600.000 -> 6E+2)
.toPlainString // converts to normal string (6E+2 -> "600")
case _ =>
Comment on lines +152 to 162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the record, this we might want to revisit and rethink. But not part of this PR. 😉

val ds2 = ds.select(col(measureColumn).as(valueColumnName))
val collected = ds2.agg(aggExpression).collect()(0)(0)
if (collected == null) 0 else collected
if (result == null) 0 else result
}
// check if total is required to be presented as larger type - big decimal

workaroundBigDecimalIssues(aggregatedValue)
}

Expand Down
Loading