diff --git a/pom.xml b/pom.xml index 262a670f6..d57cc1b52 100644 --- a/pom.xml +++ b/pom.xml @@ -155,6 +155,12 @@ test + + org.apache.iceberg + iceberg-spark-runtime-3.3_2.12 + 0.14.0 + test + diff --git a/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala new file mode 100644 index 000000000..5fb195b07 --- /dev/null +++ b/src/main/scala/com/amazon/deequ/repository/sparktable/SparkMetricsRepository.scala @@ -0,0 +1,135 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ +package com.amazon.deequ.repository.sparktable + +import com.amazon.deequ.analyzers.Analyzer +import com.amazon.deequ.analyzers.runners.AnalyzerContext +import com.amazon.deequ.metrics.Metric +import com.amazon.deequ.repository._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StringType, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession} + +class SparkTableMetricsRepository(session: SparkSession, tableName: String) extends MetricsRepository { + + private val SCHEMA = StructType(Array( + StructField("result_key", StringType), + StructField("metric_name", StringType), + StructField("metric_value", StringType), + StructField("result_timestamp", StringType), + StructField("serialized_context", StringType) + )) + + override def save(resultKey: ResultKey, analyzerContext: AnalyzerContext): Unit = { + val serializedContext = AnalysisResultSerde.serialize(Seq(AnalysisResult(resultKey, analyzerContext))) + + val rows = analyzerContext.metricMap.map { case (analyzer, metric) => + Row(resultKey.toString, analyzer.toString, metric.value.toString, + resultKey.dataSetDate.toString, serializedContext) + }.toSeq + + val metricDF = session.createDataFrame(session.sparkContext.parallelize(rows), SCHEMA) + + metricDF.write + .mode(SaveMode.Append) + .saveAsTable(tableName) + } + + override def loadByKey(resultKey: ResultKey): Option[AnalyzerContext] = { + val df: DataFrame = session.table(tableName) + val matchingRows = df.filter(col("result_key") === resultKey.toString).collect() + + if (matchingRows.isEmpty) { + None + } else { + val serializedContext = matchingRows(0).getAs[String]("serialized_context") + val analysisResult = AnalysisResultSerde.deserialize(serializedContext).head + Some(analysisResult.analyzerContext) + } + } + + override def load(): MetricsRepositoryMultipleResultsLoader = { + SparkTableMetricsRepositoryMultipleResultsLoader(session, tableName) + } + +} + + +case class SparkTableMetricsRepositoryMultipleResultsLoader(session: SparkSession, + tableName: String, + tagValues: Option[Map[String, String]] = None, + analyzers: Option[Seq[Analyzer[_, Metric[_]]]] = None, + timeAfter: Option[Long] = None, + timeBefore: Option[Long] = None + ) extends MetricsRepositoryMultipleResultsLoader { + + override def withTagValues(tagValues: Map[String, String]): MetricsRepositoryMultipleResultsLoader = + this.copy(tagValues = Some(tagValues)) + + override def forAnalyzers(analyzers: Seq[Analyzer[_, Metric[_]]]): MetricsRepositoryMultipleResultsLoader = + this.copy(analyzers = Some(analyzers)) + + override def after(dateTime: Long): MetricsRepositoryMultipleResultsLoader = + this.copy(timeAfter = Some(dateTime)) + + override def before(dateTime: Long): MetricsRepositoryMultipleResultsLoader = + this.copy(timeBefore = Some(dateTime)) + + override def get(): Seq[AnalysisResult] = { + val initialDF: DataFrame = session.table(tableName) + + initialDF.printSchema() + val tagValuesFilter: DataFrame => DataFrame = df => { + tagValues.map { tags => + tags.foldLeft(df) { (currentDF, tag) => + currentDF.filter(row => { + val ser = row.getAs[String]("serialized_context") + AnalysisResultSerde.deserialize(ser).exists(ar => { + val tags = ar.resultKey.tags + tags.contains(tag._1) && tags(tag._1) == tag._2 + }) + }) + } + }.getOrElse(df) + } + + val specificAnalyzersFilter: DataFrame => DataFrame = df => { + analyzers.map(analyzers => df.filter(col("metric_name").isin(analyzers.map(_.toString): _*))) + .getOrElse(df) + } + + val timeAfterFilter: DataFrame => DataFrame = df => { + timeAfter.map(time => df.filter(col("result_timestamp") > time.toString)).getOrElse(df) + } + + val timeBeforeFilter: DataFrame => DataFrame = df => { + timeBefore.map(time => df.filter(col("result_timestamp") < time.toString)).getOrElse(df) + } + + val filteredDF = Seq(tagValuesFilter, specificAnalyzersFilter, timeAfterFilter, timeBeforeFilter) + .foldLeft(initialDF) { + (df, filter) => filter(df) + } + + // Convert the final DataFrame to the desired output format + filteredDF.collect().flatMap(row => { + val serializedContext = row.getAs[String]("serialized_context") + AnalysisResultSerde.deserialize(serializedContext) + }).toSeq + } + + +} diff --git a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala index b54e05eee..cff2c7448 100644 --- a/src/test/scala/com/amazon/deequ/SparkContextSpec.scala +++ b/src/test/scala/com/amazon/deequ/SparkContextSpec.scala @@ -19,11 +19,16 @@ package com.amazon.deequ import org.apache.spark.SparkContext import org.apache.spark.sql.{SQLContext, SparkSession} +import java.nio.file.{Files, Path} +import scala.collection.convert.ImplicitConversions.`iterator asScala` + /** * To be mixed with Tests so they can use a default spark context suitable for testing */ trait SparkContextSpec { + val warehouseDir: Path = Files.createTempDirectory("my_temp_dir_") + /** * @param testFun thunk to run with SparkSession as an argument */ @@ -37,6 +42,20 @@ trait SparkContextSpec { } } + def withSparkSessionIcebergCatalog(testFun: SparkSession => Any): Unit = { + val session = setupSparkSession + session.conf.set("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog") + session.conf.set("spark.sql.catalog.local.type", "hadoop") + session.conf.set("spark.sql.catalog.local.warehouse", warehouseDir.toAbsolutePath.toString) + + try { + testFun(session) + } finally { + /* empty cache of RDD size, as the referred ids are only valid within a session */ + tearDownSparkSession(session) + } + } + /** * @param testFun thunk to run with SparkSession and SparkMonitor as an argument for the tests * that would like to get details on spark jobs submitted @@ -80,11 +99,22 @@ trait SparkContextSpec { .config("spark.sql.shuffle.partitions", 2.toString) .config("spark.sql.adaptive.enabled", value = false) .config("spark.driver.bindAddress", "127.0.0.1") + .config("spark.sql.warehouse.dir", warehouseDir.toAbsolutePath.toString) .getOrCreate() session.sparkContext.setCheckpointDir(System.getProperty("java.io.tmpdir")) session } + /** + * to cleanup temp directory used in test + * @param path - path to cleanup + */ + private def deleteDirectory(path: Path): Unit = { + if (Files.exists(path)) { + Files.walk(path).iterator().toList.reverse.foreach(Files.delete) + } + } + /** * Tears down the sparkSession * @@ -94,6 +124,8 @@ trait SparkContextSpec { private def tearDownSparkSession(session: SparkSession) = { session.stop() System.clearProperty("spark.driver.port") + deleteDirectory(warehouseDir) + } } diff --git a/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala new file mode 100644 index 000000000..5b66ce305 --- /dev/null +++ b/src/test/scala/com/amazon/deequ/repository/sparktable/SparkTableMetricsRepositoryTest.scala @@ -0,0 +1,119 @@ +/** + * Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not + * use this file except in compliance with the License. A copy of the License + * is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + * + */ + +package com.amazon.deequ.repository.sparktable + +import com.amazon.deequ.SparkContextSpec +import com.amazon.deequ.analyzers.Size +import com.amazon.deequ.analyzers.runners.AnalyzerContext +import com.amazon.deequ.metrics.{DoubleMetric, Entity} +import com.amazon.deequ.repository.ResultKey +import com.amazon.deequ.utils.FixtureSupport +import org.scalatest.wordspec.AnyWordSpec + +import scala.util.Try + +class SparkTableMetricsRepositoryTest extends AnyWordSpec + with SparkContextSpec + with FixtureSupport { + + // private var spark: SparkSession = _ + // private var repository: SparkTableMetricsRepository = _ + private val analyzer = Size() + + "spark table metrics repository " should { + "save and load a single metric" in withSparkSession { spark => { + val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context = AnalyzerContext(Map(analyzer -> metric)) + + val repository = new SparkTableMetricsRepository(spark, "metrics_table") + // Save the metric + repository.save(resultKey, context) + + // Load the metric + val loadedContext = repository.loadByKey(resultKey) + + assert(loadedContext.isDefined) + assert(loadedContext.get.metric(analyzer).contains(metric)) + } + + } + + "save multiple metrics and load them" in withSparkSession { spark => { + val repository = new SparkTableMetricsRepository(spark, "metrics_table") + + val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue1")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context1 = AnalyzerContext(Map(analyzer -> metric)) + + val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "tagValue2")) + val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101)) + val context2 = AnalyzerContext(Map(analyzer -> metric2)) + + repository.save(resultKey1, context1) + repository.save(resultKey2, context2) + + val loadedMetrics = repository.load().get() + + assert(loadedMetrics.length == 2) + + loadedMetrics.flatMap(_.resultKey.tags) + } + } + + "save and load metrics with tag" in withSparkSession { spark => { + val repository = new SparkTableMetricsRepository(spark, "metrics_table") + + val resultKey1 = ResultKey(System.currentTimeMillis(), Map("tag" -> "A")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context1 = AnalyzerContext(Map(analyzer -> metric)) + + val resultKey2 = ResultKey(System.currentTimeMillis(), Map("tag" -> "B")) + val metric2 = DoubleMetric(Entity.Column, "m2", "", Try(101)) + val context2 = AnalyzerContext(Map(analyzer -> metric2)) + + repository.save(resultKey1, context1) + repository.save(resultKey2, context2) + val loadedMetricsForTagA = repository.load().withTagValues(Map("tag" -> "A")).get() + assert(loadedMetricsForTagA.length == 1) + // additional assertions to ensure the loaded metric is the one with tag "A" + + val loadedMetricsForMetricM1 = repository.load().forAnalyzers(Seq(analyzer)) + assert(loadedMetricsForTagA.length == 1) + + } + } + + "save and load to iceberg a single metric" in withSparkSessionIcebergCatalog { spark => { + val resultKey = ResultKey(System.currentTimeMillis(), Map("tag" -> "value")) + val metric = DoubleMetric(Entity.Column, "m1", "", Try(100)) + val context = AnalyzerContext(Map(analyzer -> metric)) + + val repository = new SparkTableMetricsRepository(spark, "local.metrics_table") + // Save the metric + repository.save(resultKey, context) + + // Load the metric + val loadedContext = repository.loadByKey(resultKey) + + assert(loadedContext.isDefined) + assert(loadedContext.get.metric(analyzer).contains(metric)) + } + + } + } +}