spark table repository
vpenikalapati authored and VenkataKarthikP committed Oct 30, 2023
1 parent d7eb316 commit 209eba3
Showing 4 changed files with 292 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pom.xml
@@ -155,6 +155,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.apache.iceberg</groupId>
<artifactId>iceberg-spark-runtime-3.3_2.12</artifactId>
<version>0.14.0</version>
<scope>test</scope>
</dependency>

</dependencies>
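For projects that consume deequ via sbt rather than Maven, the equivalent test-scoped dependency (same coordinates as the snippet above) would look roughly like this:

// build.sbt sketch: sbt equivalent of the Maven test dependency above
libraryDependencies += "org.apache.iceberg" % "iceberg-spark-runtime-3.3_2.12" % "0.14.0" % Test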

135 changes: 135 additions & 0 deletions src/main/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepository.scala
@@ -0,0 +1,135 @@
/**
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/
package com.amazon.deequ.repository.sparktable

import com.amazon.deequ.analyzers.Analyzer
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.Metric
import com.amazon.deequ.repository._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}

class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository {

private val SCHEMA = StructType(Array(
StructField("result_key", StringType),
StructField("metric_name", StringType),
StructField("metric_value", StringType),
StructField("result_timestamp", StringType),
StructField("serialized_context", StringType)
))

override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = {
val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext)))

val rows = analyzerContext.metricMap.map { case (analyzer, metric) =>
Row(resultKey.toString, analyzer.toString, metric.value.toString,
resultKey.dataSetDate.toString, serializedContext)
}.toSeq

val metricDF = session.createDataFrame(session.sparkContext.parallelize(rows), SCHEMA)

metricDF.write
.mode(SaveMode.Append)
.saveAsTable(tableName)
}

override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = {
val df: DataFrame = session.table(tableName)
val matchingRows = df.filter(col("result_key") === resultKey.toString).collect()

if (matchingRows.isEmpty) {
None
} else {
val serializedContext = matchingRows(0).getAs[String]("serialized_context")
val analysisResult = AnalysisResultSerde.deserialize(serializedContext).head
Some(analysisResult.analyzerContext)
}
}

override def load(): MetricsRepositoryMultipleResultsLoader = {
SparkTableMetricsRepositoryMultipleResultsLoader(session, tableName)
}

}


case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession,
tableName: String,
tagValues: Option[Map[String, String]] = None,
analyzers: Option[Seq[Analyzer[_, Metric[_]]]] = None,
timeAfter: Option[Long] = None,
timeBefore: Option[Long] = None
) extends MetricsRepositoryMultipleResultsLoader {

override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader =
this.copy(tagValues = Some(tagValues))

override def forAnalyzers(analyzers: Seq[Analyzer[_, Metric[_]]]): MetricsRepositoryMultipleResultsLoader =
this.copy(analyzers = Some(analyzers))

override def after(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
this.copy(timeAfter = Some(dateTime))

override def before(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
this.copy(timeBefore = Some(dateTime))

  override def get(): Seq[AnalysisResult] = {
    val initialDF: DataFrame = session.table(tableName)

    // Keep only rows whose result key carries every requested tag/value pair.
    // The tags live inside the serialized context, so each row is deserialized for the check.
    val tagValuesFilter: DataFrame => DataFrame = df => {
      tagValues.map { tags =>
        tags.foldLeft(df) { (currentDF, tag) =>
          currentDF.filter(row => {
            val ser = row.getAs[String]("serialized_context")
            AnalysisResultSerde.deserialize(ser).exists(ar => {
              val keyTags = ar.resultKey.tags
              keyTags.contains(tag._1) && keyTags(tag._1) == tag._2
            })
          })
        }
      }.getOrElse(df)
    }

    // Keep only rows produced by the requested analyzers.
    val specificAnalyzersFilter: DataFrame => DataFrame = df => {
      analyzers.map(requested => df.filter(col("metric_name").isin(requested.map(_.toString): _*)))
        .getOrElse(df)
    }

    // result_timestamp is stored as a string, so these comparisons are lexicographic
    // on the stringified epoch-millisecond values.
    val timeAfterFilter: DataFrame => DataFrame = df => {
      timeAfter.map(time => df.filter(col("result_timestamp") > time.toString)).getOrElse(df)
    }

    val timeBeforeFilter: DataFrame => DataFrame = df => {
      timeBefore.map(time => df.filter(col("result_timestamp") < time.toString)).getOrElse(df)
    }

val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter)
.foldLeft(initialDF) {
(df, filter) => filter(df)
}

// Convert the final DataFrame to the desired output format
filteredDF.collect().flatMap(row => {
val serializedContext = row.getAs[String]("serialized_context")
AnalysisResultSerde.deserialize(serializedContext)
}).toSeq
}


}
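For orientation, here is a minimal sketch of how the new repository could be wired into a deequ run. It relies on deequ's standard useRepository and saveOrAppendResult hooks on the analysis builder; the input path, table name, and column name are illustrative assumptions, not part of this commit.

import com.amazon.deequ.analyzers.{Completeness, Size}
import com.amazon.deequ.analyzers.runners.AnalysisRunner
import com.amazon.deequ.repository.ResultKey
import com.amazon.deequ.repository.sparktable.SparkTableMetricsRepository
import org.apache.spark.sql.SparkSession

object SparkTableRepositoryExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("deequ-spark-table-repository").getOrCreate()
    val data = spark.read.parquet("/tmp/orders") // illustrative input

    // Metrics are appended to a Spark-managed table; a catalog-qualified name
    // such as "local.metrics_table" (Iceberg catalog) works as well.
    val repository = new SparkTableMetricsRepository(spark, "deequ_metrics")
    val resultKey = ResultKey(System.currentTimeMillis(), Map("dataset" -> "orders"))

    AnalysisRunner
      .onData(data)
      .addAnalyzer(Size())
      .addAnalyzer(Completeness("order_id"))
      .useRepository(repository)
      .saveOrAppendResult(resultKey)
      .run()

    // Previously stored results can be read back through the loader.
    val pastResults = repository.load().after(0L).get()
    pastResults.foreach(println)
  }
}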
32 changes: 32 additions & 0 deletions src/test/scala/com/amazon/deequ/SparkContextSpec.scala
@@ -19,11 +19,16 @@ package com.amazon.deequ
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, SparkSession}

import java.nio.file.{Files, Path}
import scala.collection.convert.ImplicitConversions.`iterator asScala`

/**
* To be mixed with Tests so they can use a default spark context suitable for testing
*/
trait SparkContextSpec {

val warehouseDir: Path = Files.createTempDirectory("my_temp_dir_")

/**
* @param testFun thunk to run with SparkSession as an argument
*/
@@ -37,6 +42,20 @@
}
}

def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = {
val session = setupSparkSession
session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
session.conf.set("spark.sql.catalog.local.type", "hadoop")
session.conf.set("spark.sql.catalog.local.warehouse", warehouseDir.toAbsolutePath.toString)

try {
testFun(session)
} finally {
/* empty cache of RDD size, as the referred ids are only valid within a session */
tearDownSparkSession(session)
}
}

/**
* @param testFun thunk to run with SparkSession and SparkMonitor as an argument for the tests
* that would like to get details on spark jobs submitted
@@ -80,11 +99,22 @@
.config("spark.sql.shuffle.partitions", 2.toString)
.config("spark.sql.adaptive.enabled", value = false)
.config("spark.driver.bindAddress", "127.0.0.1")
.config("spark.sql.warehouse.dir", warehouseDir.toAbsolutePath.toString)
.getOrCreate()
session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir"))
session
}

  /**
   * Deletes the temporary warehouse directory used by the tests.
   *
   * @param path path to clean up
   */
private def deleteDirectory(path: Path): Unit = {
if (Files.exists(path)) {
Files.walk(path).iterator().toList.reverse.foreach(Files.delete)
}
}

/**
* Tears down the sparkSession
*
@@ -94,6 +124,8 @@
private def tearDownSparkSession(session: SparkSession) = {
session.stop()
System.clearProperty("spark.driver.port")
deleteDirectory(warehouseDir)

}

}
119 changes: 119 additions & 0 deletions src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala
@@ -0,0 +1,119 @@
/**
* Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not
* use this file except in compliance with the License. A copy of the License
* is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*
*/

package com.amazon.deequ.repository.sparktable

import com.amazon.deequ.SparkContextSpec
import com.amazon.deequ.analyzers.Size
import com.amazon.deequ.analyzers.runners.AnalyzerContext
import com.amazon.deequ.metrics.{DoubleMetric, Entity}
import com.amazon.deequ.repository.ResultKey
import com.amazon.deequ.utils.FixtureSupport
import org.scalatest.wordspec.AnyWordSpec

import scala.util.Try

class SparkTableMetricsRepositoryTest extends AnyWordSpec
with SparkContextSpec
with FixtureSupport {

private val analyzer = Size()

"spark table metrics repository " should {
"save and load a single metric" in withSparkSession { spark => {
val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context = AnalyzerContext(Map(analyzer -> metric))

val repository = new SparkTableMetricsRepository(spark, "metrics_table")
// Save the metric
repository.save(resultKey, context)

// Load the metric
val loadedContext = repository.loadByKey(resultKey)

assert(loadedContext.isDefined)
assert(loadedContext.get.metric(analyzer).contains(metric))
}

}

"save multiple metrics and load them" in withSparkSession { spark => {
val repository = new SparkTableMetricsRepository(spark, "metrics_table")

val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context1 = AnalyzerContext(Map(analyzer -> metric))

val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue2"))
val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
val context2 = AnalyzerContext(Map(analyzer -> metric2))

repository.save(resultKey1, context1)
repository.save(resultKey2, context2)

val loadedMetrics = repository.load().get()

assert(loadedMetrics.length == 2)

      val loadedTags = loadedMetrics.flatMap(_.resultKey.tags).toSet
      assert(loadedTags == Set("tag" -> "tagValue1", "tag" -> "tagValue2"))
}
}

"save and load metrics with tag" in withSparkSession { spark => {
val repository = new SparkTableMetricsRepository(spark, "metrics_table")

val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context1 = AnalyzerContext(Map(analyzer -> metric))

val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "B"))
val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
val context2 = AnalyzerContext(Map(analyzer -> metric2))

repository.save(resultKey1, context1)
repository.save(resultKey2, context2)
      val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get()
      assert(loadedMetricsForTagA.length == 1)
      // ensure the loaded metric is the one saved with tag "A"
      assert(loadedMetricsForTagA.head.resultKey.tags("tag") == "A")

      val loadedMetricsForAnalyzer = repository.load().forAnalyzers(Seq(analyzer)).get()
      // both results saved above used the Size analyzer, so both are returned
      assert(loadedMetricsForAnalyzer.length == 2)

}
}

"save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => {
val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
val context = AnalyzerContext(Map(analyzer -> metric))

val repository = new SparkTableMetricsRepository(spark, "local.metrics_table")
// Save the metric
repository.save(resultKey, context)

// Load the metric
val loadedContext = repository.loadByKey(resultKey)

assert(loadedContext.isDefined)
assert(loadedContext.get.metric(analyzer).contains(metric))
}

}
}
}
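Because the metrics land in an ordinary Spark table with the schema declared above (result_key, metric_name, metric_value, result_timestamp, serialized_context), they can also be inspected outside of deequ with plain Spark SQL. A small illustrative query, assuming a SparkSession named spark and the table name used in these tests:

// illustrative only: inspect persisted metrics directly
spark.sql(
  """SELECT result_key, metric_name, metric_value, result_timestamp
    |FROM metrics_table
    |ORDER BY result_timestamp DESC""".stripMargin)
  .show(truncate = false)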
