Skip to content

Commit

Permalink
Merge branch 'master' into feature/#120-retrieve-Measures-and-Additio…
Browse files Browse the repository at this point in the history
…nalData-for-a-given-partitioning
  • Loading branch information
TebaleloS authored Mar 11, 2024
2 parents 74af153 + a7f3f64 commit 74993c0
Show file tree
Hide file tree
Showing 22 changed files with 380 additions and 324 deletions.
2 changes: 1 addition & 1 deletion agent/src/main/scala/za/co/absa/atum/agent/AtumAgent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class AtumAgent private[agent] () {
}

/**
* Provides an AtumContext given an `AtumPartitions` instance. Retrieves the data from the AtumService API.
* Provides an AtumContext given an `AtumPartitions` instance. Retrieves the data from the AtumService API.
*
* Note: if partitioning doesn't exist in the store yet, a new one will be created with the author stored in
* `AtumAgent.currentUser`. If partitioning already exists, this attribute will be ignored because there
Expand Down
20 changes: 10 additions & 10 deletions agent/src/main/scala/za/co/absa/atum/agent/AtumContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package za.co.absa.atum.agent

import org.apache.spark.sql.DataFrame
import za.co.absa.atum.agent.AtumContext.AtumPartitions
import za.co.absa.atum.agent.model.Measurement.MeasurementByAtum
import za.co.absa.atum.agent.model._
import za.co.absa.atum.model.dto._

Expand Down Expand Up @@ -56,10 +55,11 @@ class AtumContext private[agent] (
agent.getOrCreateAtumSubContext(atumPartitions ++ subPartitions)(this)
}

private def takeMeasurements(df: DataFrame): Set[MeasurementByAtum] = {
private def takeMeasurements(df: DataFrame): Set[MeasurementDTO] = {
measures.map { m =>
val measurementResult = m.function(df)
MeasurementByAtum(m, measurementResult.result, measurementResult.resultType)
// TODO group measurements together: https://github.com/AbsaOSS/atum-service/issues/98
val measureResult = m.function(df)
MeasurementBuilder.buildMeasurementDTO(m, measureResult)
}
}

Expand All @@ -78,7 +78,7 @@ class AtumContext private[agent] (
*/
def createCheckpoint(checkpointName: String, dataToMeasure: DataFrame): AtumContext = {
val startTime = ZonedDateTime.now()
val measurements = takeMeasurements(dataToMeasure)
val measurementDTOs = takeMeasurements(dataToMeasure)
val endTime = ZonedDateTime.now()

val checkpointDTO = CheckpointDTO(
Expand All @@ -89,7 +89,7 @@ class AtumContext private[agent] (
partitioning = AtumPartitions.toSeqPartitionDTO(this.atumPartitions),
processStartTime = startTime,
processEndTime = Some(endTime),
measurements = measurements.map(MeasurementBuilder.buildMeasurementDTO).toSeq
measurements = measurementDTOs
)

agent.saveCheckpoint(checkpointDTO)
Expand All @@ -103,7 +103,7 @@ class AtumContext private[agent] (
* @param measurements the measurements to be included in the checkpoint
* @return the AtumContext after the checkpoint has been created
*/
def createCheckpointOnProvidedData(checkpointName: String, measurements: Seq[Measurement]): AtumContext = {
def createCheckpointOnProvidedData(checkpointName: String, measurements: Map[AtumMeasure, MeasureResult]): AtumContext = {
val dateTimeNow = ZonedDateTime.now()

val checkpointDTO = CheckpointDTO(
Expand All @@ -113,7 +113,7 @@ class AtumContext private[agent] (
partitioning = AtumPartitions.toSeqPartitionDTO(this.atumPartitions),
processStartTime = dateTimeNow,
processEndTime = Some(dateTimeNow),
measurements = measurements.map(MeasurementBuilder.buildMeasurementDTO)
measurements = MeasurementBuilder.buildAndValidateMeasurementsDTO(measurements)
)

agent.saveCheckpoint(checkpointDTO)
Expand Down Expand Up @@ -154,7 +154,7 @@ class AtumContext private[agent] (
/**
* Adds a measure to the AtumContext.
*
* @param measure the measure to be added
* @param newMeasure the measure to be added
*/
def addMeasure(newMeasure: AtumMeasure): AtumContext = {
measures = measures + newMeasure
Expand All @@ -164,7 +164,7 @@ class AtumContext private[agent] (
/**
* Adds multiple measures to the AtumContext.
*
* @param measures set of measures to be added
* @param newMeasures set of measures to be added
*/
def addMeasures(newMeasures: Set[AtumMeasure]): AtumContext = {
measures = measures ++ newMeasures
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package za.co.absa.atum.agent.core

import org.apache.spark.sql.DataFrame
import za.co.absa.atum.agent.core.MeasurementProcessor.MeasurementFunction
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType
import za.co.absa.atum.agent.model.MeasureResult

/**
* This trait provides a contract for different measurement processors
Expand All @@ -27,29 +27,19 @@ trait MeasurementProcessor {

/**
* This method is used to compute a measure on a Spark `DataFrame`.
* @param df: Spark `Dataframe` to be measured.
* @return Result of measurement.
*/
def function: MeasurementFunction

}

/**
* This companion object provides a set of types for measurement processors
*/
object MeasurementProcessor {
/**
* The raw result of a measurement is always going to be a string, because we want to avoid some floating point issues
* (overflows, consistent representation of numbers - whether they are coming from the Java or Scala world, and more),
* but the actual type is stored alongside the computation because we don't want to lose this information.
* This type alias describes a function that is used to compute a measure on a Spark `DataFrame`.
* It receives the Spark `DataFrame` to be measured as its input.
*
* @return Result of measurement.
*/
final case class ResultOfMeasurement(result: String, resultType: ResultValueType.ResultValueType)

/**
* This type alias describes a function that is used to compute measure on Spark `Dataframe`.
* @param df: Spark `Dataframe` to be measured.
* @return Result of measurement.
*/
type MeasurementFunction = DataFrame => ResultOfMeasurement

type MeasurementFunction = DataFrame => MeasureResult
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import za.co.absa.atum.agent.exception.AtumAgentException.HttpException
import za.co.absa.atum.model.dto.{AdditionalDataSubmitDTO, AtumContextDTO, CheckpointDTO, PartitioningSubmitDTO}
import za.co.absa.atum.model.utils.SerializationUtils

import scala.util.{Failure, Success, Try}

class HttpDispatcher(config: Config) extends Dispatcher with Logging {

private val serverUrl = config.getString("url")
Expand All @@ -51,7 +49,7 @@ class HttpDispatcher(config: Config) extends Dispatcher with Logging {
val response = backend.send(request)

SerializationUtils.fromJson[AtumContextDTO](
safeResponseBody(response).get
handleResponseBody(response)
)
}

Expand All @@ -62,7 +60,7 @@ class HttpDispatcher(config: Config) extends Dispatcher with Logging {

val response = backend.send(request)

safeResponseBody(response).get
handleResponseBody(response)
}

override def saveAdditionalData(additionalDataSubmitDTO: AdditionalDataSubmitDTO): Unit = {
Expand All @@ -72,13 +70,13 @@ class HttpDispatcher(config: Config) extends Dispatcher with Logging {

val response = backend.send(request)

safeResponseBody(response).get
handleResponseBody(response)
}

def safeResponseBody(response: Response[Either[String, String]]): Try[String] = {
private def handleResponseBody(response: Response[Either[String, String]]): String = {
response.body match {
case Left(body) => Failure(HttpException(response.code.code, body))
case Right(body) => Success(body)
case Left(body) => throw HttpException(response.code.code, body)
case Right(body) => body
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ object AtumAgentException {
*
* @param message A message describing the exception.
*/
case class MeasurementProvidedException(message: String) extends AtumAgentException(message)
case class MeasurementException(message: String) extends AtumAgentException(message)

/**
* This type represents an exception thrown when a measure is not supported by the Atum Agent.
Expand Down
66 changes: 33 additions & 33 deletions agent/src/main/scala/za/co/absa/atum/agent/model/Measure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DecimalType, LongType, StringType}
import org.apache.spark.sql.{Column, DataFrame}
import za.co.absa.atum.agent.core.MeasurementProcessor
import za.co.absa.atum.agent.core.MeasurementProcessor.{MeasurementFunction, ResultOfMeasurement}
import za.co.absa.atum.agent.core.MeasurementProcessor.MeasurementFunction
import za.co.absa.atum.model.dto.MeasureResultDTO.ResultValueType

/**
* Type of different measures to be applied to the columns.
*/
sealed trait Measure {
val measureName: String
def controlColumns: Seq[String]
def measuredColumns: Seq[String]
}

trait AtumMeasure extends Measure with MeasurementProcessor {
val resultValueType: ResultValueType.ResultValueType
val resultValueType: ResultValueType
}

object AtumMeasure {
Expand All @@ -51,83 +51,83 @@ object AtumMeasure {
override def function: MeasurementFunction =
(ds: DataFrame) => {
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(resultValue(0).toString, resultValueType)
MeasureResult(resultValue(0).toString, resultValueType)
}

override def controlColumns: Seq[String] = Seq.empty
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Long
override def measuredColumns: Seq[String] = Seq.empty
override val resultValueType: ResultValueType = ResultValueType.Long
}
/**
 * Companion object of `RecordCount`: holds the measure name and a no-argument factory.
 */
object RecordCount {
// Name handed to every `RecordCount` built by `apply()`; accessible only within the `agent` package.
private[agent] val measureName: String = "count"
// Builds a `RecordCount` from the fixed `measureName` above, so callers do not supply a name themselves.
def apply(): RecordCount = RecordCount(measureName)
}

case class DistinctRecordCount private (measureName: String, controlCols: Seq[String]) extends AtumMeasure {
require(controlCols.nonEmpty, "At least one control column has to be defined.")
case class DistinctRecordCount private (measureName: String, measuredCols: Seq[String]) extends AtumMeasure {
require(measuredCols.nonEmpty, "At least one measured column has to be defined.")

private val columnExpression = countDistinct(col(controlCols.head), controlCols.tail.map(col): _*)
private val columnExpression = countDistinct(col(measuredCols.head), measuredCols.tail.map(col): _*)

override def function: MeasurementFunction =
(ds: DataFrame) => {
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(resultValue(0)(0).toString, resultValueType)
MeasureResult(resultValue(0)(0).toString, resultValueType)
}

override def controlColumns: Seq[String] = controlCols
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.Long
override def measuredColumns: Seq[String] = measuredCols
override val resultValueType: ResultValueType = ResultValueType.Long
}
object DistinctRecordCount {
private[agent] val measureName: String = "distinctCount"
def apply(controlCols: Seq[String]): DistinctRecordCount = DistinctRecordCount(measureName, controlCols)
def apply(measuredCols: Seq[String]): DistinctRecordCount = DistinctRecordCount(measureName, measuredCols)
}

case class SumOfValuesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
case class SumOfValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
private val columnAggFn: Column => Column = column => sum(column)

override def function: MeasurementFunction = (ds: DataFrame) => {
val dataType = ds.select(controlCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(controlCol)))).collect()
ResultOfMeasurement(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
val dataType = ds.select(measuredCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.BigDecimal
override def measuredColumns: Seq[String] = Seq(measuredCol)
override val resultValueType: ResultValueType = ResultValueType.BigDecimal
}
object SumOfValuesOfColumn {
private[agent] val measureName: String = "aggregatedTotal"
def apply(controlCol: String): SumOfValuesOfColumn = SumOfValuesOfColumn(measureName, controlCol)
def apply(measuredCol: String): SumOfValuesOfColumn = SumOfValuesOfColumn(measureName, measuredCol)
}

case class AbsSumOfValuesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
case class AbsSumOfValuesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
private val columnAggFn: Column => Column = column => sum(abs(column))

override def function: MeasurementFunction = (ds: DataFrame) => {
val dataType = ds.select(controlCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(controlCol)))).collect()
ResultOfMeasurement(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
val dataType = ds.select(measuredCol).schema.fields(0).dataType
val resultValue = ds.select(columnAggFn(castForAggregation(dataType, col(measuredCol)))).collect()
MeasureResult(handleAggregationResult(dataType, resultValue(0)(0)), resultValueType)
}

override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.BigDecimal
override def measuredColumns: Seq[String] = Seq(measuredCol)
override val resultValueType: ResultValueType = ResultValueType.BigDecimal
}
object AbsSumOfValuesOfColumn {
private[agent] val measureName: String = "absAggregatedTotal"
def apply(controlCol: String): AbsSumOfValuesOfColumn = AbsSumOfValuesOfColumn(measureName, controlCol)
def apply(measuredCol: String): AbsSumOfValuesOfColumn = AbsSumOfValuesOfColumn(measureName, measuredCol)
}

case class SumOfHashesOfColumn private (measureName: String, controlCol: String) extends AtumMeasure {
private val columnExpression: Column = sum(crc32(col(controlCol).cast("String")))
case class SumOfHashesOfColumn private (measureName: String, measuredCol: String) extends AtumMeasure {
private val columnExpression: Column = sum(crc32(col(measuredCol).cast("String")))
override def function: MeasurementFunction = (ds: DataFrame) => {
val resultValue = ds.select(columnExpression).collect()
ResultOfMeasurement(Option(resultValue(0)(0)).getOrElse("").toString, resultValueType)
MeasureResult(Option(resultValue(0)(0)).getOrElse("").toString, resultValueType)
}

override def controlColumns: Seq[String] = Seq(controlCol)
override val resultValueType: ResultValueType.ResultValueType = ResultValueType.String
override def measuredColumns: Seq[String] = Seq(measuredCol)
override val resultValueType: ResultValueType = ResultValueType.String
}
object SumOfHashesOfColumn {
private[agent] val measureName: String = "hashCrc32"
def apply(controlCol: String): SumOfHashesOfColumn = SumOfHashesOfColumn(measureName, controlCol)
def apply(measuredCol: String): SumOfHashesOfColumn = SumOfHashesOfColumn(measureName, measuredCol)
}

private def castForAggregation(
Expand Down
Loading

0 comments on commit 74993c0

Please sign in to comment.