Skip to content

Commit

Permalink
MetricsRepository using Spark tables as the data source (#518)
Browse files Browse the repository at this point in the history
* spark table repository

* review comments

---------

Co-authored-by: vpenikalapati <[email protected]>
  • Loading branch information
VenkataKarthikP and vpenikalapati authored Nov 28, 2023
1 parent 54c5e48 commit 1fc09e1
Show file tree
Hide file tree
Showing 4 changed files with 306 additions and 6 deletions.
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.iceberg</groupId>
<artifactId>iceberg-spark-runtime-3.3_2.12</artifactId>
<version>0.14.0</version>
<scope>test</scope>
</dependency>

</dependencies>

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/**
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/
package com.amazon.deequ.repository.sparktable

import com.amazon.deequ.analyzers.Analyzer
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.repository._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository {

import session.implicits._

override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = {
val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext)))

val successfulMetrics = analyzerContext.metricMap
.filter { case (_, metric) => metric.value.isSuccess }

val metricDF = successfulMetrics.map { case (analyzer, metric) =>
SparkTableMetric(resultKey.toString, analyzer.toString, metric.value.toString,
resultKey.dataSetDate, serializedContext)
}.toSeq.toDF()

metricDF.write
.mode(SaveMode.Append)
.saveAsTable(tableName)
}

override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = {
val df: DataFrame = session.table(tableName)
val matchingRows = df.filter(col("resultKey") === resultKey.toString).collect()

if (matchingRows.isEmpty) {
None
} else {
val serializedContext = matchingRows(0).getAs[String]("serializedContext")
AnalysisResultSerde.deserialize(serializedContext).headOption.map(_.analyzerContext)
}
}

override def load(): MetricsRepositoryMultipleResultsLoader = {
SparkTableMetricsRepositoryMultipleResultsLoader(session, tableName)
}

}

case class SparkTableMetric(resultKey: String, metricName: String, metricValue: String, resultTimestamp: Long,
serializedContext: String)

case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession,
tableName: String,
private val tagValues: Option[Map[String, String]] = None,
private val analyzers: Option[Seq[Analyzer[_, Metric[_]]]]
= None,
private val timeAfter: Option[Long] = None,
private val timeBefore: Option[Long] = None
) extends MetricsRepositoryMultipleResultsLoader {

override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader =
this.copy(tagValues = Some(tagValues))

override def forAnalyzers(analyzers: Seq[Analyzer[_, Metric[_]]]): MetricsRepositoryMultipleResultsLoader =
this.copy(analyzers = Some(analyzers))

override def after(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
this.copy(timeAfter = Some(dateTime))

override def before(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
this.copy(timeBefore = Some(dateTime))

override def get(): Seq[AnalysisResult] = {
val initialDF: DataFrame = session.table(tableName)

val tagValuesFilter: DataFrame => DataFrame = df => {
tagValues.map { tags =>
tags.foldLeft(df) { (currentDF, tag) =>
currentDF.filter(row => {
val ser = row.getAs[String]("serializedContext")
AnalysisResultSerde.deserialize(ser).exists(ar => {
val tags = ar.resultKey.tags
tags.contains(tag._1) && tags(tag._1) == tag._2
})
})
}
}.getOrElse(df)
}

val specificAnalyzersFilter: DataFrame => DataFrame = df => {
analyzers.map(analyzers => df.filter(col("metricName").isin(analyzers.map(_.toString): _*)))
.getOrElse(df)
}

val timeAfterFilter: DataFrame => DataFrame = df => {
timeAfter.map(time => df.filter(col("resultTimestamp") > time.toString)).getOrElse(df)
}

val timeBeforeFilter: DataFrame => DataFrame = df => {
timeBefore.map(time => df.filter(col("resultTimestamp") < time.toString)).getOrElse(df)
}

val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter)
.foldLeft(initialDF) {
(df, filter) => filter(df)
}

// Convert the final DataFrame to the desired output format
filteredDF.collect().flatMap(row => {
val serializedContext = row.getAs[String]("serializedContext")
AnalysisResultSerde.deserialize(serializedContext)
}).toSeq
}


}
55 changes: 49 additions & 6 deletions src/test/scala/com/amazon/deequ/SparkContextSpec.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
Expand All @@ -19,16 +19,44 @@ package com.amazon.deequ
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, SparkSession}

import java.nio.file.{Files, Path}
import scala.collection.convert.ImplicitConversions.`iterator asScala`

/**
* To be mixed with Tests so they can use a default spark context suitable for testing
*/
trait SparkContextSpec {

val tmpWareHouseDir: Path = Files.createTempDirectory("deequ_tmp")

/**
* @param testFun thunk to run with SparkSession as an argument
*/
def withSparkSession(testFun: SparkSession => Any): Unit = {
val session = setupSparkSession
val session = setupSparkSession()
try {
testFun(session)
} finally {
/* empty cache of RDD size, as the referred ids are only valid within a session */
tearDownSparkSession(session)
}
}

def withSparkSessionCustomWareHouse(testFun: SparkSession => Any): Unit = {
val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString))
try {
testFun(session)
} finally {
tearDownSparkSession(session)
}
}

def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = {
val session = setupSparkSession(Some(tmpWareHouseDir.toAbsolutePath.toString))
session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
session.conf.set("spark.sql.catalog.local.type", "hadoop")
session.conf.set("spark.sql.catalog.local.warehouse", tmpWareHouseDir.toAbsolutePath.toString)

try {
testFun(session)
} finally {
Expand All @@ -44,7 +72,7 @@ trait SparkContextSpec {
*/
def withMonitorableSparkSession(testFun: (SparkSession, SparkMonitor) => Any): Unit = {
val monitor = new SparkMonitor
val session = setupSparkSession
val session = setupSparkSession()
session.sparkContext.addSparkListener(monitor)
try {
testFun(session, monitor)
Expand Down Expand Up @@ -72,19 +100,32 @@ trait SparkContextSpec {
*
* @return sparkSession to be used
*/
private def setupSparkSession = {
val session = SparkSession.builder()
private def setupSparkSession(wareHouseDir: Option[String] = None) = {
val sessionBuilder = SparkSession.builder()
.master("local")
.appName("test")
.config("spark.ui.enabled", "false")
.config("spark.sql.shuffle.partitions", 2.toString)
.config("spark.sql.adaptive.enabled", value = false)
.config("spark.driver.bindAddress", "127.0.0.1")
.getOrCreate()

val session = wareHouseDir.fold(sessionBuilder.getOrCreate())(sessionBuilder
.config("spark.sql.warehouse.dir", _).getOrCreate())

session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir"))
session
}

/**
* to cleanup temp directory used in test
* @param path - path to cleanup
*/
private def deleteDirectory(path: Path): Unit = {
if (Files.exists(path)) {
Files.walk(path).iterator().toList.reverse.foreach(Files.delete)
}
}

/**
* Tears down the sparkSession
*
Expand All @@ -94,6 +135,8 @@ trait SparkContextSpec {
private def tearDownSparkSession(session: SparkSession) = {
session.stop()
System.clearProperty("spark.driver.port")
deleteDirectory(tmpWareHouseDir)

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/**
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/

package com.amazon.deequ.repository.sparktable

import com.amazon.deequ.SparkContextSpec
import com.amazon.deequ.analyzers.Size
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.{DoubleMetric, Entity}
import com.amazon.deequ.repository.ResultKey
import com.amazon.deequ.utils.FixtureSupport
import org.scalatest.wordspec.AnyWordSpec

import scala.util.Try

class SparkTableMetricsRepositoryTest extends AnyWordSpec
with SparkContextSpec
with FixtureSupport {

// private var spark: SparkSession = _
// private var repository: SparkTableMetricsRepository = _
private val analyzer = Size()

"spark table metrics repository " should {
"save and load a single metric" in withSparkSessionCustomWareHouse { spark =>
val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context = AnalyzerContext(Map(analyzer -> metric))

val repository = new SparkTableMetricsRepository(spark, "metrics_table")
// Save the metric
repository.save(resultKey, context)

// Load the metric
val loadedContext = repository.loadByKey(resultKey)

assert(loadedContext.isDefined)
assert(loadedContext.get.metric(analyzer).contains(metric))

}

"save multiple metrics and load them" in withSparkSessionCustomWareHouse { spark =>
val repository = new SparkTableMetricsRepository(spark, "metrics_table")

val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context1 = AnalyzerContext(Map(analyzer -> metric))

val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue2"))
val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
val context2 = AnalyzerContext(Map(analyzer -> metric2))

repository.save(resultKey1, context1)
repository.save(resultKey2, context2)

val loadedMetrics = repository.load().get()

assert(loadedMetrics.length == 2)

loadedMetrics.flatMap(_.resultKey.tags)

}

"save and load metrics with tag" in withSparkSessionCustomWareHouse { spark =>
val repository = new SparkTableMetricsRepository(spark, "metrics_table")

val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context1 = AnalyzerContext(Map(analyzer -> metric))

val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "B"))
val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
val context2 = AnalyzerContext(Map(analyzer -> metric2))

repository.save(resultKey1, context1)
repository.save(resultKey2, context2)
val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get()
assert(loadedMetricsForTagA.length == 1)

val tagsMapA = loadedMetricsForTagA.flatMap(_.resultKey.tags).toMap
assert(tagsMapA.size == 1, "should have 1 result")
assert(tagsMapA.contains("tag"), "should contain tag")
assert(tagsMapA("tag") == "A", "tag should be A")

val loadedMetricsForAllMetrics = repository.load().forAnalyzers(Seq(analyzer)).get()
assert(loadedMetricsForAllMetrics.length == 2, "should have 2 results")

}

"save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => {
val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context = AnalyzerContext(Map(analyzer -> metric))

val repository = new SparkTableMetricsRepository(spark, "local.metrics_table")
// Save the metric
repository.save(resultKey, context)

// Load the metric
val loadedContext = repository.loadByKey(resultKey)

assert(loadedContext.isDefined)
assert(loadedContext.get.metric(analyzer).contains(metric))
}

}
}
}

0 comments on commit 1fc09e1

Please sign in to comment.