diff --git a/pom.xml b/pom.xml
index 262a670f6..d57cc1b52 100644
--- a/pom.xml
+++ b/pom.xml
@@ -155,6 +155,12 @@
test
+
+ org.apache.iceberg
+ iceberg-spark-runtime-3.3_2.12
+ 0.14.0
+ test
+
diff --git a/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala
new file mode 100644
index 000000000..5fb195b07
--- /dev/null
+++ b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala
@@ -0,0 +1,135 @@
+/**
+ * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License
+ * is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ *
+ */
+package com.amazon.deequ.repository.sparktable
+
+import com.amazon.deequ.analyzers.Analyzer
+import com.amazon.deequ.analyzers.runners.AnalyzerContext
+import com.amazon.deequ.metrics.Metric
+import com.amazon.deequ.repository._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
+
+class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository {
+
+ private val SCHEMA = StructType(Array(
+ StructField("result_key", StringType),
+ StructField("metric_name", StringType),
+ StructField("metric_value", StringType),
+ StructField("result_timestamp", StringType),
+ StructField("serialized_context", StringType)
+ ))
+
+ override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = {
+ val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext)))
+
+ val rows = analyzerContext.metricMap.map { case (analyzer, metric) =>
+ Row(resultKey.toString, analyzer.toString, metric.value.toString,
+ resultKey.dataSetDate.toString, serializedContext)
+ }.toSeq
+
+ val metricDF = session.createDataFrame(session.sparkContext.parallelize(rows), SCHEMA)
+
+ metricDF.write
+ .mode(SaveMode.Append)
+ .saveAsTable(tableName)
+ }
+
+ override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = {
+ val df: DataFrame = session.table(tableName)
+ val matchingRows = df.filter(col("result_key") === resultKey.toString).collect()
+
+ if (matchingRows.isEmpty) {
+ None
+ } else {
+ val serializedContext = matchingRows(0).getAs[String]("serialized_context")
+ val analysisResult = AnalysisResultSerde.deserialize(serializedContext).head
+ Some(analysisResult.analyzerContext)
+ }
+ }
+
+ override def load(): MetricsRepositoryMultipleResultsLoader = {
+ SparkTableMetricsRepositoryMultipleResultsLoader(session, tableName)
+ }
+
+}
+
+
+case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession,
+ tableName: String,
+ tagValues: Option[Map[String, String]] = None,
+ analyzers: Option[Seq[Analyzer[_, Metric[_]]]] = None,
+ timeAfter: Option[Long] = None,
+ timeBefore: Option[Long] = None
+ ) extends MetricsRepositoryMultipleResultsLoader {
+
+ override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader =
+ this.copy(tagValues = Some(tagValues))
+
+ override def forAnalyzers(analyzers: Seq[Analyzer[_, Metric[_]]]): MetricsRepositoryMultipleResultsLoader =
+ this.copy(analyzers = Some(analyzers))
+
+ override def after(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
+ this.copy(timeAfter = Some(dateTime))
+
+ override def before(dateTime: Long): MetricsRepositoryMultipleResultsLoader =
+ this.copy(timeBefore = Some(dateTime))
+
+ override def get(): Seq[AnalysisResult] = {
+ val initialDF: DataFrame = session.table(tableName)
+
+ initialDF.printSchema()
+ val tagValuesFilter: DataFrame => DataFrame = df => {
+ tagValues.map { tags =>
+ tags.foldLeft(df) { (currentDF, tag) =>
+ currentDF.filter(row => {
+ val ser = row.getAs[String]("serialized_context")
+ AnalysisResultSerde.deserialize(ser).exists(ar => {
+ val tags = ar.resultKey.tags
+ tags.contains(tag._1) && tags(tag._1) == tag._2
+ })
+ })
+ }
+ }.getOrElse(df)
+ }
+
+ val specificAnalyzersFilter: DataFrame => DataFrame = df => {
+ analyzers.map(analyzers => df.filter(col("metric_name").isin(analyzers.map(_.toString): _*)))
+ .getOrElse(df)
+ }
+
+ val timeAfterFilter: DataFrame => DataFrame = df => {
+ timeAfter.map(time => df.filter(col("result_timestamp") > time.toString)).getOrElse(df)
+ }
+
+ val timeBeforeFilter: DataFrame => DataFrame = df => {
+ timeBefore.map(time => df.filter(col("result_timestamp") < time.toString)).getOrElse(df)
+ }
+
+ val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter)
+ .foldLeft(initialDF) {
+ (df, filter) => filter(df)
+ }
+
+ // Convert the final DataFrame to the desired output format
+ filteredDF.collect().flatMap(row => {
+ val serializedContext = row.getAs[String]("serialized_context")
+ AnalysisResultSerde.deserialize(serializedContext)
+ }).toSeq
+ }
+
+
+}
diff --git a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala
index b54e05eee..cff2c7448 100644
--- a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala
+++ b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala
@@ -19,11 +19,16 @@ package com.amazon.deequ
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, SparkSession}
+import java.nio.file.{Files, Path}
+import scala.collection.convert.ImplicitConversions.`iterator asScala`
+
/**
* To be mixed with Tests so they can use a default spark context suitable for testing
*/
trait SparkContextSpec {
+ val warehouseDir: Path = Files.createTempDirectory("my_temp_dir_")
+
/**
* @param testFun thunk to run with SparkSession as an argument
*/
@@ -37,6 +42,20 @@ trait SparkContextSpec {
}
}
+ def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = {
+ val session = setupSparkSession
+ session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
+ session.conf.set("spark.sql.catalog.local.type", "hadoop")
+ session.conf.set("spark.sql.catalog.local.warehouse", warehouseDir.toAbsolutePath.toString)
+
+ try {
+ testFun(session)
+ } finally {
+ /* empty cache of RDD size, as the referred ids are only valid within a session */
+ tearDownSparkSession(session)
+ }
+ }
+
/**
* @param testFun thunk to run with SparkSession and SparkMonitor as an argument for the tests
* that would like to get details on spark jobs submitted
@@ -80,11 +99,22 @@ trait SparkContextSpec {
.config("spark.sql.shuffle.partitions", 2.toString)
.config("spark.sql.adaptive.enabled", value = false)
.config("spark.driver.bindAddress", "127.0.0.1")
+ .config("spark.sql.warehouse.dir", warehouseDir.toAbsolutePath.toString)
.getOrCreate()
session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir"))
session
}
+ /**
+ * to cleanup temp directory used in test
+ * @param path - path to cleanup
+ */
+ private def deleteDirectory(path: Path): Unit = {
+ if (Files.exists(path)) {
+ Files.walk(path).iterator().toList.reverse.foreach(Files.delete)
+ }
+ }
+
/**
* Tears down the sparkSession
*
@@ -94,6 +124,8 @@ trait SparkContextSpec {
private def tearDownSparkSession(session: SparkSession) = {
session.stop()
System.clearProperty("spark.driver.port")
+ deleteDirectory(warehouseDir)
+
}
}
diff --git a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala
new file mode 100644
index 000000000..5b66ce305
--- /dev/null
+++ b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala
@@ -0,0 +1,119 @@
+/**
+ * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not
+ * use this file except in compliance with the License. A copy of the License
+ * is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ *
+ */
+
+package com.amazon.deequ.repository.sparktable
+
+import com.amazon.deequ.SparkContextSpec
+import com.amazon.deequ.analyzers.Size
+import com.amazon.deequ.analyzers.runners.AnalyzerContext
+import com.amazon.deequ.metrics.{DoubleMetric, Entity}
+import com.amazon.deequ.repository.ResultKey
+import com.amazon.deequ.utils.FixtureSupport
+import org.scalatest.wordspec.AnyWordSpec
+
+import scala.util.Try
+
+class SparkTableMetricsRepositoryTest extends AnyWordSpec
+ with SparkContextSpec
+ with FixtureSupport {
+
+ // private var spark: SparkSession = _
+ // private var repository: SparkTableMetricsRepository = _
+ private val analyzer = Size()
+
+ "spark table metrics repository " should {
+ "save and load a single metric" in withSparkSession { spark => {
+ val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
+ val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
+ val context = AnalyzerContext(Map(analyzer -> metric))
+
+ val repository = new SparkTableMetricsRepository(spark, "metrics_table")
+ // Save the metric
+ repository.save(resultKey, context)
+
+ // Load the metric
+ val loadedContext = repository.loadByKey(resultKey)
+
+ assert(loadedContext.isDefined)
+ assert(loadedContext.get.metric(analyzer).contains(metric))
+ }
+
+ }
+
+ "save multiple metrics and load them" in withSparkSession { spark => {
+ val repository = new SparkTableMetricsRepository(spark, "metrics_table")
+
+ val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1"))
+ val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
+ val context1 = AnalyzerContext(Map(analyzer -> metric))
+
+ val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue2"))
+ val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
+ val context2 = AnalyzerContext(Map(analyzer -> metric2))
+
+ repository.save(resultKey1, context1)
+ repository.save(resultKey2, context2)
+
+ val loadedMetrics = repository.load().get()
+
+ assert(loadedMetrics.length == 2)
+
+ loadedMetrics.flatMap(_.resultKey.tags)
+ }
+ }
+
+ "save and load metrics with tag" in withSparkSession { spark => {
+ val repository = new SparkTableMetricsRepository(spark, "metrics_table")
+
+ val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A"))
+ val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
+ val context1 = AnalyzerContext(Map(analyzer -> metric))
+
+ val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "B"))
+ val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101))
+ val context2 = AnalyzerContext(Map(analyzer -> metric2))
+
+ repository.save(resultKey1, context1)
+ repository.save(resultKey2, context2)
+ val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get()
+ assert(loadedMetricsForTagA.length == 1)
+ // additional assertions to ensure the loaded metric is the one with tag "A"
+
+ val loadedMetricsForMetricM1 = repository.load().forAnalyzers(Seq(analyzer))
+ assert(loadedMetricsForTagA.length == 1)
+
+ }
+ }
+
+ "save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => {
+ val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value"))
+ val metric = DoubleMetric(Entity.Column, "m1", "", Try(100))
+ val context = AnalyzerContext(Map(analyzer -> metric))
+
+ val repository = new SparkTableMetricsRepository(spark, "local.metrics_table")
+ // Save the metric
+ repository.save(resultKey, context)
+
+ // Load the metric
+ val loadedContext = repository.loadByKey(resultKey)
+
+ assert(loadedContext.isDefined)
+ assert(loadedContext.get.metric(analyzer).contains(metric))
+ }
+
+ }
+ }
+}