MetricsRepository using Spark tables as the data source #518

Merged: 2 commits (Nov 28, 2023)
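For context, a minimal usage sketch of the repository added in this PR, assembled from the tests below; `spark` is assumed to be an existing SparkSession and the table name and metric values are illustrative:

import com.amazon.deequ.analyzers.Size
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.{DoubleMetric, Entity}
import com.amazon.deequ.repository.ResultKey
import com.amazon.deequ.repository.sparktable.SparkTableMetricsRepository

import scala.util.Try

// Persist deequ metrics in a Spark table instead of a file-system repository.
val repository = new SparkTableMetricsRepository(spark, "metrics_table")

val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
val context = AnalyzerContext(Map(Size() -> DoubleMetric(Entity.Column, "m1", "", Try(100))))

repository.save(resultKey, context)            // appends a row to metrics_table
val loaded = repository.loadByKey(resultKey)   // Option[AnalyzerContext]
val byTag = repository.load().withTagValues(Map("tag" -> "value")).get()  // Seq[AnalysisResult]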
6 changes: 6 additions & 0 deletions pom.xml
@@ -155,6 +155,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.iceberg</groupId>
<artifactId>iceberg-spark-runtime-3.3_2.12</artifactId>
<version>0.14.0</version>
<scope>test</scope>
</dependency>

</dependencies>

@@ -0,0 +1,130 @@
/**
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/
package com.amazon.deequ.repository.sparktable

import com.amazon.deequ.analyzers.Analyzer
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.repository._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository {

import session.implicits._

override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = {
val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext)))

val successfulMetrics = analyzerContext.metricMap
.filter { case (_, metric) => metric.value.isSuccess }

val metricDF = successfulMetrics.map { case (analyzer, metric) =>
SparkTableMetric(resultKey.toString, analyzer.toString, metric.value.toString,
resultKey.dataSetDate, serializedContext)
}.toSeq.toDF()

metricDF.write
.mode(SaveMode.Append)
.saveAsTable(tableName)
Review comment on lines +40 to +42

@SirWerto (Nov 14, 2023):
Hello @VenkataKarthikP, thanks a lot for the added feature! Could we consider letting the user define the write options in some way?

@VenkataKarthikP (Contributor, Author):
Sure, I think it's a great idea. Let me check!
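As a follow-up to this thread, one hypothetical way the write options could be exposed (the helper name and parameter below are illustrative and not part of this PR):

import org.apache.spark.sql.{DataFrame, SaveMode}

object MetricsWriteHelper {
  // Hypothetical sketch only: append metrics with caller-supplied DataFrameWriter options.
  def saveWithOptions(metricDF: DataFrame,
                      tableName: String,
                      writeOptions: Map[String, String] = Map.empty): Unit = {
    metricDF.write
      .mode(SaveMode.Append)
      .options(writeOptions) // e.g. Map("mergeSchema" -> "true")
      .saveAsTable(tableName)
  }
}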

}

override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = {
val df: DataFrame = session.table(tableName)
val matchingRows = df.filter(col("resultKey") === resultKey.toString).collect()

if (matchingRows.isEmpty) {
None
} else {
val serializedContext = matchingRows(0).getAs[String]("serializedContext")
AnalysisResultSerde.deserialize(serializedContext).headOption.map(_.analyzerContext)
}
}

override def load(): MetricsRepositoryMultipleResultsLoader = {
SparkTableMetricsRepositoryMultipleResultsLoader(session, tableName)
}

}

case class SparkTableMetric(resultKey: String, metricName: String, metricValue: String, resultTimestamp: Long,
serializedContext: String)

case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession,
tableName: String,
private val tagValues: Option[Map[String, String]] = None,
private val analyzers: Option[Seq[Analyzer[_, Metric[_]]]]
= None,
private val timeAfter: Option[Long] = None,
private val timeBefore: Option[Long] = None
) extends MetricsRepositoryMultipleResultsLoader {

override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader =
this.copy(tagValues = Some(tagValues))

override def forAnalyzers(analyzers: Seq[Analyzer[_, Metric[_]]]): MetricsRepositoryMultipleResultsLoader =
this.copy(analyzers = Some(analyzers))

override def after(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
this.copy(timeAfter = Some(dateTime))

override def before(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
this.copy(timeBefore = Some(dateTime))

override def get(): Seq[AnalysisResult] = {
val initialDF: DataFrame = session.table(tableName)
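// Each filter below is a DataFrame => DataFrame transformation; when its criterion is unset, it returns the DataFrame unchanged.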

val tagValuesFilter: DataFrame => DataFrame = df => {
tagValues.map { tags =>
tags.foldLeft(df) { (currentDF, tag) =>
currentDF.filter(row => {
val ser = row.getAs[String]("serializedContext")
AnalysisResultSerde.deserialize(ser).exists(ar => {
val keyTags = ar.resultKey.tags // renamed to avoid shadowing the outer `tags`
keyTags.contains(tag._1) && keyTags(tag._1) == tag._2
})
})
}
}.getOrElse(df)
}

val specificAnalyzersFilter: DataFrame => DataFrame = df => {
analyzers.map(analyzers => df.filter(col("metricName").isin(analyzers.map(_.toString): _*)))
.getOrElse(df)
}

val timeAfterFilter: DataFrame => DataFrame = df => {
timeAfter.map(time => df.filter(col("resultTimestamp") > time.toString)).getOrElse(df)
}

val timeBeforeFilter: DataFrame => DataFrame = df => {
timeBefore.map(time => df.filter(col("resultTimestamp") < time.toString)).getOrElse(df)
}

val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter)
.foldLeft(initialDF) {
(df, filter) => filter(df)
}

// Convert the final DataFrame to the desired output format
filteredDF.collect().flatMap(row => {
val serializedContext = row.getAs[String]("serializedContext")
AnalysisResultSerde.deserialize(serializedContext)
}).toSeq
}


}
55 changes: 49 additions & 6 deletions src/test/scala/com/amazon/deequ/SparkContextSpec.scala
@@ -1,5 +1,5 @@
/**
- * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
@@ -19,16 +19,44 @@ package com.amazon.deequ
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, SparkSession}

import java.nio.file.{Files, Path}
import scala.collection.convert.ImplicitConversions.`iterator asScala`

/**
* To be mixed with Tests so they can use a default spark context suitable for testing
*/
trait SparkContextSpec {

val tmpWareHouseDir: Path = Files.createTempDirectory("deequ_tmp")

/**
* @param testFun thunk to run with SparkSession as an argument
*/
def withSparkSession(testFun: SparkSession => Any): Unit = {
- val session = setupSparkSession
+ val session = setupSparkSession()
try {
testFun(session)
} finally {
/* empty cache of RDD size, as the referred ids are only valid within a session */
tearDownSparkSession(session)
}
}

def withSparkSessionCustomWareHouse(testFun: SparkSession => Any): Unit = {
val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString))
try {
testFun(session)
} finally {
tearDownSparkSession(session)
}
}

def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = {
val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString))
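// Register a Hadoop-backed Iceberg catalog named `local`, pointing at the temp warehouse directory.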
session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
session.conf.set("spark.sql.catalog.local.type", "hadoop")
session.conf.set("spark.sql.catalog.local.warehouse", tmpWareHouseDir.toAbsolutePath.toString)

try {
testFun(session)
} finally {
@@ -44,7 +72,7 @@
*/
def withMonitorableSparkSession(testFun: (SparkSession, SparkMonitor) => Any): Unit = {
val monitor = new SparkMonitor
- val session = setupSparkSession
+ val session = setupSparkSession()
session.sparkContext.addSparkListener(monitor)
try {
testFun(session, monitor)
@@ -72,19 +100,32 @@
*
* @return sparkSession to be used
*/
- private def setupSparkSession = {
- val session = SparkSession.builder()
+ private def setupSparkSession(wareHouseDir: Option[String] = None) = {
+ val sessionBuilder = SparkSession.builder()
.master("local")
.appName("test")
.config("spark.ui.enabled", "false")
.config("spark.sql.shuffle.partitions", 2.toString)
.config("spark.sql.adaptive.enabled", value = false)
.config("spark.driver.bindAddress", "127.0.0.1")
- .getOrCreate()

+ val session = wareHouseDir.fold(sessionBuilder.getOrCreate())(sessionBuilder
+ .config("spark.sql.warehouse.dir", _).getOrCreate())

session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir"))
session
}

/**
* Cleans up the temporary warehouse directory used in tests.
* @param path path to clean up
*/
private def deleteDirectory(path: Path): Unit = {
if (Files.exists(path)) {
Files.walk(path).iterator().toList.reverse.foreach(Files.delete)
}
}

/**
* Tears down the sparkSession
*
@@ -94,6 +135,8 @@
private def tearDownSparkSession(session: SparkSession) = {
session.stop()
System.clearProperty("spark.driver.port")
deleteDirectory(tmpWareHouseDir)

}

}
@@ -0,0 +1,121 @@
/**
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/

package com.amazon.deequ.repository.sparktable

import com.amazon.deequ.SparkContextSpec
import com.amazon.deequ.analyzers.Size
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.{DoubleMetric, Entity}
import com.amazon.deequ.repository.ResultKey
import com.amazon.deequ.utils.FixtureSupport
import org.scalatest.wordspec.AnyWordSpec

import scala.util.Try

class SparkTableMetricsRepositoryTest extends AnyWordSpec
with SparkContextSpec
with FixtureSupport {

// private var spark: SparkSession = _
// private var repository: SparkTableMetricsRepository = _
private val analyzer = Size()

"spark table metrics repository " should {
"save and load a single metric" in withSparkSessionCustomWareHouse { spark =>
val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context = AnalyzerContext(Map(analyzer -> metric))

val repository = new SparkTableMetricsRepository(spark, "metrics_table")
// Save the metric
repository.save(resultKey, context)

// Load the metric
val loadedContext = repository.loadByKey(resultKey)

assert(loadedContext.isDefined)
assert(loadedContext.get.metric(analyzer).contains(metric))

}

"save multiple metrics and load them" in withSparkSessionCustomWareHouse { spark =>
val repository = new SparkTableMetricsRepository(spark, "metrics_table")

val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context1 = AnalyzerContext(Map(analyzer -> metric))

val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue2"))
val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
val context2 = AnalyzerContext(Map(analyzer -> metric2))

repository.save(resultKey1, context1)
repository.save(resultKey2, context2)

val loadedMetrics = repository.load().get()

assert(loadedMetrics.length == 2)

loadedMetrics.flatMap(_.resultKey.tags)

}

"save and load metrics with tag" in withSparkSessionCustomWareHouse { spark =>
val repository = new SparkTableMetricsRepository(spark, "metrics_table")

val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context1 = AnalyzerContext(Map(analyzer -> metric))

val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "B"))
val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
val context2 = AnalyzerContext(Map(analyzer -> metric2))

repository.save(resultKey1, context1)
repository.save(resultKey2, context2)
val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get()
assert(loadedMetricsForTagA.length == 1)

val tagsMapA = loadedMetricsForTagA.flatMap(_.resultKey.tags).toMap
assert(tagsMapA.size == 1, "should have 1 result")
assert(tagsMapA.contains("tag"), "should contain tag")
assert(tagsMapA("tag") == "A", "tag should be A")

val loadedMetricsForAllMetrics = repository.load().forAnalyzers(Seq(analyzer)).get()
assert(loadedMetricsForAllMetrics.length == 2, "should have 2 results")

}

"save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => {
val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context = AnalyzerContext(Map(analyzer -> metric))

val repository = new SparkTableMetricsRepository(spark, "local.metrics_table")
// Save the metric
repository.save(resultKey, context)

// Load the metric
val loadedContext = repository.loadByKey(resultKey)

assert(loadedContext.isDefined)
assert(loadedContext.get.metric(analyzer).contains(metric))
}

}
}
}